1. 了解引入的需要神经网络解决的问题
  2. 学习用神经网络的基本结构、表达方式和编程实现
  3. 学习训练神经网络的基本方法

三好学生成绩问题

总分 = 德育分 * 60% + 智育分 * 60% + 体育分 * 60%

假设家长不知道这个规则,已知:

  • 学校一定是以德育分、智育分和体育分三项分数的总分来确定三好学生的
  • 计算总分时,三项分数应该有各自的权重系数
  • 各自孩子的三项分数都已经知道,总分也已经知道

经过家长们的分析,只有三项分数各自乘以的权重系数是未知的。问题演变成求解方程:w1x + w2y + w3z = A 中的三个 w 即权重。其中 x、y、z、A 分别对应几位学生的德育分、智育分、体育分和总分。

两个方程式解三个未知数无法求解:

90w1 + 80w2 + 70w3 = 85

98
w1 + 95w2 + 87w3 = 96

搭建对应的网络神经

神经网络模型图的一般约定:

  • 神经网络图一般包含一个输入层、一个或多个隐藏层,以及一个输出层
  • 输入层是描逑输入数据的形态的(输入节点)
  • 隐藏层是描迒神经网络模型结构中最重要的部分隐藏层可以有多个;每一层有一个或多个神经元(神经元节点/节点);每个节点接收上层的数据并进行运算向下层输出数据(计算操作/操作)
  • 输出层一般是神经网络的最后一层,包含一个或多个输出节点

神经网络的代码:

import tensorflow as tf

x1 = tf.placeholder(dtype = tf.float32)
x2 = tf.placeholder(dtype = tf.float32)
x3 = tf.placeholder(dtype = tf.float32) w1 = tf.Variable(0.1, dtype = tf.float32)
w2 = tf.Variable(0.1, dtype = tf.float32)
w3 = tf.Variable(0.1, dtype = tf.float32) n1 = x1 * w1
n2 = x2 * w2
n3 = x3 * w3 y = n1 + n2 + n3 sess = tf.Session()
init = tf.global_variable_initializer() sess.run(init) result = sess.run([x1, x2, x3, w1, w2, w3, y], feed_dict={x1: 90, x2: 80, x3: 70})
print(result)
x1 = tf.placeholder(dtype = tf.float32)
x2 = tf.placeholder(dtype = tf.float32)
x3 = tf.placeholder(dtype = tf.float32)

通过 tf.placeholder 定义三个占位符(placeholder),作为神经网络的输入节点,来准备分别接收德育、智育、体育三门分数作为神经网络的输入。dtype 是 data type 的缩写,dtype = tf.float3 是命令参数,tf.float32 代表 32 位小数。

w1 = tf.Variable(0.1, dtype = tf.float32)
w2 = tf.Variable(0.1, dtype = tf.float32)
w3 = tf.Variable(0.1, dtype = tf.float32)

通过 tf.Variable() 定义三个可变参数。

n1 = x1 * w1
n2 = x2 * w2
n3 = x3 * w3

n1、n2、n3 是三个隐藏层节点,实际上是他们的计算算式。

y = n1 + n2 + n3

定义输出节点 y,也就是总分的计算公式(加权求和)。至此,神经网络模型的定义完成。

sess = tf.Session()

定义神经网络的会话对象

init = tf.global_variable_initializer()

tf.global_variable_initializer() 返回专门用于初始化可变参数的对象。

sess.run(init)

初始化所有的可变参数。

result = sess.run([x1, x2, x3, w1, w2, w3, y], feed_dict={x1: 90, x2: 80, x3: 70})
print(result)

[x1, x2, x3, w1, w2, w3, y] 为要查看的结果项,feed_dict={x1: 90, x2: 80, x3: 70} 为输入的数据。输入三门分数运行神经网络并获得该神经网络输出的节点值。

运行代码,查看结果:

根据随意设置的可变参数初始值计算出的输出结果正确,证明搭建的神经网络可以运行,但不能真正投入使用,存在一定误差。

如果你使用了 TensorFlow 2.x 上述代码中可能存在兼容问题,但是可以通过更改部分代码解决:

代码
# import tensorflow as tf
import tensorflow.compat.v1 as tf tf.disable_v2_behavior() x1 = tf.placeholder(dtype = tf.float32)
x2 = tf.placeholder(dtype = tf.float32)
x3 = tf.placeholder(dtype = tf.float32) w1 = tf.Variable(0.1, dtype = tf.float32)
w2 = tf.Variable(0.1, dtype = tf.float32)
w3 = tf.Variable(0.1, dtype = tf.float32) n1 = x1 * w1
n2 = x2 * w2
n3 = x3 * w3 y = n1 + n2 + n3 sess = tf.Session()
# init = tf.global_variable_initializer()
init = tf.compat.v1.global_variables_initializer() sess.run(init) result = sess.run([x1, x2, x3, w1, w2, w3, y], feed_dict={x1: 90, x2: 80, x3: 70})
print(result)

训练神经网络

神经网络在投入使用前,都要经过训练(train)的过程才能有准确的输出。

  • 神经网络训练时一定要有训练数据
  • 有监督学习中,训练数据中的每一条是由一组输入值和一个目标值组成的
  • 目标值就是根据这一组输入数值应该得到的 “准答案”
  • 般来说,训练数据越多、离散性(覆盖面)越强越好

x1 = tf.placeholder(dtype = tf.float32)
x2 = tf.placeholder(dtype = tf.float32)
x3 = tf.placeholder(dtype = tf.float32) yTrain = tf.placeholder(dtype = tf.float32)

给神经网络增加一个输入项 —— 目标值 yTrain,用来表示正确的总分结果。增加误差函数 loss,优化器 optimizer 和训练对象 train

y = n1 + n2 + n3

loss = tf.abs(y - yTrain)
optimizer = tf.train.RMSPropOptimizer(0.001)
train = optimizer.minimize(loss)

tf.abs 函数用于取绝对值:计算结果 y 与目标值 yTrain 之间的误差。使用 RMSProp 优化器其中参数是学习率。optimizer.minimize 让优化器按照把 loss 最小化的原则来调整可变参数。

“误差函数”(又叫损失函数)用于让神经网络来判断当前网络的计算结果与目标值(也就是标准答案)相差多少。“训练对象”被神经网络用于控制训练的方式,常见的训练的方式是设法使误差函数的计算值越来越小。

result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 96})
print(result)

训练两次并查看输出结果,注意与前面的区别:训练时要在 sess.run 函数的第一个参数中添加 train 这个训练对象;在 feed_dict 参数中多指定了 Train 的数值。

代码
# import tensorflow as tf
import tensorflow.compat.v1 as tf tf.disable_v2_behavior() x1 = tf.placeholder(dtype = tf.float32)
x2 = tf.placeholder(dtype = tf.float32)
x3 = tf.placeholder(dtype = tf.float32) yTrain = tf.placeholder(dtype = tf.float32) w1 = tf.Variable(0.1, dtype = tf.float32)
w2 = tf.Variable(0.1, dtype = tf.float32)
w3 = tf.Variable(0.1, dtype = tf.float32) n1 = x1 * w1
n2 = x2 * w2
n3 = x3 * w3 y = n1 + n2 + n3 loss = tf.abs(y - yTrain)
optimizer = tf.train.RMSPropOptimizer(0.001)
train = optimizer.minimize(loss) sess = tf.Session()
# init = tf.global_variable_initializer()
init = tf.compat.v1.global_variables_initializer() sess.run(init) result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 96})
print(result)

w1、w2、w3 和计算结果 y 已经开始有了变化。

循环进行多次训练:

for i in range(5000):
result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 85})
print(result) result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain], feed_dict={x1: 98, x2: 95, x3: 87, yTrain: 96})
print(result)
代码
# import tensorflow as tf
import tensorflow.compat.v1 as tf tf.disable_v2_behavior() x1 = tf.placeholder(dtype = tf.float32)
x2 = tf.placeholder(dtype = tf.float32)
x3 = tf.placeholder(dtype = tf.float32) yTrain = tf.placeholder(dtype = tf.float32) w1 = tf.Variable(0.1, dtype = tf.float32)
w2 = tf.Variable(0.1, dtype = tf.float32)
w3 = tf.Variable(0.1, dtype = tf.float32) n1 = x1 * w1
n2 = x2 * w2
n3 = x3 * w3 y = n1 + n2 + n3 loss = tf.abs(y - yTrain)
optimizer = tf.train.RMSPropOptimizer(0.001)
train = optimizer.minimize(loss) sess = tf.Session()
# init = tf.global_variables_initializer()
init = tf.compat.v1.global_variables_initializer() sess.run(init) for i in range(5000):
result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain], feed_dict={x1: 90, x2: 80, x3: 70, yTrain: 85})
print(result) result = sess.run([train, x1, x2, x3, w1, w2, w3, y, yTrain], feed_dict={x1: 98, x2: 95, x3: 87, yTrain: 96})
print(result)

w1、w2、w3 已经非常接近于预期的 0.6、0.3、0.1,y 也非常接近目标值。

CH3 初识 TensorFlow的更多相关文章

  1. 初识TensorFlow

    在前边几期的文章中,笔者已经用TensorFlow进行的一些基础性的探索工作,想必大家对TensorFlow框架也是非常的好奇,本着发扬雷锋精神,笔者将详细的阐述TensorFlow框架的基本用法,并 ...

  2. 机器学习之路: 初识tensorflow 第一个程序

    计算图 tensorflow是一个通过计算图的形式来表示计算的编程系统tensorflow中每一个计算都是计算图上的一个节点节点之间的边描述了计算之间的依赖关系 张量 tensor张量可以简单理解成多 ...

  3. 初识 ❤ TensorFlow |【一见倾心】

    说明

  4. Tensorflow 安装 和 初识

    Windows中 Anaconda,Tensorflow 和 Pycharm的安装和配置   https://blog.csdn.net/zhuiqiuzhuoyue583/article/detai ...

  5. TensorFlow学习(1)-初识

    初识TensorFlow 一.术语潜知 深度学习:深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法. 深度学 ...

  6. TensorFlow 基础概念

    初识TensorFlow,看了几天教程后有些无聊,决定写些东西,来夯实一下基础,提供些前进动力. 一.Session.run()和Tensor.eval()的区别: 最主要的区别就是可以使用sess. ...

  7. TensorFlow学习(1)

    初识TensorFlow 一.术语潜知 深度学习:深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法. 深度学 ...

  8. TensorFlow从入门到入坑(1)

    初识TensorFlow 一.术语潜知 深度学习:深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法. 深度学 ...

  9. 语音识别(LSTM+CTC)

    完整版请微信关注“大数据技术宅” 序言:语音识别作为人工智能领域重要研究方向,近几年发展迅猛,其中RNN的贡献尤为突出.RNN设计的目的就是让神经网络可以处理序列化的数据.本文笔者将陪同小伙伴们一块儿 ...

  10. 大数据利器Hive

    序言:在大数据领域存在一个现象,那就是组件繁多,粗略估计一下轻松超过20种.如果你是初学者,瞬间就会蒙圈,不知道力往哪里使.那么,为什么会出现这种现象呢?在本文的开头笔者就简单的阐述一下这种现象出现的 ...

随机推荐

  1. 面试题:java Runnable与Callable 的区别

    相同点 都是接口:(废话,当然是接口了) 都可用来编写多线程程序: 都需要调用Thread.start()启动线程. Callable是类似于Runnable的接口,实现Callable接口的类和实现 ...

  2. SQL语句between and边界问题

       BETWEEN AND 需要两个参数,即范围的起始值a和终止值b,而且要求a<b.如果字段值在指定的[闭区间[a,b]]内,则这些记录被返回:否则,记录不会被返回. 字段值可以是数值.文本 ...

  3. K8s集群中的DNS服务(CoreDNS)详解

    概述 官网文档:https://kubernetes.io/zh-cn/docs/concepts/services-networking/dns-pod-service/ 在 Kubernetes( ...

  4. MongoDB可视化工具

    简单说明 这里使用mongodb的过程中,可以通过mongo shell或者mongo的可视化工具进行连接. mongo shell连接 # 使用root用户登录mongo mongodb@p8lnp ...

  5. 你应该懂的AI 大模型(五)之 LangChain 之 LCEL

    本文 对<LangChain>一文中的 Chain 与 LCEL 部分的示例进行详细的展示. 先回顾下 在LangChain框架中,Chain(链) 和 LCEL(LangChain Ex ...

  6. SAP的PI日志查看工具

    被很多人吐槽的SAP PI能坚挺的活下来,真是不容易... SXI_MONITOR是PI的标准的消息查看器,如果又权限的话,甚至可以做自定义字段的查询增强(如果对单据创建接口,增加单号...速度杠杠的 ...

  7. 详解鸿蒙Next仓颉开发语言中的全屏模式

    大家好,今天跟大家分享一下仓颉开发语言中的全屏模式. 和ArkTS一样,仓颉的新建项目默认是非全屏模式的,如果你的应用颜色比较丰富,就会发现屏幕上方和底部的留白,这是应用自动避让了屏幕上方摄像头区域和 ...

  8. 简单的php奥运倒计时牌

    1 <?php 2 3 date_default_timezone_set ( "Asia/Shanghai" ); 4 $kaimu = mktime ( 4, 0, 0, ...

  9. 那些年拿过的shell之adminer

    扫敏感文件扫到一个adminer 第三次遇到了,先看版本4.2.5比较低可以利用mysql服务端读客户端文件漏洞(高版本修复了). 通过报错得到这个站是linux.虚拟主机.thinkphp3.绝对路 ...

  10. SQL Server 2008~2022版本序列号/密钥/激活码 汇总

    SQL Server 2008~2022版本序列号/密钥/激活码 汇总 - 重庆熊猫 - 博客园 (cnblogs.com) SQL Server 2022# Enterprise: J4V48-P8 ...