前面在mnist中使用了三个非线性层来增加模型复杂度,并通过最小化损失函数来更新参数,下面实用最底层的方式即张量进行前向传播(暂不采用层的概念)。

主要注意点如下:

  · 进行梯度运算时,tensorflow只对tf.Variable类型的变量进行记录,而不对tf.Tensor或者其他类型的变量记录

  · 进行梯度更新时,如果采用赋值方法更新即w1=w1+x的形式,那么所得的w1是tf.Tensor类型的变量,所以要采用原地更新的方式即assign_sub函数,或者再次使用tf.Variable包起来(不推荐)

代码如下:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import os os.environ['TF_CPP_MIN_LOG_LEVEL']='' # x:[60k,28,28]
# y:[60k]
(x,y),_=datasets.mnist.load_data() x = tf.convert_to_tensor(x,dtype=tf.float32)/255.0
y = tf.convert_to_tensor(y,dtype=tf.int32) print(x.shape,y.shape,x.dtype,y.dtype)
print(tf.reduce_min(x),tf.reduce_max(x))
print(tf.reduce_min(y),tf.reduce_max(y)) train_db=tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter=iter(train_db)
sample=next(train_iter)
print('batch:',sample[0].shape,sample[1].shape) # [b,784]=>[b,256]=>[b,128]=>[b,10]
# w shape[dim_in,dim_out] b shape[dim_out]
w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1))
b1 = tf.Variable(tf.zeros([256])) w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2 = tf.Variable(tf.zeros([128])) w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])) # 设置学习率
lr = 0.001
for epoch in range(10): # 对数据集迭代
for step,(x,y) in enumerate(train_db):
# x:[128,28,28] y:[128]
x = tf.reshape(x,[-1,28*28]) with tf.GradientTape() as tape: # tape只会跟踪tf.Variable
# x:[b,28*28]
# [b,784]@[784,256]+[256]=>[b,256]+[256]
h1 = x@w1 + b1
h1 = tf.nn.relu(h1) # 去线性化
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2) # 去线性化
out = h2@w3 + b3 # 计算损失
y_onehot = tf.one_hot(y,depth=10)
# mse = mean(sum(y-out)^2)
loss = tf.square(y_onehot - out)
# mean:scalar
loss = tf.reduce_mean(loss) # 计算梯度
grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
# w1 = w1 -lr * w1_grad
w1.assign_sub(lr * grads[0]) # 原地更新
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5]) if step % 100 == 0:
print('epoch = ',epoch,'step =',step,',loss =',float(loss))

效果如下:

前向传播和反向传播实战(Tensor)的更多相关文章

  1. 机器学习(ML)八之正向传播、反向传播和计算图,及数值稳定性和模型初始化

    正向传播 正向传播的计算图 通常绘制计算图来可视化运算符和变量在计算中的依赖关系.下图绘制了本节中样例模型正向传播的计算图,其中左下角是输入,右上角是输出.可以看到,图中箭头方向大多是向右和向上,其中 ...

  2. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  3. caffe中 softmax 函数的前向传播和反向传播

    1.前向传播: template <typename Dtype> void SoftmaxLayer<Dtype>::Forward_cpu(const vector< ...

  4. caffe中的前向传播和反向传播

    caffe中的网络结构是一层连着一层的,在相邻的两层中,可以认为前一层的输出就是后一层的输入,可以等效成如下的模型 可以认为输出top中的每个元素都是输出bottom中所有元素的函数.如果两个神经元之 ...

  5. BP原理 - 前向计算与反向传播实例

    Outline 前向计算 反向传播 很多事情不是需要聪明一点,而是需要耐心一点,踏下心来认真看真的很简单的. 假设有这样一个网络层: 第一层是输入层,包含两个神经元i1 i2和截距b1: 第二层是隐含 ...

  6. 反向传播算法(前向传播、反向传播、链式求导、引入delta)

    参考链接: 一文搞懂反向传播算法

  7. Tensorflow笔记——神经网络图像识别(一)前反向传播,神经网络八股

      第一讲:人工智能概述       第三讲:Tensorflow框架         前向传播: 反向传播: 总的代码: #coding:utf-8 #1.导入模块,生成模拟数据集 import t ...

  8. BP神经网络反向传播之计算过程分解(详细版)

    摘要:本文先从梯度下降法的理论推导开始,说明梯度下降法为什么能够求得函数的局部极小值.通过两个小例子,说明梯度下降法求解极限值实现过程.在通过分解BP神经网络,详细说明梯度下降法在神经网络的运算过程, ...

  9. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

随机推荐

  1. PHP Ajax 跨域问题最佳解决方案 【摘自菜鸟教程】

    PHP Ajax 跨域问题最佳解决方案 分类 编程技术 http://www.runoob.com/w3cnote/php-ajax-cross-border.html 本文通过设置Access-Co ...

  2. JAVA 调用控件开发

    最近homoloCzh有个小伙伴接到一个需求说是把一个c# 写的具备扫描.调阅等功能 winfrom 影像控件嵌入到java Swing当中,让小伙伴很苦恼啊,从年前一直研究到年后,期间用了很多种方法 ...

  3. python笔记22(面向对象课程四)

    今日内容 讲作业 栈 顺序查找 可迭代对象 约束 + 异常 反射 内容详细 1.作业 1.1 代码从上到下执行 print('你好') def func(): pass func() class Fo ...

  4. vue学习(一)项目搭建

    首先需要配置node和npm,如果没有安装的话,百度一下安装教程. 如果感觉npm下载速度慢,可以使用淘宝镜像cnpm,链接地址: http://npm.taobao.org/ 安装cnpm npm ...

  5. 《Web渗透与漏洞挖掘》第一章 安全知识

    漏洞:漏洞是指一个系统存在的弱点或缺陷,系统对特定威胁攻击或危险时间的敏感性,或进行攻击威胁的可能性.漏洞可能来自应用软件或操作系统设计时的缺陷或编码时的错误,也可能来自业务交互处理过程中的设计缺陷或 ...

  6. 反弹shell备忘录

    反弹shell备忘录 简单理解,通常是我们主动发起请求,去访问服务器(某个IP的某个端口),比如我们常访问的web服务器:http(https)://ip:80,这是因为在服务器上面开启了80端口的监 ...

  7. React之深入了解虚拟DOM

    JS在手机中也可运行,React Native通过将虚拟DOM转换为底层的原生组件,即可在手机端运行,在浏览器运行时,只需要将虚拟DOM转换为真实DOM即可运行,虚拟DOM就是将真实DOM转换为JS对 ...

  8. Nginx是什么 ? 能干嘛 ?

    学习博客:https://blog.csdn.net/forezp/article/details/87887507 学习博客:https://blog.csdn.net/qq_29677867/ar ...

  9. centos7搭建SVN并配置使用http方式访问SVN服务器

    一.检查SVN是否安装 centos7系统自带SVN # rpm -qa subversion [root@localhost ~]# rpm -qa subversion subversion--. ...

  10. 4,ZooKeeper原理

    1,ZooKeeper概述 ··· 作用:     · ZooKeeper是为分布式应用程序提供的一个分布式开源协调框架,是Hadoop和Hbase的重要组件:     · 主要用于解决分布式集群中应 ...