mnist手写数字问题初体验
上一篇我们提到了回归问题中的梯度下降算法,而且我们知道线性模型只能解决简单的线性回归问题,对于高维图片,线性模型不能完成这样复杂的分类任务。那么是不是线性模型在离散值预测或图像分类问题中就没有用武之地了呢?
本篇我们就套用regression中的部分机制来处理classification中的问题。
在这里首先介绍一下激活函数。
所谓激活函数,实际上就是引入非线性因子,将线性模型去线性化,增强模型的表达能力。ReLU激活函数是我要介绍的第一个激活函数,其定义式为φ(z)=max{0,z},图像表示如下:
简单的说relu就是一个取最大值的函数,在负区间取值为0,正区间取值不变,这种操作被称为单侧抑制(输出为0时代表神经元不会被激活)。单侧抑制的特点就是同一时间只会有一部分神经元被激活(结合函数图像可以看出),也就使得神经元具有了稀疏激活性。加入relu激活函数的神经元被称作整流线性单元,它与线性单元非常相似,唯一的区别就是在一半定义域上输出为0。整流线性单元易于优化,当其处于激活状态时(输出不为0),它的一阶导数能够保持一个较大值(等于1),并且处处一致,它的二阶导数几乎处处为0,这样的好处就是避免了梯度下降时的梯度消失问题(可参考前一篇回归问题的随笔)。
简单介绍了激活函数,那么是不是将激活函数引入我们的线性模型out=X@w+b就能使其解决复杂的图像分类问题了呢?
很显然不是的,虽然加了激活函数,但是我们可以看到模型变为out=relu(X@w+b) 依然还是太简单。那么怎么办呢?
我们可以联系一下零件加工的流程,从原料到成品,零件的加工经历了多个工序,期间每一道工序都是由前一道工序为基础,这时候,原料就相当于神经网络的输入,成品零件就相当于神经网络的输出,他们中间并不是也不能一步到位,而是经过若干“隐藏”的工序一步一步的生成产品。我们的模型同样可以借助于这种思想。即给数据处理多添加几道所谓的“工序”,我们称之为“隐藏层”,因为我们关心的只有模型的输入和输出,隐藏层的数据是我们不可见的(当然也可以在运行过程中打印出来方便调试),下面我们就利用这样的思想来解决mnist手写数字分类问题。
我们使用的是mnist数据集,也是深度学习的基础入门数据集。它一共有70k张不同的手写数字图片,其中60k用来训练模型,10k用来评估模型,且所有图片均为28*28的灰度图。我们首先设计一个稍微复杂的模型
h1=relu(X@w1+b1)
h2=relu(h1@w2+b2)
out=relu(h2@w3+b3)
其中X为输入,out为输出,h1、h2均为隐藏层,且除输入层外每一层的输入均是前一层的输出。
首先我们将输入的28*28*1的图片扁平化,即将每张图片转化成784维的向量(28*28=784),这样的好处是可以以矩阵的形式同时喂入多张图片(每一行向量为一张灰度图的信息),提高效率。对于输出out,我们令其输出一个10维的向量,代表10个数字的概率。模型可以用以下公式概括:
out=relu { relu { relu[ X@w1+b1 ] @w2+b2 }@w3+b3 }
pred=argmax(out)
loss=MSE(out,label) (均方误差损失函数即loss=∑(label-out)2)
minimize loss→[w1',b1',w2',b2',w3',b3']
参数调整完成后,可以对新的输入x进行运算从而得到对应的输出
代码如下:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets # 屏蔽通知和警告信息,减少用处不大的问题输出
os.environ['TF_CPP_MIN_LOG_LEVEL']='' (x, y), (x_val, y_val) = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200) # 搭建网络结构
model = keras.Sequential([
layers.Dense(512, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(10)]) # 初始化优化器为梯度下降优化器
optimizer = optimizers.SGD(learning_rate=0.001) def train_epoch(epoch): # Step4.循环迭代
for step, (x, y) in enumerate(train_dataset): with tf.GradientTape() as tape:
# 将输入数据压平 [b, 28, 28] => [b, 784]
x = tf.reshape(x, (-1, 28*28))
# Step1. 计算输出
# 输入域数据经过神经网络降维 [b, 784] => [b, 10]
out = model(x)
# Step2. 计算损失
loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0] # Step3. 优化更新参数 w1, w2, w3, b1, b2, b3
grads = tape.gradient(loss, model.trainable_variables)
# w' = w - lr * grad
optimizer.apply_gradients(zip(grads, model.trainable_variables)) if step % 100 == 0:
print(epoch, step, 'loss:', loss.numpy()) def train(): for epoch in range(30): train_epoch(epoch) if __name__ == '__main__':
train()
运行结果如下:
可以看到损失从初始的1.65降到0.25,在这里我们先只对mnist进行一个初步探索,测试一下模型的表现,后续会通过一些更好的优化方法来不断改良我们的模型。
mnist手写数字问题初体验的更多相关文章
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- mnist 手写数字识别
mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- Tensorflow可视化MNIST手写数字训练
简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...
随机推荐
- 【人类观察所】"当代人"正经历的生活
一."即时满足"的互联网 "轻微烦躁,偶尔自燃,当代生活多数时刻的心情基调." 如果你出生于上个世纪,应该能明白木心的<从前慢>里的 「从前的日色变 ...
- opencv简单实用(cv2)
一.介绍 安装:pip install opencv-python OpenCV是一个基于BSD许可(开源)发行的跨平台计算机视觉库,可以运行在Linux.Windows.Android和Mac OS ...
- 网络设备 密码、用户级别 AAA授权 的管理
一.进入 特权模式 密码 设置访问网络设备特权模式口令 cisco>enable cisco#config terminal cisco(config)#enable password 密码 e ...
- Java8尽管很香,你想过升级到Java11吗?会踩那些坑?
目前最新JDK 11,Oracle会一直维护到2026年. Java11的新特性 1.更新支持到Unicode 10编码 Unicode 10(version 10.0 of the Unicode ...
- 使用PropTypes进行类型检查
原文地址 1.组件特殊属性——propTypes 对Component设置propTypes属性,可以为Component的props属性进行类型检查. import PropTypes from ' ...
- hive内置方法一览
引用 https://www.cnblogs.com/qingyunzong/p/8744593.html#_label0 官方文档 https://cwiki.apache.org/confluen ...
- 基于JavaSwing开发银行信用卡管理系统
开发环境: Windows操作系统开发工具: MyEclipse10/Eclipse+Jdk+Mysql数据库 运行效果图 源码及原文链接:https://javadao.xyz/forum.php? ...
- Openshift部署流程介绍
背景 Openshift是一个开源容器云平台,是一个基于主流的容器技术Docker和Kubernetes构建的云平台.Openshift底层以Docker作为容器引擎驱动,以Kubernetes 作为 ...
- RHEL7开机不能正常进入系统(图形化界面)
今天在重启RHEL7的虚拟机后一直无法正常开机,一直提示输入管理员密码,如下图所示: 输入密码后进入命令行模式,经排查出现此现象的问题是在挂载银盘的时候文件格式写错,在格式化硬盘的时候格式化的是xfs ...
- You are my great sunshine
"何为孤寂?" "清风,艳日,无笑意." "可否具体?" "左拥,右抱,无情欲." "可否再具体?" ...