PyTorch深度学习实践——处理多维特征的输入
处理多维特征的输入
课程来源:PyTorch深度学习实践——河北工业大学
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
这一讲介绍输入为多维数据时的分类。
一个数据集示例如下:
由于使用的是多维的数据,因此模型中的x和y都应该变为向量的形式,变为如下式子:
而下方针对多维数据的式子中的一部分可以使用矩阵相乘的方式表示:
w_1\\
.\\
.\\
.\\
w_8
\end{bmatrix}+b)
\]
由于我们使用的是mini-batch的计算方式,因此计算的形式如下:
\hat y^{(1)}\\
.\\
.\\
.\\
\hat y^{(N)}
\end{bmatrix}=\sigma
\begin{bmatrix}
z^{(1)}\\
.\\
.\\
.\\
z^{(N)}
\end{bmatrix}
\]
其中z的计算方式如下:
w_1\\
.\\
.\\
.\\
w_8
\end{bmatrix}+b
\]
为了利用并行计算进行优化,因此将计算改为矩阵运算如下:
z^{(1)}\\
.\\
.\\
.\\
z^{(N)}
\end{bmatrix}=
\begin{bmatrix}
x_1^{(1)}...x_8^{(1)}\\
.\\
.\\
.\\
x_1^{(N)}...x_8^{(N)}
\end{bmatrix}
\begin{bmatrix}
w_1\\
.\\
.\\
.\\
w_8
\end{bmatrix}+b
\]
由于我们想将神经网络的层数增加几层,不是只用一层来预测,因此模型使用主要部分代码示例如下:
线性层的使用:
self.linear1 = torch.nn.Linear(8, 6)
注:叠加线性层每两层之间一定要加入非线性层,否则没有意义。
非线性层的使用:
x = self.sigmoid(self.linear1(x))
一般而言,神经网络中的隐层越多,中间神经元越多学习能力越强,但是过拟合的可能性也越大。
一个简单的神经网络的模型如下图:
代码如下:
import torch
import numpy as np
import matplotlib.pyplot as plt
##1. Prepare Dataset
xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:,:-1])
y_data = torch.from_numpy(xy[:, [-1]])
loss_list=[]
epoch_list=[]
##2. Define Model
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
##定义了三层线性层
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 1)
##定义激活函数,除了sigmoid也有其他的如self.activate = torch.nn.ReLU()
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
##处理单元(线性层+非线性变化层),三层,用同一个变量x(每一层处理的结果都传递到下一层)
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model = Model()
##3. Construct Loss and Optimizer
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
##4. Training Cycle
for epoch in range(10000):
##Forward
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
loss_list.append(loss.item())
epoch_list.append(epoch)
print(epoch, loss.item())
# Backward
optimizer.zero_grad()
loss.backward()
# Update
optimizer.step()
print(epoch,loss)
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()
注:上述代码没有实现mini-batch的训练模式,还是使用全部输入,一次性训练的结果。
PyTorch深度学习实践——处理多维特征的输入的更多相关文章
- PyTorch深度学习实践——反向传播
反向传播 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 目录 反向传播 笔记 作业 笔记 在之前课程中介绍的线性 ...
- PyTorch深度学习实践——多分类问题
多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...
- PyTorch深度学习实践-Overview
Overview 1.PyTorch简介 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...
- 深度学习实践系列(2)- 搭建notMNIST的深度神经网络
如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...
- 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络
前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...
- 【PyTorch深度学习60分钟快速入门 】Part1:PyTorch是什么?
0x00 PyTorch是什么? PyTorch是一个基于Python的科学计算工具包,它主要面向两种场景: 用于替代NumPy,可以使用GPU的计算力 一种深度学习研究平台,可以提供最大的灵活性 ...
- 【PyTorch深度学习】学习笔记之PyTorch与深度学习
第1章 PyTorch与深度学习 深度学习的应用 接近人类水平的图像分类 接近人类水平的语音识别 机器翻译 自动驾驶汽车 Siri.Google语音和Alexa在最近几年更加准确 日本农民的黄瓜智能分 ...
- PyTorch 60 分钟入门教程:PyTorch 深度学习官方入门中文教程
什么是 PyTorch? PyTorch 是一个基于 Python 的科学计算包,主要定位两类人群: NumPy 的替代品,可以利用 GPU 的性能进行计算. 深度学习研究平台拥有足够的灵活性和速度 ...
随机推荐
- Centos配置yum本地源最简单的办法
有关centos配置yum本地源的方法 一.前提 先连接镜像 然后在命令行输入如下命令 mount /dev/sr0 /mnt cd /etc/yum.repos.d/ ls 之后会看到如下的界面 二 ...
- 使用Hot Chocolate和.NET 6构建GraphQL应用(4) —— 实现Query映射功能
系列导航 使用Hot Chocolate和.NET 6构建GraphQL应用文章索引 需求 在上一篇文章使用Hot Chocolate和.NET 6构建GraphQL应用(3) -- 实现Query基 ...
- WTM多租户改造
首先简单说下多租户的几种实现方式 多租户(Multi-Tenant ),即多个租户共用一个实例,租户的数据既有隔离又有共享,说到底是要解决数据存储的问题. 常用的数据存储方式有三种. 方案一:独立数据 ...
- 【免杀技术】Tomcat内存马-Filter
Tomcat内存马-Filter型 什么是内存马?为什么要有内存马?什么又是Filter型内存马?这些问题在此就不做赘述 Filter加载流程分析 tomcat启动后正常情况下对于Filter的处理过 ...
- git命令行-新建分支与已提交分支合并
例如要将A分支的一个commit合并到B分支: 首先切换到A分支 git checkout A git log 找出要合并的commit ID : 例如 325d41 然后切换到B分支上 git ch ...
- Atcoder ARC-063
ARC063(2020.7.16) A \(A\) 题如果洛谷评分很低就不看了. B 可以发现一定是选择在一个地方全部买完然后在之后的一个地方全部卖完,那么我们就只需要即一个后缀最大值就可以计算答案了 ...
- java.lang.IllegalArgumentException: Failed to register servlet with name 'dispatcher'.Check if there is another servlet registered under the same name
前言 一年前接手了一个项目,项目始终无法运行,不管咋样,都无法处理,最近,在一次尝试中,终于成功处理了. 含义 意思很明显了,注册了一个相同的dispatcher,可是找了很久,没有相同的Contro ...
- 自定义 RestTemplate 异常处理 (转)
转自:https://ethendev.github.io/2018/11/06/RestTemplate-error-handler/ 一些 API 的报错信息通过 Response 的 body返 ...
- Java中float、double、long类型变量赋值添加f、d、L尾缀问题
展开1. 添加尾缀说明 我们知道Java在变量赋值的时候,其中float.double.long数据类型变量,需要在赋值直接量后面分别添加f或F.d或D.l或L尾缀来说明. 其中,long类型最好以 ...
- 【转】Python 并行分布式框架 Celery
原文链接:https://blog.csdn.net/freeking101/article/details/74707619 Celery 官网:http://www.celeryproject.o ...