Task

Sample Baseline模型介绍

class Classifier(nn.Module):
def __init__(self, d_model=80, n_spks=600, dropout=0.1):
super().__init__()
# Project the dimension of features from that of input into d_model.
self.prenet = nn.Linear(40, d_model) # transformer
self.encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, dim_feedforward=256, nhead=2
)
self.encoder = self.encoder_layer self.pred_layer = nn.Sequential(
nn.Linear(d_model, d_model),
nn.ReLU(),
nn.Linear(d_model, n_spks),
) def forward(self, mels):
"""
args:
mels: (batch size, length, 40)
return:
out: (batch size, n_spks)
"""
# out: (batch size, length, d_model)
out = self.prenet(mels)
# out: (length, batch size, d_model)
out = out.permute(1, 0, 2)
# The encoder layer expect features in the shape of (length, batch size, d_model).
out = self.encoder(out)
# out: (batch size, length, d_model)
out = out.transpose(0, 1)
# mean pooling
stats = out.mean(dim=1)
# out: (batch, n_spks)
out = self.pred_layer(stats)
return out

模型开始对特征进行了升维以增强表示能力,随后通过transformer的encoder对数据进一步编码(未使用decoder),到这一步就包含了原来没有包含的注意力信息,以英文Sequence为例,如果原来的Sequence中每个单词是独立编码的是没有任何关联的,那么经过这一步之后,每一个单词的编码都是由其他单词编码的叠加而成。最后通过pred_layer进行预测(当然在此之前进行了一个mean pooling,这个下面会讲)。

在模型的前向传播时,模型基本是安装前面定义的各层进行计算的,我们注意到在给encoder的输入时,维度的顺序为(length,batch_size, d_model),而不是(batch_size, length, d_model),实际上这是为了并行计算?

下图是batch_first时所对应的存储顺序

下图时length_first时所对应的存储顺序

在其他时序模型中,由于需要按序输入,因此直接拿到一个sequence没啥用,不如直接得到一批batch中的所有sequence中的第一个语音序列或单词,但是在transformer中应该不需要这样吧?

另外一个小细节是进行mean pooling

stats = out.mean(dim=1)

这一步是不必可少的,不然没法输入pred_layer,这里做mean的意思是把每个sequence的所有frame通过平均合并为一个frame,如下图所示



维度由batch_size\(\times\)length\(\times\)d_model变成了batch_size\(\times\)d_model

Medium Baseline

对于medium baseline,只需要调节视频中提示的地方进行修改即可,我的得分如下:



相关参数如下

d_model=120

4个encoder_layer的nhead=4

Strong Baseline

对于strong baseline,需要引入conformer架构,我的得分如下:

而conformer的引入需要注意以下几点:

引入(当然你也可以使用pip进行单独安装)

from torchaudio.models.conformer import Conformer

由于torchaudio中实现的conformer默认是batch_first,因此在代码中我们需要去掉下面两行

out = out.permute(1, 0, 2)
out = out.transpose(0, 1)

Boss Baseline

引入self-attention pooling和additive margin softmax后,准确率下降了。

李宏毅2022机器学习HW4 Speaker Identification下的更多相关文章

  1. 李宏毅老师机器学习课程笔记_ML Lecture 3-1: Gradient Descent

    引言: 这个系列的笔记是台大李宏毅老师机器学习的课程笔记 视频链接(bilibili):李宏毅机器学习(2017) 另外已经有有心的同学做了速记并更新在github上:李宏毅机器学习笔记(LeeML- ...

  2. Python 机器学习实战 —— 监督学习(下)

    前言 近年来AI人工智能成为社会发展趋势,在IT行业引起一波热潮,有关机器学习.深度学习.神经网络等文章多不胜数.从智能家居.自动驾驶.无人机.智能机器人到人造卫星.安防军备,无论是国家级军事设备还是 ...

  3. 李宏毅老师机器学习课程笔记_ML Lecture 2: Where does the error come from?

    引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...

  4. 李宏毅老师机器学习课程笔记_ML Lecture 1: ML Lecture 1: Regression - Demo

    引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...

  5. 李宏毅老师机器学习课程笔记_ML Lecture 1: 回归案例研究

    引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...

  6. 李宏毅老师机器学习课程笔记_ML Lecture 0-2: Why we need to learn machine learning?

    引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...

  7. 李宏毅老师机器学习课程笔记_ML Lecture 0-1: Introduction of Machine Learning

    引言: 最近开始学习"机器学习",早就听说祖国宝岛的李宏毅老师的大名,一直没有时间看他的系列课程.今天听了一课,感觉非常棒,通俗易懂,而又能够抓住重点,中间还能加上一些很有趣的例子 ...

  8. 机器学习之神经网络模型-下(Neural Networks: Representation)

    3. Model Representation I 1 神经网络是在模仿大脑中的神经元或者神经网络时发明的.因此,要解释如何表示模型假设,我们不妨先来看单个神经元在大脑中是什么样的. 我们的大脑中充满 ...

  9. Andrew Ng机器学习课程笔记--week5(下)

    Neural Networks: Learning 内容较多,故分成上下两篇文章. 一.内容概要 Cost Function and Backpropagation Cost Function Bac ...

  10. 机器学习模型从windows下 spring上传到预发布会导致模型不可加载

    1.通过上传到redis,程序通过redis拉取模型,解决问题. 2.问题原因初步思考为windows下模型文件上传到 linux导致,待继续跟进查找.

随机推荐

  1. 【团队效率提升】Python-PyWebIO介绍

    作者:京东零售 关键 Q&A快速了解PyWebIO Q:首先,什么是PyWebIO? A:PyWebIO提供了一系列命令式的交互函数,能够让咱们用只用Python就可以编写 Web 应用, 不 ...

  2. Mac 版的 Quicker CirMenu

    之前在Windows上用过一款圆盘菜单工具Quicker, 感觉非常方便, 换成Macos后,一直没有找到类似应用. 最近终于发现,一款好用的快捷键收集,触发工具CirMenu. 其核心功能是可以根据 ...

  3. SP5464 CT - Counting triangles 题解

    题目翻译 题意 有一个网格,左上角是 \((0,0)\),右上角是 \((x,y)\).求这个网格中一共有多少个等腰直角三角形. 输入 第一行给定一个 \(c\),表示有 \(c\) 组数据. 后面 ...

  4. 小白学k8s(1)二进制部署k8s

    二进制部署k8s 前言 准备工作 关闭防火墙 关闭 swap 分区 关闭 SELinux 更新系统时间 秘钥免密码 设置主机名称 服务器角色 安装etcd 创建证书 生成证书 部署Etcd 在Node ...

  5. 一键式文本纠错工具,整合了BERT、ERNIE等多种模型,让您立即享受纠错的便利和效果

    pycorrector一键式文本纠错工具,整合了BERT.MacBERT.ELECTRA.ERNIE等多种模型,让您立即享受纠错的便利和效果 pycorrector: 中文文本纠错工具.支持中文音似. ...

  6. Git 简单实用教程

    相关链接: 码云(gitee)配置SSH密钥 码云gitee创建仓库并用git上传文件 git 上传错误This oplation equires one of the flowi vrsionsot ...

  7. Linux 统计Web服务日志命令

    本人在Linux运维中收集的一些通用的统计,Apache/Nginx服务器日志的命令组合. Apache日志统计 # 列出当天访问次数最多的IP命令 [root@lyshark.cnblogs.com ...

  8. 【Matlab】蒙特卡罗法模拟圆周率+对应解析的GIF生成【超详细的注释和解释】

    文章目录 前言 模拟思路 GIF模拟动图的生成 GIF动图生成的基本思路 单张静态图的生成 GIF的生成 尾声 前言 因为博主最近要准备数学建模大赛了,在学习matlab和python之余,博主也会继 ...

  9. P8512 [Ynoi Easy Round 2021] TEST_152 题解

    题目链接:[Ynoi Easy Round 2021] TEST_152 题目比较抽象,翻译一下.就是有 \(n\) 个操作,每个操作为 \((l_i,r_i,v_i)\) 表示把长为 \(m\) 序 ...

  10. 一个关于用netty的小错误反思

    一个关于用netty的小认知 在使用netty时,观看了黑马的netty网课,没想就直接用他的依赖了 依赖如下 <dependency> <groupId>io.netty&l ...