TensorFlow是谷歌推出的深度学习平台,目前在各大深度学习平台中使用的最广泛。

一、安装命令

pip3 install -U tensorflow --default-timeout=1800 -i https://mirrors.ustc.edu.cn/pypi/web/simple

上面是不支持GPU的版本,支持GPU版本的安装命令如下

pip3 install -U tensorflow-gpu --default-timeout=1800 -i https://mirrors.ustc.edu.cn/pypi/web/simple

https://mirrors.ustc.edu.cn/pypi/web/simple 是国内的镜像,安装速度更快

二、基本数据类型

TensorFlow 中最基本的单位是常量(Constant)、变量(Variable)和占位符(Placeholder)。常量在定义后它的值和维度不可变,变量在定义后它的值可变而维度不可变。在神经网络中,变量一般可作为存储权重和其他信息的矩阵,常量可作为存储超参数或其他结构信息的变量。

三、使用TensorFlow进行机器学习的基本流程

● 准备样本数据(训练样本、验证样本、测试样本)

● 定义节点准备接收数据

● 设计神经网络:隐藏层和输出层

● 定义损失函数loss

● 选择优化器(optimizer) 使 loss 达到最小

● 对所有变量进行初始化,通过 sess.run optimizer,迭代N次进行学习

下面的示意图是所有 TensorFlow 机器学习模型所遵循的构建流程,即构建计算图、把数据输入张量、更新权重变量并返回输出值。

在第一步使用 TensorFlow 构建计算图中,需要构建整个模型的架构。例如在神经网络模型中,需要从输入层开始构建整个神经网络的架构,包括隐藏层的数量、每一层神经元的数量、层级之间连接的情况与权重、整个网络中每个神经元使用的激活函数等。此外,还需要配置整个训练、验证与测试的过程。例如在神经网络中,定义整个正向传播的过程与参数并设定学习率、正则化率和批量大小等各类超参数。

第二步将训练数据或测试数据等输送到模型中,TensorFlow 在这一步中一般需要打开一个会话(Session)来执行参数初始化和输送数据等任务。例如在计算机视觉中,需要随机初始化整个模型参数数值,并将图像成批(图像数等于批量大小)地输送到定义好的卷积神经网络中。

第三步更新权重并获取返回值,控制训练过程与获得最终的预测结果。

TensorFlow 线性回归示例

线性回归模型如下图所示

其中「×」为数据点,找到一条直线最好地拟合这些数据点,这条直线和数据点之间的距离即损失函数,所以我们希望找到一条能令损失函数值最小的直线。以下是使用 TensorFlow 构建线性回归的简单示例。

1、构建目标函数(即直线)

目标函数即 H(x)=Wx+b,其中 x 是特征向量、W是特征向量中每个元素对应的权重(Weight)、b(Bias)是偏置项。

# 用来训练模型的样本数据

x_train = [1, 2, 3]

y_train = [1, 2, 3]

W = tf.Variable(tf.random_normal([1]), name='weight')

b = tf.Variable(tf.random_normal([1]), name='bias')

# hypothesis函数 XW+b

hypothesis = x_train * W + b

如上所示定义了 y=wx+b 的运算,即需要拟合的一条直线。

2、构建损失函数

下面构建损失函数,即各数据点到该直线的距离,这里构建的损失函数是均方误差函数:

该函数表明根据数据点预测的值和该数据点真实值之间的距离,代码实现:

# 代价/损失 函数

cost = tf.reduce_mean(tf.square(hypothesis - y_train))

其中 tf.square() 是取某个数的平方, tf.reduce_mean() 是取均值。

3、采用梯度下降更新权重

α是学习速率(Learning rate),控制学习速度,需要调节的超参数。

# 最小化

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

train = optimizer.minimize(cost)

为了寻找能拟合数据的最好直线,需要最小化损失函数,即数据与直线之间的距离,采用梯度下降算法。

4、 运行计算图开始训练模型

# 打开一个会话Session

sess = tf.Session()

# 初始化变量

sess.run(tf.global_variables_initializer())

# 迭代

for step in range(2000):

sess.run(train)

if step % 200 == 0:

print(step, sess.run(cost), sess.run(W), sess.run(b))

上面的代码打开了一个会话并执行变量初始化和输送数据。

5、完整的实现代码

6、某次训练时的输出

step(s): 0001 cost = 0.595171

step(s): 0201 cost = 0.002320

step(s): 0401 cost = 0.000886

step(s): 0601 cost = 0.000338

step(s): 0801 cost = 0.000129

step(s): 1001 cost = 0.000049

step(s): 1201 cost = 0.000019

step(s): 1401 cost = 0.000007

step(s): 1601 cost = 0.000003

step(s): 1801 cost = 0.000001

四、简单小结

本文简述了使用TensorFlow训练模型的过程,无论设计多么复杂的神经网络,都可以参考以上过程,当然在实际生产中还需要考虑很多因素,比如:样本数据的收集、样本数据的预处理、模型的选择和神经网络的设计、过拟合/欠拟合问题、梯度消失/膨胀问题、超参数的设置、是否需要GPU和分布式加快训练速度等等。

在设计深度网络时需要注意每层神经元的维度,这个地方容易出错,特别是层数深、每层神经元数量多的复杂神经网络, 参见《介绍一个快速确定神经网络模型中各层矩阵维度的方法》。

ps:目前多个厂家都推出了机器学习公有平台,一般都会支持TensorFlow,在公有平台上学习AI算法比自己搭建平台省心。

TensorFlow简要教程及线性回归算法示例的更多相关文章

  1. 使用TensorFlow v2库实现线性回归

    使用TensorFlow v2库实现线性回归 此示例使用简单方法来更好地理解训练过程背后的所有机制 from __future__ import absolute_import, division, ...

  2. Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例

    紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取.建立网络以及模型训练整理成一个小样例,完整代码如下. #coding:utf-8 import t ...

  3. 《BI那点儿事》Microsoft 线性回归算法

    Microsoft 线性回归算法是 Microsoft 决策树算法的一种变体,有助于计算依赖变量和独立变量之间的线性关系,然后使用该关系进行预测.该关系采用的表示形式是最能代表数据序列的线的公式.例如 ...

  4. [机器学习Lesson 2]代价函数之线性回归算法

    本章内容主要是介绍:单变量线性回归算法(Linear regression with one variable) 1. 线性回归算法(linear regression) 1.1 预测房屋价格 该问题 ...

  5. Tensorflow快餐教程(1) - 30行代码搞定手写识别

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/lusing/article/details ...

  6. TensorFlow v2.0实现Word2Vec算法

    使用TensorFlow v2.0实现Word2Vec算法计算单词的向量表示,这个例子是使用一小部分维基百科文章来训练的. 更多信息请查看论文: Mikolov, Tomas et al. " ...

  7. scikit-learn 线性回归算法库小结

    scikit-learn对于线性回归提供了比较多的类库,这些类库都可以用来做线性回归分析,本文就对这些类库的使用做一个总结,重点讲述这些线性回归算法库的不同和各自的使用场景. 线性回归的目的是要得到输 ...

  8. CGContextRef使用简要教程

    CGContextRef使用简要教程 Graphics Context是图形上下文,也可以理解为一块画布,我们可以在上面进行绘画操作,绘制完成后,将画布放到我们的view中显示即可,view看作是一个 ...

  9. Yii Framework 开发教程Zii组件-Tabs示例

    有关Yii Tab类: http://www.yiichina.com/api/CTabView http://www.yiichina.com/api/CJuiTabs http://blog.cs ...

随机推荐

  1. UOJ#172. 【WC2016】论战捆竹竿

    传送门 首先这个题目显然就是先求出所有的 \(border\),问题转化成一个可行性背包的问题 一个方法就是同余类最短路,裸跑 \(30\) 分,加优化 \(50\) 分 首先有个性质 \(borde ...

  2. [HAOI2009]逆序对数列(加强)

    ZJL 的妹子序列 暴力就是 \(\Theta(n\times m)\) 如果 \(n,m \le 10^5\) ? 考虑问题的转换,设 \(a_i\) 表示 \(i\) 小的在它后面的数的个数 \( ...

  3. P1025[SCOI2009]游戏

    windy学会了一种游戏.对于1到N这N个数字,都有唯一且不同的1到N的数字与之对应.最开始windy把数字按 顺序1,2,3,……,N写一排在纸上.然后再在这一排下面写上它们对应的数字.然后又在新的 ...

  4. svg矢量图标在html中的使用, (知识点:1.通过h5中的css实现点击变色,2.一个svg文件包含多个图标)

    svg矢量文件体积小,不变形,比传统的png先进,比现在流行的icon-font灵活.然而在使用过程中还是遇到了很多坑.今天花了一天时间把经验整理出来,以供后来者借鉴.如果您从本文收益,请留言mark ...

  5. 03_netty实现聊天室功能

    [概述] 聊天室主要由两块组成:聊天服务器端(ChatRoomServer)和聊天客户端(ChatClient). [ 聊天服务器(ChatRoomServer)功能概述 ] 1.监听所有客户端的接入 ...

  6. Linux 虚拟机配置-network is unreachable

    配置虚拟机时,遇到network is unreachable,根据网上找来的方法处理,最终自己试过,成功修改的方法在这里记录一下: 修改虚拟机的网络适配器:桥接,复制物理机网络 vim /etc/s ...

  7. 绛河 初识WCF5

    然后我们在<Client>中添加一个终结点,这个是客户端的终结点,我们前面曾经提过,通信实际上发生在两个终结点间,客户端也有个终结点,然而请求总是从客户端首先发起,所以终结点地址应该填写为 ...

  8. CAGradientLayer渐变颜色动画

    CAGradientLayer渐变颜色动画 或许你用过CAGradientLayer,你知道他是用于渐变颜色的,但你是否直到,CAGradientLayer的渐变颜色是可以动画的哦. 源码: // / ...

  9. Python学习---函数的学习1209[all]

    1.基础函数 2.高阶函数 3.递归函数 4.内置函数 5.匿名函数和闭包

  10. TCP/IP 协议图--TCP/IP 基础

    1. TCP/IP 的具体含义 从字面意义上讲,有人可能会认为 TCP/IP 是指 TCP 和 IP 两种协议.实际生活当中有时也确实就是指这两种协议.然而在很多情况下,它只是利用 IP 进行通信时所 ...