一、神经网络训练

# file     : train.py
# time : 2022/8/11 上午10:03
# function :
import torchvision.datasets
from model import *
from torch.utils.data import DataLoader # DataSet
train_data = torchvision.datasets.CIFAR10("../dataset", train=True, transform=torchvision.transforms.ToTensor(), download=False)
test_data = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(), download=False) # len长度
train_data_size = len(train_data)
test_data_size = len(test_data) print(format(train_data_size))
# ctrl+D
print(format(test_data_size)) # liyong DataLoader
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64) # create
tudui = Tudui() #loss
loss_fn = nn.CrossEntropyLoss() #
# learning_rate = 0.01
# 1e-2 = 1 x (10)^(-2) = 1/100 = 0.01
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate) # shezhi
total_train_step = 0
#
total_test_step = 0
#
epoch = 10 for i in range(epoch):
print("------------第{} 轮训练开始----------- ".format(i+1)) #
for data in train_dataloader:
imgs, targets = data
output = tudui(imgs)
loss = loss_fn(output, targets)
#
optimizer.zero_grad()
loss.backward()
optimizer.step() total_train_step = total_train_step + 1
print("训练次数:{}, loss:{}".format(total_train_step, loss.item))

pytorch学习笔记(10)--完整的模型训练(待完善)的更多相关文章

  1. 《C++ Primer Plus》学习笔记10

    <C++ Primer Plus>学习笔记10 <<<<<<<<<<<<<<<<<&l ...

  2. golang学习笔记10 beego api 用jwt验证auth2 token 获取解码信息

    golang学习笔记10 beego api 用jwt验证auth2 token 获取解码信息 Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放 ...

  3. Android:日常学习笔记(10)———使用LitePal操作数据库

    Android:日常学习笔记(10)———使用LitePal操作数据库 引入LitePal 什么是LitePal LitePal是一款开源的Android数据库框架,采用了对象关系映射(ORM)的模式 ...

  4. JavaScript:学习笔记(10)——XMLHttpRequest对象

    JavaScript:学习笔记(10)——XMLHttpRequest对象 XHR对象 使用XMLHttpRequest (XHR)对象可以与服务器交互.您可以从URL获取数据,而无需让整个的页面刷新 ...

  5. WebGL three.js学习笔记 加载外部模型以及Tween.js动画

    WebGL three.js学习笔记 加载外部模型以及Tween.js动画 本文的程序实现了加载外部stl格式的模型,以及学习了如何把加载的模型变为一个粒子系统,并使用Tween.js对该粒子系统进行 ...

  6. thinkphp学习笔记10—看不懂的路由规则

    原文:thinkphp学习笔记10-看不懂的路由规则 路由这部分貌似在实际工作中没有怎么设计过,只是在用默认的设置,在手册里面看到部分,艰涩难懂. 1.路由定义 要使用路由功能需要支持PATH_INF ...

  7. SQL反模式学习笔记10 取整错误

    目标:使用小数取代整数 反模式:使用Float类型 根据IEEE754标识,float类型使用二进制格式编码实数数据. 缺点:(1)舍入的必要性: 并不是所有的十进制中描述的信息都能使用二进制存储,处 ...

  8. Spring MVC 学习笔记10 —— 实现简单的用户管理(4.3)用户登录显示全局异常信息

    </pre>Spring MVC 学习笔记10 -- 实现简单的用户管理(4.3)用户登录--显示全局异常信息<p></p><p></p>& ...

  9. ArcGIS案例学习笔记-批量裁剪地理模型

    ArcGIS案例学习笔记-批量裁剪地理模型 联系方式:谢老师,135-4855-4328,xiexiaokui#qq.com 功能:空间数据的批量裁剪 优点:1.批量裁剪:任意多个目标数据,去裁剪任意 ...

  10. Hadoop学习笔记(10) ——搭建源码学习环境

    Hadoop学习笔记(10) ——搭建源码学习环境 上一章中,我们对整个hadoop的目录及源码目录有了一个初步的了解,接下来计划深入学习一下这头神象作品了.但是看代码用什么,难不成gedit?,单步 ...

随机推荐

  1. 03初识MapReduce

    初识MapReduce 一.什么是MapReduce MapReduce是一种编程范式,它借助Map将一个大任务分解成多个小任务,再借助Reduce归并Map的结果.MapReduce虽然原理很简单, ...

  2. 图文并茂手把手教你How to copy files or directory in nodejs npm scripts编写脚本用npm或者node命令复制文件

    每天都要开心哦~~~ 今天来个双语文档 先放出来官方文档 https://www.npmjs.com/package/copyfiles 先来说一下npm 执行的方式 1.首先,进入项目目录,下载依赖 ...

  3. Quartz.Net源码Example之Quartz.Examples.AspNetCore

    Quartz.Examples.AspNetCore ​ .NetCore的Web系统,后台主要执行多个触发器任务,前台展示所有触发器信息和正在执行的作业的相关信息,还可以通过访问health-UI来 ...

  4. 【学习笔记】XR872 Audio 驱动框架分析

    Xradio Sdk 的 Audio 驱动框架和 Linux 的 ASOC 驱动框架非常相似,只不过简化了很多. 驱动和芯片之间的关系图 下面的 SOC 表示的是 XR872 芯片,这里以 AC107 ...

  5. 计算机网络基础06-Email应用

    1 构成组件 邮件客户端 邮件服务器 SMTP协议 Simple Mail Transfer Protocol 1.1 邮件客户端 读写Email消息 和服务器交互,收发消息 1.2 邮件服务器 邮箱 ...

  6. OPENMP FOR CONSTRUCT GUIDED 调度方式实现原理和源码分析

    OPENMP FOR CONSTRUCT GUIDED 调度方式实现原理和源码分析 前言 在本篇文章当中主要给大家介绍在 OpenMP 当中 guided 调度方式的实现原理.这个调度方式其实和 dy ...

  7. elasticsearch之日期类型有点怪

    一.Date类型简介 elasticsearch通过JSON格式来承载数据的,而JSON中是没有Date对应的数据类型的,但是elasticsearch可以通过以下三种方式处理JSON承载的Date数 ...

  8. 5.安装&卸载子应用 投票

    另起一个新的Django项目 djangoProject_poll_test ........ 把.tar.gz包下载到某个路径 D:\此电脑下分类\桌面\django-polls\dist\djan ...

  9. Javascript中0.1+0.2===0.3?怎么解决这个问题?

    一.问题分析 计算机存储以二进制的方式,而0.1 在二进制中是无限循环的一个数字,所以会出现裁剪,精度丢失会出现,0.100000000000000002 === 0.1,0.200000000000 ...

  10. 使用云服务器配置MariaDB环境,Navicat远程连接一直出错误代码 "2002 - Can't connect to server on '' (10060)"

    使用腾讯云或者阿里云的服务器配置MariaDB数据库环境的时候,用Navicat远程连接在Centos7的Linux上配置MariaDB数据库环境的时候一直出错误代码 "2002 - Can ...