tensorflow入门(二)
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt #使用numpy生成200个随机点
x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis]
noise = np.random.normal(0,0.02,x_data.shape)
y_data = np.square(x_data) + noise #定义两个placeholder
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1]) #定义神经网络中间层
Weights_L1 = tf.Variable(tf.random_normal([1,10]))
biases_L1 = tf.Variable(tf.zeros([1,10]))
Wx_plus_b_L1 = tf.matmul(x,Weights_L1) + biases_L1
L1 = tf.nn.tanh(Wx_plus_b_L1) #定义神经网络输出层
Weights_L2 = tf.Variable(tf.random_normal([10,1]))
biases_L2 = tf.Variable(tf.zeros([1,1]))
Wx_plus_b_L2 = tf.matmul(L1,Weights_L2) + biases_L2
prediction = tf.nn.tanh(Wx_plus_b_L2) #二次代价函数
loss = tf.reduce_mean(tf.square(y - prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) with tf.Session() as sess:
#变量初始化
sess.run(tf.global_variables_initializer())
for _ in range(2000):
sess.run(train_step,feed_dict={x:x_data,y:y_data}) #获得预测值
prediction_value = sess.run(prediction,feed_dict={x:x_data})
#画图
plt.figure()
plt.scatter(x_data,y_data)
plt.plot(x_data,prediction_value,'r-',lw = 5)
plt.show()
得到结果:

——
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784]) #图片
y = tf.placeholder(tf.float32,[None,10]) #标签 #创建一个简单的神经网络
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,w)+b) #二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#梯度下降算法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个bool类型的列表中,argmax()返回一维张量中最大值所在的位置,equal函数判断两者是否相等
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率,cast:将bool型转换为0,1然后求平均正好算到是准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
for batch in range(batch_size):#next_batch:不断获取下一组数据
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter "+str(epoch)+",Testing Accuracy "+str(acc))
训练结果:
Iter 0,Testing Accuracy 0.657
Iter 1,Testing Accuracy 0.7063
Iter 2,Testing Accuracy 0.7458
Iter 3,Testing Accuracy 0.7876
Iter 4,Testing Accuracy 0.8203
Iter 5,Testing Accuracy 0.8406
Iter 6,Testing Accuracy 0.8492
Iter 7,Testing Accuracy 0.8586
Iter 8,Testing Accuracy 0.8642
Iter 9,Testing Accuracy 0.8678
Iter 10,Testing Accuracy 0.8709
Iter 11,Testing Accuracy 0.8736
Iter 12,Testing Accuracy 0.8751
Iter 13,Testing Accuracy 0.8784
Iter 14,Testing Accuracy 0.88
Iter 15,Testing Accuracy 0.8816
Iter 16,Testing Accuracy 0.883
Iter 17,Testing Accuracy 0.8841
Iter 18,Testing Accuracy 0.8847
Iter 19,Testing Accuracy 0.8855
Iter 20,Testing Accuracy 0.8878
可见,通过多次的训练,识别的准确值不断提高
——
tensorflow入门(二)的更多相关文章
- 48、tensorflow入门二,线性模型的拟合
import tensorflow as tf import numpy as np#生成2维的100个0-1的随机数 x_data = np.float32(np.random.rand(2,100 ...
- TensorFlow 入门之手写识别(MNIST) softmax算法 二
TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- #tensorflow入门(1)
tensorflow入门(1) 关于 TensorFlow TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操 ...
- TensorFlow入门(五)多层 LSTM 通俗易懂版
欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...
- 1 TensorFlow入门笔记之基础架构
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- TensorFlow 入门之手写识别(MNIST) 数据处理 一
TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...
- TensorFlow 入门 | iBooker·ApacheCN
原文:Getting Started with TensorFlow 协议:CC BY-NC-SA 4.0 自豪地采用谷歌翻译 不要担心自己的形象,只关心如何实现目标.--<原则>,生活原 ...
- 【原创】NIO框架入门(二):服务端基于MINA2的UDP双向通信Demo演示
前言 NIO框架的流行,使得开发大并发.高性能的互联网服务端成为可能.这其中最流行的无非就是MINA和Netty了,MINA目前的主要版本是MINA2.而Netty的主要版本是Netty3和Netty ...
- (转)TensorFlow 入门
TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...
- Swift语法基础入门二(数组, 字典, 字符串)
Swift语法基础入门二(数组, 字典, 字符串) 数组(有序数据的集) *格式 : [] / Int / Array() let 不可变数组 var 可变数组 注意: 不需要改变集合的时候创建不可变 ...
随机推荐
- Codeforces Round #425 (Div. 2) C - Strange Radiation
地址:http://codeforces.com/contest/832/problem/C 题目: C. Strange Radiation time limit per test 3 second ...
- netbeans通过wsdl生成webservice的UTF8问题
在netbeans通过wsdl方式生成的webservice,打开类文件时,提示无法通过UTF-8打开. 这是因为默认生成的文件不是UTF-8格式的,解决方案如下: 1.打开netbeans的安装目录 ...
- iOS & Android APP crash保护机制
一.背景 还在码代码,码好再BB... 二.思路 三.解决方案 四.注意点 五.开源项目 github:https://github.com/qiyer/QYCrashProtector
- Appium移动自动化
一. 安装node.js 因为Appium是使用nodejs实现的,所以node是解释器,首先需要确认安装好 官网下载node.js:https://nodejs.org/en/download/ 安 ...
- 国光大力推荐(安利)Deepin15.4
简介 深度操作系统15.4 Beta(deepin15.4)相比deepin15.3来看,外观上要更加优雅.现在还在内测中,相信不就官网就会发布正式版.小子昨天下午删了我的windows10,特意来尝 ...
- 为什么iterator,foreach遍历时不能进行remove操作?除了一种情况可以这样(特殊情况)?
Exception in thread "main" java.util.ConcurrentModificationException 并发修改异常引发的思考! 1 foreac ...
- SQL Server 对XML数据类型的SQL语句总结
--创建XMLTable , ) primary key, XMLCol xml); go ------------------------------------------------------ ...
- TypeScript 3.3来了!快看看有什么新功能
翻译:疯狂的技术宅原文:https://github.com/Microsoft/TypeScript/wiki/What's-new-in-TypeScript 本文首发微信公众号:jingchen ...
- 20145302张薇《Java程序设计》实验一报告
20145302 <Java程序设计>实验一:Java开发环境的熟悉 实验内容 使用dos命令行编译.运行简单的Java程序: 使用IDEA编辑.编译.运行.调试Java程序. 1.命令行 ...
- Android系统启动过程[典☆☆☆]
Android系统启动过程 首先Android框架架构图:(来自网上,我觉得这张图看起来很清晰) Linux内核启动之后就到Android Init进程,进而启动Android相关的服务和应用. 启动 ...