前面在mnist中使用了三个非线性层来增加模型复杂度,并通过最小化损失函数来更新参数,下面实用最底层的方式即张量进行前向传播(暂不采用层的概念)。

主要注意点如下:

  · 进行梯度运算时,tensorflow只对tf.Variable类型的变量进行记录,而不对tf.Tensor或者其他类型的变量记录

  · 进行梯度更新时,如果采用赋值方法更新即w1=w1+x的形式,那么所得的w1是tf.Tensor类型的变量,所以要采用原地更新的方式即assign_sub函数,或者再次使用tf.Variable包起来(不推荐)

代码如下:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import os os.environ['TF_CPP_MIN_LOG_LEVEL']='' # x:[60k,28,28]
# y:[60k]
(x,y),_=datasets.mnist.load_data() x = tf.convert_to_tensor(x,dtype=tf.float32)/255.0
y = tf.convert_to_tensor(y,dtype=tf.int32) print(x.shape,y.shape,x.dtype,y.dtype)
print(tf.reduce_min(x),tf.reduce_max(x))
print(tf.reduce_min(y),tf.reduce_max(y)) train_db=tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter=iter(train_db)
sample=next(train_iter)
print('batch:',sample[0].shape,sample[1].shape) # [b,784]=>[b,256]=>[b,128]=>[b,10]
# w shape[dim_in,dim_out] b shape[dim_out]
w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1))
b1 = tf.Variable(tf.zeros([256])) w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2 = tf.Variable(tf.zeros([128])) w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])) # 设置学习率
lr = 0.001
for epoch in range(10): # 对数据集迭代
for step,(x,y) in enumerate(train_db):
# x:[128,28,28] y:[128]
x = tf.reshape(x,[-1,28*28]) with tf.GradientTape() as tape: # tape只会跟踪tf.Variable
# x:[b,28*28]
# [b,784]@[784,256]+[256]=>[b,256]+[256]
h1 = x@w1 + b1
h1 = tf.nn.relu(h1) # 去线性化
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2) # 去线性化
out = h2@w3 + b3 # 计算损失
y_onehot = tf.one_hot(y,depth=10)
# mse = mean(sum(y-out)^2)
loss = tf.square(y_onehot - out)
# mean:scalar
loss = tf.reduce_mean(loss) # 计算梯度
grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
# w1 = w1 -lr * w1_grad
w1.assign_sub(lr * grads[0]) # 原地更新
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5]) if step % 100 == 0:
print('epoch = ',epoch,'step =',step,',loss =',float(loss))

效果如下:

前向传播和反向传播实战(Tensor)的更多相关文章

  1. 机器学习(ML)八之正向传播、反向传播和计算图,及数值稳定性和模型初始化

    正向传播 正向传播的计算图 通常绘制计算图来可视化运算符和变量在计算中的依赖关系.下图绘制了本节中样例模型正向传播的计算图,其中左下角是输入,右上角是输出.可以看到,图中箭头方向大多是向右和向上,其中 ...

  2. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  3. caffe中 softmax 函数的前向传播和反向传播

    1.前向传播: template <typename Dtype> void SoftmaxLayer<Dtype>::Forward_cpu(const vector< ...

  4. caffe中的前向传播和反向传播

    caffe中的网络结构是一层连着一层的,在相邻的两层中,可以认为前一层的输出就是后一层的输入,可以等效成如下的模型 可以认为输出top中的每个元素都是输出bottom中所有元素的函数.如果两个神经元之 ...

  5. BP原理 - 前向计算与反向传播实例

    Outline 前向计算 反向传播 很多事情不是需要聪明一点,而是需要耐心一点,踏下心来认真看真的很简单的. 假设有这样一个网络层: 第一层是输入层,包含两个神经元i1 i2和截距b1: 第二层是隐含 ...

  6. 反向传播算法(前向传播、反向传播、链式求导、引入delta)

    参考链接: 一文搞懂反向传播算法

  7. Tensorflow笔记——神经网络图像识别(一)前反向传播,神经网络八股

      第一讲:人工智能概述       第三讲:Tensorflow框架         前向传播: 反向传播: 总的代码: #coding:utf-8 #1.导入模块,生成模拟数据集 import t ...

  8. BP神经网络反向传播之计算过程分解(详细版)

    摘要:本文先从梯度下降法的理论推导开始,说明梯度下降法为什么能够求得函数的局部极小值.通过两个小例子,说明梯度下降法求解极限值实现过程.在通过分解BP神经网络,详细说明梯度下降法在神经网络的运算过程, ...

  9. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

随机推荐

  1. [python]getpass模块

    python3的input函数不能隐藏用户输入,可以用getpass模块的getpass方法获取用户输入的时候用于隐藏显示密码. *需要注意的是该方法在IDE中看不到隐藏效果,在内置IDLE中会有Ge ...

  2. CCF_ 201312-3_最大的矩形

    遍历数组中每一元素,左右延伸得出宽度. #include<iostream> #include<cstdio> using namespace std; int main() ...

  3. StackExchange.Redis 之 hash 类型示例

    StackExchange.Redis 的组件封装示例网上有很多,自行百度搜索即可. 这里只演示如何使用Hash类型操作数据: // 在 hash 中存入或修改一个值 并设置order_hashkey ...

  4. sqlserver install on linux chapter two

    The previous chapter is tell us how to install sqlerver on linuix Today, we will see how to make it ...

  5. 在centos6.3下安装php的Xdebug

    首先下载一个xdebug http://www.xdebug.org/docs/ 官网上有windos版本和linux源码版本的,我们下载一个源码包xdebug-2.2.5.tgz 然后进入安装流程 ...

  6. rabbit MQ 消息队列

    为什么会需要消息队列(MQ)? 一.消息队列概述消息队列中间件是分布式系统中重要的组件,主要解决应用解耦,异步消息,流量削锋等问题,实现高性能,高可用,可伸缩和最终一致性架构.目前使用较多的消息队列有 ...

  7. 在Linux实例上自动安装并运行VNC Server

    #!/bin/bash ######################################### #Function: install vnc server #Usage: bash ins ...

  8. 【重新整理】log4j 2的使用

    一 概述 1.1 日志框架 日志接口(slf4j) slf4j是对所有日志框架制定的一种规范.标准.接口,并不是一个框架的具体的实现,因为接口并不能独立使用,需要和具体的日志框架实现配合使用(如log ...

  9. 如何通过adb command 完成自动SD卡升级?

    如何通过adb command 完成自动SD卡升级? 原创 2014年09月09日 10:50:57 2746 通过adb 命令的方式,免去了按powerkey+volumeup进入menu sele ...

  10. Shiro -- (一)简介

    简介: Apache Shiro 是一个强大易用的 Java 安全框架,提供了认证.授权.加密和会话管理等功能,对于任何一个应用程序,Shiro 都可以提供全面的安全管理服务.并且相对于其他安全框架, ...