作者|Praneet Bomma
编译|VK
来源|https://towardsdatascience.com/visualising-lstm-activations-in-keras-b50206da96ff

你是否想知道LSTM层学到了什么?有没有想过是否有可能看到每个单元如何对最终输出做出贡献。我很好奇,试图将其可视化。在满足我好奇的神经元的同时,我偶然发现了Andrej Karpathy的博客,名为“循环神经网络的不合理有效性”。如果你想获得更深入的解释,建议你浏览他的博客。

在本文中,我们不仅将在Keras中构建文本生成模型,还将可视化生成文本时某些单元格正在查看的内容。就像CNN一样,它学习图像的一般特征,例如水平和垂直边缘,线条,斑块等。类似,在“文本生成”中,LSTM则学习特征(例如空格,大写字母,标点符号等)。 LSTM层学习每个单元中的特征。

我们将使用Lewis Carroll的《爱丽丝梦游仙境》一书作为训练数据。该模型体系结构将是一个简单的模型体系结构,在其末尾具有两个LSTM和Dropout层以及一个Dense层。

你可以在此处下载训练数据和训练好的模型权重

https://github.com/Praneet9/Visualising-LSTM-Activations

这就是我们激活单个单元格的样子。

让我们深入研究代码。

步骤1:导入所需的库

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Dropout, CuDNNLSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
import re # 可视化库
from IPython.display import HTML as html_print
from IPython.display import display
import keras.backend as K

注意:我使用CuDNN-LSTM代替LSTM,因为它的训练速度提高了15倍。CuDNN-LSTM由CuDNN支持,只能在GPU上运行。

步骤2:读取训练资料并进行预处理

使用正则表达式,我们将使用单个空格删除多个空格。该char_to_int和int_to_char只是数字字符和字符数的映射。

# 读取数据
filename = "wonderland.txt"
raw_text = open(filename, 'r', encoding='utf-8').read()
raw_text = re.sub(r'[ ]+', ' ', raw_text) # 创建字符到整数的映射
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))
int_to_char = dict((i, c) for i, c in enumerate(chars)) n_chars = len(raw_text)
n_vocab = len(chars)

步骤3:准备训练资料

准备我们的数据很重要,每个输入都是一个字符序列,而输出是后面的字符。

seq_length = 100
dataX = []
dataY = [] for i in range(0, n_chars - seq_length, 1):
seq_in = raw_text[i:i + seq_length]
seq_out = raw_text[i + seq_length]
dataX.append([char_to_int[char] for char in seq_in])
dataY.append(char_to_int[seq_out]) n_patterns = len(dataX)
print("Total Patterns: ", n_patterns) X = np.reshape(dataX, (n_patterns, seq_length, 1)) # 标准化
X = X / float(n_vocab) # one-hot编码
y = np_utils.to_categorical(dataY) filepath="weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]

步骤4:构建模型架构

# 定义 LSTM 模型
model = Sequential() model.add(CuDNNLSTM(512, input_shape=(X.shape[1], X.shape[2]), return_sequences=True))
model.add(Dropout(0.5)) model.add(CuDNNLSTM(512))
model.add(Dropout(0.5)) model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) model.summary()

步骤5:训练模型

model.fit(X, y, epochs=300, batch_size=2048, callbacks=callbacks_list)

使用Google Colab训练模型时,我无法一口气训练模型300个epoch。我必须通过缩减权重数量并再次加载它们来进行3天的训练,每天100个epoch

如果你拥有强大的GPU,则可以一次性训练300个epoch的模型。如果你不这样做,我建议你使用Colab,因为它是免费的。

你可以使用下面的代码加载模型,并从最后一点开始训练。

from keras.models import load_model

filename = "weights-improvement-303-0.2749_wonderland.hdf5"
model = load_model(filename)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) # 用相同的数据训练模型
model.fit(X, y, epochs=300, batch_size=2048, callbacks=callbacks_list)

现在到文章最重要的部分-可视化LSTM激活。我们将需要一些功能来实际使这些可视化变得可理解。

步骤6:后端功能以获取中间层输出

正如我们在上面的步骤4中看到的那样,第一层和第三层是LSTM层。我们的目标是可视化第二LSTM层(即整个体系结构中的第三层)的输出。

Keras Backend帮助我们创建一个函数,该函数接受输入并为我们提供来自中间层的输出。我们可以使用它来创建我们自己的管道功能。这里attn_func将返回大小为512的隐藏状态向量。这将是具有512个单位的LSTM层的激活。我们可以可视化这些单元激活中的每一个,以了解它们试图解释的内容。为此,我们必须将其转换为可以表示其重要性的范围的数值。

#第三层是输出形状为LSTM层(Batch_Size, 512)
lstm = model.layers[2] #从中间层获取输出以可视化激活
attn_func = K.function(inputs = [model.get_input_at(0), K.learning_phase()],
outputs = [lstm.output]
)

步骤7:辅助功能

这些助手功能将帮助我们使用每个激活值来可视化字符序列。我们正在通过sigmoid功能传递激活,因为我们需要一个可以表示其对整个输出重要性的规模值。get_clr功能有助于获得给定值的适当颜色。

#获取html元素
def cstr(s, color='black'):
if s == ' ':
return "<text style=color:#000;padding-left:10px;background-color:{}> </text>".format(color, s)
else:
return "<text style=color:#000;background-color:{}>{} </text>".format(color, s) # 输出html
def print_color(t):
display(html_print(''.join([cstr(ti, color=ci) for ti,ci in t]))) #选择合适的颜色
def get_clr(value):
colors = ['#85c2e1', '#89c4e2', '#95cae5', '#99cce6', '#a1d0e8'
'#b2d9ec', '#baddee', '#c2e1f0', '#eff7fb', '#f9e8e8',
'#f9e8e8', '#f9d4d4', '#f9bdbd', '#f8a8a8', '#f68f8f',
'#f47676', '#f45f5f', '#f34343', '#f33b3b', '#f42e2e']
value = int((value * 100) / 5)
return colors[value] # sigmoid函数
def sigmoid(x):
z = 1/(1 + np.exp(-x))
return z

下图显示了如何用各自的颜色表示每个值。

步骤8:获取预测

get_predictions函数随机选择一个输入种子序列,并获得该种子序列的预测序列。visualize函数将预测序列,序列中每个字符的S形值以及要可视化的单元格编号作为输入。根据输出的值,将以适当的背景色打印字符。

将Sigmoid应用于图层输出后,值在0到1的范围内。数字越接近1,它的重要性就越高。如果该数字接近于0,则意味着不会以任何主要方式对最终预测做出贡献。这些单元格的重要性由颜色表示,其中蓝色表示较低的重要性,红色表示较高的重要性。

def visualize(output_values, result_list, cell_no):
print("\nCell Number:", cell_no, "\n")
text_colours = []
for i in range(len(output_values)):
text = (result_list[i], get_clr(output_values[i][cell_no]))
text_colours.append(text)
print_color(text_colours) # 从随机序列中获得预测
def get_predictions(data):
start = np.random.randint(0, len(data)-1)
pattern = data[start]
result_list, output_values = [], []
print("Seed:")
print("\"" + ''.join([int_to_char[value] for value in pattern]) + "\"")
print("\nGenerated:") for i in range(1000):
#为预测下一个字符而重塑输入数组
x = np.reshape(pattern, (1, len(pattern), 1))
x = x / float(n_vocab) # 预测
prediction = model.predict(x, verbose=0) # LSTM激活函数
output = attn_func([x])[0][0]
output = sigmoid(output)
output_values.append(output) # 预测字符
index = np.argmax(prediction)
result = int_to_char[index] # 为下一个字符准备输入
seq_in = [int_to_char[value] for value in pattern]
pattern.append(index)
pattern = pattern[1:len(pattern)] # 保存生成的字符
result_list.append(result)
return output_values, result_list

步骤9:可视化激活

超过90%的单元未显示任何可理解的模式。我手动可视化了所有512个单元,并注意到其中的三个(189、435、463)显示了一些可以理解的模式。


output_values, result_list = get_predictions(dataX) for cell_no in [189, 435, 463]:
visualize(output_values, result_list, cell_no)

单元格189将激活引号内的文本,如下所示。这表示单元格在预测时要查找的内容。如下所示,这个单元格对引号之间的文本贡献很大。

引用句中的几个单词后激活了单元格435。

对于每个单词中的第一个字符,将激活单元格463。

通过更多的训练或更多的数据可以进一步改善结果。这恰恰证明了深度学习毕竟不是一个完整的黑匣子。

你可以在我的Github个人资料中得到整个代码。

https://github.com/Praneet9/Visualising-LSTM-Activations

欢迎关注磐创博客资源汇总站:
http://docs.panchuang.net/

欢迎关注PyTorch官方中文教程站:
http://pytorch.panchuang.net/

OpenCV中文官方文档:
http://woshicver.com/

在Keras中可视化LSTM的更多相关文章

  1. Keras中使用LSTM层时设置的units参数是什么

    https://www.zhihu.com/question/64470274 http://colah.github.io/posts/2015-08-Understanding-LSTMs/ ht ...

  2. Python中利用LSTM模型进行时间序列预测分析

    时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...

  3. 通过keras例子理解LSTM 循环神经网络(RNN)

    博文的翻译和实践: Understanding Stateful LSTM Recurrent Neural Networks in Python with Keras 正文 一个强大而流行的循环神经 ...

  4. Python机器学习笔记:深入学习Keras中Sequential模型及方法

    Sequential 序贯模型 序贯模型是函数式模型的简略版,为最简单的线性.从头到尾的结构顺序,不分叉,是多个网络层的线性堆叠. Keras实现了很多层,包括core核心层,Convolution卷 ...

  5. keras 文本分类 LSTM

    首先,对需要导入的库进行导入,读入数据后,用jieba来进行中文分词 # encoding: utf-8 #载入接下来分析用的库 import pandas as pd import numpy as ...

  6. keras实例学习-双向LSTM进行imdb情感分类

    源码:https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py 及keras中文文档 1. ...

  7. keras: 在构建LSTM模型时,使用变长序列的方法

    众所周知,LSTM的一大优势就是其能够处理变长序列.而在使用keras搭建模型时,如果直接使用LSTM层作为网络输入的第一层,需要指定输入的大小.如果需要使用变长序列,那么,只需要在LSTM层前加一个 ...

  8. keras中seq2seq实现

    这里只是简单的一个例子 输入序列 目标序列 [13, 28, 18, 7, 9, 5] [18, 28, 13] [29, 44, 38, 15, 26, 22] [38, 44, 29] [27, ...

  9. keras中VGG19预训练模型的使用

    keras提供了VGG19在ImageNet上的预训练权重模型文件,其他可用的模型还有VGG16.Xception.ResNet50.InceptionV3 4个. VGG19在keras中的定义: ...

随机推荐

  1. Microsoft Translator:消除面对面交流的语言障碍

    ​ Translator:消除面对面交流的语言障碍" title="Microsoft Translator:消除面对面交流的语言障碍"> ​ James Simm ...

  2. 【原创】面试官问我G1回收器怎么知道你是什么时候的垃圾?

    这是why技术的第36篇原创文章 上面的图片是我上周末在家拍的.以后的文章里面我的第一张配图都用自己随手拍下的照片吧.分享生活,分享技术,哈哈. 阳台上的花开了,成都的春天快来了,疫情也应该快要过去了 ...

  3. WebAPI-处理架构

    带着问题去思考,大家好! 问题1:HTTP请求和返回相应的HTTP响应信息之间发生了什么? 1:首先是最底层,托管层,位于WebAPI和底层HTTP栈之间 2:其次是 消息处理程序管道层,这里比如日志 ...

  4. Cisco模拟器的基本使用

    获取帮助查找命令 只需输入一个'?'便可得到详细的帮助信息,如果想获取c开头的命名,那么直接输入'c?'即可. 在各个模式下切换的方法 给如图所示路由器接口配置IP地址 第一步:安装HWIC-2T(串 ...

  5. 微信Android自动播放视频(可交互,设置层级,无控制条,非X5)ffmpeg,jsmpeg.js,.ts视频

    原料: ffmpeg : http://ffmpeg.zeranoe.com/builds/  win64 https://evermeet.cx/ffmpeg/   mac OS X 64 jsmp ...

  6. HTML、CSS笔记

    盒模型 在CSS中,使用标准盒模型描述这些矩形盒子中的每一个.这个模型描述了元素所占空间的内容.每个盒子有四个边:外边距边, 边框边, 内填充边 与 内容边. 在标准模式下,一个块的总宽度= widt ...

  7. Linux学习4-部署LAMP项目

    前言 LAMP——linux  Apache  Mysql  PHP 今天我们来学习如何在Linux部署Discuz论坛 准备工作 1.一台linux服务器,没有购买服务器的小伙伴也可以使用虚拟机,操 ...

  8. openwrt 外挂usb 网卡 RTL8188CU 及添加 RT5572 kernel支持

    RT5572 原来叫 Ralink雷凌 现在被 MTK 收购了,淘宝上买的很便宜50块邮,2.4 5G 双频.在 win10 上插了试试,果然是支持 5G.这上面写着 飞荣 是什么牌子,有知道的和我说 ...

  9. 使用ajax提交登录

    引入jquery-1.10.2.js或者jquery-1.10.2.min.js 页面 <h3>后台系统登录</h3> <form name="MyForm&q ...

  10. zabbix图表出现中文乱码

    搭建完成Zabbix监控服务器之后,切换到中文语言,图表展示出现乱码,如图所示 按照网上流传的上传windows下的字体的方法,还是不行,最后发现是PHP编译时的问题: php在编译时开启了-enab ...