import tensorflow as tf
from tensorflow import keras
from keras import Sequential,datasets, layers, optimizers, metrics def preprocess(x, y):
"""数据处理函数"""
x = tf.cast(x, dtype=tf.float32) / 255.
y = tf.cast(y, dtype=tf.int32)
return x, y # 加载数据
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
print(x.shape, y.shape) # 处理train数据
batch_size = 128
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(10000).batch(batch_size) # 处理test数据
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess).batch(batch_size) # # 生成train数据的迭代器
db_iter = iter(db)
sample = next(db_iter)
print(f'batch: {sample[0].shape,sample[1].shape}') # 设计网络结构
model = Sequential([
layers.Dense(256, activation=tf.nn.relu), # [b,784] --> [b,256]
layers.Dense(128, activation=tf.nn.relu), # [b,256] --> [b,128]
layers.Dense(64, activation=tf.nn.relu), # [b,128] --> [b,64]
layers.Dense(32, activation=tf.nn.relu), # [b,64] --> [b,32]
layers.Dense(10) # [b,32] --> [b,10], 330=32*10+10
]) model.build(input_shape=[None, 28 * 28])
model.summary() # 调试
# w = w - lr*grad
optimizer = optimizers.Adam(lr=1e-3) # 优化器,加快训练速度 def main():
"""主运行函数"""
for epoch in range(10):
for step, (x, y) in enumerate(db):
# x:[b,28,28] --> [b,784]
# y:[b]
x = tf.reshape(x, [-1, 28 * 28])
with tf.GradientTape() as tape:
# [b,784] --> [b,10]
logits = model(x)
y_onehot = tf.one_hot(y, depth=10)
# [b]
loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
loss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True))
grads = tape.gradient(loss_ce, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
if step % 100 == 0:
print(epoch, step, f'loss: {float(loss_ce),float(loss_mse)}') # test
total_correct = 0
total_num = 0
for x, y in db_test:
# x:[b,28,28] --> [b,784]
# y:[b]
x = tf.reshape(x, [-1, 28 * 28])
# [b,10]
logits = model(x)
# logits --> prob [b,10]
prob = tf.nn.softmax(logits, axis=1)
# [b,10] --> [b], int32
pred = tf.argmax(prob, axis=1)
pred = tf.cast(pred, dtype=tf.int32)
# pred:[b]
# y:[b]
# correct: [b], True: equal; False: not equal
correct = tf.equal(pred, y)
correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
total_correct += int(correct)
total_num += x.shape[0]
acc = total_correct / total_num
print(epoch, f'test acc: {acc}') if __name__ == '__main__':
main()

吴裕雄--天生自然TensorFlow2教程:手写数字问题实战的更多相关文章

  1. 吴裕雄--天生自然TensorFlow2教程:前向传播(张量)- 实战

    手写数字识别流程 MNIST手写数字集7000*10张图片 60k张图片训练,10k张图片测试 每张图片是28*28,如果是彩色图片是28*28*3-255表示图片的灰度值,0表示纯白,255表示纯黑 ...

  2. 吴裕雄--天生自然TensorFlow2教程:函数优化实战

    import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D def himme ...

  3. 吴裕雄--天生自然TensorFlow2教程:反向传播算法

  4. 吴裕雄--天生自然TensorFlow2教程:链式法则

    import tensorflow as tf x = tf.constant(1.) w1 = tf.constant(2.) b1 = tf.constant(1.) w2 = tf.consta ...

  5. 吴裕雄--天生自然TensorFlow2教程:多输出感知机及其梯度

    import tensorflow as tf x = tf.random.normal([2, 4]) w = tf.random.normal([4, 3]) b = tf.zeros([3]) ...

  6. 吴裕雄--天生自然TensorFlow2教程:单输出感知机及其梯度

    import tensorflow as tf x = tf.random.normal([1, 3]) w = tf.ones([3, 1]) b = tf.ones([1]) y = tf.con ...

  7. 吴裕雄--天生自然TensorFlow2教程:损失函数及其梯度

    import tensorflow as tf x = tf.random.normal([2, 4]) w = tf.random.normal([4, 3]) b = tf.zeros([3]) ...

  8. 吴裕雄--天生自然TensorFlow2教程:激活函数及其梯度

    import tensorflow as tf a = tf.linspace(-10., 10., 10) a with tf.GradientTape() as tape: tape.watch( ...

  9. 吴裕雄--天生自然TensorFlow2教程:梯度下降简介

    import tensorflow as tf w = tf.constant(1.) x = tf.constant(2.) y = x * w with tf.GradientTape() as ...

随机推荐

  1. CAD

    文件另存为——Autocad.doc.SaveAs   一.前言 使用pyautocad编辑好cad图纸后,往往涉及到一个保存的问题,但是官方文档并未提及,所以只能自己来了,测试了好久,终于是找到了保 ...

  2. 在网页中插入背景音乐代码(html)

    有两种 分别用<bgsound>和<embed></embed>标签,当用<embed>插入背景音乐时可以设置宽度和高度为0,隐藏播放器. 二者的参数如 ...

  3. HTML学习(10)图像

    HTML图像标签<img>,没有闭合标签 <img src="" alt="" width="" height=" ...

  4. chrome firefox浏览器屏蔽百度热搜

    我是原文 操作 点击拦截元素,然后选择页面元素,添加

  5. @RequestMapping(value = {"list", ""})

    https://www.cnblogs.com/tongs/p/7486478.html   @RequestMapping是请求路径的注解 里面写两个value就是,路径可以是这两个, 第二个空,是 ...

  6. bugku flag在index里

    原题内容: http://120.24.86.145:8005/post/ Mark一下这道题,前前后后弄了两个多小时,翻了一下别的博主的wp感觉还是讲的太粗了,这里总结下自己的理解: 首先打开这道题 ...

  7. 安卓按键:读取txt开头出现未知字符的问题

    很多时候 我们读取txt 用traceprint输出后 最头上会莫名其妙多出一个问号 但是你用问号匹配他 却匹配不到  就是1个未知字符  这个到底是什么 怎么避免出现这个东西呢 这个主要是txt文件 ...

  8. 查询数据操作:distinct

    1.作用:distinct 去除重复记录.重复记录,指的是字段值,都相同的记录,而不是部分字段值相同的记录 与之相对的是all,表示所有.在MySQL中默认就是all. 2.例子: select ch ...

  9. socket udp编程的一些积累的记录

    接了个小活,要求写udp的客户端,循环接收服务端的固定的指令并显示数据 我设计的逻辑是,用户在界面输入框输入服务器ip.端口,随后udp连接,开启线程循环接收,接收指令,解析成数据,存在结构体的lis ...

  10. (原创)Windows下编译的Shell脚本不能再Linux中运行的解决办法

    一.原理 Windows编译的文件和Linux编译的文件格式不太一样,导致在Linux运行Shell脚本的时候会提示:/bin/bash^M: bad interpreter: 没有那个文件或目录. ...