TensorFlow入门案例
入门小案例,分别是回归模型建立和mnist数据集的模型建立
1、回归案例:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import pandas as pd #1===================================================================
'''
总结:
1、这个整体比较简单,就是先构造整个图,然后再代入进去计算,注意设置的学习率
'''
#构造数据
x_data = np.float32(np.random.rand(2,100))
y_data = np.dot([1,2],x_data) + 4 #构造线性模型
b = tf.Variable(tf.zeros([1]))
w = tf.Variable(tf.random_uniform([1,2]))
y = tf.matmul(w,x_data) + b #最小化方差
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) #初始化变量,并启动图
sess = tf.Session()
sess.run(tf.global_variables_initializer()) #拟合平面
for step in range(201):
sess.run(train)
if step % 20 ==0:
print(step,sess.run(w),sess.run(b)) # 200 [[1.0000179 2.0000112]] [3.9999843]
2、MNIST普通神经网络
'''
总结一下需要注意的点:
1、x是[None,784] * w是[784,10] + [10] 这里的10会自动广播至[None,10]行,最后的结果就是[None,10]
2、softmax之后行列数不变,每一行求和都是1,哪个位置概率最大就是那个位置的值
3、交叉熵的公式是-y实际*log(y_hat)
4、arg_max(y,1)每一行的最大值下标,0的话是每一列的最大值下标,这个怎么跟concat不一样,感觉很乱
4、tf.equal(tf.arg_max(y,1),tf.arg(y-,1)),两个样本的最大标算出来一共是[None,1][None,1],然后匹配是否一致,一致就是0,不一致就是1,得到[None,1]
5、再求均值,求均值之前需要转float32格式
6、注意后面计算准确率的时候,传入的变量值是x_test,y_test,这里的y_test回去直接是y_,而计算y时候需要用到x,w,这里x就变成了x_test,不是前面的batc_x了,记住啊
7、使用mnist.trian.next_batch(100)来每次传入100个值
''' #加载数据
mnist = input_data.read_data_sets('mnist_data/',one_hot = True)
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels #定义变量
x = tf.placeholder(tf.float32,shape=[None,784])
y_= tf.placeholder(tf.float32,shape=[None,10])
w = tf.Variable(tf.truncated_normal([784,10]))
b = tf.Variable(tf.constant(0.01,shape=[10]))
y = tf.nn.softmax(tf.matmul(x,w)+b) #定义损失
loss = -tf.reduce_mean(y_*tf.log(y))
optimizer = tf.train.GradientDescentOptimizer(0.2)
train = optimizer.minimize(loss) #初始化变量
sess = tf.Session()
sess.run(tf.global_variables_initializer()) #验证正确率
accuracy_rate = tf.equal(tf.arg_max(y,1),tf.arg_max(y_,1))
accuracy = tf.reduce_mean(tf.cast(accuracy_rate,'float32')) #训练
for step in range(20000):
batch_x,batch_y = mnist.train.next_batch(100)
sess.run(train,feed_dict={x:batch_x,y_:batch_y})
if step % 1000 == 0:
print(step,sess.run(accuracy,feed_dict={x:x_test,y_:y_test})) #最终的正确率大概是在91%
3、MNIST卷积神经网络
'''
总结一下:
1、首先要弄清楚整个卷积神经网络的构图,分别是卷积池化》卷积池化》全连接》输出
2、样本是一个长784维的变量,要reshape成为28,28的图片才能做卷积,
3、第一次用32个5,5的卷积核来处理,由于样本是灰度图,所以是[5,5,1,32]所以样本从28,28变成了28,28,32的数据,这里因为步长[1,1,1,1]
4、然后再用[1,2,2,1]大小的pool处理卷积,步长是[1,2,2,1]所以[28,28,32] 变成了[14,14,32]
5、顺序是先做卷积,在套relu,在做pooling
6、同理第二个卷积是这么做的,用[5,5,32,64]大小的权重把刚刚pooling之后的[14,14,32]的变成了[14,14,64]的,再做pooling变成[7,7,64]
7、刚刚前面都是针对一个图片做的,其实到最后是[None,7,7,64],
8、全连接层之前需要把图片重新恢复掉,因为现在是[None,7,7,64],所以reshape[-1,7*7*64],-1的意思就是None
9、reshape之后是[None,3136]太大了,全连接处转成1024,所以用的w是[7*7*64,1024]
10、最后用softmax转成[None,10]就可以了
'''
#加载数据
x_data = mnist.train.images
y_data = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels #权重函数
def weights(shape):
initial = tf.truncated_normal(shape,stddev=0.1,dtype=tf.float32)
return tf.Variable(initial) #偏置项
def bias(shape):
initial = tf.constant(0.1,shape=shape,dtype=tf.float32)
return tf.Variable(initial) #输入值
xs = tf.placeholder(tf.float32,shape=[None,784])
ys = tf.placeholder(tf.float32,shape=[None,10])
x_images = tf.reshape(xs,[-1,28,28,1]) #第一层卷积
#con_1
w_con1 = weights([5,5,1,32])
b_con1 = bias([32])
h_con1 = tf.nn.conv2d(x_images,w_con1,[1,1,1,1],padding='SAME')
h_relu1 = tf.nn.relu(h_con1 + b_con1)
#pool1
h_pool1 = tf.nn.max_pool(h_relu1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') #第二层卷积
#con2
w_con2 = weights([5,5,32,64])
b_con2 = bias([64])
h_con2 = tf.nn.conv2d(h_pool1,w_con2,strides=[1,1,1,1],padding='SAME')
h_relu2 = tf.nn.relu(h_con2)
#pool2
h_pool2 = tf.nn.max_pool(h_relu2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') #全连接层
w_fc1 = weights([7*7*64,1024])
b_fc1 = bias([1024])
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1) + b_fc1) #drop_out
keep_pro = tf.placeholder(dtype=tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob=keep_pro) #输出层
w_fc2 = weights([1024,10])
b_fc2 = bias([10])
h_fc2 = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2) + b_fc2) #损失函数
loss = -tf.reduce_mean(ys*tf.log(h_fc2))
train = tf.train.AdamOptimizer(1e-4).minimize(loss)
#初始化变量
sess.run(tf.global_variables_initializer()) #计算误差
accuracy = tf.equal(tf.arg_max(ys,1),tf.arg_max(h_fc2,1))
accuracy = tf.reduce_mean(tf.cast(accuracy,tf.float32)) #开始训练
for step in range(5000):
batch_x,batch_y = mnist.train.next_batch(100)
sess.run(train,feed_dict={xs:batch_x,ys:batch_y,keep_pro:0.8})
if step % 100 == 0 :
print(step,sess.run(accuracy,feed_dict={xs:mnist.test.images,ys:mnist.test.labels,keep_pro:1}))
训练的好慢啊,只有CPU真的这么慢吗。。。。还是我写的代码有问题,不过正确率确实好高,跑一会就97%了
TensorFlow入门案例的更多相关文章
- 资源 | 数十种TensorFlow实现案例汇集:代码+笔记
选自 Github 机器之心编译 参与:吴攀.李亚洲 这是使用 TensorFlow 实现流行的机器学习算法的教程汇集.本汇集的目标是让读者可以轻松通过案例深入 TensorFlow. 这些案例适合那 ...
- TensorFlow 入门之手写识别(MNIST) softmax算法
TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- 数十种TensorFlow实现案例汇集:代码+笔记(转)
转:https://www.jiqizhixin.com/articles/30dc6dd9-39cd-406b-9f9e-041f5cbf1d14 这是使用 TensorFlow 实现流行的机器学习 ...
- 数十种TensorFlow实现案例汇集:代码+笔记
这是使用 TensorFlow 实现流行的机器学习算法的教程汇集.本汇集的目标是让读者可以轻松通过案例深入 TensorFlow. 这些案例适合那些想要清晰简明的 TensorFlow 实现案例的初学 ...
- TensorFlow 入门之手写识别(MNIST) softmax算法 二
TensorFlow 入门之手写识别(MNIST) softmax算法 二 MNIST Fly softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...
- SpringMVC入门案例及请求流程图(关于处理器或视图解析器或处理器映射器等的初步配置)
SpringMVC简介:SpringMVC也叫Spring Web mvc,属于表现层的框架.Spring MVC是Spring框架的一部分,是在Spring3.0后发布的 Spring结构图 Spr ...
- SpringMvc核心流程以及入门案例的搭建
1.什么是SpringMvc Spring MVC属于SpringFrameWork的后续产品,已经融合在Spring Web Flow里面.Spring 框架提供了构建 Web 应用程序的全功能 M ...
- Struts2第一个入门案例
一.如何获取Struts2,以及Struts2资源包的目录结构的了解 Struts的官方地址为http://struts.apache.org 在他的主页当中,我们可以通过左侧的Apache ...
- MyBatis入门案例、增删改查
一.MyBatis入门案例: ①:引入jar包 ②:创建实体类 Dept,并进行封装 ③ 在Src下创建大配置mybatis-config.xml <?xml version="1.0 ...
随机推荐
- hibernate事务隔离机制
事务的基本概念 ACID:A是atomicity(原子性),C是consistency(一致性),I是isolation(隔离性),D是durability(持久性) 事务隔离级别从低到高: 读取未提 ...
- pod优先级与抢占测试
# kubectl describe node k8s-n2Name: k8s-n2Roles: <none>Labels: ...
- conductor元数据定义
Task Definition conductor维护工作任务类型的注册表. 必须在工作流中使用之前注册任务类型. 例如: { "name": "encode_task& ...
- 获取iframe内的元素
$("#iframeID").contents().find("#index_p") 2获取父窗体的值 $('#father', parent.document ...
- Spring Cloud feign
Spring Cloud feign使用 前言 环境准备 应用模块 应用程序 应用启动 feign特性 综上 1. 前言 我们在前一篇文章中讲了一些我使用过的一些http的框架 服务间通信之Http框 ...
- Mysql操作日志
任何一种数据库中,都有各种各样的日志.MySQL也不例外,在Mysql中有4种不同的日志.分别错误日志.二进制日志.查询日志和慢查询日志.这些日志记录着Mysql数据库不同方面的踪迹.下文将介绍这4种 ...
- Idea多个module下maven的pom.xml失效的问题
今天在Idea中配置spring-cloud时,配置了两个module,结果其中一个module的pom.xml失效了.. 解决方法: 1.点击Idea右侧的Maven Project 2.点击&qu ...
- mysql基本的增删改查和条件语句
增 insert into 表名(列名,列名......) values("test1",23),("test2",23),("test3" ...
- My97DatePicker 日历控件
My97DatePicker 是一款非常强大的日历控件,使用也非常简单,也能修改源码,牛逼我就不吹了,自己用用看 使用 1.引入 <script language="javascrip ...
- access数据库收缩(压缩)
一般是因为表中有大量没用的数据,把没用的数据全部删除 菜单栏的“工具”——“数据库实用工具”——“压缩和修复数据库” OK啦