import torch
from torch import nn
from torch.nn import functional as F
from torch import optim import torchvision
from matplotlib import pyplot as plt # 小工具 def plot_curve(data):
fig = plt.figure()
plt.plot(range(len(data)),data,color='blue')
plt.legend(['value'],loc='upper right')
plt.xlabel('step')
plt.tlabel('value')
plt.show() def plot_image(img,label,name):
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt,tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')
plt.title("{}:{}".format(name,label[i].item()))
plt.xticks([])
plt.xticks([]) plt.show() def one_hot(label,depth = 10):
out = torch.zeros(label.size(0),depth)
idx = torch.LongTensor(label).view(-1,1)
out.scatter_(dim=1,index=idx,value=1)
return out # 一次加载多少图片
batch_size = 512
# step1. load dataset 数据加载
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MINST('mnist_data',train=True,download=True,
transform=torchvision.transforms.Compose([
torchvision.transfroms.ToTensor(), torchvision.transfroms.Normalize(
(0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MINST('mnist_data/',train=False,download=True,
transform=torchvision.transforms.Compose([
torchvision.transfroms.ToTensor(),
torchvision.transfroms.Normalize(
(0.1307,),(0.3081,))
])),
batch_size=batch_size,shuffle=False) # 网络创建
class Net(nn.Module): def __init__(self):
super(Net,self).__init__() #xw+b
self.fc1 = nn.Linear(28*28,256)
self.fc2 = nn.Linear(256,64)
self.fc3 = nn.Linear(64,10) def forward(self,x):
# x:[batch_size,1,28,28]
# h1 = relu(xw1+b1)
x = F.relu(self.fc1(x))
# h1 = relu(h1w2+b2)
x = F.relu(self.fc2(x))
# h3 = h2w3+b3
x = self.fc3(x) return x net = Net()
# [w1,b1,w2,b1,w3,b3]
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9) train_loss = [] # 训练
for epoch in range(3): for batch_idx,(x,y) in enumerate(train_loader): # x: [b,1,28,28], y:[512]
# [b,1,28,28]-->[b,feature]
x = x.view(x.size(0),28*28)
# --> [b,10]
out = net(x)
# --> [b,10]
y_onehot = one_hot(y)
# loss = mse(out,y_onehot)
loss = F.mse_loss(out,y_onehot)
# 清零梯度
optimizer.zero_grad()
# 计算梯度
loss.backward()
#w' = w - lr*grad 更新梯度
optimizer.step() train_loss.append(loss.item()) if batch_idx % 10 == 0:
print(epoch,batch_idx,loss.item()) plot_curve(train_loss) # 得到一个比较好的 [w1,b1,w2,b1,w3,b3] # 验证准确率
total_correct = 0
for x,y in test_loader"
x = x.view(x.size(0),28*28)
out = net(x)
# out: [b,10] --> pred: [b]
pred = out.argmax(dim = 1)
correct = pred.eq(y).sum().float().item()
total_correct += correct total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:',acc) # 直观显示验证
x,y = next(iter(test_loader))
out = net(x.view(x.size(0),28*28))
pred = out.argmax(dim = 1)
plot_image(x,pred,'test')

龙良曲pytorch学习笔记_03的更多相关文章

  1. Pytorch学习笔记(二)---- 神经网络搭建

    记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...

  2. Pytorch学习笔记(一)---- 基础语法

    书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...

  3. 【pytorch】pytorch学习笔记(一)

    原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...

  4. 【深度学习】Pytorch 学习笔记

    目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...

  5. Pytorch学习笔记(一)——简介

    一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...

  6. [PyTorch 学习笔记] 1.3 张量操作与线性回归

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/linear_regression.py 张量的操作 拼 ...

  7. [PyTorch 学习笔记] 1.1 PyTorch 简介与安装

    PyTorch 的诞生 2017 年 1 月,FAIR(Facebook AI Research)发布了 PyTorch.PyTorch 是在 Torch 基础上用 python 语言重新打造的一款深 ...

  8. [PyTorch 学习笔记] 1.4 计算图与动态图机制

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/computational_graph.py 计算图 深 ...

  9. [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制

    PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...

随机推荐

  1. VC windows 多网卡情况下 获取当前网卡ip地址

    参考 代码如下 记录下以后用得到或者能帮到有需要的朋友 #include <iostream> #include <WinSock2.h> #include <Iphlp ...

  2. 洛谷$P4316$ 绿豆蛙的归宿 期望

    正解:期望 解题报告: 传送门! 看懂题目还是挺水的$(bushi$ 三个方法,但因为题目太水了懒得一一介绍了,,,反正都是期望,,,$so$随便港个最简单的趴$QwQ$ 直接考虑每条边的贡献,就会是 ...

  3. 基于Redis的分布式锁和Redlock算法

    1 前言 前面写了4篇Redis底层实现和工程架构相关文章,感兴趣的读者可以回顾一下: Redis面试热点之底层实现篇-1 Redis面试热点之底层实现篇-2 Redis面试热点之工程架构篇-1 Re ...

  4. Redis实战 | 持久化、主从复制特性和故障处理思路

    前言 前面两篇我们了解了Redis的安装.Redis最常用的5种数据类型.本篇总结下Redis的持久化.主从复制特性,以及Redis服务挂了之后的一些处理思路. 前期回顾传送门: Linux下安装Re ...

  5. mac使用python识别图形验证码

    前言 最近在研究验证码相关的操作,所以准备记录下安装以及使用的过程.虽然之前对验证码的破解有所了解的,但是之前都是简单使用之后就不用了,没有记录一个详细的过程,所以后面再用起来也要重新从网上查找资料比 ...

  6. MinIO 搭建使用

    MinIO简介¶ MinIO 是一款基于Go语言的高性能对象存储服务,在Github上已有19K+Star.它采用了Apache License v2.0开源协议,非常适合于存储大容量非结构化的数据, ...

  7. 洛谷P3292 [SCOI2016]幸运数字 线性基+倍增

    P3292 [SCOI2016]幸运数字 传送门 题目描述 A 国共有 n 座城市,这些城市由 n-1 条道路相连,使得任意两座城市可以互达,且路径唯一.每座城市都有一个幸运数字,以纪念碑的形式矗立在 ...

  8. Spring Cloud 如何动态刷新 Git 仓库配置?

    有时候在配置中心有些参数是需要修改的,这时候如何不重启而达到实时生效的效果呢? 本文基于以下讲解: Spring Cloud Greenwich.SR3 Spring Boot 2.1.7.RELEA ...

  9. H5录音音频可视化-实时波形频谱绘制、频率直方图

    这段时间给GitHub Recorder开源库添加了两个新的音频可视化功能,比以前单一的动态波形显示丰富了好多(下图后两行是不是比第一行看起来丰满些):趁热打铁写了一个音频可视化相关扩展测试代码,下面 ...

  10. 【Linux】---Linux系统下各种常用命令总结

    在Linux系统下,“万物皆文件”,之所以强调在强调这个概念,是因为很多人已经习惯了win系统下找找点点得那种方式和思维,因此总是会觉得linux系统下很多指令既复杂又难记.其实都是一样得东西,只是w ...