import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt #使用numpy生成200个随机点
x_data = np.linspace(-0.5,0.5,200)[:,np.newaxis]
noise = np.random.normal(0,0.02,x_data.shape)
y_data = np.square(x_data) + noise #定义两个placeholder
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1]) #定义神经网络中间层
Weights_L1 = tf.Variable(tf.random_normal([1,10]))
biases_L1 = tf.Variable(tf.zeros([1,10]))
Wx_plus_b_L1 = tf.matmul(x,Weights_L1) + biases_L1
L1 = tf.nn.tanh(Wx_plus_b_L1) #定义神经网络输出层
Weights_L2 = tf.Variable(tf.random_normal([10,1]))
biases_L2 = tf.Variable(tf.zeros([1,1]))
Wx_plus_b_L2 = tf.matmul(L1,Weights_L2) + biases_L2
prediction = tf.nn.tanh(Wx_plus_b_L2) #二次代价函数
loss = tf.reduce_mean(tf.square(y - prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) with tf.Session() as sess:
#变量初始化
sess.run(tf.global_variables_initializer())
for _ in range(2000):
sess.run(train_step,feed_dict={x:x_data,y:y_data}) #获得预测值
prediction_value = sess.run(prediction,feed_dict={x:x_data})
#画图
plt.figure()
plt.scatter(x_data,y_data)
plt.plot(x_data,prediction_value,'r-',lw = 5)
plt.show()

得到结果:

——

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #载入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) #每个批次的大小
batch_size = 100
#计算一共有多少个批次
n_batch = mnist.train.num_examples // batch_size #定义两个placeholder
x = tf.placeholder(tf.float32,[None,784]) #图片
y = tf.placeholder(tf.float32,[None,10]) #标签 #创建一个简单的神经网络
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,w)+b) #二次代价函数
loss = tf.reduce_mean(tf.square(y-prediction))
#梯度下降算法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) #初始化变量
init = tf.global_variables_initializer() #结果存放在一个bool类型的列表中,argmax()返回一维张量中最大值所在的位置,equal函数判断两者是否相等
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
#求准确率,cast:将bool型转换为0,1然后求平均正好算到是准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) with tf.Session() as sess:
sess.run(init)
for epoch in range(21):
for batch in range(batch_size):#next_batch:不断获取下一组数据
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Iter "+str(epoch)+",Testing Accuracy "+str(acc))

训练结果:

Iter 0,Testing Accuracy 0.657
Iter 1,Testing Accuracy 0.7063
Iter 2,Testing Accuracy 0.7458
Iter 3,Testing Accuracy 0.7876
Iter 4,Testing Accuracy 0.8203
Iter 5,Testing Accuracy 0.8406
Iter 6,Testing Accuracy 0.8492
Iter 7,Testing Accuracy 0.8586
Iter 8,Testing Accuracy 0.8642
Iter 9,Testing Accuracy 0.8678
Iter 10,Testing Accuracy 0.8709
Iter 11,Testing Accuracy 0.8736
Iter 12,Testing Accuracy 0.8751
Iter 13,Testing Accuracy 0.8784
Iter 14,Testing Accuracy 0.88
Iter 15,Testing Accuracy 0.8816
Iter 16,Testing Accuracy 0.883
Iter 17,Testing Accuracy 0.8841
Iter 18,Testing Accuracy 0.8847
Iter 19,Testing Accuracy 0.8855
Iter 20,Testing Accuracy 0.8878

可见,通过多次的训练,识别的准确值不断提高

——

tensorflow入门(二)的更多相关文章

  1. 48、tensorflow入门二,线性模型的拟合

    import tensorflow as tf import numpy as np#生成2维的100个0-1的随机数 x_data = np.float32(np.random.rand(2,100 ...

  2. TensorFlow 入门之手写识别(MNIST) softmax算法 二

    TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

  3. #tensorflow入门(1)

    tensorflow入门(1) 关于 TensorFlow TensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库.节点(Nodes)在图中表示数学操 ...

  4. TensorFlow入门(五)多层 LSTM 通俗易懂版

    欢迎转载,但请务必注明原文出处及作者信息. @author: huangyongye @creat_date: 2017-03-09 前言: 根据我本人学习 TensorFlow 实现 LSTM 的经 ...

  5. 1 TensorFlow入门笔记之基础架构

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  6. TensorFlow 入门之手写识别(MNIST) 数据处理 一

    TensorFlow 入门之手写识别(MNIST) 数据处理 一 MNIST Fly softmax回归 准备数据 解压 与 重构 手写识别入门 MNIST手写数据集 图片以及标签的数据格式处理 准备 ...

  7. TensorFlow 入门 | iBooker·ApacheCN

    原文:Getting Started with TensorFlow 协议:CC BY-NC-SA 4.0 自豪地采用谷歌翻译 不要担心自己的形象,只关心如何实现目标.--<原则>,生活原 ...

  8. 【原创】NIO框架入门(二):服务端基于MINA2的UDP双向通信Demo演示

    前言 NIO框架的流行,使得开发大并发.高性能的互联网服务端成为可能.这其中最流行的无非就是MINA和Netty了,MINA目前的主要版本是MINA2.而Netty的主要版本是Netty3和Netty ...

  9. (转)TensorFlow 入门

        TensorFlow 入门 本文转自:http://www.jianshu.com/p/6766fbcd43b9 字数3303 阅读904 评论3 喜欢5 CS224d-Day 2: 在 Da ...

  10. Swift语法基础入门二(数组, 字典, 字符串)

    Swift语法基础入门二(数组, 字典, 字符串) 数组(有序数据的集) *格式 : [] / Int / Array() let 不可变数组 var 可变数组 注意: 不需要改变集合的时候创建不可变 ...

随机推荐

  1. vue react angular对比

    1.数据绑定 1)vue 把一个普通对象传给Vued的data选项,Vue会遍历此对象的所有属性,并使用Object.defineProperty将这些属性全部转为getter/setter.Obje ...

  2. Jquery each循环用法小结

    var str = res.ZhaoPian; var piclist = str.substring(0, str.length - 1).split(','); $.each(piclist, f ...

  3. XDU 1001 又是苹果(状态压缩)

    #include<cstdio> #include<cstring> ; using namespace std; int r[maxn],c[maxn]; char pic[ ...

  4. vue指令与组件

    参考http://blog.csdn.net/lioldamon/article/details/74058222?utm_source=itdadao&utm_medium=referral ...

  5. topcoder SRM712 Div1 LR

    题目: Problem Statement      We have a cyclic array A of length n. For each valid i, element i-1 the l ...

  6. 写入Csv

    //定义文件输出流  FILE *f; f = fopen("a.csv" , "wb"); fprintf(f,"aaa,23,sdf\n" ...

  7. maven 介绍(一)

    本文内容主要摘自:http://www.konghao.org/index 内部视频                              http://www.ibm.com/developer ...

  8. oracle中length、lengthb、substr、substrb用法小结

    我记得我曾经在开发form的时候犯过这样一个错误,对于form中的某个字段,对应于数据库中某张表的字段,假设在数据库中这个字段一般也就用到20个汉字的长度,后来我在开发form的时候,设置item类型 ...

  9. win7系统下查看端口的占用情况以及如何删除端口进程

    经常在本地测试开发使用tomcat的时候容易报端口占用的情况,比如我要查看8080端口的使用情况 1.按如下操作,输入 cmd 回车 2.在doc窗口中输入命令    netstat -ano | f ...

  10. firewall 防火墙相关

    修改配置文件: /etc/sysconfig/network-scripts/ifcfg-ens33 文件 ONBOOT=no 改为yes 然后重启  service network restart ...