pytorch学习笔记(7)--线性层
(一)Liner Layers线性层
b 是偏移量bias



代码输入:
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(), download=False)
dataloader = DataLoader(dataset, batch_size=64) class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.linear1 = Linear(196608, 10) def forward(self, input):
output = self.linear1(input)
return output tudui = Tudui() for data in dataloader:
imgs, target = data
print(imgs.shape)
output = torch.reshape(imgs, (1, 1, 1, -1))
print(output.shape)
output = tudui(output)
print(output.shape)
输出:
torch.Size([64, 3, 32, 32])
torch.Size([1, 1, 1, 196608])
torch.Size([1, 1, 1, 10])
改为 flatten 类似“平铺”:
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(), download=False)
dataloader = DataLoader(dataset, batch_size=64) class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.linear1 = Linear(196608, 10) def forward(self, input):
output = self.linear1(input)
return output tudui = Tudui() for data in dataloader:
imgs, target = data
print(imgs.shape)
# flatten
output = torch.flatten(imgs)
print(output.shape)
输出:
torch.Size([64, 3, 32, 32])
torch.Size([196608])
图形图像方面Module:

pytorch学习笔记(7)--线性层的更多相关文章
- [PyTorch 学习笔记] 3.3 池化层、线性层和激活函数层
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/nn_layers_others.py 这篇文章主要介绍 ...
- [PyTorch 学习笔记] 3.2 卷积层
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/nn_layers_convolution.py 这篇文 ...
- 【pytorch】pytorch学习笔记(一)
原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...
- [PyTorch 学习笔记] 4.1 权值初始化
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/grad_vanish_explod.py 在搭建好网络 ...
- PyTorch学习笔记6--案例2:PyTorch神经网络(MNIST CNN)
上一节中,我们使用autograd的包来定义模型并求导.本节中,我们将使用torch.nn包来构建神经网络. 一个nn.Module包含各个层和一个forward(input)方法,该方法返回outp ...
- 【深度学习】Pytorch 学习笔记
目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...
- TensorFlow 深度学习笔记 从线性分类器到深度神经网络
转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnotes 欢迎star,有问题可以到Issue区讨论 官方教程地址 视频/字幕下载 L ...
- Pytorch学习笔记(二)---- 神经网络搭建
记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...
- Pytorch学习笔记(一)---- 基础语法
书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...
- 学习笔记TF014:卷积层、激活函数、池化层、归一化层、高级层
CNN神经网络架构至少包含一个卷积层 (tf.nn.conv2d).单层CNN检测边缘.图像识别分类,使用不同层类型支持卷积层,减少过拟合,加速训练过程,降低内存占用率. TensorFlow加速所有 ...
随机推荐
- 085_JS Promise
js对undefined的处理 https://www.liaoxuefeng.com/wiki/001434446689867b27157e896e74d51a89c25cc8b43bdb3000 ...
- react toolkit 异步请求之后调取其他函数
在slice切片文件中,页面dispatch执行action之后,异步请求完成后调取另外一个异步请求,要在异步请求的 第二个参数添加 thunkAPI,调取thunkAPI的dispatch方法即 ...
- [Swift] 在 OC 工程中,创建和使用 Swift
1.我们创建了一个 Objective-C 的工程,叫做 playGround. 2.首先,我们需要在 工程的 Build Settings,找到 如中所示的项目,并将 Defines Module ...
- element ui form 表单 校验upload是否有上传
查到资料 可以绑定在一个多选上,校验此绑定的值 1 <el-form-item label="上传图片" prop="双向绑定值"> 2 <e ...
- WPF-UI框架MahApps.Metro使用教程
参考教程:https://www.shuzhiduo.com/A/xl561ZaoJr/ 一,MahApps.Metro安装 1,项目中引用"MahApps.Metro.dll"[ ...
- 由于CVE-2020-10770漏洞,k8s集群升级keycloak 8.0.0-->15.0.0(由于15.0.0最近又出安全漏洞,升级为16.0.0)
背景 前段时间项目组用到的8.0.0版本的keycloak被安全部门同事扫出来一个中危漏洞: A flaw was found in Keycloak, where it is possible to ...
- nginx(一) の 入门解析
OSI 模型的前三层 应用层: 每一个应用程序自定义的协议 表示层: 数据的压缩与解压缩.图片的编码与解码 会话层: 会话管理(session) 和 网络验证 .包括断点续传和服务器验证用户登录等.比 ...
- AWT+Swing区别
AWT 是Abstract Window ToolKit (抽象窗口工具包)的缩写,这个工具包提供了一套与本地图形界面进行交互的接口.AWT 中的图形函数与操作系统所提供的图形函数之间有着一一对应的关 ...
- Install MySQL wsl1
To install MySQL on WSL (ie. Ubuntu) env Ubuntu 22.04.1 LTS mysql Ver 8.0.32-0ubuntu0.22.04.2 for Li ...
- 前后端分离 基于session的验证码功能实现
前后端分离 基于session的验证码功能实现 1.后端代码 1.1 SessionContextUtils 用于获取session import javax.servlet.http.HttpSes ...