TF之RNN:基于顺序的RNN分类案例对手写数字图片mnist数据集实现高精度预测—Jason niu
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) lr=0.001
training_iters=100000
batch_size=128 n_inputs=28
n_steps=28
n_hidden_units=128
n_classes=10 x=tf.placeholder(tf.float32, [None,n_steps,n_inputs])
y=tf.placeholder(tf.float32, [None,n_classes]) weights ={
'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])),
'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])),
}
biases ={
'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units,])),
'out':tf.Variable(tf.constant(0.1,shape=[n_classes,])),
} def RNN(X,weights,biases): X=tf.reshape(X,[-1,n_inputs])
X_in=tf.matmul(X,weights['in'])+biases['in']
X_in=tf.reshape(X_in,[-1,n_steps,n_hidden_units])
lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(n_hidden_units,forget_bias=1.0,state_is_tuple=True)
__init__state=lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs,states=tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=__init__state,time_major=False) outputs=tf.unpack(tf.transpose(outputs, [1,0,2]))
results=tf.matmul(outputs[-1],weights['out'])+biases['out']
return results pred =RNN(x,weights,biases)
cost =tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op=tf.train.AdamOptimizer(lr).minimize(cost) correct_pred=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_pred,tf.float32))
with tf.Session() as sess:
sess.run(init)
step=0
while step*batch_size < training_iters:
batch_xs,batch_ys=mnist.train.next_batch(batch_size)
batch_xs=batch_xs.reshape([batch_size,n_steps,n_inputs])
sess.run([train_op],feed_dict={
x:batch_xs,
y:batch_ys,})
if step%20==0:
print(sess.run(accuracy,feed_dict={
x:batch_xs,
y:batch_ys,}))
step+=1
TF之RNN:基于顺序的RNN分类案例对手写数字图片mnist数据集实现高精度预测—Jason niu的更多相关文章
- 基于OpenCV的KNN算法实现手写数字识别
基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...
- TF之RNN:matplotlib动态演示之基于顺序的RNN回归案例实现高效学习逐步逼近余弦曲线—Jason niu
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIME_STEP ...
- 用Keras搭建神经网络 简单模版(四)—— RNN Classifier 循环神经网络(手写数字图片识别)
# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) from keras.datasets import mnist fro ...
- 机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别
一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...
- 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)
一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...
- TF之RNN:TensorBoard可视化之基于顺序的RNN回归案例实现蓝色正弦虚线预测红色余弦实线—Jason niu
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt BATCH_START = 0 TIME_STEP ...
- LSTM用于MNIST手写数字图片分类
按照惯例,先放代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data #载入数据集 ...
- 用Keras搭建神经网络 简单模版(二)——Classifier分类(手写数字识别)
# -*- coding: utf-8 -*- import numpy as np np.random.seed(1337) #for reproducibility再现性 from keras.d ...
- 基于TensorFlow的MNIST手写数字识别-初级
一:MNIST数据集 下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...
随机推荐
- Modbus库开发笔记之六:Modbus RTU Master开发
这一节我们来封装最后一种应用(Modbus RTU Master应用),RTU主站的开发与TCP客户端的开发是一致的.同样的我们也不是做具体的应用,而是实现RTU主站的基本功能.我们将RTU主站的功能 ...
- Confluence 6 Cron 表达式
一个 cron 表达式是以 6-7 时间字段来定义一个计划任务是如何按照时间被执行的.每一个字段中的数据库而已为数字或者是一些特定的字符串来进行表达.每一个字段是使用空格或者 tab 进行分隔的. 下 ...
- Confluence 6 使用 WebDAV 客户端来对页面进行操作
下面的部分告诉你如何在不同的系统中来设置原生的 WebDAV 客户端,这个客户端通常显示在你操作系统的文件浏览器中,例如,Windows 的 Windows Explorer 或者 Linux 的 K ...
- Confluence 6 启用和禁用 Office 连接器
如果你希望限制访问 Office 连接器的所有组件或者部分组件,你可以禁用整个插件也可以禁用插件中的某个模块. 希望启用或禁用 Office 连接器模块: 进入 > 基本配置(General ...
- LeetCode(68):文本左右对齐
Hard! 题目描述: 给定一个单词数组和一个长度 maxWidth,重新排版单词,使其成为每行恰好有 maxWidth 个字符,且左右两端对齐的文本. 你应该使用“贪心算法”来放置给定的单词:也就是 ...
- checkbox 选中的id拼接长字符串
需求描述:为了做一个批量操作,需要获取到checkbox选中的项的id,并且把选中的id拼接成字符串. 解决思路:先获取到checkbox选中项,然后拼接.(这tm不废话么),问题的关键就是获取che ...
- laravel 里面结合关联查询 的when()用法
Laravel 5.6 里面的when用法: $name = $request->get('name'); //活动标题 $start_time = $request->get('star ...
- HTML&javaSkcript&CSS&jQuery&ajax-Css
CSS 1 .eg <head> <style> body{ background-color:#d0e4fe;} h1{ color:orange; text-alin:ce ...
- ftp的自动部署以及添加虚拟账户的脚本
#!/bin/bash #本脚本为自动化安装vsftp,使用虚拟用户认证登录ftp上传下载文件 echo =============================================== ...
- JMeter 中跨线程组 变量值传递的方法
关于jmeter中跨线程组 变量值传递的方法 找了好久,终于找到方法了,赶紧整理下来. 1.在线程组1 中使用__setProperty函数设置jmeter属性值(此值为全局变量值), ...