Keras入门(六)模型训练实时可视化
  在北京做某个项目的时候,客户要求能够对数据进行训练、预测,同时能导出模型,还有在页面上显示训练的进度。前面的几个要求都不难实现,但在页面上显示训练进度当时笔者并没有实现。
  本文将会分享如何在Keras中将模型训练的过程实时可视化。
  幸运的是,已经有人帮我们做好了这件事,这个项目名叫hualos,Github的访问网址为:https://github.com/fchollet/hualos, 作者为François Chollet和Eder Santana,前面的作者就是Keras的创造者,同时也是书籍《Deep Learning with Python》的作者。
  大神的工作大大地方便了我们的使用。调用该项目仅需要三行代码,示例如下:
from keras import callbacks
remote = callbacks.RemoteMonitor(root='http://localhost:9000')
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, validation_data=(X_test, Y_test), callbacks=[remote])
  该项目使用Python2写的,用到的第三方模块为Flask, gevent,其中Flask为网页端框架,gevent用于并发。用到的JavaScript的第三方模块为D3.js和C3.js。该项目使用起来非常方便,只需要切换至hualos项目所在文件夹,然后python api.py即可。
  下面将介绍其使用方法,我们的项目结构如下:

其中hualos可以从Github上直接clone下来,笔者对代码和HTML网页稍作了修改,便于自己使用。model_train.py为Keras模型训练脚本,iris.csv为著名的鸢尾花数据集。
  model_train.py中利用Keras搭建了简单的DNN模型对鸢尾花数据集进行训练及预测,该模型的介绍已经在文章Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中给出,其完整代码如下:
# 导入模块
import numpy as np
import keras as K
import tensorflow as tf
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from keras import callbacks
# 读取CSV数据集,并拆分为训练集和测试集
# 该函数的传入参数为CSV_FILE_PATH: csv文件路径
def load_data(CSV_FILE_PATH):
    IRIS = pd.read_csv(CSV_FILE_PATH)
    target_var = 'class'  # 目标变量
    # 数据集的特征
    features = list(IRIS.columns)
    features.remove(target_var)
    # 目标变量的类别
    Class = IRIS[target_var].unique()
    # 目标变量的类别字典
    Class_dict = dict(zip(Class, range(len(Class))))
    # 增加一列target, 将目标变量进行编码
    IRIS['target'] = IRIS[target_var].apply(lambda x: Class_dict[x])
    # 对目标变量进行0-1编码(One-hot Encoding)
    lb = LabelBinarizer()
    lb.fit(list(Class_dict.values()))
    transformed_labels = lb.transform(IRIS['target'])
    y_bin_labels = []  # 对多分类进行0-1编码的变量
    for i in range(transformed_labels.shape[1]):
        y_bin_labels.append('y' + str(i))
        IRIS['y' + str(i)] = transformed_labels[:, i]
    # 将数据集分为训练集和测试集
    train_x, test_x, train_y, test_y = train_test_split(IRIS[features], IRIS[y_bin_labels], \
                                                        train_size=0.7, test_size=0.3, random_state=0)
    return train_x, test_x, train_y, test_y, Class_dict
if __name__ == '__main__':
    # 0. 开始
    print("\nIris dataset using Keras")
    np.random.seed(4)
    tf.set_random_seed(13)
    # 1. 读取CSV数据集
    print("Loading Iris data into memory")
    CSV_FILE_PATH = 'iris.csv'
    train_x, test_x, train_y, test_y, Class_dict = load_data(CSV_FILE_PATH)
    # 2. 定义模型
    init = K.initializers.glorot_uniform(seed=1)
    simple_adam = K.optimizers.Adam()
    model = K.models.Sequential()
    model.add(K.layers.Dense(units=5, input_dim=4, kernel_initializer=init, activation='relu'))
    model.add(K.layers.Dense(units=6, kernel_initializer=init, activation='relu'))
    model.add(K.layers.Dense(units=3, kernel_initializer=init, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer=simple_adam, metrics=['accuracy'])
    # 3. 训练模型
    b_size = 1
    max_epochs = 100
    print("Starting training ")
    remote = callbacks.RemoteMonitor(root='http://localhost:9000')
    h = model.fit(train_x, train_y, validation_data=(test_x, test_y), batch_size=b_size, epochs=max_epochs,
                  shuffle=True, verbose=1, callbacks=[remote])
    print("Training finished \n")
    # 4. 评估模型
    eval = model.evaluate(test_x, test_y, verbose=0)
    print("Evaluation on test data: loss = %0.6f accuracy = %0.2f%% \n" \
          % (eval[0], eval[1] * 100) )
    # 5. 使用模型进行预测
    np.set_printoptions(precision=4)
    unknown = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
    predicted = model.predict(unknown)
    print("Using model to predict species for features: ")
    print(unknown)
    print("\nPredicted softmax vector is: ")
    print(predicted)
    species_dict = {v:k for k,v in Class_dict.items()}
    print("\nPredicted species is: ")
    print(species_dict[np.argmax(predicted)])
  我们切换至hualos文件夹,运行python api.py,然后再用Python3运行model_train.py文件,在浏览器中输入网址:http://localhost:9000,即可看到在网页中显示的模型训练的实施可视化的结果,图像如下:

因为这里无法给出视频,需要观看视频的读者可以移步网址:https://mp.weixin.qq.com/s?__biz=MzU2NTYyMDk5MQ==&mid=2247484522&idx=1&sn=dab46a55945baf2411e30bd109cee76f&chksm=fcb9bdfacbce34ec02f3e958988e9b400676d29f88c1efad5ce01fb1f4ce2f5f96ccf0e4af66&token=1377830530&lang=zh_CN#rd 。
  本项目的Github地址为:https://github.com/percent4/keras_train_visualization 。
  本期分享到此结束,感谢大家阅读~
Keras入门(六)模型训练实时可视化的更多相关文章
- Keras入门(二)模型的保存、读取及加载
		
本文将会介绍如何利用Keras来实现模型的保存.读取以及加载. 本文使用的模型为解决IRIS数据集的多分类问题而设计的深度神经网络(DNN)模型,模型的结构示意图如下: 具体的模型参数可以参考文章 ...
 - 入门项目数字手写体识别:使用Keras完成CNN模型搭建(重要)
		
摘要: 本文是通过Keras实现深度学习入门项目——数字手写体识别,整个流程介绍比较详细,适合初学者上手实践. 对于图像分类任务而言,卷积神经网络(CNN)是目前最优的网络结构,没有之一.在面部识别. ...
 - Keras入门(四)之利用CNN模型轻松破解网站验证码
		
项目简介 在之前的文章keras入门(三)搭建CNN模型破解网站验证码中,笔者介绍介绍了如何用Keras来搭建CNN模型来破解网站的验证码,其中验证码含有字母和数字. 让我们一起回顾一下那篇文 ...
 - tensorflow笔记:模型的保存与训练过程可视化
		
tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...
 - Keras(六)Autoencoder 自编码 原理及实例 Save&reload 模型的保存和提取
		
Autoencoder 自编码 压缩与解压 原来有时神经网络要接受大量的输入信息, 比如输入信息是高清图片时, 输入信息量可能达到上千万, 让神经网络直接从上千万个信息源中学习是一件很吃力的工作. 所 ...
 - keras入门(三)搭建CNN模型破解网站验证码
		
项目介绍 在文章CNN大战验证码中,我们利用TensorFlow搭建了简单的CNN模型来破解某个网站的验证码.验证码如下: 在本文中,我们将会用Keras来搭建一个稍微复杂的CNN模型来破解以上的 ...
 - Keras实践:模型可视化
		
Keras实践:模型可视化 安装Graphviz 官方网址为:http://www.graphviz.org/.我使用的是mac系统,所以我分享一下我使用时遇到的坑. Mac安装时在终端中执行: br ...
 - tensorflow:模型的保存和训练过程可视化
		
在使用tf来训练模型的时候,难免会出现中断的情况.这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始. 保存模型的方法: #之前是各种构建模型graph的操作(矩阵相乘,sig ...
 - Keras入门(一)搭建深度神经网络(DNN)解决多分类问题
		
Keras介绍 Keras是一个开源的高层神经网络API,由纯Python编写而成,其后端可以基于Tensorflow.Theano.MXNet以及CNTK.Keras 为支持快速实验而生,能够把 ...
 
随机推荐
- wannafly 27 D 巧妙求取约数
			
链接:https://www.nowcoder.com/acm/contest/215/D来源:牛客网 题目描述 “我不知道你在说什么,因为我只是个pupil.”--绿魔法师 一个空的可重集合S. n ...
 - mysql 向字段添加数据或者删除数据
			
UPDATE table SET cids = CONCAT(cids , ',12') where id=id //向字段添加数据 //因为要用逗号分隔 所以在在前面加了一个逗号 UPDATE ta ...
 - Java入门 - 高级教程 - 03.泛型
			
原文地址:http://www.work100.net/training/java-generic.html 更多教程:光束云 - 免费课程 泛型 序号 文内章节 视频 1 概述 2 泛型方法 3 泛 ...
 - 「 神器 」在线PDF文件管理工具和图片编辑神器
			
每天进步一丢丢,连接梦与想 在线PDF文件管理工具 完全免费的PDF文件在线管理工具,其功能包括:合并PDF文件.拆分PDF文件.压缩PDF文件.Office文件转换为PDF文件.PDF文件转换为JP ...
 - 基于bootstrap和knockoutjs使用 mvc 查询
			
这是我摘抄的码 http://pan.baidu.com/s/1nvKWdsd
 - Scala 学习(7)之「trait (1) 」
			
作为接口使用 在 triat 中可以定义抽象方法,就与抽象类中的抽象方法一样,只要不给出方法的具体实现即可 类可以使用 extends 关键字继承 trait,注意,这里不是 implement,而是 ...
 - Redis(八):zset/zadd/zrange/zrembyscore 命令源码解析
			
前面几篇文章,我们完全领略了redis的string,hash,list,set数据类型的实现方法,相信对redis已经不再神秘. 本篇我们将介绍redis的最后一种数据类型: zset 的相关实现. ...
 - .NET Core微服务二:Ocelot API网关
			
.NET Core微服务一:Consul服务中心 .NET Core微服务二:Ocelot API网关 .NET Core微服务三:polly熔断与降级 本文的项目代码,在文章结尾处可以下载. 本文使 ...
 - 移动端ui框架
			
https://blog.csdn.net/Robin_star_/article/details/81810197
 - model form
			
ModelForm 能允许我们通过一个 Model 直接创建一个和该模型的字段一一对应的表单,大大方便了表单操作. 下面来看一个例子. 首先我们有这样的 model: from django.db i ...