%matplotlib inline
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt d = 1
n = 200
X = torch.rand(n,d) #200*1, batch * feature_dim
#y = 3*torch.sin(X) + 5* torch.cos(X**2)
y = 4 * torch.sin(np.pi * X) * torch.cos(6*np.pi*X**2) #注意这里hid_dim 设置是超参数(如果太小,效果就不好),使用tanh还是relu效果也不同,优化器自选
hid_dim_1 = 128
hid_dim_2 = 32
d_out = 1 model = nn.Sequential(nn.Linear(d,hid_dim_1),
nn.Tanh(),
nn.Linear(hid_dim_1, hid_dim_2),
nn.Tanh(),
nn.Linear(hid_dim_2, d_out)
)
loss_func = nn.MSELoss()
optim = torch.optim.SGD(model.parameters(), 0.05) epochs = 6000
print("epoch\t loss\t")
for i in range(epochs):
y_hat = model(X)
loss = loss_func(y_hat, y)
optim.zero_grad()
loss.backward()
optim.step()
if((i+1)%100 == 0):
print("{}\t {:.5f}".format(i+1,loss.item())) #这个地方容易出错,测试时不要用原来的x,因为原来的x不是从小到达排序,导致x在连线时会混乱,所以要用np.linspace重新来构造
test_x = torch.tensor(np.linspace(0,1,50), dtype = torch.float32).reshape(-1,1)
final_y = model(test_x)
plt.scatter(X,y)
plt.plot(test_x.detach(),final_y.detach(),"r") #不使用detach会报错
print("over")
epoch	 loss
100 3.84844
200 3.83552
300 3.78960
400 3.64596
500 3.43755
600 3.17153
700 2.59001
800 2.21228
900 1.87939
1000 1.55716
1100 1.41315
1200 1.26750
1300 1.05869
1400 0.91269
1500 0.81320
1600 0.74047
1700 0.67874
1800 0.61939
1900 0.56204
2000 0.51335
2100 0.47797
2200 0.45317
2300 0.43151
2400 0.40505
2500 0.37628
2600 0.34879
2700 0.32457
2800 0.30431
2900 0.28866
3000 0.30260
3100 0.26200
3200 0.30286
3300 0.25229
3400 0.21422
3500 0.22737
3600 0.22905
3700 0.19909
3800 0.24601
3900 0.17733
4000 0.22905
4100 0.15704
4200 0.21570
4300 0.14141
4400 0.14657
4500 0.14609
4600 0.11998
4700 0.12598
4800 0.10871
4900 0.08616
5000 0.18319
5100 0.08111
5200 0.08213
5300 0.11087
5400 0.06879
5500 0.07235
5600 0.11281
5700 0.06817
5800 0.08423
5900 0.06886
6000 0.06301

3、pytorch实现最基础的MLP网络的更多相关文章

  1. 07_利用pytorch的nn工具箱实现LeNet网络

    07_利用pytorch的nn工具箱实现LeNet网络 目录 一.引言 二.定义网络 三.损失函数 四.优化器 五.数据加载和预处理 六.Hub模块简介 七.总结 pytorch完整教程目录:http ...

  2. 你必须了解的基础的 Linux 网络命令

    Linux 基础网络命令列表 我在计算机网络课程上使用 FreeBSD,不过这些 UNIX 命令应该也能在 Linux 上同样工作. 连通性 ping <host>:发送 ICMP ech ...

  3. JAVA基础知识之网络编程——-网络基础(Java的http get和post请求,多线程下载)

    本文主要介绍java.net下为网络编程提供的一些基础包,InetAddress代表一个IP协议对象,可以用来获取IP地址,Host name之类的信息.URL和URLConnect可以用来访问web ...

  4. 基础的 Linux 网络命令,你值得拥有

    导读 有抱负的 Linux 系统管理员和 Linux 狂热者必须知道的.最重要的.而且基础的 Linux 网络命令合集.在 It's FOSS 我们并非每天都谈论 Linux 的"命令行方面 ...

  5. 黑马程序员:Java基础总结----GUI&网络&IO综合开发

    黑马程序员:Java基础总结 GUI&网络&IO综合开发   ASP.Net+Android+IO开发 . .Net培训 .期待与您交流! 网络架构 C/S:Client/Server ...

  6. Java基础教程:网络编程

    Java基础教程:网络编程 基础 Socket与ServerSocket Socket又称"套接字",网络上的两个程序通过一个双向的通信连接实现数据的交换,这个连接的一端称为一个s ...

  7. Linux基础入门之网络属性配置

    Linux基础入门之网络属性配置 摘要 Linux网络属性配置,最根本的就是ip和子网掩码(netmask),子网掩码是用来让本地主机来判断通信目标是否是本地网络内主机的,从而采取不同的通信机制. L ...

  8. 【转载】[基础知识]【网络编程】TCP/IP

    转自http://mc.dfrobot.com.cn/forum.php?mod=viewthread&tid=27043 [基础知识][网络编程]TCP/IP iooops  胖友们楼主我又 ...

  9. 10个基础的linux网络和监控命令

    配置zookeeper集群时,需要查看本机ip,输入命令 hostname -i   就会只显示主机ip, 下边搜了一篇常用的    命令,闲的时候多敲敲命令,以便用的时候再找! 我下面列出来的10个 ...

随机推荐

  1. MSSQL 高并发下生成连续不重复的订单号

    参考: https://www.cnblogs.com/h-change/p/6699683.html 这里在数据库层面生成的,经测试确实不会重复. 附上自己修改后的版本,这里也可以预先生成一年的记录 ...

  2. 经典c程序100例==81--90

    [程序81] 题目:809*??=800*??+9*??+1 其中??代表的两位数,8*??的结果为两位数,9*??的结果为3位数.求??代表的两位数,及809*??后的结果. 1.程序分析: 2.程 ...

  3. 关于mybatisPlus一些坑,当条件为null时

    1.TStaffDepart 属性有值是才匹配条件,会报错,相当于mybatis if 判断 eg:TStaffDepart staffDepart = new TStaffDepart();staf ...

  4. 部署sftp服务

    部署sftp服务有风险,可能造成ssh无法连接到服务器,因此写个脚本定时覆盖一下,保证ssh可以正常使用. 创建数据目录并赋权,创建账号密码,修改ssh文件. * mkdir /sftp groupa ...

  5. webug第十五关:什么?图片上传不了?

    第十五关:什么?图片上传不了? 直接上传php一句话失败,将content type改为图片 成功

  6. linux qt 5.12.6 编译mysql驱动

    环境:ubuntu 18.4 x64.qt 5.12.6 问题:安装后是没有mysql的驱动的 解决过程: 各种搜索,先后安装了mysql mysql-client,mysql-server,和各种l ...

  7. 公式编辑器MathType之入门攻略

    许多时候在工作.学习,尤其是写文献时,需要在Word文档中输入较多公式,简单的公式或符号,可以借助Word自带的公式编辑器,但是,遇到较多并且复杂的公式,该如何高效解决呢?其实可以借助一款强大的公式编 ...

  8. 用思维导图软件MindManager整理假期

    今天带大家使用MindManager2020软件构建出2020年的节假日思维导图. 既然是做2020年的节假日思维导图,那么有个MindManager技巧就是,关于这一类思维导图我们都可以选择时间线导 ...

  9. java中String类的使用

    一.Strng类的概念 String类在我们开发中经常使用,在jdk1.8版本之前(包括1.8),String类的底层是一个char类型的数组,1.8版本之后是byte类型的数组,正是因为String ...

  10. 自学linux——1.VMware的安装及VM下centos的安装

    1.CentOS下载 网址:https://www.centos.org/download/ 网盘:https://pan.baidu.com/s/1HrtK6xNig6KC8oh6O-6fyg 提取 ...