上一篇我们提到了回归问题中的梯度下降算法,而且我们知道线性模型只能解决简单的线性回归问题,对于高维图片,线性模型不能完成这样复杂的分类任务。那么是不是线性模型在离散值预测或图像分类问题中就没有用武之地了呢?

本篇我们就套用regression中的部分机制来处理classification中的问题。

在这里首先介绍一下激活函数。

所谓激活函数,实际上就是引入非线性因子,将线性模型去线性化,增强模型的表达能力。ReLU激活函数是我要介绍的第一个激活函数,其定义式为φ(z)=max{0,z},图像表示如下:

简单的说relu就是一个取最大值的函数,在负区间取值为0,正区间取值不变,这种操作被称为单侧抑制(输出为0时代表神经元不会被激活)。单侧抑制的特点就是同一时间只会有一部分神经元被激活(结合函数图像可以看出),也就使得神经元具有了稀疏激活性。加入relu激活函数的神经元被称作整流线性单元,它与线性单元非常相似,唯一的区别就是在一半定义域上输出为0。整流线性单元易于优化,当其处于激活状态时(输出不为0),它的一阶导数能够保持一个较大值(等于1),并且处处一致,它的二阶导数几乎处处为0,这样的好处就是避免了梯度下降时的梯度消失问题(可参考前一篇回归问题的随笔)。

简单介绍了激活函数,那么是不是将激活函数引入我们的线性模型out=X@w+b就能使其解决复杂的图像分类问题了呢?

很显然不是的,虽然加了激活函数,但是我们可以看到模型变为out=relu(X@w+b) 依然还是太简单。那么怎么办呢?

我们可以联系一下零件加工的流程,从原料到成品,零件的加工经历了多个工序,期间每一道工序都是由前一道工序为基础,这时候,原料就相当于神经网络的输入,成品零件就相当于神经网络的输出,他们中间并不是也不能一步到位,而是经过若干“隐藏”的工序一步一步的生成产品。我们的模型同样可以借助于这种思想。即给数据处理多添加几道所谓的“工序”,我们称之为“隐藏层”,因为我们关心的只有模型的输入和输出,隐藏层的数据是我们不可见的(当然也可以在运行过程中打印出来方便调试),下面我们就利用这样的思想来解决mnist手写数字分类问题。

我们使用的是mnist数据集,也是深度学习的基础入门数据集。它一共有70k张不同的手写数字图片,其中60k用来训练模型,10k用来评估模型,且所有图片均为28*28的灰度图。我们首先设计一个稍微复杂的模型

h1=relu(X@w1+b1)

h2=relu(h1@w2+b2)

out=relu(h2@w3+b3)

其中X为输入,out为输出,h1、h2均为隐藏层,且除输入层外每一层的输入均是前一层的输出。

首先我们将输入的28*28*1的图片扁平化,即将每张图片转化成784维的向量(28*28=784),这样的好处是可以以矩阵的形式同时喂入多张图片(每一行向量为一张灰度图的信息),提高效率。对于输出out,我们令其输出一个10维的向量,代表10个数字的概率。模型可以用以下公式概括:

out=relu { relu { relu[ X@w1+b1 ] @w2+b2 }@w3+b3 }

pred=argmax(out)

loss=MSE(out,label) (均方误差损失函数即loss=∑(label-out)2

minimize loss→[w1',b1',w2',b2',w3',b3']

参数调整完成后,可以对新的输入x进行运算从而得到对应的输出

代码如下:

 import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets # 屏蔽通知和警告信息,减少用处不大的问题输出
os.environ['TF_CPP_MIN_LOG_LEVEL']='' (x, y), (x_val, y_val) = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200) # 搭建网络结构
model = keras.Sequential([
layers.Dense(512, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(10)]) # 初始化优化器为梯度下降优化器
optimizer = optimizers.SGD(learning_rate=0.001) def train_epoch(epoch): # Step4.循环迭代
for step, (x, y) in enumerate(train_dataset): with tf.GradientTape() as tape:
# 将输入数据压平 [b, 28, 28] => [b, 784]
x = tf.reshape(x, (-1, 28*28))
# Step1. 计算输出
# 输入域数据经过神经网络降维 [b, 784] => [b, 10]
out = model(x)
# Step2. 计算损失
loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0] # Step3. 优化更新参数 w1, w2, w3, b1, b2, b3
grads = tape.gradient(loss, model.trainable_variables)
# w' = w - lr * grad
optimizer.apply_gradients(zip(grads, model.trainable_variables)) if step % 100 == 0:
print(epoch, step, 'loss:', loss.numpy()) def train(): for epoch in range(30): train_epoch(epoch) if __name__ == '__main__':
train()

运行结果如下:

   

可以看到损失从初始的1.65降到0.25,在这里我们先只对mnist进行一个初步探索,测试一下模型的表现,后续会通过一些更好的优化方法来不断改良我们的模型。

mnist手写数字问题初体验的更多相关文章

  1. MindSpore手写数字识别初体验,深度学习也没那么神秘嘛

    摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...

  2. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

  3. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  4. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  5. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  6. 第三节,CNN案例-mnist手写数字识别

    卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...

  7. mnist 手写数字识别

    mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...

  8. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  9. Tensorflow可视化MNIST手写数字训练

    简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...

随机推荐

  1. Android客户端OkHttp的使用以及tomcat服务器的解析客户端发过来的数据

    2020-02-15 21:25:42 ### android客户端客户向服务器发送json字符串或者以参数请求的方式发送数据 其中又分为post请求和get请求 1.activity.xml < ...

  2. 3、实战:OutOfMemoryError异常

    目的:第一,通过代码验证Java虚拟机规范中描述的各个运行时区域存储的内容:第二,工作中遇到实际的内存溢出异常时,能根据异常的信息快速判断是哪个区域的内存溢出,知道什么样的代码可能会导致这些区域内存溢 ...

  3. XXE漏洞复现步骤

    XXE漏洞复现步骤 0X00XXE注入定义 XXE注入,即XML External Entity,XML外部实体注入.通过 XML 实体,”SYSTEM”关键词导致 XML 解析器可以从本地文件或者远 ...

  4. 关于ThinkPHP在Nginx服务器下因PATH_INFO出错的解决方法

    参考:https://www.linuxidc.com/Linux/2011-11/46871.htm 这是一个ningx设置的问题,和TP无关.TP默认使用PATH_INFO来做CURD,而ngin ...

  5. sql 忘记密码 解决方法(window cmd命令解决)

    cd wamp\bin\mysql\mysql5.6.17\bin mysqld --skip-grant-tables

  6. Ikuai路由安装及简单配置 v1.0

    第一部分:创建虚拟机: 1.点击创建新的虚拟机   2.选择自定义模式创建(选择经典模式会更友好一些),然后点击下一步 3.下图内容不用管,直接点击下一步:   4.这里是选择安装系统路径.在这里我们 ...

  7. 在windows系统安装nginx

    1.下载Nginx,链接:http://nginx.org/en/download.html 2.解压放到自己的磁盘,双击击运行nginx.exe,会有命令框一闪而过,在浏览器上面输入localhos ...

  8. Django设置异步任务

    1.安装Django-celery 包:pip install django-celery==3.2.2 2.开启redis服务 需要使用redis做broker,所以在使用异步和定时任务时需要开启r ...

  9. C++ 解决列车重排问题

    问题节选自<<数据结构.算法与应用(C++语言描述)>>, 思路与代码为原创, 如有疏漏及问题欢迎指正 问题描述: 一辆列车有n节车厢, 车厢排列乱序(如: 284657139 ...

  10. JS对象的概念、声明方式等及js中的继承与封装

    对象的遍历 对象可以当做数组处理,使用for in var person={}; person.name="cyy"; person.age=25; person.infos=fu ...