tensorflow中使用mnist数据集训练全连接神经网络

——学习曹健老师“人工智能实践:tensorflow笔记”的学习笔记, 感谢曹老师

前期准备:mnist数据集下载,并存入data目录:

文件列表:四个文件,分别为训练和测试集数据

Four files are available on 官网  http://yann.lecun.com/exdb/mnist/ :

train-images-idx3-ubyte.gz:  training set images (9912422 bytes)
train-labels-idx1-ubyte.gz
training set labels (28881 bytes)


t10k-images-idx3-ubyte.gz:  
test set images (1648877 bytes)


t10k-labels-idx1-ubyte.gz:  
test set labels (4542 bytes)

一、主要思路:

1、训练集输入数据X为28×28图像,和 Y_  onehot label

2、构建一个三层NN,input layer,one hidden layer,outputlayer

3、使用指数衰减学习率,交叉熵loss,移动平均loss构建NN

二、主要代码:

forward构建:

//mnist_forward.py

import tensorflow as tf
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight(shape, regularizer):
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer != None:
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):
    b = tf.Variable(tf.zeros(shape))
    return b

def forward(x,regularizer):
    w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
    b1 = get_bias([LAYER1_NODE])
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
    b2 = get_bias([OUTPUT_NODE])
    y = tf.matmul(y1, w2) + b2
    return y

backward构建:

//mnist_backward.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
REGULARIZER = 0.0001
STEPS = 50000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"

def backward(mnist):
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    y = mnist_forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)

ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_,1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE,
            global_step,
            mnist.train.num_examples / BATCH_SIZE,
            LEARNING_RATE_DECAY,
            staircase = True
            )

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step = global_step)
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
   sess.run(init_op)

for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 ==0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step = global_step)

def main():
    mnist = input_data.read_data_sets("./data/", one_hot = True)
    backward(mnist)

if __name__ == '__main__':
    main()

tensorflow中使用mnist数据集训练全连接神经网络-学习笔记的更多相关文章

  1. 【TensorFlow/简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型

    初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构 ...

  2. 深度学习tensorflow实战笔记(1)全连接神经网络(FCN)训练自己的数据(从txt文件中读取)

    1.准备数据 把数据放进txt文件中(数据量大的话,就写一段程序自己把数据自动的写入txt文件中,任何语言都能实现),数据之间用逗号隔开,最后一列标注数据的标签(用于分类),比如0,1.每一行表示一个 ...

  3. TensorFlow之DNN(二):全连接神经网络的加速技巧(Xavier初始化、Adam、Batch Norm、学习率衰减与梯度截断)

    在上一篇博客<TensorFlow之DNN(一):构建“裸机版”全连接神经网络>中,我整理了一个用TensorFlow实现的简单全连接神经网络模型,没有运用加速技巧(小批量梯度下降不算哦) ...

  4. TensorFlow之DNN(一):构建“裸机版”全连接神经网络

    博客断更了一周,干啥去了?想做个聊天机器人出来,去看教程了,然后大受打击,哭着回来补TensorFlow和自然语言处理的基础了.本来如意算盘打得挺响,作为一个初学者,直接看项目(不是指MINIST手写 ...

  5. Tensorflow 多层全连接神经网络

    本节涉及: 身份证问题 单层网络的模型 多层全连接神经网络 激活函数 tanh 身份证问题新模型的代码实现 模型的优化 一.身份证问题 身份证号码是18位的数字[此处暂不考虑字母的情况],身份证倒数第 ...

  6. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  7. Caffe系列4——基于Caffe的MNIST数据集训练与测试(手把手教你使用Lenet识别手写字体)

    基于Caffe的MNIST数据集训练与测试 原创:转载请注明https://www.cnblogs.com/xiaoboge/p/10688926.html  摘要 在前面的博文中,我详细介绍了Caf ...

  8. tensorflow读取本地MNIST数据集

    tensorflow读取本地MNIST数据集 数据放入文件夹(不要解压gz): >>> import tensorflow as tf >>> from tenso ...

  9. Keras入门——(1)全连接神经网络FCN

    Anaconda安装Keras: conda install keras 安装完成: 在Jupyter Notebook中新建并执行代码: import keras from keras.datase ...

随机推荐

  1. ios宏定义字符串

    ios宏定义字符串 #define objcString(str) @""#str"" 使用效果: objcString(字符串)

  2. 由使用request-promise-native想到的异步处理方法

    由使用request-promise-native想到的异步处理方法 问题场景 因为js语言的特性,使用node开发程序的时候经常会遇到异步处理的问题.对于之前专长App开发的我来说,会纠结node中 ...

  3. ios开发遇到的问题

    运行后界面空白,Xcode跳转到APPDelegate.swift文件提示如下 第一种可能原因: 做输出口后在代码中重新命名了输出口 解决方法: 右键控件关闭输出口的连接,变回+号,将它重新连到代码的 ...

  4. javascript设计模式系列二-封装

    JavaScript封装: var Book = function (id, name, price) { this.id = id, this.name = name, this.price = p ...

  5. 1001. 温度转换 (Standard IO)

    1001. 温度转换 (Standard IO) 时间限制: 1000 ms  空间限制: 262144 KB  具体限制   题目描述 将输入的华氏温度转换为摄氏温度.由华氏温度F与摄氏温度C的转换 ...

  6. VS2017 编译 QT5.10.1 X64位 静态库 MT

    参考文章 https://blog.csdn.net/Devout_programming/article/details/78827112 准备工作* Supported compiler (Vis ...

  7. 如何做好一个优秀的web项目心得

    最近利用空余的时间(坐公交车看教程视频),想了很多自己做的做果项目的优缺点,重新了解了前后端分离,前端工程化等概念学习,思考如何打造好一个优秀的web前端项目. 前端准备篇 前端代码规范:制定前端开发 ...

  8. 10JavaScript作用域

    (作用域可访问变量的集合) 1.JavaScript 作用域 在 JavaScript 中, 对象和函数同样也是变量. 在 JavaScript 中, 作用域为可访问变量,对象,函数的集合. Java ...

  9. linux3.4.2内核之块设备驱动

    1. 基本概念: 扇区(Sectors):任何块设备硬件对数据处理的基本单位.通常,1个扇区的大小为512byte.(对设备而言) 块  (Blocks):由Linux制定对内核或文件系统等数据处理的 ...

  10. python教程(零)·前言

    本教程是作者根据自己学习python的经验写下的,一来是想将经验分享给对python同样感兴趣的小白(大神请忽略),二来是想借此加深本人对python的理解,温故而知新. 学习基础 本教程面向的读者, ...