李宏毅2022机器学习HW4 Speaker Identification下
Task
Sample Baseline模型介绍
class Classifier(nn.Module):
def __init__(self, d_model=80, n_spks=600, dropout=0.1):
super().__init__()
# Project the dimension of features from that of input into d_model.
self.prenet = nn.Linear(40, d_model)
# transformer
self.encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, dim_feedforward=256, nhead=2
)
self.encoder = self.encoder_layer
self.pred_layer = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(),
nn.Linear(d_model, n_spks),
)
def forward(self, mels):
"""
args:
mels: (batch size, length, 40)
return:
out: (batch size, n_spks)
"""
# out: (batch size, length, d_model)
out = self.prenet(mels)
# out: (length, batch size, d_model)
out = out.permute(1, 0, 2)
# The encoder layer expect features in the shape of (length, batch size, d_model).
out = self.encoder(out)
# out: (batch size, length, d_model)
out = out.transpose(0, 1)
# mean pooling
stats = out.mean(dim=1)
# out: (batch, n_spks)
out = self.pred_layer(stats)
return out
模型开始对特征进行了升维以增强表示能力,随后通过transformer的encoder对数据进一步编码(未使用decoder),到这一步就包含了原来没有包含的注意力信息,以英文Sequence为例,如果原来的Sequence中每个单词是独立编码的是没有任何关联的,那么经过这一步之后,每一个单词的编码都是由其他单词编码的叠加而成。最后通过pred_layer进行预测(当然在此之前进行了一个mean pooling,这个下面会讲)。
在模型的前向传播时,模型基本是安装前面定义的各层进行计算的,我们注意到在给encoder的输入时,维度的顺序为(length,batch_size, d_model),而不是(batch_size, length, d_model),实际上这是为了并行计算?
下图是batch_first时所对应的存储顺序

下图时length_first时所对应的存储顺序

在其他时序模型中,由于需要按序输入,因此直接拿到一个sequence没啥用,不如直接得到一批batch中的所有sequence中的第一个语音序列或单词,但是在transformer中应该不需要这样吧?
另外一个小细节是进行mean pooling
stats = out.mean(dim=1)
这一步是不必可少的,不然没法输入pred_layer,这里做mean的意思是把每个sequence的所有frame通过平均合并为一个frame,如下图所示

维度由batch_size\(\times\)length\(\times\)d_model变成了batch_size\(\times\)d_model
Medium Baseline
对于medium baseline,只需要调节视频中提示的地方进行修改即可,我的得分如下:

相关参数如下
d_model=120
4个encoder_layer的nhead=4
Strong Baseline
对于strong baseline,需要引入conformer架构,我的得分如下:

而conformer的引入需要注意以下几点:
引入(当然你也可以使用pip进行单独安装)
from torchaudio.models.conformer import Conformer
由于torchaudio中实现的conformer默认是batch_first,因此在代码中我们需要去掉下面两行
out = out.permute(1, 0, 2)
out = out.transpose(0, 1)
Boss Baseline
引入self-attention pooling和additive margin softmax后,准确率下降了。

李宏毅2022机器学习HW4 Speaker Identification下的更多相关文章
- 李宏毅老师机器学习课程笔记_ML Lecture 3-1: Gradient Descent
引言: 这个系列的笔记是台大李宏毅老师机器学习的课程笔记 视频链接(bilibili):李宏毅机器学习(2017) 另外已经有有心的同学做了速记并更新在github上:李宏毅机器学习笔记(LeeML- ...
- Python 机器学习实战 —— 监督学习(下)
前言 近年来AI人工智能成为社会发展趋势,在IT行业引起一波热潮,有关机器学习.深度学习.神经网络等文章多不胜数.从智能家居.自动驾驶.无人机.智能机器人到人造卫星.安防军备,无论是国家级军事设备还是 ...
- 李宏毅老师机器学习课程笔记_ML Lecture 2: Where does the error come from?
引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...
- 李宏毅老师机器学习课程笔记_ML Lecture 1: ML Lecture 1: Regression - Demo
引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...
- 李宏毅老师机器学习课程笔记_ML Lecture 1: 回归案例研究
引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...
- 李宏毅老师机器学习课程笔记_ML Lecture 0-2: Why we need to learn machine learning?
引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...
- 李宏毅老师机器学习课程笔记_ML Lecture 0-1: Introduction of Machine Learning
引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...
- 机器学习之神经网络模型-下(Neural Networks: Representation)
3. Model Representation I 1 神经网络是在模仿大脑中的神经元或者神经网络时发明的.因此,要解释如何表示模型假设,我们不妨先来看单个神经元在大脑中是什么样的. 我们的大脑中充满 ...
- Andrew Ng机器学习课程笔记--week5(下)
Neural Networks: Learning 内容较多,故分成上下两篇文章. 一.内容概要 Cost Function and Backpropagation Cost Function Bac ...
- 机器学习模型从windows下 spring上传到预发布会导致模型不可加载
1.通过上传到redis,程序通过redis拉取模型,解决问题. 2.问题原因初步思考为windows下模型文件上传到 linux导致,待继续跟进查找.
随机推荐
- 【团队效率提升】Python-PyWebIO介绍
作者:京东零售 关键 Q&A快速了解PyWebIO Q:首先,什么是PyWebIO? A:PyWebIO提供了一系列命令式的交互函数,能够让咱们用只用Python就可以编写 Web 应用, 不 ...
- Mac 版的 Quicker CirMenu
之前在Windows上用过一款圆盘菜单工具Quicker, 感觉非常方便, 换成Macos后,一直没有找到类似应用. 最近终于发现,一款好用的快捷键收集,触发工具CirMenu. 其核心功能是可以根据 ...
- SP5464 CT - Counting triangles 题解
题目翻译 题意 有一个网格,左上角是 \((0,0)\),右上角是 \((x,y)\).求这个网格中一共有多少个等腰直角三角形. 输入 第一行给定一个 \(c\),表示有 \(c\) 组数据. 后面 ...
- 小白学k8s(1)二进制部署k8s
二进制部署k8s 前言 准备工作 关闭防火墙 关闭 swap 分区 关闭 SELinux 更新系统时间 秘钥免密码 设置主机名称 服务器角色 安装etcd 创建证书 生成证书 部署Etcd 在Node ...
- 一键式文本纠错工具,整合了BERT、ERNIE等多种模型,让您立即享受纠错的便利和效果
pycorrector一键式文本纠错工具,整合了BERT.MacBERT.ELECTRA.ERNIE等多种模型,让您立即享受纠错的便利和效果 pycorrector: 中文文本纠错工具.支持中文音似. ...
- Git 简单实用教程
相关链接: 码云(gitee)配置SSH密钥 码云gitee创建仓库并用git上传文件 git 上传错误This oplation equires one of the flowi vrsionsot ...
- Linux 统计Web服务日志命令
本人在Linux运维中收集的一些通用的统计,Apache/Nginx服务器日志的命令组合. Apache日志统计 # 列出当天访问次数最多的IP命令 [root@lyshark.cnblogs.com ...
- 【Matlab】蒙特卡罗法模拟圆周率+对应解析的GIF生成【超详细的注释和解释】
文章目录 前言 模拟思路 GIF模拟动图的生成 GIF动图生成的基本思路 单张静态图的生成 GIF的生成 尾声 前言 因为博主最近要准备数学建模大赛了,在学习matlab和python之余,博主也会继 ...
- P8512 [Ynoi Easy Round 2021] TEST_152 题解
题目链接:[Ynoi Easy Round 2021] TEST_152 题目比较抽象,翻译一下.就是有 \(n\) 个操作,每个操作为 \((l_i,r_i,v_i)\) 表示把长为 \(m\) 序 ...
- 一个关于用netty的小错误反思
一个关于用netty的小认知 在使用netty时,观看了黑马的netty网课,没想就直接用他的依赖了 依赖如下 <dependency> <groupId>io.netty&l ...