在TF1.8之后Keras被当作为一个内置API:tf.keras.

并且之前的下载语句会报错。

 mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

下面给出Keras和TensorFlow两种方式的训练代码(附验证代码):

Keras:

import numpy as np
import matplotlib.pyplot as plt import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train = x_train.reshape(60000,784)
x_test = x_test.reshape(10000,784) # 归一化
x_train = x_train/255
x_test = x_test/255 y_train = keras.utils.to_categorical(y_train,10)
y_test = keras.utils.to_categorical(y_test,10) model = Sequential()
model.add(Dense(512,activation='relu',input_shape=(784,)))
model.add(Dense(256,activation="relu"))
model.add(Dense(10,activation="softmax")) # 显示网络结构
model.summary() model.compile(optimizer=SGD(),loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train,y_train,batch_size=64,epochs=5,validation_data=(x_test,y_test)) score = model.evaluate(x_test,y_test) # 输出 loss 和 accuracy
print("loss",score[0])
print("accu",score[1]) # 输入样本生成输出预测。
predictions = model.predict([x_test]) # 预测
print(np.argmax(predictions[23])) # 查看图片 检查是否预测正确
plt.imshow(x_test[23])
plt.show()

TensorFlow:

代码来自TensorFlow官网

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
]) model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']) model.fit(x_train, y_train, epochs=5) model.evaluate(x_test, y_test) #验证代码同上 略

另附Keras与tf.keras的区别(https://www.zhihu.com/question/313111229/answer/606660552)

【TensorFlow 3】mnist数据集:与Keras对比的更多相关文章

  1. 一个简单的TensorFlow可视化MNIST数据集识别程序

    下面是TensorFlow可视化MNIST数据集识别程序,可视化内容是,TensorFlow计算图,表(loss, 直方图, 标准差(stddev)) # -*- coding: utf-8 -*- ...

  2. 深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)

    1. tf.matmul(X, w) # 进行点乘操作 参数说明:X,w都表示输入的数据, 2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False 参 ...

  3. TensorFlow 训练MNIST数据集(2)—— 多层神经网络

    在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...

  4. TensorFlow训练MNIST数据集(1) —— softmax 单层神经网络

    1.MNIST数据集简介 首先通过下面两行代码获取到TensorFlow内置的MNIST数据集: from tensorflow.examples.tutorials.mnist import inp ...

  5. 基于 tensorflow 的 mnist 数据集预测

    1. tensorflow 基本使用方法 2. mnist 数据集简介与预处理 3. 聚类算法模型 4. 使用卷积神经网络进行特征生成 5. 训练网络模型生成结果 how to install ten ...

  6. TensorFlow 下 mnist 数据集的操作及可视化

    from tensorflow.examples.tutorials.mnist import input_data 首先需要连网下载数据集: mnsit = input_data.read_data ...

  7. 《Hands-On Machine Learning with Scikit-Learn&TensorFlow》mnist数据集错误及解决方案

    最近在看这本书看到Chapter 3.Classification,是关于mnist数据集的分类,里面有个代码是 from sklearn.datasets import fetch_mldata m ...

  8. Tensorflow基础-mnist数据集

    MNIST数据集,每张图片包含28*28个像素,把一个数组展开成向量,长度为28*28=784,故数据集中mnist.train.images是一个形状为[60000,784]的张量,第一个维度数字用 ...

  9. 基于TensorFlow的MNIST数据集的实验

    一.MNIST实验内容 MNIST的实验比较简单,可以直接通过下面的程序加上程序上的部分注释就能很好的理解了,后面在完善具体的相关的数学理论知识,先记录在这里: 代码如下所示: import tens ...

  10. TensorFlow训练MNIST数据集(3) —— 卷积神经网络

    前面两篇随笔实现的单层神经网络 和多层神经网络, 在MNIST测试集上的正确率分别约为90%和96%.在换用多层神经网络后,正确率已有很大的提升.这次将采用卷积神经网络继续进行测试. 1.模型基本结构 ...

随机推荐

  1. hive -e和hive -f的区别(转)

    大家都知道,hive -f 后面指定的是一个文件,然后文件里面直接写sql,就可以运行hive的sql,hive -e 后面是直接用双引号拼接hivesql,然后就可以执行命令. 但是,有这么一个东西 ...

  2. python之数据分析pandas

    做数据分析的同学大部分入门都是从excel开始的,excel也是微软office系列评价最高的一种工具. 但当数据量超过百万行的时候,excel就无能无力了,python第三方包pandas极大的扩展 ...

  3. 【转载】BIO、NIO、AIO

    请看原文,排版更佳>转载请注明出处:http://blog.csdn.net/anxpp/article/details/51512200,谢谢! 本文会从传统的BIO到NIO再到AIO自浅至深 ...

  4. 汇编入门三-CPU工作原理

    本文为读书笔记,个人总结与摘抄自<汇编语言 第二版> 1.CPU从内存中读取数据,首先要获得存储单元的地址. 2.指明进行的操作,如存储或者读写 所以,CPU要进行操作总结为: 1.存储单 ...

  5. python算法与数据结构-数据结构中常用树的介绍(45)

    一.树的定义 树是一种非线性的数据结构,是由n(n >=0)个结点组成的有限集合.如果n==0,树为空树.如果n>0,树有一个特定的结点,根结点根结点只有直接后继,没有直接前驱.除根结点以 ...

  6. Python向FTP服务器上传文件

    上传 代码示例: #!/usr/bin/python # -*- coding:utf-8 -*- from ftplib import FTP ftp = FTP() # 打开调试级别2, 显示详细 ...

  7. Python开发【第九篇】: 并发编程

    内容概要 操作系统介绍 进程 线程 协程 二. 进程 python并发编程之多进程理论部分 在python程序中的进程操作 运行中的程序就是一个进程.所有的进程都是通过它的父进程来创建的.因此,运行起 ...

  8. C++类的完美单元测试方案——基于C++11扩展的friend语法

    版权相关声明:本文所述方案来自于<深入理解C++11—C++11新特性解析与应用>(Michael Wong著,机械工业出版社,2016.4重印)一书的学习. 项目管理中,C语言工程做单元 ...

  9. 线性模型之LDA和PCA推导

    线性模型之LDA和PCA 线性判别分析LDA LDA是一种无监督学习的降维技术. 思想:投影后类内方差最小,类间方差最大,即期望同类实例投影后的协方差尽可能小,异类实例的投影后的类中心距离尽量大. 二 ...

  10. redis 是如何做持久化的

    Redis 是一个键值对数据库服务器.基于内存存储数据,它常被用做缓存数据库,用来替代 memcached.官网:https://redis.io/ 什么是持久化? 持久化,指将数据存储到可永久保存的 ...