最近几年,随着AlphaGo的崛起,深度学习开始出现在各个领域,比如无人车、图像识别、物体检测、推荐系统、语音识别、聊天问答等等。因此具备深度学习的知识并能应用实践,已经成为很多开发者包括博主本人的下一个目标了。

目前最流行的框架莫过于Tensorflow了,但是只要接触过它的人,就知道它使用起来是多么让人恐惧。Tensorflow对我们来说,仿佛是一门高深的Deep Learning学习语言,需要具备很深的机器学习和深度学习功底,才能玩得转。

Keras正是在这种背景下应运而生的,它是一个对开发者很友好的框架,底层可以基于TensorFlow和Theano,使用起来仿佛是在搭积木。只要不停的添加已有的“层”,就可以实现各种复杂的深度网络模型。

因此,开发者需要熟悉的不过是两点:如何搭建积木?都有什么积木可以用?

安装

安装的步骤直接按照官方文档来就行了,我笔记本的环境已经杂乱不堪,没有办法一步一步记录安装配置了。主要是安装python3.6,然后各种pip install就行了。

参考文档:http://keras-cn.readthedocs.io/en/latest/for_beginners/keras_linux/

基础概念

在使用Keras前,首先要了解Keras里面关于模型如何创建。在上面可爱的小盆友的图片中,想要把积木罗列在一起,需要一个中心的木棍。那么Sequential就可以看做是这个木棍。

剩下的工作就是add不同的层就行了:

model = Sequential()
model.add(Dense(32, input_shape=(784,)))
model.add(Activation('relu'))

建立好model后,相当于我们定义好了逻辑模型。此时就需要编译模型,生成对应的代码:

model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['accuracy'])

其中optimizer是参数优化的方法,loss是损失函数的定义,metrics是衡量模型效果的指标。

最后再灌入数据进行训练即可:

model.fit(data, labels, epochs=10, batch_size=32)

完整的例子

代码已经上传到github:https://github.com/xinghalo/keras-examples/blob/master/keras-cn/mnist/mnist_mlp.py

很多人hello world跑不通是因为网络问题,不能下载到对应的数据集。我这里把数据集也上传到对应的目录下了,修改对应的path即可。

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop batch_size = 128
num_classes = 10
epochs = 3 # the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data("/Users/xingoo/PycharmProjects/keras-examples/keras-cn/mnist/mnist.npz") x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples') # convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes) model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax')) model.summary() model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy']) history = model.fit(x_train, y_train,
batch_size=batch_size,
epochs=epochs,
verbose=1,
validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

运行效果

Using TensorFlow backend.
/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6
return f(*args, **kwds)
60000 train samples
10000 test samples
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 512) 401920
_________________________________________________________________
dropout_1 (Dropout) (None, 512) 0
_________________________________________________________________
dense_2 (Dense) (None, 512) 262656
_________________________________________________________________
dropout_2 (Dropout) (None, 512) 0
_________________________________________________________________
dense_3 (Dense) (None, 10) 5130
=================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
_________________________________________________________________
2018-05-25 17:15:22.294036: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
Train on 60000 samples, validate on 10000 samples
Epoch 1/3
128/60000 [..............................] - ETA: 166s - loss: 2.4237 - acc: 0.0703
640/60000 [..............................] - ETA: 38s - loss: 1.8688 - acc: 0.3812
1152/60000 [..............................] - ETA: 23s - loss: 1.5497 - acc: 0.5087
1664/60000 [..............................] - ETA: 18s - loss: 1.3466 - acc: 0.5655
2176/60000 [>.............................] - ETA: 15s - loss: 1.1902 - acc: 0.6167
2688/60000 [>.............................] - ETA: 13s - loss: 1.0736 - acc: 0.6536
3200/60000 [>.............................] - ETA: 12s - loss: 0.9968 - acc: 0.6778
3712/60000 [>.............................] - ETA: 11s - loss: 0.9323 - acc: 0.7002
4096/60000 [=>............................] - ETA: 11s - loss: 0.8971 - acc: 0.7109
...
51328/60000 [========================>.....] - ETA: 0s - loss: 0.0733 - acc: 0.9775
51840/60000 [========================>.....] - ETA: 0s - loss: 0.0733 - acc: 0.9774
52352/60000 [=========================>....] - ETA: 0s - loss: 0.0735 - acc: 0.9774
52864/60000 [=========================>....] - ETA: 0s - loss: 0.0733 - acc: 0.9775
53376/60000 [=========================>....] - ETA: 0s - loss: 0.0736 - acc: 0.9774
53888/60000 [=========================>....] - ETA: 0s - loss: 0.0734 - acc: 0.9775
54400/60000 [==========================>...] - ETA: 0s - loss: 0.0736 - acc: 0.9774
54912/60000 [==========================>...] - ETA: 0s - loss: 0.0740 - acc: 0.9773
55424/60000 [==========================>...] - ETA: 0s - loss: 0.0744 - acc: 0.9772
55936/60000 [==========================>...] - ETA: 0s - loss: 0.0746 - acc: 0.9771
56448/60000 [===========================>..] - ETA: 0s - loss: 0.0749 - acc: 0.9771
56960/60000 [===========================>..] - ETA: 0s - loss: 0.0751 - acc: 0.9772
57472/60000 [===========================>..] - ETA: 0s - loss: 0.0756 - acc: 0.9772
57984/60000 [===========================>..] - ETA: 0s - loss: 0.0754 - acc: 0.9772
58496/60000 [============================>.] - ETA: 0s - loss: 0.0750 - acc: 0.9773
59008/60000 [============================>.] - ETA: 0s - loss: 0.0750 - acc: 0.9773
59520/60000 [============================>.] - ETA: 0s - loss: 0.0749 - acc: 0.9774
60000/60000 [==============================] - 7s - loss: 0.0749 - acc: 0.9774 - val_loss: 0.0819 - val_acc: 0.9768
Test loss: 0.0819479118524
Test accuracy: 0.9768

参考

  1. Keras中文官方文档:http://keras-cn.readthedocs.io/en/latest/getting_started/sequential_model/
  2. Keras github examples:https://github.com/keras-team/keras/blob/master/examples/mnist_mlp.py
  3. 神经网络(一):概念:https://blog.csdn.net/xierhacker/article/details/51771428
  4. 神经网络(二):感知机:https://blog.csdn.net/xierhacker/article/details/51816484
  5. 深度学习笔记二:多层感知机(MLP)与神经网络结构:https://blog.csdn.net/xierhacker/article/details/53282038
  6. 多层感知机:Multi-Layer Perceptron:https://blog.csdn.net/xholes/article/details/78461164

Keras学习笔记——Hello Keras的更多相关文章

  1. 官网实例详解-目录和实例简介-keras学习笔记四

    官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras   版权声明: ...

  2. Keras学习笔记(完结)

    使用Keras中文文档学习 基本概念 Keras的核心数据结构是模型,也就是一种组织网络层的方式,最主要的是序贯模型(Sequential).创建好一个模型后就可以用add()向里面添加层.模型搭建完 ...

  3. Keras学习笔记

    Keras基于Tensorflow和Theano.作为一个更高级的框架,用其编写网络更加方便.具体流程为根据设想的网络结构,使用函数式模型API逐层构建网络即可,每一层的结构都是一个函数,上一层的输出 ...

  4. Keras学习笔记1--基本入门

    """ 1.30s上手keras """ #keras的核心数据结构是“模型”,模型是一种组织网络层的方式,keras 的主要模型是Sequ ...

  5. keras 学习笔记:从头开始构建网络处理 mnist

    全文参考 < 基于 python 的深度学习实战> import numpy as np from keras.datasets import mnist from keras.model ...

  6. Keras学习笔记二:保存本地模型和调用本地模型

    使用深度学习模型时当然希望可以保存下训练好的模型,需要的时候直接调用,不再重新训练 一.保存模型到本地 以mnist数据集下的AutoEncoder 去噪为例.添加: file_path=" ...

  7. Keras学习笔记。

    1. keras.layers.Dense (Fully Connected Neural NetWork),所实现的运算是output = activation(dot(input, kernel) ...

  8. keras学习笔记2

    1.keras的sequential模型需要知道输入数据的shape,因此,sequential的第一层需要接受一个关于输入数据shape的参数,后面的各个层则可以自动的推导出中间数据的shape,因 ...

  9. keras 学习笔记(二) ——— data_generator

    data_generator 每次输出一个batch,基于keras.utils.Sequence Base object for fitting to a sequence of data, suc ...

随机推荐

  1. 使用Narrator读取RichTextBlock内容

    先测试基本的RichTextBlock,看能否读取. 测试RichTextBlock中哪些子控件是可以被读取的. 结论:只有Hyperlink能Tab到,能被读取. 问题:RichTextBlock在 ...

  2. 树状数组训练题2:SuperBrother打鼹鼠(vijos1512)

    先给题目链接:打鼹鼠 这道题怎么写? 很明显是树状数组. 而且,很明显是二维树状数组. 如果你没学过二维的树状数组,那么戳开这里:二维树状数组 看完以后,你就会知道怎么做了. 没有什么好解释的,几乎就 ...

  3. mybatis拦截器案例之获取结果集总条数

    最近做的项目前端是外包出去的,所以在做查询分页的时候比较麻烦 我们需要先吧结果集的条数返回给前端,然后由前端根据页面情况(当前页码,每页显示条数)将所需参数传到后端. 由于在项目搭建的时候,是没有考虑 ...

  4. R入门(一)

    简单的算术操作和向量运算 向量赋值:函数c( ),参数可以是一个或多个数,也可以是向量 赋值符号‘<-’ 向量运算:exp(),log(),sin(),tan(),sqrt(),max(),mi ...

  5. hibernate添加数据报错:Could not execute JDBC batch update

    报错如下图所示: 报错原因:在配置文件或注解里设置了字段关联,但数据却没有关联. 解决方法:我的错误是向一个多对多的关联表里插入数据,由于表中一个字段的数据是从另一张表里get到的,通过调试发现,从以 ...

  6. 实习番外篇:解决C语言使用Makefile无法实现更好的持续集成问题

    工作中遇见的一个问题,提供项目源代码的情况下,希望对项目进行持续集成,达到一个C项目增量编译的效果.原本第一天是想通过模拟Makefile执行步骤来实现整个过程的,但是事实上发现整个Makefile显 ...

  7. MFCC

    在语音识别研究领域,音频特征的选择至关重要.在这里介绍一种非常成功的音频特征——Mel Frequency Cepstrum Coefficient(MFCC),中文名字为梅尔频率倒谱系数.MFCC特 ...

  8. 源项目 -> fork -> 本地 (如何把源项目的代码合并到本地然后推送给fork)

    git remote -v git remote add 别名 地址 git fetch 别名 git merge  别名/分支 第一步:命令行进入到本地.git 所在的目录,查看remote 信息 ...

  9. Raft协议学习笔记

    目录 目录 1 1. 前言 1 2. 名词 1 3. 什么是分布式一致性? 3 4. Raft选举 3 4.1. 什么是Leader选举? 3 4.2. 选举的实现 4 4.3. Term和Lease ...

  10. Visual C++实现局域网IP多播

    //////////////////////////////////////////////////////////////////////////////////////////////////// ...