pytorch入门2.x构建回归模型系列:

pytorch入门2.0构建回归模型初体验(数据生成)

pytorch入门2.1构建回归模型初体验(模型构建)

pytorch入门2.2构建回归模型初体验(开始训练)

经过上面两个部分,我们完成了数据生成、网络结构定义,下面我们终于可以小试牛刀,训练模型了!

首先,我们先定义一些训练时要用到的参数:

EPOCH = 1000  # 就是要把数据用几遍
LR = 0.1 # 优化器的学习率,类似爬山的时候应该迈多大的步子。
BATCH_SIZE=50

其次,按照定义的模型类实例化一个网络:

if torch.cuda.is_available():  # 检查机器是否支持GPU计算,如果支持GPU计算,那么就用GPU啦,快!
model = LinearRegression().cuda() # 这里的这个.cuda操作就是把模型放到GPU上
else:
model = LinearRegression() # 如果不支持,那么用cpu也可以哦
# 定义损失函数,要有个函数让模型的输出知道他做的对、还是错,对到什么程度或者错到什么程度,这就是损失函数。
loss_fun = nn.MSELoss() # loss function
# 定义优化器,就是告诉模型,改如何优化内部的参数、还有该迈多大的步子(学习率LR)。
optimizer = torch.optim.SGD(model.parameters(), lr=LR) # opimizer

下面终于可以开始训练了,但是训练之前解释一下EPOCH,比如我们有300个样本,训练的时候我们不会把300个样本放到模型里面训练一遍,就停止了。即在模型中我们每个样本不会只用一次,而是会使用多次。这300个样本到底要用多少次呢,就是EPOCH的值的意义。

for epoch in range(EPOCH):
# 此处类似前面实例化模型是,我们把模型放到GPU上来跑道理是一样的。此处,我们要把变量放到GPU上,跑的快!如果不行, 那就放到CPU上吧。
# 其中x是输入数据,y是训练集的groundtruth。为什么要有y呢?因为我们要知道我们算的对不对,到底有多对(这里由损失函数控制)
if torch.cuda.is_available():
x = Variable(x_train).cuda()
y = Variable(y_train).cuda()
else:
x = Variable(x_train)
y = Variable(y_train)
# 我们把x丢进模型,得到输出y。哇,是不是好简单,这样我们就得到结果了呢?但是不要高兴的太早,我们只是把输入数据放到一个啥都不懂(参数没有训练)的模型中,得到的结果肯定不准啊。不准的结果怎么办,看下一步。
out = model(x)
# 拿到模型输出的结果,我们就要看看模型算的准不准,就是计算损失函数了。
loss = loss_fun(out,y)
# 好了好了,我已经知道模型算的准不准了,那么就该让模型自己去朝着好的方向优化了。模型,你已经是个大孩子了,应该会自己优化的。
optimizer.zero_grad() # 在优化之前,我们首先要清空优化器的梯度。因为每次循环都要靠这个优化器呢,不能翻旧账,就只算这次我们怎么优化。
loss.backward() # 优化开始,首先,我们要把算出来的误差、损失倒着传回去。(是你们这些模块给我算的这个值,现在这个值有错误,错了这么多,返回给你们,你们自己看看自己错哪了) optimizer.step() # 按照优化器的方式,一步一步优化吧。 if (epoch+1)%100==0: # 中间每循环100次,偷偷看看结果咋样。
print('Epoch[{}/{}],loss:{:.6f}'.format(epoch+1,EPOCH,loss.data.item()))

上面我们训练了1000(EPOCH=1000)次,应该差不多了。是时候看看训练的咋样啦!其实我们已经知道训练的咋样了,就是上面输出的损失值,只不过是在训练集上的。

下面我们就要看看在测试集上表现咋样呢?

model.eval()  # 开启模型的测试模式
# 拿到测试集中x的值,放到GPU上
if torch.cuda.is_available():
x = x_test.cuda()
#通过把x的值输入模型,得到预测结果
predict = model(x)
# 那预测结果的值取出来,因为预测结果是封装好的,现在h只要它的值。
predict = predict.cpu().data.numpy()
#画个图看看,到底拟合成啥样了?
plt.plot(x.cpu().numpy(),y_test.cpu().numpy(),'ro',label='original data')
plt.plot(sorted(x.cpu().numpy()),sorted(predict),label='fitting line')
plt.show()

看看图,结果还凑合吧,要想结果更好需要进一步对模型的结构、超参数进行设置,我们之后在学。

到此为止,我们用pytorch就已经建立完,并且训练完一个线性回归模型了,我们可以回顾下,多看几遍,仔细回想一下这里面到底发生了什么。

pytorch入门2.2构建回归模型初体验(开始训练)的更多相关文章

  1. pytorch入门2.0构建回归模型初体验(数据生成)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  2. pytorch入门2.1构建回归模型初体验(模型构建)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  3. 2,turicreate入门 - 一个简单的回归模型

    turicreate入门系列文章目录 1,turicreate入门 - jupyter & turicreate安装 2,turicreate入门 - 一个简单的回归模型 3,turicrea ...

  4. cucumber java从入门到精通(1)初体验

    cucumber java从入门到精通(1)初体验 cucumber在ruby环境下表现让人惊叹,作为BDD框架的先驱,cucumber后来被移植到了多平台,有cucumber-js以及我们今天要介绍 ...

  5. python--爬虫入门(七)urllib库初体验以及中文编码问题的探讨

    python系列均基于python3.4环境 ---------@_@? --------------------------------------------------------------- ...

  6. 【小白学PyTorch】18 TF2构建自定义模型

    [机器学习炼丹术]的炼丹总群已经快满了,要加入的快联系炼丹兄WX:cyx645016617 参考目录: 目录 1 创建自定义网络层 2 创建一个完整的CNN 2.1 keras.Model vs ke ...

  7. 3,turicreate入门 - 优化回归模型,使得预测更准确

    turicreate入门系列文章目录 1,turicreate入门 - jupyter & turicreate安装 2,turicreate入门 - 一个简单的回归模型 3,turicrea ...

  8. 吴裕雄 python 神经网络——TensorFlow实现回归模型训练预测MNIST手写数据集

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  9. weka实际操作--构建分类、回归模型

    weka提供了几种处理数据的方式,其中分类和回归是平时用到最多的,也是非常容易理解的,分类就是在已有的数据基础上学习出一个分类函数或者构造出一个分类模型.这个函数或模型能够把数据集中地映射到某个给定的 ...

随机推荐

  1. 关于web标准

    从我去年接触前端,到现在,我似乎都没有特意去研究过web标准.我只知道传统上推崇结构样式行为分离,js.css.html各司其职, 不推荐在文档的节点上写类似<p onclick=“fn()”& ...

  2. Hbase-二级索引 Hbase+Hbase-indexer+solr (CDH)

    最近一段时间工作涉及到hbase sql查询和可视化展示的工作,hbase作为列存储,数据单一为二进制数组,本身就不擅长sql查询:而且有hive来作为补充作为sql查询和存储,但是皮皮虾需要低延迟的 ...

  3. J2EE项目分类管理中,提交表单数据是二进制形式时,对数据的修改失败。category赋值失败。

    原因: 在条件判断时,对字符串的比较进行了错误比较. 解决方法: A==B,比较的是两个字符串是否是同一个对象. A.equal(B),比较的是两个字符串内容是否相同. 出现错误是用了第一种比较,应该 ...

  4. PHP cookie基本操作

    PHP中Cookie的使用---添加/更新/删除/获取Cookie 及 自动填写该用户的用户名和密码和判断是否第一次登陆 什么是cookie 服务器在客户端保存用户的信息,比如登录名,密码等 这些数据 ...

  5. 这次终于可以愉快的进行 appium 自动化测试了

    appium 是进行 app 自动化测试非常成熟的一套框架.但是因为 appium 设计到的安装内容比较多,很多同学入门都跪在了环境安装的部分.本篇讲述 appium 安卓环境的搭建,希望让更多童鞋轻 ...

  6. 像宝石一样的Java原子类

    十五年前,多处理器系统是高度专业化的系统,通常耗资数十万美元(其中大多数具有两到四个处理器). 如今,多处理器系统既便宜又丰富,几乎主流的微处理器都内置了对多处理器的支持,很多能够支持数十或数百个处理 ...

  7. JavaScript变量语法扩展

    1.更新变量 一个变量被重新赋值后,它原有的值会被覆盖,变量值将会以最后一次赋值为准. 2.同时声明多个变量 var   age = 18 , address ='火影村' , gz = 2000 ; ...

  8. 《计算机网络》课程笔记 (Ch03-运输层)

    为运行在不同主机上的应用进程之间提供逻辑通信功能. 将应用层报文切分为块,然后加上运输层首部,形成报文段,交付给网络层. 多路复用与多路分解 将网络层提供的主机到主机交付服务延伸到进程到进程交付服务. ...

  9. 关于如何查看论文是否被SCI或者EI收录

    最好的方法,在高校图书馆网站上进行查询. 另外还有就是去对应网站查询: SCI:https://apps.webofknowledge.com/UA_GeneralSearch_input.do?pr ...

  10. Excel常用小方法

    Excel快捷键 Excel中处理工作表的快捷键 插入新工作表 Shift+F11或Alt+Shift+F1 移动到工作簿中的下一张工作表 Ctrl+PageDown 移动到工作簿中的上一张工作表 C ...