import tensorflow as tf

INPUT_NODE = 784
OUTPUT_NODE = 10 IMAGE_SIZE = 28
NUM_CHANNELS = 1
NUM_LABELS = 10 CONV1_DEEP = 32
CONV1_SIZE = 5 CONV2_DEEP = 64
CONV2_SIZE = 5 FC_SIZE = 512 def inference(input_tensor, train, regularizer):
with tf.variable_scope('layer1-conv1'):
conv1_weights = tf.get_variable(
"weight", [CONV1_SIZE, CONV1_SIZE, NUM_CHANNELS, CONV1_DEEP],
initializer=tf.truncated_normal_initializer(stddev=0.1))
conv1_biases = tf.get_variable("bias", [CONV1_DEEP], initializer=tf.constant_initializer(0.0))
conv1 = tf.nn.conv2d(input_tensor, conv1_weights, strides=[1, 1, 1, 1], padding='SAME')
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases)) with tf.name_scope("layer2-pool1"):
pool1 = tf.nn.max_pool(relu1, ksize = [1,2,2,1],strides=[1,2,2,1],padding="SAME") with tf.variable_scope("layer3-conv2"):
conv2_weights = tf.get_variable(
"weight", [CONV2_SIZE, CONV2_SIZE, CONV1_DEEP, CONV2_DEEP],
initializer=tf.truncated_normal_initializer(stddev=0.1))
conv2_biases = tf.get_variable("bias", [CONV2_DEEP], initializer=tf.constant_initializer(0.0))
conv2 = tf.nn.conv2d(pool1, conv2_weights, strides=[1, 1, 1, 1], padding='SAME')
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases)) with tf.name_scope("layer4-pool2"):
pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
pool_shape = pool2.get_shape().as_list()
nodes = pool_shape[1] * pool_shape[2] * pool_shape[3]
reshaped = tf.reshape(pool2, [pool_shape[0], nodes]) with tf.variable_scope('layer5-fc1'):
fc1_weights = tf.get_variable("weight", [nodes, FC_SIZE],
initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer != None: tf.add_to_collection('losses', regularizer(fc1_weights))
fc1_biases = tf.get_variable("bias", [FC_SIZE], initializer=tf.constant_initializer(0.1)) fc1 = tf.nn.relu(tf.matmul(reshaped, fc1_weights) + fc1_biases)
if train: fc1 = tf.nn.dropout(fc1, 0.5) with tf.variable_scope('layer6-fc2'):
fc2_weights = tf.get_variable("weight", [FC_SIZE, NUM_LABELS],
initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer != None: tf.add_to_collection('losses', regularizer(fc2_weights))
fc2_biases = tf.get_variable("bias", [NUM_LABELS], initializer=tf.constant_initializer(0.1))
logit = tf.matmul(fc1, fc2_weights) + fc2_biases return logit
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import LeNet5_infernece
import os
import numpy as np BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 6000
MOVING_AVERAGE_DECAY = 0.99 def train(mnist):
# 定义输出为4维矩阵的placeholder
x = tf.placeholder(tf.float32, [
BATCH_SIZE,
LeNet5_infernece.IMAGE_SIZE,
LeNet5_infernece.IMAGE_SIZE,
LeNet5_infernece.NUM_CHANNELS],
name='x-input')
y_ = tf.placeholder(tf.float32, [None, LeNet5_infernece.OUTPUT_NODE], name='y-input') regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
y = LeNet5_infernece.inference(x,False,regularizer)
global_step = tf.Variable(0, trainable=False) # 定义损失函数、学习率、滑动平均操作以及训练过程。
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
staircase=True) train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train') # 初始化TensorFlow持久化类。
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(TRAINING_STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE) reshaped_xs = np.reshape(xs, (
BATCH_SIZE,
LeNet5_infernece.IMAGE_SIZE,
LeNet5_infernece.IMAGE_SIZE,
LeNet5_infernece.NUM_CHANNELS))
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: reshaped_xs, y_: ys}) if i % 1000 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value)) def main(argv=None):
mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
train(mnist) if __name__ == '__main__':
main()

吴裕雄--天生自然python Google深度学习框架:经典卷积神经网络模型的更多相关文章

  1. 吴裕雄--天生自然python Google深度学习框架:Tensorflow实现迁移学习

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  2. 吴裕雄--天生自然python Google深度学习框架:图像识别与卷积神经网络

  3. 吴裕雄--天生自然python Google深度学习框架:MNIST数字识别问题

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_NODE = 784 ...

  4. 吴裕雄--天生自然python Google深度学习框架:深度学习与深层神经网络

  5. 吴裕雄--天生自然python Google深度学习框架:TensorFlow实现神经网络

    http://playground.tensorflow.org/

  6. 吴裕雄--天生自然python Google深度学习框架:Tensorflow基础应用

    import tensorflow as tf a = tf.constant([1.0, 2.0], name="a") b = tf.constant([2.0, 3.0], ...

  7. 吴裕雄--天生自然python Google深度学习框架:人工智能、深度学习与机器学习相互关系介绍

  8. 吴裕雄--天生自然神经网络与深度学习实战Python+Keras+TensorFlow:Bellman函数、贪心算法与增强性学习网络开发实践

    !pip install gym import random import numpy as np import matplotlib.pyplot as plt from keras.layers ...

  9. 吴裕雄--天生自然神经网络与深度学习实战Python+Keras+TensorFlow:使用TensorFlow和Keras开发高级自然语言处理系统——LSTM网络原理以及使用LSTM实现人机问答系统

    !mkdir '/content/gdrive/My Drive/conversation' ''' 将文本句子分解成单词,并构建词库 ''' path = '/content/gdrive/My D ...

随机推荐

  1. EUI库 - 9 - 数据集合 - 列表

      List 和DataGroup的区别 1 选中一项 会触发 eui.ItemEvent.ITEM_TAP 事件, 2 有选中项的概念,可以设置 List 里的默认选中项    selectedIn ...

  2. OpenResty从入门到开发一个网关服务(使用etcd作为注册中心)

    简介 OpenResty(也称为 ngx_openresty)是一个全功能的 Web 应用服务器.它打包了标准的 Nginx 核心,很多的常用的第三方模块,以及它们的大多数依赖项. 通过揉和众多设计良 ...

  3. 201812-2 小明放学 Java

    思路: 红绿灯每种灯亮划分区间,在[0,r]区间内红灯亮,在(r,g+r]区间内绿灯亮,在(r+g,r+g+y]区间内黄灯亮,在划分好区间后只需要判断当小明到达红绿灯时是哪个灯在亮,就可以判断出通过红 ...

  4. 启动mysql遇到问题Can't connect to local MySQL server through socket '/tmp/mysql.sock' (2)

    在mysql的启动过程中有时会遇到下述错误 Can't connect to local MySQL server through socket '/tmp/mysql.sock' (2) 请问mys ...

  5. Linux误删所有内核,恢复内核的解决办法

    前言 我用df -h命令查看磁盘使用情况的时候发现,系统根目录空间已经比较小了,于是我就使用clean命令对系统内核进行清理,一不小心,就把所有的内核删除了,你很有可能也是我的这种经历,非常的崩溃.好 ...

  6. Fedora、SuSE、Redhat、Ubuntu、Centos

    想学Linux,但版本太多了,如Fedora.SuSE.Redhat.Ubuntu等,不免让人眼花缭乱,那么初学者该如何选择呢?也许很多人会不屑的说,Linux不就是个操作系统么.错!Linux不是一 ...

  7. JAVAEE 和项目开发(第一课:浏览器和服务器的交互模式和HTTP协议的概念和介绍)

    互联网的发展非常迅速,但是万变不离其宗.学习 web 开发,需要我们对互 联的交互机制有一定的了解.为了更好的理解并掌握 Servlet,在正式学习 Servlet之前需要对 web 开发中客户端和服 ...

  8. 使用maven打包问题

    项目打包:选择项目 右键->run as-> maven install . 项目中使用的是maven项目,将项目打包成war的时候有时候会出现 出现这种情况的时候解决步骤如下: 选择要打 ...

  9. HTML5 SVG应用(1)——loading效果

    先看一下效果: 链接 代码: <svg version="1.1" id="loader-1" xmlns="http://www.w3.org ...

  10. Windows下MariaDB数据搬家问题

    背景:公司买了一台服务器要将原来老服务器数据库数据导入新的服务器上,两台服务器环境如下 做好准备后开始实施,老数据库进行停库 找到MariaDB的安装目录下的整data目录进行拷贝,然后到新服务器进行 ...