mnist手写数字问题初体验
上一篇我们提到了回归问题中的梯度下降算法,而且我们知道线性模型只能解决简单的线性回归问题,对于高维图片,线性模型不能完成这样复杂的分类任务。那么是不是线性模型在离散值预测或图像分类问题中就没有用武之地了呢?
本篇我们就套用regression中的部分机制来处理classification中的问题。
在这里首先介绍一下激活函数。
所谓激活函数,实际上就是引入非线性因子,将线性模型去线性化,增强模型的表达能力。ReLU激活函数是我要介绍的第一个激活函数,其定义式为φ(z)=max{0,z},图像表示如下:
简单的说relu就是一个取最大值的函数,在负区间取值为0,正区间取值不变,这种操作被称为单侧抑制(输出为0时代表神经元不会被激活)。单侧抑制的特点就是同一时间只会有一部分神经元被激活(结合函数图像可以看出),也就使得神经元具有了稀疏激活性。加入relu激活函数的神经元被称作整流线性单元,它与线性单元非常相似,唯一的区别就是在一半定义域上输出为0。整流线性单元易于优化,当其处于激活状态时(输出不为0),它的一阶导数能够保持一个较大值(等于1),并且处处一致,它的二阶导数几乎处处为0,这样的好处就是避免了梯度下降时的梯度消失问题(可参考前一篇回归问题的随笔)。
简单介绍了激活函数,那么是不是将激活函数引入我们的线性模型out=X@w+b就能使其解决复杂的图像分类问题了呢?
很显然不是的,虽然加了激活函数,但是我们可以看到模型变为out=relu(X@w+b) 依然还是太简单。那么怎么办呢?
我们可以联系一下零件加工的流程,从原料到成品,零件的加工经历了多个工序,期间每一道工序都是由前一道工序为基础,这时候,原料就相当于神经网络的输入,成品零件就相当于神经网络的输出,他们中间并不是也不能一步到位,而是经过若干“隐藏”的工序一步一步的生成产品。我们的模型同样可以借助于这种思想。即给数据处理多添加几道所谓的“工序”,我们称之为“隐藏层”,因为我们关心的只有模型的输入和输出,隐藏层的数据是我们不可见的(当然也可以在运行过程中打印出来方便调试),下面我们就利用这样的思想来解决mnist手写数字分类问题。
我们使用的是mnist数据集,也是深度学习的基础入门数据集。它一共有70k张不同的手写数字图片,其中60k用来训练模型,10k用来评估模型,且所有图片均为28*28的灰度图。我们首先设计一个稍微复杂的模型
h1=relu(X@w1+b1)
h2=relu(h1@w2+b2)
out=relu(h2@w3+b3)
其中X为输入,out为输出,h1、h2均为隐藏层,且除输入层外每一层的输入均是前一层的输出。
首先我们将输入的28*28*1的图片扁平化,即将每张图片转化成784维的向量(28*28=784),这样的好处是可以以矩阵的形式同时喂入多张图片(每一行向量为一张灰度图的信息),提高效率。对于输出out,我们令其输出一个10维的向量,代表10个数字的概率。模型可以用以下公式概括:
out=relu { relu { relu[ X@w1+b1 ] @w2+b2 }@w3+b3 }
pred=argmax(out)
loss=MSE(out,label) (均方误差损失函数即loss=∑(label-out)2)
minimize loss→[w1',b1',w2',b2',w3',b3']
参数调整完成后,可以对新的输入x进行运算从而得到对应的输出
代码如下:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets # 屏蔽通知和警告信息,减少用处不大的问题输出
os.environ['TF_CPP_MIN_LOG_LEVEL']='' (x, y), (x_val, y_val) = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200) # 搭建网络结构
model = keras.Sequential([
layers.Dense(512, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(10)]) # 初始化优化器为梯度下降优化器
optimizer = optimizers.SGD(learning_rate=0.001) def train_epoch(epoch): # Step4.循环迭代
for step, (x, y) in enumerate(train_dataset): with tf.GradientTape() as tape:
# 将输入数据压平 [b, 28, 28] => [b, 784]
x = tf.reshape(x, (-1, 28*28))
# Step1. 计算输出
# 输入域数据经过神经网络降维 [b, 784] => [b, 10]
out = model(x)
# Step2. 计算损失
loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0] # Step3. 优化更新参数 w1, w2, w3, b1, b2, b3
grads = tape.gradient(loss, model.trainable_variables)
# w' = w - lr * grad
optimizer.apply_gradients(zip(grads, model.trainable_variables)) if step % 100 == 0:
print(epoch, step, 'loss:', loss.numpy()) def train(): for epoch in range(30): train_epoch(epoch) if __name__ == '__main__':
train()
运行结果如下:
可以看到损失从初始的1.65降到0.25,在这里我们先只对mnist进行一个初步探索,测试一下模型的表现,后续会通过一些更好的优化方法来不断改良我们的模型。
mnist手写数字问题初体验的更多相关文章
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- mnist 手写数字识别
mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- Tensorflow可视化MNIST手写数字训练
简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...
随机推荐
- MySQL学习笔记——基础与进阶篇
目录 一.###MySQL登录和退出 二.###MySQL常用命令 三.###MySQL语法规范 四.###基础查询 五.###条件查询 六.###排序查询 七.###常见函数的学习 八.###分组查 ...
- spyder学习记录---如何调试
调试技巧: 当我们想单步执行某段代码(但是不进入调用的函数)时,点击运行当前行. 当我们想进入某个函数内部进行调试,在函数调用处点击进入函数或方法内运行. 当我们不想看函数内部的运行过程时,点击跳出函 ...
- SpringBoot学习笔记 文件访问映射
通过SpringBoot可以把磁盘内所有的文件都访问到 有一张图片存放在 E://images/acti/123.jpg import org.springframework.context.anno ...
- .Net Core中IOC容器的使用
打代码之前先说一下几个概念,那就是什么是IOC.DI.DIP 虽然网上讲这些的已经有很多了,我这里还是要再赘述一下 IOC容器就是一个工厂,负责创建对象的 IOC控制反转:只是把上端对下端的依赖,换成 ...
- 3.【Spring Cloud Alibaba】声明式HTTP客户端-Feign
使用Feign实现远程HTTP调用 什么是Feign Feign是Netflix开源的声明式HTTP客户端 GitHub地址:https://github.com/openfeign/feign 实现 ...
- js能力测评——查找元素的位置
查找元素的位置 题目描述: 找出元素 item 在给定数组 arr 中的位置 输出描述: 如果数组中存在 item,则返回元素在数组中的位置,否则返回 -1 示例1 输入 [ 1, 2, 3, 4 ] ...
- pikachu-越权漏洞(Over Permission)
一.越权漏洞概述 1.1 概述 由于没有用户权限进行严格的判断,导致低权限的账户(例如普通用户)可以去完成高权限账户(例如管理员账户)范围内的操作. 1.2 越权漏洞的分类 (1)平行越权 ...
- android 华为、魅族手机无法打印 Log 日志的问题
最近使用魅族真机测试 App 时,发现 LogCat 不显示项目工程中通过Log.d()和Log.v()打印的 debug 和 verbose 级别的日志,甚是奇怪,通过 debug 模式断点调试也没 ...
- Python股票量化 选股操作不好用 完结
这几日,写了一些python的代码,打算来选择股票的, 那么这个思路和开始的一篇文章类似,你会不会被贾跃亭坑?,所以基本的思路也是这样的,举一个简单的例子,就是通过连续几年的ROE数据,和其他的一些财 ...
- Django3的安装以及web项目的创建
cmd 直接输入:pip install -i https://pypi.douban.com/simple django 2.检测是否安装成功:用到的命令:import django ,检测版本 ...