参考: https://github.com/Iallen520/lhy_DL_Hw/blob/master/PyTorch_Introduction.ipynb

模拟一个回归模型,y = X * w + 随机数

如,y : n*1矩阵, X : n*2矩阵 , w : 1,2矩阵

设置true_w = [[-1.0], [2.0]] , 随机初始化w 比如[[1.0], [0.0]],目标是拟合出正确的w

代码如下:

#自己写一个test case
import torch d = 2
n = 50
X = torch.randn(n,d)
true_w = torch.tensor([[-1.0], [2.0]])
y = X @ true_w + torch.randn(n,1) * 0.1
print('X shape', X.shape)
print('y shape', y.shape)
print('w shape', true_w.shape) print(X.shape)
print(true_w.shape) #w = torch.rand(2,1, requires_grad = True)
w = torch.tensor([[1.],[0]], requires_grad= True)
print("w:", w) print('iter,\tloss,\tw')
for i in range(20):
loss = torch.norm(y - torch.matmul(X,w))**2 / n loss.backward() w.data = w.data - 0.1 * w.grad print('{},\t{:.2f},\t{}'.format(i, loss.item(), w.view(2).detach().numpy())) w.grad.zero_() print('\ntrue w\t\t', true_w.view(2).numpy())
print('estimated w\t', w.view(2).detach().numpy())
X shape torch.Size([50, 2])
y shape torch.Size([50, 1])
w shape torch.Size([2, 1])
torch.Size([50, 2])
torch.Size([2, 1])
w: tensor([[1.],
[0.]], requires_grad=True)
iter, loss, w
0, 6.20, [0.7062541 0.32114884]
1, 4.45, [0.45446268 0.59016734]
2, 3.20, [0.2387764 0.81564814]
3, 2.30, [0.05413117 1.0047431 ]
4, 1.65, [-0.10385066 1.1634098 ]
5, 1.19, [-0.23894812 1.2966142 ]
6, 0.86, [-0.3544196 1.4084985]
7, 0.62, [-0.4530713 1.5025208]
8, 0.45, [-0.5373176 1.5815694]
9, 0.32, [-0.60923356 1.6480589 ]
10, 0.24, [-0.6706013 1.7040086]
11, 0.17, [-0.7229501 1.7511086]
12, 0.13, [-0.76759106 1.7907746 ]
13, 0.09, [-0.8056477 1.8241924]
14, 0.07, [-0.8380821 1.8523566]
15, 0.05, [-0.86571765 1.8761011 ]
16, 0.04, [-0.8892586 1.8961263]
17, 0.03, [-0.90930706 1.91302 ]
18, 0.02, [-0.9263775 1.9272763]
19, 0.02, [-0.9409093 1.9393103] true w [-1. 2.]
estimated w [-0.9409093 1.9393103]

出错点:

1.初始化tensor时,要为float,否则容易报错

2. 初始化时设置 requires_grad = True

3. 定义的loss要在for循坏之内

4. 要用w.data不用w,否则报错,可能和pytorch初始化有关

5. w.grad.zero_()  注意这个写法

1、pytorch写的第一个Linear模型(原始版,不调用nn.Modules模块)的更多相关文章

  1. 2、pytorch——Linear模型(最基础版,理解框架,背诵记忆)(调用nn.Modules模块)

    #define y = X @ w import torch from torch import nn #第一模块,数据初始化 n = 100 X = torch.rand(n,2) true_w = ...

  2. Online Coding开发模式 (通过在线配置实现一个表模型的增删改查功能,无需写任何代码)

    JEECG 智能开发平台. 开发模式由代码生成器转变为Online Coding模式                      (通过在线配置实现一个表模型的增删改查功能,无需一行代码,支持用户自定义 ...

  3. GAN实战笔记——第三章第一个GAN模型:生成手写数字

    第一个GAN模型-生成手写数字 一.GAN的基础:对抗训练 形式上,生成器和判别器由可微函数表示如神经网络,他们都有自己的代价函数.这两个网络是利用判别器的损失记性反向传播训练.判别器努力使真实样本输 ...

  4. Pytorch写CNN

    用Pytorch写了两个CNN网络,数据集用的是FashionMNIST.其中CNN_1只有一个卷积层.一个全连接层,CNN_2有两个卷积层.一个全连接层,但训练完之后的准确率两者差不多,且CNN_1 ...

  5. pytorch入门2.1构建回归模型初体验(模型构建)

    pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...

  6. 详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

    转载:https://zhuanlan.zhihu.com/p/53927068 https://blog.csdn.net/wangdongwei0/article/details/88956527 ...

  7. 【小白学PyTorch】1 搭建一个超简单的网络

    文章目录: 目录 1 任务 2 实现思路 3 实现过程 3.1 引入必要库 3.2 创建训练集 3.3 搭建网络 3.4 设置优化器 3.5 训练网络 3.6 测试 1 任务 首先说下我们要搭建的网络 ...

  8. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  9. Anaconda+django写出第一个web app(二)

    今天开始建立App中的第一个Model,命名为Tutorial. Model的定义在main文件夹下的models.py中通过类进行,我们希望Tutorial这个model包含三个属性:标题.内容和发 ...

随机推荐

  1. CV 履历 格式

    CV 指的是 "Curriculum Vitae" Curriculum vitae 在拉丁语中的意思是"生命的故事" CV 经常被称为 "Resum ...

  2. Spring Boot 2.4.0 正式发布!全新的配置处理机制,拥抱云原生!

    2020年11月12日,Spring官方发布了Spring Boot 2.4.0 GA的公告. 在这个版本中增加了大量的新特性和改进,下面我们一起看看在这个重要版本中都有哪些值得关注的内容! 更新内容 ...

  3. socket connect tcp_v4_connect

    tcp_v4_connect /* This will initiate an outgoing connection. tcp_v4_connect函数初始化一个对外的连接请求,创建一个SYN包并发 ...

  4. [LeetCode题解]206. 反转链表 | 迭代 + 递归

    方法一:迭代 解题思路 遍历过程,同时反转,这里需要一个指针 pre 要保存前一个节点. 代码 /** * Definition for singly-linked list. * public cl ...

  5. Linux权限位(含特殊权限位s s t) 及chown\chmod命令使用

    1.普通权限位 ls –l查看文件的属性 [root@oldboy ~]# ls -l -rw-------. 1 root root 1073 Mar 4 22:08 anaconda-ks.cfg ...

  6. 通过一道CTF学习HTTP协议请求走私

    HTTP请求走私 HTTP请求走私 HTTP请求走私是针对于服务端处理一个或者多个接收http请求序列的方式,进行绕过安全机制,实施未授权访问一种攻击手段,获取敏感信息,并直接危害其他用户. 请求走私 ...

  7. 面试官:你说你精通SpringBoot,你给我说一下类的自动装配吧

    ## 剖析@SpringBootApplication注解 创建一个SpringBoot工程后,SpringBoot会为用户提供一个Application类,该类负责项目的启动: ```@Spring ...

  8. 兄弟萌,这份SpringMVC框架学习笔记真的建议反复看,写的太细了

    概述 是Spring为展现层提供的基于MVC设计理念的Web框架,通过一套MVC注解,让POJO成为处理请求的控制器,而无需实现任何接口 支持REST风格的URL请求 采用松散耦合的可插拔组件结构,比 ...

  9. Markdown的应知应会

    Markdown介绍 什么是Markdown Markdown是一种纯文本.轻量级的标记语言,常用作文本编辑器使用.和记事本.notepad++相比,Markdown可以进行排版:和Word相比,Ma ...

  10. hashmap(有空可以看看算法这本书中对于这部分的实现,很有道理)

    //转载:https://baijiahao.baidu.com/s?id=1618550070727689060&wfr=spider&for=pc 1.为什么用HashMap? H ...