%matplotlib inline
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt d = 1
n = 200
X = torch.rand(n,d) #200*1, batch * feature_dim
#y = 3*torch.sin(X) + 5* torch.cos(X**2)
y = 4 * torch.sin(np.pi * X) * torch.cos(6*np.pi*X**2) #注意这里hid_dim 设置是超参数(如果太小,效果就不好),使用tanh还是relu效果也不同,优化器自选
hid_dim_1 = 128
hid_dim_2 = 32
d_out = 1 model = nn.Sequential(nn.Linear(d,hid_dim_1),
nn.Tanh(),
nn.Linear(hid_dim_1, hid_dim_2),
nn.Tanh(),
nn.Linear(hid_dim_2, d_out)
)
loss_func = nn.MSELoss()
optim = torch.optim.SGD(model.parameters(), 0.05) epochs = 6000
print("epoch\t loss\t")
for i in range(epochs):
y_hat = model(X)
loss = loss_func(y_hat, y)
optim.zero_grad()
loss.backward()
optim.step()
if((i+1)%100 == 0):
print("{}\t {:.5f}".format(i+1,loss.item())) #这个地方容易出错,测试时不要用原来的x,因为原来的x不是从小到达排序,导致x在连线时会混乱,所以要用np.linspace重新来构造
test_x = torch.tensor(np.linspace(0,1,50), dtype = torch.float32).reshape(-1,1)
final_y = model(test_x)
plt.scatter(X,y)
plt.plot(test_x.detach(),final_y.detach(),"r") #不使用detach会报错
print("over")
epoch	 loss
100 3.84844
200 3.83552
300 3.78960
400 3.64596
500 3.43755
600 3.17153
700 2.59001
800 2.21228
900 1.87939
1000 1.55716
1100 1.41315
1200 1.26750
1300 1.05869
1400 0.91269
1500 0.81320
1600 0.74047
1700 0.67874
1800 0.61939
1900 0.56204
2000 0.51335
2100 0.47797
2200 0.45317
2300 0.43151
2400 0.40505
2500 0.37628
2600 0.34879
2700 0.32457
2800 0.30431
2900 0.28866
3000 0.30260
3100 0.26200
3200 0.30286
3300 0.25229
3400 0.21422
3500 0.22737
3600 0.22905
3700 0.19909
3800 0.24601
3900 0.17733
4000 0.22905
4100 0.15704
4200 0.21570
4300 0.14141
4400 0.14657
4500 0.14609
4600 0.11998
4700 0.12598
4800 0.10871
4900 0.08616
5000 0.18319
5100 0.08111
5200 0.08213
5300 0.11087
5400 0.06879
5500 0.07235
5600 0.11281
5700 0.06817
5800 0.08423
5900 0.06886
6000 0.06301

3、pytorch实现最基础的MLP网络的更多相关文章

  1. 07_利用pytorch的nn工具箱实现LeNet网络

    07_利用pytorch的nn工具箱实现LeNet网络 目录 一.引言 二.定义网络 三.损失函数 四.优化器 五.数据加载和预处理 六.Hub模块简介 七.总结 pytorch完整教程目录:http ...

  2. 你必须了解的基础的 Linux 网络命令

    Linux 基础网络命令列表 我在计算机网络课程上使用 FreeBSD,不过这些 UNIX 命令应该也能在 Linux 上同样工作. 连通性 ping <host>:发送 ICMP ech ...

  3. JAVA基础知识之网络编程——-网络基础(Java的http get和post请求,多线程下载)

    本文主要介绍java.net下为网络编程提供的一些基础包,InetAddress代表一个IP协议对象,可以用来获取IP地址,Host name之类的信息.URL和URLConnect可以用来访问web ...

  4. 基础的 Linux 网络命令,你值得拥有

    导读 有抱负的 Linux 系统管理员和 Linux 狂热者必须知道的.最重要的.而且基础的 Linux 网络命令合集.在 It's FOSS 我们并非每天都谈论 Linux 的"命令行方面 ...

  5. 黑马程序员:Java基础总结----GUI&网络&IO综合开发

    黑马程序员:Java基础总结 GUI&网络&IO综合开发   ASP.Net+Android+IO开发 . .Net培训 .期待与您交流! 网络架构 C/S:Client/Server ...

  6. Java基础教程:网络编程

    Java基础教程:网络编程 基础 Socket与ServerSocket Socket又称"套接字",网络上的两个程序通过一个双向的通信连接实现数据的交换,这个连接的一端称为一个s ...

  7. Linux基础入门之网络属性配置

    Linux基础入门之网络属性配置 摘要 Linux网络属性配置,最根本的就是ip和子网掩码(netmask),子网掩码是用来让本地主机来判断通信目标是否是本地网络内主机的,从而采取不同的通信机制. L ...

  8. 【转载】[基础知识]【网络编程】TCP/IP

    转自http://mc.dfrobot.com.cn/forum.php?mod=viewthread&tid=27043 [基础知识][网络编程]TCP/IP iooops  胖友们楼主我又 ...

  9. 10个基础的linux网络和监控命令

    配置zookeeper集群时,需要查看本机ip,输入命令 hostname -i   就会只显示主机ip, 下边搜了一篇常用的    命令,闲的时候多敲敲命令,以便用的时候再找! 我下面列出来的10个 ...

随机推荐

  1. spring处理静态资源方式

    1. <mvc:default-servlet-handler/>default-servlet-handler在SpringMVC上下文定义一个org.springframework.w ...

  2. 主动关闭 time-wait 2msl 处理

    先上传后面整理 /* * This routine is called by the ICMP module when it gets some * sort of error condition. ...

  3. 栈(Stack)和队列(Queue)是两种操作受限的线性表。

    (线性表:线性表是一种线性结构,它是一个含有n≥0个结点的有限序列,同一个线性表中的数据元素数据类型相同并且满足"一对一"的逻辑关系. "一对一"的逻辑关系指的 ...

  4. linux中/etc/passwd和/etc/shadow文件说明

    /etc/passwd是用来存储登陆用户信息: [root@localhost test]# cat /etc/passwd root:x:0:0:root:/root:/bin/bash bin:x ...

  5. 12.java设计模式之代理模式

    基本介绍: 代理模式(Proxy)为一个对象提供一个替身,以控制对这个对象的访问.即通过代理对象访问目标对象.这样做的好处是:可以在目标对象实现的基础上,增强额外的功能操作,即扩展目标对象的功能,想在 ...

  6. 字符串匹配—KMP算法

    KMP算法是一种改进的字符串匹配算法,由D.E.Knuth,J.H.Morris和V.R.Pratt提出的,因此人们称它为克努特-莫里斯-普拉特操作(简称KMP算法).KMP算法的核心是利用匹配失败后 ...

  7. sqlilab less32-less37

    less-32 过滤了单引号,双引号,斜杠,同时设置数据库为GBK编码,可以考虑宽字节注入, 当设置gbk编码后,遇到连续两个字节,都符合gbk取值范围,会自动解析为一个汉字.用脚本来测试下哪些符合 ...

  8. 小程序后端获取openid (php实例)

    小程序获取openid 首先,小程序授权登录的时候,前端就会获取到code 而后端接收到了code之后,就可以向微信发起请求,获取用户的openid代码如下: <?php $code = $_R ...

  9. 使用SpringBoot进行优雅的数据验证

    JSR-303 规范 在程序进行数据处理之前,对数据进行准确性校验是我们必须要考虑的事情.尽早发现数据错误,不仅可以防止错误向核心业务逻辑蔓延,而且这种错误非常明显,容易发现解决. JSR303 规范 ...

  10. RSA脚本环境配置-攻防世界-OldDriver

    [Crypto] 题目链接 [RSA算法解密] 审题分析 首先拿到一个压缩包,解压得到文件enc.txt. 先不用去管其他,第一眼enc马上联想到 RSA解密.接着往下看 [{"c" ...