TensorFlow运行方式。加载数据、定义超参数,构建网络,训练模型,评估模型、预测。

构造一个满足一元二次函数y=ax^2+b原始数据,构建最简单神经网络,包含输入层、隐藏层、输出层。TensorFlow学习隐藏层、输出层weights、biases。观察训练次数增加,损失值变化。

生成、加载数据。方程y=x^2-0.5。构造满足方程的x、y。加入不满足方程噪声点。

import tensor flow as tf
import bumpy as np
# 构造满中一元二次方程的函数
x_data = np.linspace(-1,1,300)[:,np.newaxis] # 构建起300个点,分布在-1到1区间,用np生成等差数列,300个点的一维数组转换为300x1的二维数组
noise = np.random.normal(0, 0.05, x_data.shape) # 加入噪声点,与x_data维度一致,拟合均值0、方差0.05正态分布
y_data = np.square(x_data) - 0.5 + noise # y = x^2 - 0.5 + 噪声

定义x、y占位符作输入神经网络变量。

xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

构建网络模型。

构建一个隐藏层,一个输出层。输入参数4个变量,输入数据、输入数据维度、输出数据维度、激活函数。每层向量化处理(y = weights*x +biases),激活函数非线性化处理,输出数据。定义隐藏层、输出层:

def add_layer(inputs, in_size, out_size, activation_function=None):
# 构建权重:in_size*out_size 大小的矩阵
weights = tf.Variable(tf.random_normal([in_size, out_size]))
# 构建偏置:1 * out_size矩阵
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
# 矩阵相乘
Wx_plus_b = tf.matmul(inputs, weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs #得到输出数据
# 构建隐藏层,假设隐藏层有10个神经元
h1 = add_layer(xs, 1, 20, activation_function=tf.nn.relu)
# 构建输出层,假设输出层和输入层一样,有1个神经元
prediction = add_layer(h1, 20, 1, activation_function=None)

构建损失函数,计算输出层预测值、真实值间误差。二者差的平方求和再取平均。梯度下降法,以0.1效率最小化损失。

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

训练模型。训练1000次,每50次输出训练损失值。

init = tf.global_variables_initializer() # 初始化所有变量
sess = tf.Session()
sess.run(init)

for i in range(1000): # 训练1000次
sess.run(train_step, feed_dict = (xs: x_data, ys: y_data)
if i % 50 == 0: #每50次打印出一次损失值
print(sets.run(loss, feed_dict={xs: x_data, ys: y_data}))

训练权重值,模型拟合y = x^2-0.5的系数1和-0.5。损失值越来越小,训练参数越来越逼近目标结果。评估模型,学习系数weights、biase前向传播后和真值y = x^2-0.5结果系数比较,根据相近程度计算准确率。

超参数设定。hyper-parameters,机器学习模型框架参数。手动设定、不断试错。

学习率(learning rate),设置越大,训练时间越短,速度越快。设置越小,训练准确度越高。可变学习率,训练过程记录最桂准确率,连续n轮(epoch)没达到最佳准确率,认为准确率不再提高,停止训练,early stopping,no_improvement-in-n规则。学习率减半,再满足时再减半。逐渐接近最优解,学习率越小,准确度越高。

mini-batch大小。每批大小决定权重更新规则。整批样本梯度全部计算完,才求平均值,更新权重。批次越大训练速度越快,利用矩阵、线性代数库加速,权重更新频率低。批次越小,训练速度越慢。结合机器硬件性能与数据集大小设定。

正则项系数(regularization parameter,λ)。凭经验。复杂网络出现明显过拟合(训练数据准确率高,测试数据准确率下降)。一开始设0,确定好学习率,再给λ设值,根据准确率精细调整。

参考资料:
《TensorFlow技术解析与实战》

欢迎推荐上海机器学习工作机会,我的微信:qingxingfengzi

学习笔记TF055:TensorFlow神经网络简单实现一元二次函数的更多相关文章

  1. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  3. [转载]SharePoint 2013搜索学习笔记之搜索构架简单概述

    Sharepoint搜索引擎主要由6种组件构成,他们分别是爬网组件,内容处理组件,分析处理组件,索引组件,查询处理组件,搜索管理组件.可以将这6种组件分别部署到Sharepoint场内的多个服务器上, ...

  4. OGG学习笔记03-单向复制简单故障处理

    OGG学习笔记03-单向复制简单故障处理 环境:参考:OGG学习笔记02-单向复制配置实例 实验目的:了解OGG简单故障的基本处理思路. 1. 故障现象 故障现象:启动OGG源端的extract进程, ...

  5. QML学习笔记(六)- 简单计时器和定时器

    做一个简单的qml计时器和定时器,左键触发计时,右键触发定时 GitHub:八至 作者:狐狸家的鱼 本文链接:QML学习笔记(六)- 简单计时器和定时器 左键点击按钮,触发计时器,中键可以暂停计时,同 ...

  6. CNN学习笔记:卷积神经网络

    CNN学习笔记:卷积神经网络 卷积神经网络 基本结构 卷积神经网络是一种层次模型,其输入是原始数据,如RGB图像.音频等.卷积神经网络通过卷积(convolution)操作.汇合(pooling)操作 ...

  7. [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)

    3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...

  8. [DL学习笔记]从人工神经网络到卷积神经网络_1_神经网络和BP算法

    前言:这只是我的一个学习笔记,里边肯定有不少错误,还希望有大神能帮帮找找,由于是从小白的视角来看问题的,所以对于初学者或多或少会有点帮助吧. 1:人工全连接神经网络和BP算法 <1>:人工 ...

  9. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

随机推荐

  1. UML类图中最重要的几种类关系及其表示

    阅读UML图最常见到的类与类之间的关系有如下几种: 1.依赖关系 依赖关系是指一个类在计算时,应用了“另一个类”类型的参数,这种关系是偶然.临时.弱的. UML类图中,依赖关系用带单箭头的虚线表示,即 ...

  2. python机器可读数据-XML

    XML XML是一门标记语言.也就是说,它具有包含格式化数据的文档结构. XML文档本质上只是格式特殊的数据文件. 在XML文件中有两个位置可以保存数据值:2个标签之间,标签的属性. 导入XML数据 ...

  3. Ubuntu 18.04安装VNC远程登录

    reference: https://blog.csdn.net/bluewhalerobot/article/details/73649353 https://community.bwbot.org ...

  4. python的迭代器

    迭代器 我们已经知道,可以直接作用于for循环的数据类型有以下几种: 一类是集合数据类型,如list.tuple.dict.set.str等: 一类是generator,包括生成器和带yield的ge ...

  5. java基础知识—数组

    1.数组:是一个变量,存储相同数据类型的一组数据. 2.数据的优点:减少代码量.易查找. 3.数组的使用步骤: 1)声明数组:int scores []: 2)开辟空间:scores = new in ...

  6. Linux文件打包与解压缩

    一.文件打包和解压缩 常用的压缩包文件格式.在 Windows 上我们最常见的不外乎这三种*.zip,*.rar,*.7z后缀的压缩文件,而在 Linux 上面常见常用的除了以上这三种外,还有*.gz ...

  7. 使用Babel将单独的js文件 中的 ES6转码为ES5

      如果你并没有接触过ES6,当你看到下面的代码时,肯定是有点懵逼的(这是什么鬼?心中一万头神兽奔腾而过),但是你没看错,这就是ES6.不管你看不看它,它都在这里. 1 2 3 4 5 6 7 8 9 ...

  8. pytest自动化1:兼容unittest代码实例

    初级版本 源码: #!/usr/bin/env python # -*- coding:utf-8 -*- #使用车管家的接口实际调用类函数 import unittest from urllib i ...

  9. python定时脚本判断服务器内存

    经常我们会发现服务器跑着跑着内存使用率达到了百分之八九十,或者有时候直接挂掉,在我们还没定位是哪块代码有问题导致内存占用很大的时候,可以先写个定时脚本,当服务器内存使用率达到一定值的时候,就重启一起服 ...

  10. shell版的nginx安装

    #!/bin/bash # Name:Centos 6.4 安装nginx1.8.1 # Date:-- # Author:qifei@meizu.com Home=$(cd ``;pwd) #这个命 ...