CTCLoss如何使用

什么是CTC

CTC全称为Connectionist Temporal Classification,中文翻译不好类似“联结主义按时间分类”。

CTCLoss是一类损失函数,用于计算模型输出\(y\)和标签\(label\)的损失。

\[loss=CTCLoss(y,label)
\]

神经网络在训练过程中,是让\(loss\)减少的过程。常用于图片文字识别OCR和语音识别项目,因为CTCLoss计算过程中不需要\(y\)和\(label\)对齐,这样做的好处就是大幅的减轻了数据对齐标注的工作量,极大的提高了效率。

架构介绍

本文主要是介绍CTCLoss,这里介绍模型架构是为了更好的理解CTCLoss函数在整体的做用。现有一段原始数据,它可以是一张带文字的图片或一段说话的音频。



如图所示原始的声音通过DFT(离散傅立叶变化)得到一张具有时频特性的特征图,将特征图通过网络\(\mathcal{N}_w\)后输出结果\(y\)(\(y\in\mathbb{R}^{K \times T}\),\(K\)维是在每一时间点预测的词的概率,\(T\)是时间维度)。

一个简单的例子

现在有一段语音,是一个人在拼写英文单词“CAT”,语音内容是“C”、“A”、“T”这三个字母。这个人读完这三个字母用了5s的时间。我们想通过语音识别这三个字母。

首先我们需要一个26个字母的词表,我们用序号1-26,分别来表示字母A-Z这26个字母,我们用序号0表示blankblank是用来区分那些不属于这26字母的部分。然后是假设这个模型每秒会给出一个识别字母表的概率分布,

音频持续了5s,因此有5列这样的概率分布。

\[y\in\mathbb{R}^{K \times T}\ \ \ \ \ (K=27,\ T=5)
\]

下表就是\(y\)的概率分布,每一列是当前时刻输入数据所对应的概率分布。

表1 每个时刻输出字符的概率分布
\(y_t^k\) t=1 t=2 t=3 t=4 t=5
k=0 (-) 0.031953 0.044296 0.038297 0.038320 0.027464
k=1 (A) 0.026221 0.030363 0.031878 0.027295 0.029824
k=2 (B) 0.040555 0.025838 0.023487 0.041529 0.028116
k=3 (C) 0.029333 0.045889 0.031872 0.023184 0.029338
k=4 (D) 0.023595 0.053792 0.022519 0.039882 0.025342
k=5 (E) 0.048014 0.028887 0.020526 0.041302 0.045833
k=6 (F) 0.028770 0.040735 0.045488 0.044244 0.032191
k=7 (G) 0.035127 0.032281 0.034032 0.051973 0.041613
k=8 (H) 0.044897 0.047910 0.049222 0.056956 0.048665
k=9 (I) 0.032323 0.044911 0.038994 0.046017 0.040002
k=10 (J) 0.047130 0.024608 0.034797 0.038146 0.041496
k=11 (K) 0.033491 0.049294 0.043909 0.053962 0.037901
k=12 (L) 0.044700 0.056019 0.046794 0.038094 0.027488
k=13 (M) 0.045632 0.034822 0.052229 0.021692 0.039653
k=14 (N) 0.035123 0.050406 0.019438 0.024067 0.056986
k=15 (O) 0.023015 0.037482 0.046163 0.050536 0.058191
k=16 (P) 0.031419 0.024302 0.035848 0.034614 0.031820
k=17 (Q) 0.034497 0.025424 0.052284 0.049642 0.029912
k=18 (R) 0.029572 0.031274 0.032931 0.026295 0.042725
k=19 (S) 0.027484 0.044015 0.031383 0.037050 0.046068
k=20 (T) 0.051330 0.047532 0.043297 0.040039 0.036849
k=21 (U) 0.034691 0.045869 0.024400 0.022020 0.029838
k=22 (V) 0.054835 0.028627 0.031971 0.039436 0.062661
k=23 (W) 0.033373 0.035513 0.047827 0.030642 0.026361
k=24 (X) 0.048700 0.022777 0.034515 0.022410 0.026991
k=25 (Y) 0.033561 0.023278 0.045237 0.034797 0.027990
k=26 (Z) 0.050657 0.023858 0.040665 0.025854 0.028682

上面的例子已经给出了网络\(\mathcal{N}_w\)输出\(y\)的描述,与这段音频所对应的标签\(label\),应该是‘C’、‘A’、‘T’这三个字母,将它转换成用字母表中序号表示

\[label=[3,1,20]
\]

CTC计算的推导

论文中CTCLoss的计算公式为

\[O^{ML}(S,\mathcal{N}_w)=-\sum_{x,z \in S}ln(p(z|x))
\]

那上面这个公式表示的含义是什么呢?

  • 符号\(S\)一个训练样本集合,它是总体分布的一个子集。

  • \(x,z \in S\),\(x\)是训练样本集合\(S\)中原始的数据经过网络\(\mathcal{N}_w\)后的输出,\(z\)是与\(x\)相对应的标签。

  • \(p(z|x)\)表示以\(\mathcal{N}_w\)的输出\(x\),将\(x\)恢复为标签\(z\)的概率,也就是\(z\)相对于\(x\)的条件概率。

  • 这样将样本集合\(S\)中每一条样本的\(p(z|x)\)相乘,就是样本\(S\)对于\(\mathcal{N}_w\)似然函数:

    \[L(S,\mathcal{N}_w)=\prod_{x,z \in S}{p(z|x)}
    \]
  • 我们通过训练调整网络\(\mathcal{N}_w\)的参数\(w\),使\(ln{(L(S,\mathcal{N}_w))}\)最大,这个过程就叫最大似然估计。

  • 为了方便计算,我们在等式两边取\(ln\),这就是对数似然函数。

    \[ln{(L(S,\mathcal{N}_w))}=\sum_{x,z \in S}{ln{(p(z|x))}}
    \]
  • 因为似然函数是越大表示结果越好,而损失函数是越小则表示结果越好所以需要一个负号

    \[O^{ML}(S,\mathcal{N}_w)=-ln{(S,\mathcal{N}_w)}=-\sum_{x,z \in S}ln(p(z|x))
    \]

总概率\(p(z|x)\)

CTCLoss中最关键的就是计算每一条样本\({\{x,z\}} \in S\)的条件概率\(p(z|x)\),\(z\)是目标标签与\(x\)是一一对应关系,\(l\)是任意标签只要是符合字母表规则的标签都是可以的,而\(z\)只是符合\(l\)规则中的一条。在训练的时候可以指定\(l=z\),但在公式推导时应该更严谨更泛化一些。因此\(p(z|x)\)可以用作\(p(l|x)\)替代,下面给出\(p(l|x)\)的计算公式

\[p(l|x)=\sum_{\pi \in \mathcal{B}^{-1}(l)}{p(\pi|x)}
\]

路径的含义

已知网络\(\mathcal{N}_w\)的输出\(x\in\mathbb{R}^{K \times T}\),它有\(T\)个时间点,并在每个时间点中有\(K\)种输出的可能,一共有\(K^ T\)条路径。在上面的例子中\(K=27,T=5\)所以一共就有\(27^5=14348907\)条可能的路径。仅仅\(T=5\)时,总路径条数已经相当的巨大了。

路径概率\(p(\pi|x)\)

表1已经给出于每个时刻所有的字母概率,由每个时刻选出的字母将组成一条路径,那么这条路径的概率就等于各个时刻选择字母的概率的乘积。

\[\begin{aligned}
p(\pi|x)&=\prod_{t=1}^{T}{y_{k=\pi^t}^t} \\
&=y_{k=\pi^1}^1\times y_{k=\pi^2}^2\times y_{k=\pi^3}^3\times...\times y_{k=\pi^T}^T
\end{aligned}\]

什么是\(\mathcal{B}\)变换

在上面提到的\(27^5\)条路径中\(\mathcal{B}\)变换就是将路径中所有的blank\((-)\),和相邻重复的元素删除,比如

\[\mathcal{B}(a − ab−) = \mathcal{B}(−aa − −abb) = aab
\]
\[\mathcal{B}(C − AT−) = \mathcal{B}(CC-AT) = CAT
\]

同理符号\(\mathcal{B}^{-1}(l)\)则是\(\mathcal{B}(\pi)\)的逆变换。表示所有满足\(\mathcal{B}(\pi)=l\)的路径

\(p(l|x)\)并不是计算所有路径的概率之和,而是计算所有满足\(\mathcal{B}(\pi)=l\)的路径概率之和。

一步一步手动计算CTCLoss

现在就根据上面提供的例子,一步一步手动计算CTCLoss

找出所有满足\(\mathcal{B}(\pi)=l\),\(l\)=“CAT”的路径

在上面给出的\(27^5\)条路径中给出的符合\(\mathcal{B}(\pi)=l\),\(l\)=“CAT”共有28条,

如表2所示

表2 所有满足条件的路径,共28条

t=1 t=2 t=3 t=4 t=5
\(\pi_{1}\) - - C A T
\(\pi_{2}\) - C - A T
\(\pi_{3}\) - C C A T
\(\pi_{4}\) - C A - T
\(\pi_{5}\) - C A A T
\(\pi_{6}\) - C A T -
\(\pi_{7}\) - C A T T
\(\pi_{8}\) C - - A T
\(\pi_{9}\) C - A - T
\(\pi_{10}\) C - A A T
\(\pi_{11}\) C - A T -
\(\pi_{12}\) C - A T T
\(\pi_{13}\) C C - A T
\(\pi_{14}\) C C C A T
\(\pi_{15}\) C C A - T
\(\pi_{16}\) C C A A T
\(\pi_{17}\) C C A T -
\(\pi_{18}\) C C A T T
\(\pi_{19}\) C A - - T
\(\pi_{20}\) C A - T -
\(\pi_{21}\) C A - T T
\(\pi_{22}\) C A A - T
\(\pi_{23}\) C A A A T
\(\pi_{24}\) C A A T -
\(\pi_{25}\) C A A T T
\(\pi_{26}\) C A T - -
\(\pi_{27}\) C A T T -
\(\pi_{28}\) C A T T T

计算每条路径的概率\(p(\pi|x)\)

路径\(\pi_1\)所对应的标签为"- - C A T",这段序列转换为字母表中的索引,

则路径\(\pi_1\)在每个时刻的取值如下

\[y_{k=\pi^1_1}^1=y_{0}^1=0.031953
\]
\[y_{k=\pi^2_1}^1=y_{0}^2=0.044296
\]
\[y_{k=\pi^3_1}^1=y_{3}^3=0.031872
\]
\[y_{k=\pi^4_1}^1=y_{1}^4=0.027295
\]
\[y_{k=\pi^5_1}^1=y_{20}^5=0.036849
\]

因此路径\(\pi_1的概率\)\(p(\pi_1|x)\)的计算如下

\[\begin{aligned}
p(\pi_1|x)&=\prod_{t=1}^{T}{y_{k=\pi_1^t}^t} \\
&=y_{k=\pi_1^1}^1\times y_{k=\pi_1^2}^2 \times y_{k=\pi_1^3}^3 \times ... \times y_{k=\pi_1^T}^T \\
&=y_{0}^1 \times y_{0}^2 \times y_{3}^3 \times y_{1}^4\times y_{20}^5 \\
&=0.031953 \times0.044296\times0.031872\times0.027295\times0.036849 \\
&=4.5373e^{-8}
\end{aligned}\]

同理可计算

\[p(\pi_1|x)=4.5374e^{-8},
p(\pi_2|x)=5.6482e^{-8},
p(\pi_3|x)=4.7006e^{-8},
p(\pi_4|x)=6.6003e^{-8}\]
\[p(\pi_5|x)=4.7014e^{-8},
p(\pi_6|x)=5.1401e^{-8},
p(\pi_7|x)=6.8965e^{-8},
p(\pi_8|x)=5.0050e^{-8}\]
\[p(\pi_9|x)=5.8487e^{-8},
p(\pi_{10}|x)=4.1660e^{-8},
p(\pi_{11}|x)=4.5547e^{-8},
p(\pi_{12}|x)=6.1111e^{-8}\]
\[p(\pi_{13}|x)=5.1850e^{-8},
p(\pi_{14}|x)=4.3151e^{-8},
p(\pi_{15}|x)=6.0590e^{-8},
p(\pi_{16}|x)=4.3158e^{-8}\]
\[p(\pi_{17}|x)=4.7185e^{-8},
p(\pi_{18}|x)=6.3309e^{-8},
p(\pi_{19}|x)=4.8163e^{-8},
p(\pi_{20}|x)=3.7508e^{-8}\]
\[p(\pi_{21}|x)=5.0324e^{-8},
p(\pi_{22}|x)=4.0090e^{-8},
p(\pi_{23}|x)=2.8556e^{-8},
p(\pi_{24}|x)=3.1220e^{-8}\]
\[p(\pi_{25}|x)=4.1889e^{-8},
p(\pi_{26}|x)=4.0583e^{-8},
p(\pi_{27}|x)=4.2404e^{-8},
p(\pi_{28}|x)=5.6894e^{-8}\]

计算总概率\(p(l|x)\)

\(p(l|x)\)是所有满足\(\mathcal{B}(\pi)=l\)的路径概率之和。

\[\begin{aligned}
p(l|x)&=\sum_{\pi \in \mathcal{B}^{-1}(l)}{p(\pi|x)} \\
&=p(\pi_1|x)+p(\pi_2|x)+p(\pi_1|x)+...+p(\pi_{28}|x) \\
&=4.5374e^{-8} + 5.6482e^{-8}+4.7006e^{-8}+...+5.6894e^{-8} \\
&=1.366e^{-6}
\end{aligned}\]

计算损失函数CTCLoss

由于例子中只给了1样本,所以下面的损失函数CTCLoss也就只有这一个样本的损失。

\[\begin{aligned}
O^{ML}(S,\mathcal{N}_w)&=-ln{(S,\mathcal{N}_w)} \\
&=-\sum_{x,z \in S}ln(p(z|x)) \\
&=-ln(p(z|x)) \\
&=-ln(1.366e^{-6}) \\
&=\ 13.5036

\end{aligned}\]

CTCLoss库函数的验证

网络\(\mathcal{N}_w\)输出\(y\_out\)的softmax处理

这里有一点需要解释一下,CTCLoss的输入\(ctc\_input\)与网络\(\mathcal{N}_w\)的输出\(y\_out\)之间的关系。

在网络\(\mathcal{N}_w\)输出的最后一级是没有softmax,所以\(y\_out\)在每一个时间点的的概率和都不为1,为了将概率分布归一化需要将\(y\)进行softmax计算。

\[y\_softmax=softmax(y\_out)
\]

同时CTCLoss中包含有大量的概率的乘法运算,需要将\(y\_softmax\)进行\(ln\)计算,

这样可以将乘法转换为加法计算,提升计算的速度。

\[ctc\_input=ln(y\_softmax)
\]

上面的例子,为了让文档更直观,已经默认

\[y=y\_softmax
\]

下表就是\(y\_out\),显然每一列之和不为1。

\(y\_out_t^k\) t=1 t=2 t=3 t=4 t=5
k=0 (-) 0.347713 0.755077 0.678652 0.585987 0.123084
k=1 (A) 0.149997 0.377396 0.495177 0.246735 0.205494
k=2 (B) 0.586092 0.216019 0.189710 0.666416 0.146515
k=3 (C) 0.262145 0.790407 0.495006 0.083483 0.189072
k=4 (D) 0.044454 0.949304 0.147608 0.625960 0.042652
k=5 (E) 0.754933 0.327565 0.054974 0.660945 0.635198
k=6 (F) 0.242785 0.671264 0.850713 0.729752 0.281867
k=7 (G) 0.442402 0.438645 0.560560 0.890752 0.538597
k=8 (H) 0.687796 0.833501 0.929609 0.982303 0.695163
k=9 (I) 0.359228 0.768854 0.696667 0.769029 0.499116
k=10 (J) 0.736340 0.167254 0.582791 0.581446 0.535801
k=11 (K) 0.394707 0.861980 0.815397 0.928313 0.445183
k=12 (L) 0.683416 0.989872 0.879014 0.580090 0.123932
k=13 (M) 0.704047 0.514423 0.988912 0.016983 0.490357
k=14 (N) 0.442305 0.884281 0.000522 0.120860 0.852998
k=15 (O) 0.019578 0.588026 0.865439 0.862711 0.873927
k=16 (P) 0.330858 0.154752 0.612566 0.484297 0.270294
k=17 (Q) 0.424309 0.199863 0.989950 0.844856 0.208461
k=18 (R) 0.270270 0.406955 0.527680 0.209405 0.564980
k=19 (S) 0.197054 0.748706 0.479523 0.552291 0.640312
k=20 (T) 0.821721 0.825584 0.801348 0.629883 0.417029
k=21 (U) 0.429921 0.789963 0.227843 0.031991 0.205976
k=22 (V) 0.887771 0.318524 0.498094 0.614713 0.947933
k=23 (W) 0.391183 0.534064 0.900852 0.362411 0.082071
k=24 (X) 0.769114 0.089951 0.574661 0.049533 0.105709
k=25 (Y) 0.396792 0.111706 0.845178 0.489570 0.142041
k=26 (Z) 0.808514 0.136293 0.738640 0.192510 0.166460

pytorch库函数验证

CTCLoss使用细节可以参考pytorch官方文档

import torch
import torch.nn as nn
import numpy as np y_softmax = np.array([
[[0.0319533345695271, 0.0262210133693412, 0.0405548727460100, 0.0293328834922530, 0.0235946021815836, 0.0480142162870594, 0.0287704618407728, 0.0351268637054168, 0.0448965052477630, 0.0323234212279283, 0.0471297269219778, 0.0334908192070999, 0.0447002788315031, 0.0456320948241136,
0.0351234600906292, 0.0230148922614546, 0.0314192811142228, 0.0344970346892286, 0.0295721871384341, 0.0274843752526059, 0.0513304969210734, 0.0346911732659917, 0.0548353372646645, 0.0333729892573427, 0.0486999624899632, 0.0335606882517763, 0.0506570275502634]],
[[0.0442961938109001, 0.0303627704208565, 0.0258378526020265, 0.0458891577161975, 0.0537920435977104, 0.0288868677848477, 0.0407349328912650, 0.0322806067098565, 0.0479099042067772, 0.0449106925711146, 0.0246080887866719, 0.0492939884049119, 0.0560191619281624, 0.0348218517081914,
0.0504056201105211, 0.0374815087428365, 0.0243023731122621, 0.0254237678526359, 0.0312736688595233, 0.0440148630768450, 0.0475321094768427, 0.0458687788283468, 0.0286268732637606, 0.0355125367928648, 0.0227774801386588, 0.0232784351056503, 0.0238578714997625]],
[[0.0382974368377362, 0.0318777312135849, 0.0234868589674224, 0.0318722744011979, 0.0225185381516373, 0.0205262552943881, 0.0454877627911883, 0.0340316234294017, 0.0492219436202117, 0.0389936131137926, 0.0347966678592871, 0.0439093761642613, 0.0467935124498177, 0.0522292290638150,
0.0194384495697102, 0.0461625681675025, 0.0358483354617907, 0.0522835019782284, 0.0329308772273817, 0.0313826141807340, 0.0432967801742709, 0.0243997674509821, 0.0319708630090250, 0.0478266566415420, 0.0345149265806327, 0.0452367066323343, 0.0406651295681235]],
[[0.0383195501689954, 0.0272951137973125, 0.0415288927451887, 0.0231838517718695, 0.0398823138441169, 0.0413022813256117, 0.0442442329310963, 0.0519730489462436, 0.0569558497142297, 0.0460166028890008, 0.0381459528257684, 0.0539623316283564, 0.0380942573161036, 0.0216922730554261,
0.0240667868142706, 0.0505358960731075, 0.0346143968556499, 0.0496415831760055, 0.0262949856792733, 0.0370498580320465, 0.0400391034751884, 0.0220202876462848, 0.0394362954874324, 0.0306423990773223, 0.0224099657044701, 0.0347974172676594, 0.0258544717519696]],
[[0.0274643882294982, 0.0298236175649923, 0.0281155092606543, 0.0293378537462717, 0.0253418924737544, 0.0458330002578632, 0.0321905618820226, 0.0416126048467898, 0.0486654573566434, 0.0400017201897758, 0.0414964341812715, 0.0379014590893513, 0.0274877024782956, 0.0396528862221281,
0.0569859416555112, 0.0581911831104043, 0.0318201830875284, 0.0299122412570334, 0.0427250763149338, 0.0460679863549903, 0.0368492548068844, 0.0298379764585031, 0.0626610201269008, 0.0263607892806820, 0.0269913345266294, 0.0279900073565483, 0.0286819178841385]]
]).astype("float32") labels = np.array([[3, 1, 20]]).astype("int32")
input_lengths = np.array([5]).astype("int64")
label_lengths = np.array([3]).astype("int64") ctc_input = torch.tensor(y_softmax).log()
labels = torch.tensor(labels)
input_lengths = torch.tensor(input_lengths)
label_lengths = torch.tensor(label_lengths) ctc_loss = nn.CTCLoss(reduction='none')
loss = ctc_loss(ctc_input, labels, input_lengths, label_lengths)
print('loss is {}'.format(loss))
loss is tensor([13.5036])

paddle库函数的使用

CTCLoss使用细节可以参考

paddle官方文档

由于paddle的CTCLoss库底层已经实现了log_softmax,所以它的输入可以直接为\(y\_out\)

import numpy as np
import paddle
import paddle.nn.functional as F y_out = np.array([
[[0.347712671277525, 0.149997253831683, 0.586092067231462, 0.262145317727807, 0.0444540922782385, 0.754933267231179, 0.242785357820962, 0.442402313001943, 0.687796085120107, 0.359228210401861, 0.736340074301202, 0.394707475278763, 0.683415866967978, 0.704047430334266,
0.442305413383371, 0.0195776235533187, 0.330857880214071, 0.424309496833137, 0.270270423432065, 0.197053798095456, 0.821721184961310, 0.429921409383266, 0.887770954256354, 0.391182995461163, 0.769114387388296, 0.396791517013617, 0.808514095887345]],
[[0.755077099007084, 0.377395544835103, 0.216018915961394, 0.790407217966913, 0.949303911849797, 0.327565434075205, 0.671264370451740, 0.438644982586956, 0.833500595588975, 0.768854252429615, 0.167253545494722, 0.861980478702072, 0.989872153631504, 0.514423456505704,
0.884281023126955, 0.588026055308498, 0.154752348656045, 0.199862822857452, 0.406954837138907, 0.748705718215691, 0.825583815786156, 0.789963029944531, 0.318524245398992, 0.534064127370726, 0.0899506787705811, 0.111705744193203, 0.136292548938299]],
[[0.678652304800188, 0.495177019089661, 0.189710406017580, 0.495005824990221, 0.147608221976689, 0.0549741469061882, 0.850712674289007, 0.560559527354885, 0.929608866756663, 0.696667200555228, 0.582790965175840, 0.815397211477421, 0.879013904597178, 0.988911616079589,
0.000522375356944771, 0.865438591013025, 0.612566469483999, 0.989950205708831, 0.527680069338442, 0.479523385210219, 0.801347605521952, 0.227842935706042, 0.498094291196390, 0.900852488532005, 0.574661219130188, 0.845178185054037, 0.738640291995402]],
[[0.585987035826476, 0.246734525985975, 0.666416217319468, 0.0834828136026227, 0.625959785171583, 0.660944557947342, 0.729751855317221, 0.890752116325322, 0.982303222883606, 0.769029085335896, 0.581446487875398, 0.928313062314188, 0.580090365758442, 0.0169829383372613,
0.120859571098558, 0.862710718699670, 0.484296511212103, 0.844855674576263, 0.209405084020935, 0.552291341538775, 0.629883385064421, 0.0319910157625669, 0.614713419117141, 0.362411462273053, 0.0495325790420612, 0.489569989177322, 0.192510396062075]],
[[0.123083747545945, 0.205494170907680, 0.146514910614890, 0.189072174472614, 0.0426524109111434, 0.635197916859882, 0.281866855880430, 0.538596678045340, 0.695163039444332, 0.499116013482590, 0.535801055751113, 0.445183165296042, 0.123932277598070, 0.490357293468018,
0.852998155340816, 0.873927405861733, 0.270294332292698, 0.208461358751314, 0.564979570738201, 0.640311825162758, 0.417028951642886, 0.205975515532243, 0.947933121293169, 0.0820712070977259, 0.105709426581721, 0.142041121903998, 0.166460440876421]]
]).astype("float32") labels = np.array([[3, 1, 20]]).astype("int32")
input_lengths = np.array([5]).astype("int64")
label_lengths = np.array([3]).astype("int64") y_out=paddle.to_tensor(y_out)
labels = paddle.to_tensor(labels)
input_lengths = paddle.to_tensor(input_lengths)
label_lengths = paddle.to_tensor(label_lengths) loss = paddle.nn.CTCLoss(blank=0, reduction='none')(y_out, labels,
input_lengths,
label_lengths)
print('loss is {}'.format(loss))
loss is Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[13.50364304])

CTCLoss如何使用的更多相关文章

  1. 语音识别中的CTC算法的基本原理解释

    欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 本文作者:罗冬日 目前主流的语音识别都大致分为特征提取,声学模型,语音模型几个部分.目前结合神经网络的端到端的声学模型训练方法主要CTC和基 ...

  2. 【OCR技术系列之八】端到端不定长文本识别CRNN代码实现

    CRNN是OCR领域非常经典且被广泛使用的识别算法,其理论基础可以参考我上一篇文章,本文将着重讲解CRNN代码实现过程以及识别效果. 数据处理 利用图像处理技术我们手工大批量生成文字图像,一共360万 ...

  3. CTC+pytorch编译配置warp-CTC

    CTC CTC可以生成一个损失函数,用于在序列数据上进行监督式学习,不需要对齐输入数据及标签,经常连接在一个RNN网络的末端,训练端到端的语音和文本识别系统.CTC论文地址:http://www.cs ...

  4. 服务器个人环境下pytorch0.4.1编译warp-ctc遇到的问题及解决方法

    一.关于warp-ctc CTC可以生成一个损失函数,用于在序列数据上进行监督式学习,不需要对齐输入数据及标签,经常连接在一个RNN网络的末端,训练端到端的语音或文本识别系统.CTC论文 CTC网络的 ...

  5. 从零和使用mxnet实现线性回归

    1.线性回归从零实现 from mxnet import ndarray as nd import matplotlib.pyplot as plt import numpy as np import ...

  6. Pytorch的19种损失函数

    基本用法 12 criterion = LossCriterion() loss = criterion(x, y) # 调用标准时也有参数 损失函数 L1范数损失:L1Loss 计算 output ...

  7. [PyTorch 学习笔记] 4.2 损失函数

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/loss_function_1.py https:// ...

  8. pytorch(16)损失函数(二)

    5和6是在数据回归中用的较多的损失函数 5. nn.L1Loss 功能:计算inputs与target之差的绝对值 代码: nn.L1Loss(reduction='mean') 公式: \[l_n ...

  9. PaddleOCR详解

    @ 目录 PaddleOCR简介 环境配置 PaddleOCR2.0的配置环境 Docker 数据集 文本检测 使用自己的数据集 文本识别 使用自己的数据集 字典 自定义字典 添加空格类别 文本角度分 ...

随机推荐

  1. .NET 程序读取当前目录避坑指南

    前些天有 AgileConfig 的用户反映,如果把 AgileConfig 部署成 Windows 服务程序会启动失败.我看了一下日志,发现根目录被定位到了 C:\Windows\System32 ...

  2. IOC简介 -Bean的作用域 创建对象

    创建对象 创建对象时默认使用无参构造器,无论对象在容器中后续是否被使用, 都会先实例化对象 . (婚介网站,里面人都是先存在的,到时直接牵手就行) 也可以使用以下方法,使用有参构造器来创建对象 根据参 ...

  3. 超耐心地毯式分析,来试试这道看似简单但暗藏玄机的Promise顺序执行题

    壹 ❀ 引 就在昨天,与朋友聊到JS基础时,她突然想起之前在面试时,遇到了一道难以理解的Promise执行顺序题.由于我之前专门写过手写promise的文章,对于部分原理也还算了解,出于兴趣我便要了这 ...

  4. CRM项目的整理---第一篇

    CRM:cunstomer relationship management  客户管理系统 1.项目的使用者:销售  班主任    讲师  助教 2.项目的需求分析 2.1.注册 2.2.登录 2.3 ...

  5. python使用vosk进行中文语音识别

    操作系统:Windows10 Python版本:3.9.2 vosk是一个离线开源语音识别工具,它可以识别16种语言,包括中文. 这里记录下使用vosk进行中文识别的过程,以便后续查阅. vosk地址 ...

  6. 04 Springboot 格式化LocalDateTime

    Springboot 格式化LocalDateTime 我们知道在springboot中有默认的json解析器,Spring Boot 中默认使用的 Json 解析技术框架是 jackson.我们点开 ...

  7. 手动搭建简易web框架与django框架简介

    目录 纯手写简易web框架 基于wsgiref模块 动静态网页 简单了解jinja2模块 框架请求流程 python主流web框架 django框架 简介 应用app 命令操作django pycha ...

  8. e2fsck-磁盘分区修复

    检查 ext2/ext3/ext4 类型文件系统. 语法 e2fsck [-panyrcdfvtDFV] [-b superblock] [-B blocksize] [-I inode_buffer ...

  9. 基础篇:java GC 总结,建议收藏

    垃圾标记算法 垃圾回收算法 major gc.mini gc.full gc.mixed gc 又是什么,怎么触发的 垃圾回收器的介绍 Safe Point 和 Safe Region 什么是 TLA ...

  10. 拭目以待!JNPF .NET将更新.NET 6技术,同时上线 3.4.1 版本

    2022年5月30日,福建引迈即将上线JNPF开发平台的.NET 6版本,在产品性能上做了深度优化,且极大的提升了工作效率,加强了对云服务的改进升级,全面提升用户的使用体验. JNPF是一个以PaaS ...