该项目github地址

基于keras的中文语音识别

  • 该项目实现了GRU-CTC中文语音识别,所有代码都在gru_ctc_am.py中,包括:

    • 音频文件特征提取
    • 文本数据处理
    • 数据格式处理
    • 构建模型
    • 模型训练及解码
  • 之外还包括将aishell数据处理为thchs30数据格式,合并数据进行训练。代码及数据放在gen_aishell_data中。

默认数据集为thchs30,参考gen_aishell_data中的数据及代码,也可以使用aishell的数据进行训练。

音频文件特征提取

# -----------------------------------------------------------------------------------------------------
'''
&usage: [audio]对音频文件进行处理,包括生成总的文件列表、特征提取等
'''
# -----------------------------------------------------------------------------------------------------
# 生成音频列表
def genwavlist(wavpath):
wavfiles = {}
fileids = []
for (dirpath, dirnames, filenames) in os.walk(wavpath):
for filename in filenames:
if filename.endswith('.wav'):
filepath = os.sep.join([dirpath, filename])
fileid = filename.strip('.wav')
wavfiles[fileid] = filepath
fileids.append(fileid)
return wavfiles,fileids # 对音频文件提取mfcc特征
def compute_mfcc(file):
fs, audio = wav.read(file)
mfcc_feat = mfcc(audio, samplerate=fs, numcep=26)
mfcc_feat = mfcc_feat[::3]
mfcc_feat = np.transpose(mfcc_feat)
mfcc_feat = pad_sequences(mfcc_feat, maxlen=500, dtype='float', padding='post', truncating='post').T
return mfcc_feat

文本数据处理

# -----------------------------------------------------------------------------------------------------
'''
&usage: [text]对文本标注文件进行处理,包括生成拼音到数字的映射,以及将拼音标注转化为数字的标注转化
'''
# -----------------------------------------------------------------------------------------------------
# 利用训练数据生成词典
def gendict(textfile_path):
dicts = []
textfile = open(textfile_path,'r+')
for content in textfile.readlines():
content = content.strip('\n')
content = content.split(' ',1)[1]
content = content.split(' ')
dicts += (word for word in content)
counter = Counter(dicts)
words = sorted(counter)
wordsize = len(words)
word2num = dict(zip(words, range(wordsize)))
num2word = dict(zip(range(wordsize), words))
return word2num, num2word #1176个音素 # 文本转化为数字
def text2num(textfile_path):
lexcion,num2word = gendict(textfile_path)
word2num = lambda word:lexcion.get(word, 0)
textfile = open(textfile_path, 'r+')
content_dict = {}
for content in textfile.readlines():
content = content.strip('\n')
cont_id = content.split(' ',1)[0]
content = content.split(' ',1)[1]
content = content.split(' ')
content = list(map(word2num,content))
add_num = list(np.zeros(50-len(content)))
content = content + add_num
content_dict[cont_id] = content
return content_dict,lexcion

数据格式处理

# -----------------------------------------------------------------------------------------------------
'''
&usage: [data]数据生成器构造,用于训练的数据生成,包括输入特征及标注的生成,以及将数据转化为特定格式
'''
# -----------------------------------------------------------------------------------------------------
# 将数据格式整理为能够被网络所接受的格式,被data_generator调用
def get_batch(x, y, train=False, max_pred_len=50, input_length=500):
X = np.expand_dims(x, axis=3)
X = x # for model2
# labels = np.ones((y.shape[0], max_pred_len)) * -1 # 3 # , dtype=np.uint8
labels = y
input_length = np.ones([x.shape[0], 1]) * ( input_length - 2 )
# label_length = np.ones([y.shape[0], 1])
label_length = np.sum(labels > 0, axis=1)
label_length = np.expand_dims(label_length,1)
inputs = {'the_input': X,
'the_labels': labels,
'input_length': input_length,
'label_length': label_length,
}
outputs = {'ctc': np.zeros([x.shape[0]])} # dummy data for dummy loss function
return (inputs, outputs) # 数据生成器,默认音频为thchs30\train,默认标注为thchs30\train.syllable,被模型训练方法fit_generator调用
def data_generate(wavpath = 'E:\\Data\\data_thchs30\\train', textfile = 'E:\\Data\\thchs30\\train.syllable.txt', bath_size=4):
wavdict,fileids = genwavlist(wavpath)
#print(wavdict)
content_dict,lexcion = text2num(textfile)
genloop = len(fileids)//bath_size
print("all loop :", genloop)
while True:
feats = []
labels = []
# 随机选择某个音频文件作为训练数据
i = random.randint(0,genloop-1)
for x in range(bath_size):
num = i * bath_size + x
fileid = fileids[num]
# 提取音频文件的特征
mfcc_feat = compute_mfcc(wavdict[fileid])
feats.append(mfcc_feat)
# 提取标注对应的label值
labels.append(content_dict[fileid])
# 将数据格式修改为get_batch可以处理的格式
feats = np.array(feats)
labels = np.array(labels)
# 调用get_batch将数据处理为训练所需的格式
inputs, outputs = get_batch(feats, labels)
yield inputs, outputs

构建模型

# -----------------------------------------------------------------------------------------------------
'''
&usage: [net model]构件网络结构,用于最终的训练和识别
'''
# -----------------------------------------------------------------------------------------------------
# 被creatModel调用,用作ctc损失的计算
def ctc_lambda(args):
labels, y_pred, input_length, label_length = args
y_pred = y_pred[:, :, :]
return K.ctc_batch_cost(labels, y_pred, input_length, label_length) # 构建网络结构,用于模型的训练和识别
def creatModel():
input_data = Input(name='the_input', shape=(500, 26))
layer_h1 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(input_data)
#layer_h1 = Dropout(0.3)(layer_h1)
layer_h2 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h1)
layer_h3_1 = GRU(512, return_sequences=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
layer_h3_2 = GRU(512, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', dropout=0.3)(layer_h2)
layer_h3 = add([layer_h3_1, layer_h3_2])
layer_h4 = Dense(512, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h3)
#layer_h4 = Dropout(0.3)(layer_h4)
layer_h5 = Dense(1177, activation="relu", use_bias=True, kernel_initializer='he_normal')(layer_h4)
output = Activation('softmax', name='Activation0')(layer_h5)
model_data = Model(inputs=input_data, outputs=output)
#ctc
labels = Input(name='the_labels', shape=[50], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda, output_shape=(1,), name='ctc')([labels, output, input_length, label_length])
model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)
model.summary()
ada_d = Adadelta(lr=0.01, rho=0.95, epsilon=1e-06)
#model=multi_gpu_model(model,gpus=2)
model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=ada_d)
#test_func = K.function([input_data], [output])
print("model compiled successful!")
return model, model_data

模型训练及解码

# -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的解码,用于将数字信息映射为拼音
'''
# -----------------------------------------------------------------------------------------------------
# 对model预测出的softmax的矩阵,使用ctc的准则解码,然后通过字典num2word转为文字
def decode_ctc(num_result, num2word):
result = num_result[:, :, :]
in_len = np.zeros((1), dtype = np.int32)
in_len[0] = 50;
r = K.ctc_decode(result, in_len, greedy = True, beam_width=1, top_paths=1)
r1 = K.get_value(r[0][0])
r1 = r1[0]
text = []
for i in r1:
text.append(num2word[i])
return r1, text # -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的训练
'''
# -----------------------------------------------------------------------------------------------------
# 训练模型
def train():
# 准备训练所需数据
yielddatas = data_generate()
# 导入模型结构,训练模型,保存模型参数
model, model_data = creatModel()
model.fit_generator(yielddatas, steps_per_epoch=2000, epochs=1)
model.save_weights('model.mdl')
model_data.save_weights('model_data.mdl') # -----------------------------------------------------------------------------------------------------
'''
&usage: 模型的测试,看识别结果是否正确
'''
# -----------------------------------------------------------------------------------------------------
# 测试模型
def test():
# 准备测试数据,以及生成字典
word2num, num2word = gendict('E:\\Data\\thchs30\\train.syllable.txt')
yielddatas = data_generate(bath_size=1)
# 载入训练好的模型,并进行识别
model, model_data = creatModel()
model_data.load_weights('model_data.mdl')
result = model_data.predict_generator(yielddatas, steps=1)
# 将数字结果转化为文本结果
result, text = decode_ctc(result, num2word)
print('数字结果: ', result)
print('文本结果:', text)

aishell数据转化

将aishell中的汉字标注转化为拼音标注,利用该数据与thchs30数据训练同样的网络结构。

该模型作为一个练手小项目。

没有使用语言模型,直接简单建模。

我的github: https://github.com/audier

GRU-CTC中文语音识别的更多相关文章

  1. 基于深度学习的中文语音识别系统框架(pluse)

    目录 声学模型 GRU-CTC DFCNN DFSMN 语言模型 n-gram CBHG 数据集 本文搭建一个完整的中文语音识别系统,包括声学模型和语言模型,能够将输入的音频信号识别为汉字. 声学模型 ...

  2. python使用vosk进行中文语音识别

    操作系统:Windows10 Python版本:3.9.2 vosk是一个离线开源语音识别工具,它可以识别16种语言,包括中文. 这里记录下使用vosk进行中文识别的过程,以便后续查阅. vosk地址 ...

  3. pyttsx的中文语音识别问题及探究之路

    最近在学习pyttsx时,发现中文阅读一直都识别错误,从发音来看应该是字符编码问题,但搜索之后并未发现解决方案.自己一路摸索解决,虽说最终的原因非常可笑,大牛们可能也是一眼就能洞穿,但也值得记录一下. ...

  4. Unity中使用百度中文语音识别功能

    下面是API类 Asr.cs using System; using System.Collections; using System.Collections.Generic; using Unity ...

  5. 深度学习实战篇-基于RNN的中文分词探索

    深度学习实战篇-基于RNN的中文分词探索 近年来,深度学习在人工智能的多个领域取得了显著成绩.微软使用的152层深度神经网络在ImageNet的比赛上斩获多项第一,同时在图像识别中超过了人类的识别水平 ...

  6. [DeeplearningAI笔记]序列模型3.9-3.10语音辨识/CTC损失函数/触发字检测

    5.3序列模型与注意力机制 觉得有用的话,欢迎一起讨论相互学习~Follow Me 3.9语音辨识 Speech recognition 问题描述 对于音频片段(audio clip)x ,y生成文本 ...

  7. Python实现各类验证码识别

    项目地址: https://github.com/kerlomz/captcha_trainer 编译版下载地址: https://github.com/kerlomz/captcha_trainer ...

  8. TensorFlow练习13: 制作一个简单的聊天机器人

    现在很多卖货公司都使用聊天机器人充当客服人员,许多科技巨头也纷纷推出各自的聊天助手,如苹果Siri.Google Now.Amazon Alexa.微软小冰等等.前不久有一个视频比较了Google N ...

  9. linux install Openvino

    recommend centos7 github Openvino tooltiks 1. download openvino addational installation for ncs2 ncs ...

随机推荐

  1. C++笔记005:用面向过程和面向对象方法求解圆形面积

    原创笔记,转载请注明出处! 点击[关注],关注也是一种美德~ 结束了第一个hello world程序后,我们来用面向过程和面向对象两个方法来求解圆的面积这个问题,以能够更清晰的体会面向对象和面向过程. ...

  2. eclipse中误删tomcat后,文件都报错,恢复server时无法选择tomcat7.0解决办法

    创建Tomcat v7.0 Server 不能进行下一步. 解决方法: 1.退出 eclipse 2.到[工程目录下]/.metadata/.plugins/org.eclipse.core.runt ...

  3. 竞赛题解 - [CF 1080D]Olya and magical square

    Olya and magical square - 竞赛题解 借鉴了一下神犇tly的博客QwQ(还是打一下广告) 终于弄懂了 Codeforces 传送门 『题目』(直接上翻译了) 给一个边长为 \( ...

  4. SSM整合时初始化出现异常

    java.lang.NoClassDefFoundError: org/aspectj/weaver/reflect/ReflectionWorld$ReflectionWorldException  ...

  5. shell习题第1题:每日一文件

    [题目要求] 请按照这样的日期格式(xxxx-xx-xx)每日生成一个文件 例如生成的文件为2019-04-25.log,并且把磁盘使用情况写入到这个文件中 不用考虑cron,仅仅写脚本即可 [核心要 ...

  6. day 18 类与类之间的关系

    类与类之间的关系     在我们的世界中事物和事物之间总会有一些联系.    在面向对象中,类和类之间也可以产生相关的关系 1.依赖关系     执行某个动作的时候. 需要xxx来帮助你完成这个操作, ...

  7. centos7关闭图形界面启动系统

    手动敲那么多不累么?仅2条命令(好) 1,命令模式systemctl set-default multi-user.target 2,图形模式systemctl set-default graphic ...

  8. ggnetwork

    ggnetwork ggnetwork PeRl 简介 ggnetwork是根据ggplot2的语法,开发的用于网络图可视化的包.虽然igraph是优秀的network处理包,但是在可视化方面依然是弱 ...

  9. Python 入门(一)

    IDE 个人推荐  Pycharm : 比较好用,虽然没有中文,但是练练英语也不错,毕竟大同小异 基础语法 行与缩进 python最具特色的就是使用缩进来表示代码块,不需要使用大括号 {} . 缩进的 ...

  10. 使用JAX-WS(JWS)发布WebService(一)

    JAX-WS概述: 通过Main发布一个简单WebService: JAX-WS(Java API for XML Web Services)规范是一组XML web services的JAVA AP ...