利用Tensorflow实现神经网络模型
首先看一下神经网络模型,一个比较简单的两层神经。

代码如下:
# 定义参数
n_hidden_1 = 256 #第一层神经元
n_hidden_2 = 128 #第二层神经元
n_input = 784 #输入大小,28*28的一个灰度图,彩图没有什么意义
n_classes = 10 #结果是要得到一个几分类的任务 # 输入和输出
x = tf.placeholder("float", [None, n_input])
y = tf.placeholder("float", [None, n_classes]) # 权重和偏置参数
stddev = 0.1
weights = {
'w1': tf.Variable(tf.random_normal([n_input, n_hidden_1], stddev=stddev)),
'w2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2], stddev=stddev)),
'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes], stddev=stddev))
}
biases = {
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_classes]))
}
print ("NETWORK READY") def multilayer_perceptron(_X, _weights, _biases):
#第1层神经网络 = tf.nn.激活函数(tf.加上偏置量(tf.矩阵相乘(输入Data, 权重W1), 偏置参数b1))
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(_X, _weights['w1']), _biases['b1']))
#第2层的格式与第1层一样,第2层的输入是第1层的输出。
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, _weights['w2']), _biases['b2']))
#返回预测值
return (tf.matmul(layer_2, _weights['out']) + _biases['out']) # 预测
pred = multilayer_perceptron(x, weights, biases) # 计算损失函数和优化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optm = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accr = tf.reduce_mean(tf.cast(corr, "float")) # 初始化
init = tf.global_variables_initializer()
print ("FUNCTIONS READY") # 训练
training_epochs = 20
batch_size = 100
display_step = 4
# LAUNCH THE GRAPH
sess = tf.Session()
sess.run(init)
# 优化器
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)
# 迭代训练
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
avg_cost += sess.run(cost, feed_dict=feeds)
avg_cost = avg_cost / total_batch
# 打印结果
if (epoch+1) % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print ("TRAIN ACCURACY: %.3f" % (train_acc))
feeds = {x: mnist.test.images, y: mnist.test.labels}
test_acc = sess.run(accr, feed_dict=feeds)
print ("TEST ACCURACY: %.3f" % (test_acc))
print ("OPTIMIZATION FINISHED")
利用Tensorflow实现神经网络模型的更多相关文章
- 通过TensorFlow训练神经网络模型
神经网络模型的训练过程其实质上就是神经网络参数的设置过程 在神经网络优化算法中最常用的方法是反向传播算法,下图是反向传播算法流程图: 从上图可知,反向传播算法实现了一个迭代的过程,在每次迭代的开始,先 ...
- kaggle赛题Digit Recognizer:利用TensorFlow搭建神经网络(附上K邻近算法模型预测)
一.前言 kaggle上有传统的手写数字识别mnist的赛题,通过分类算法,将图片数据进行识别.mnist数据集里面,包含了42000张手写数字0到9的图片,每张图片为28*28=784的像素,所以整 ...
- 利用Tensorflow实现卷积神经网络模型
首先看一下卷积神经网络模型,如下图: 卷积神经网络(CNN)由输入层.卷积层.激活函数.池化层.全连接层组成,即INPUT-CONV-RELU-POOL-FC池化层:为了减少运算量和数据维度而设置的一 ...
- 手写数字识别 ----卷积神经网络模型官方案例注释(基于Tensorflow,Python)
# 手写数字识别 ----卷积神经网络模型 import os import tensorflow as tf #部分注释来源于 # http://www.cnblogs.com/rgvb178/p/ ...
- 【TensorFlow/简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型
初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构 ...
- Tensorflow 对上一节神经网络模型的优化
本节涉及的知识点: 1.在程序中查看变量的取值 2.张量 3.用张量重新组织输入数据 4.简化的神经网络模型 5.标量.多维数组 6.在TensorFlow中查看和设定张量的形态 7.用softmax ...
- Keras结合Keras后端搭建个性化神经网络模型(不用原生Tensorflow)
Keras是基于Tensorflow等底层张量处理库的高级API库.它帮我们实现了一系列经典的神经网络层(全连接层.卷积层.循环层等),以及简洁的迭代模型的接口,让我们能在模型层面写代码,从而不用仔细 ...
- tensorflow 神经网络模型概览;熟悉Eager 模式;
典型神经网络模型:(图片来源:https://github.com/madalinabuzau/tensorflow-eager-tutorials) 保持更新,更多内容请关注 cnblogs.com ...
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
随机推荐
- ThinkPHP框架 3.2.2 获取系统常量信息 连接数据库 命名空间的理解
获取系统常量信息 随便一个方法里加上这句话都能获取到系统常量信息!! var_dump(get_defined_constants()); <?php namespace Admin\Contr ...
- npm构建保存 code ELIFECYCLE解决办法
参考文档https://blog.csdn.net/gh254172840/article/details/78871573 使用npm构建报错 解决办法,进入工作目录 rm -rf node_mod ...
- 关于tomcat服务器
如果遇到jsp代码反复运行不成功,并且不报错 而且代码也重复检查过,正确无误了 那么 就不要把精力放在代码上了 有可能是服务器的问题 重启下服务器试试 ……不要问我尽经历过什么
- [No0000150]VSVisualStudio提示图标,信号图标的含义
其右侧的图标表示这是一个接口类型__interface(或者是结构体类型) 其右侧图标表示这是一个类类型 其右侧图标表示这是一个.cpp文件(貌似还可以是.hpp等文件) 其右侧图标表示这是一个枚举类 ...
- 设置shell脚本静默方式输入密码方法
stty命令是一个终端处理工具.我们可以通过它来实现静默方式输入密码,脚本如下 #!/bin/sh echo –e “enter password:” stty –echo ...
- AngularJs 常用指令标签
1.ng-app:告诉Angular他应该管理页面的那一部分,可以放在html元素上也可以放在div等标签上 例:<html ng-app="problem"> 2.n ...
- c# http get post转义HttpUtility.UrlEncode
//该数据如果要http get.post提交,需要经过转义,否则该数据中含& ''等字符会导致意外错误.需要转义.这里用HttpUtility.UrlEncode来转义.接收方无需反解析 s ...
- Javascript面向对象编程(三):非构造函数的继承(对象的深拷贝与浅拷贝)
Javascript面向对象编程(三):非构造函数的继承 作者: 阮一峰 日期: 2010年5月24日 这个系列的第一部分介绍了"封装",第二部分介绍了使用构造函数实现&quo ...
- 洛谷P1966 火柴排队 贪心+离散化+逆序对(待补充QAQ
正解: 贪心+离散化+逆序对 解题报告: 链接在这儿呢quq 这题其实主要难在想方法吧我觉得?学长提点了下说用贪心之后就大概明白了,感觉没有很难 但是离散化这里还是挺有趣的,因为并不是能很熟练地掌握离 ...
- mysql 查询优化杂谈
一.把某些判断移动到应用层 我们需要在一张表里面删除某种类型的数据,大概的表结构类似这样: CREATE TABLE t ( id INT, tp ENUM ("t1", &quo ...