一、神经网络训练

# file     : train.py
# time : 2022/8/11 上午10:03
# function :
import torchvision.datasets
from model import *
from torch.utils.data import DataLoader # DataSet
train_data = torchvision.datasets.CIFAR10("../dataset", train=True, transform=torchvision.transforms.ToTensor(), download=False)
test_data = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(), download=False) # len长度
train_data_size = len(train_data)
test_data_size = len(test_data) print(format(train_data_size))
# ctrl+D
print(format(test_data_size)) # liyong DataLoader
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64) # create
tudui = Tudui() #loss
loss_fn = nn.CrossEntropyLoss() #
# learning_rate = 0.01
# 1e-2 = 1 x (10)^(-2) = 1/100 = 0.01
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate) # shezhi
total_train_step = 0
#
total_test_step = 0
#
epoch = 10 for i in range(epoch):
print("------------第{} 轮训练开始----------- ".format(i+1)) #
for data in train_dataloader:
imgs, targets = data
output = tudui(imgs)
loss = loss_fn(output, targets)
#
optimizer.zero_grad()
loss.backward()
optimizer.step() total_train_step = total_train_step + 1
print("训练次数:{}, loss:{}".format(total_train_step, loss.item))

pytorch学习笔记(10)--完整的模型训练(待完善)的更多相关文章

  1. 《C++ Primer Plus》学习笔记10

    <C++ Primer Plus>学习笔记10 <<<<<<<<<<<<<<<<<&l ...

  2. golang学习笔记10 beego api 用jwt验证auth2 token 获取解码信息

    golang学习笔记10 beego api 用jwt验证auth2 token 获取解码信息 Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放 ...

  3. Android:日常学习笔记(10)———使用LitePal操作数据库

    Android:日常学习笔记(10)———使用LitePal操作数据库 引入LitePal 什么是LitePal LitePal是一款开源的Android数据库框架,采用了对象关系映射(ORM)的模式 ...

  4. JavaScript:学习笔记(10)——XMLHttpRequest对象

    JavaScript:学习笔记(10)——XMLHttpRequest对象 XHR对象 使用XMLHttpRequest (XHR)对象可以与服务器交互.您可以从URL获取数据,而无需让整个的页面刷新 ...

  5. WebGL three.js学习笔记 加载外部模型以及Tween.js动画

    WebGL three.js学习笔记 加载外部模型以及Tween.js动画 本文的程序实现了加载外部stl格式的模型,以及学习了如何把加载的模型变为一个粒子系统,并使用Tween.js对该粒子系统进行 ...

  6. thinkphp学习笔记10—看不懂的路由规则

    原文:thinkphp学习笔记10-看不懂的路由规则 路由这部分貌似在实际工作中没有怎么设计过,只是在用默认的设置,在手册里面看到部分,艰涩难懂. 1.路由定义 要使用路由功能需要支持PATH_INF ...

  7. SQL反模式学习笔记10 取整错误

    目标:使用小数取代整数 反模式:使用Float类型 根据IEEE754标识,float类型使用二进制格式编码实数数据. 缺点:(1)舍入的必要性: 并不是所有的十进制中描述的信息都能使用二进制存储,处 ...

  8. Spring MVC 学习笔记10 —— 实现简单的用户管理(4.3)用户登录显示全局异常信息

    </pre>Spring MVC 学习笔记10 -- 实现简单的用户管理(4.3)用户登录--显示全局异常信息<p></p><p></p>& ...

  9. ArcGIS案例学习笔记-批量裁剪地理模型

    ArcGIS案例学习笔记-批量裁剪地理模型 联系方式:谢老师,135-4855-4328,xiexiaokui#qq.com 功能:空间数据的批量裁剪 优点:1.批量裁剪:任意多个目标数据,去裁剪任意 ...

  10. Hadoop学习笔记(10) ——搭建源码学习环境

    Hadoop学习笔记(10) ——搭建源码学习环境 上一章中,我们对整个hadoop的目录及源码目录有了一个初步的了解,接下来计划深入学习一下这头神象作品了.但是看代码用什么,难不成gedit?,单步 ...

随机推荐

  1. C Primer Plus 5.11 編程練習

    /*C Primer Plus (5.10) 9*/ 1 #include<stdio.h> 2 #define G 103 3 int main() 4 { 5 char ch=96; ...

  2. 通俗易懂angular搭建

  3. DLL的两种加载方式

    案例简述 在某项目中,需要使用两个不同版本的HCNetSDK库,我们通常使用的静态加载DLL的方式不能满足该需求,故用到动态加载DLL的方式. 背景技术及术语解释 静态加载:也称隐式调用,指在运行程序 ...

  4. Java 进阶P-1.1+P-1.2

    用类制造对象 对象与类 对象是实体,需要被创建,可以为我们做事情 类是规范,根据类的定义来创建对象 对象=属性+服务 数据:属性或状态 操作:函数 从这些例子可以看出来 class是提供服务的,数据是 ...

  5. Unity项目优化——Web版

    Unity项目优化--Web版 大家好,这是小黑第一次写文章(哈哈哈哈哈,好激动),我好好的写,有不对的地方多多指出. 首先呢是版本介绍,不过好像版本对于优化没有影响: 不过还是要告诉大家我用的版本: ...

  6. 学习python的编程语言

    前言 那么多编程语言,为什么学python 易于学习,是所有编程语言当中最容易学习的 没有最好的语言,只有最合适的语言 第一章 python基础 1. 课程整体介绍 课程整体介绍 python编程基础 ...

  7. 【分析笔记】DW7888 马达驱动芯片待机模式漏电流过高的问题

    发现问题 客户反馈说我们的硬件关机漏电流很大,但是拔掉电池之后再上电(仍处于关机状态)就会恢复为 16~20uA 左右.这让我也讶异,因为亲自测试过,漏电流只有 MCU 的休眠电流 16~20uA 左 ...

  8. java基础:方法

    方法 方法是解决一类问题的步骤的有序组合 包含于类/对象中 设计原则 方法的原子性:一个方法只实现一个功能 定义与调用 方法的组成: 方法的调用 若方法返回值为空 System.out.println ...

  9. Python中的函数定义中的斜杠/和星号*

    Python中的函数定义中的斜杠/和星号* 示例 看一段代码  def say_hello(name,age=18):     print(f'你好!我是{name},今年我{age}啦.') say ...

  10. spring中Utils工具类注入问题

    使用工具类的时候,我们想在static修饰的方法中,通过注入来调用其他方法,这里就存在问题. 第一:普通工具类是不在spring的管理下,spring不会依赖注入 第二:即便使用@Autowired完 ...