pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合

  • torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法

    • __len()__ :返回数据集中数据的数量
    • __getitem()__ :返回支持下标索引方式获取的一个数据
  • torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....

第一步

  自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:

  • __len()__:读取数据,返回数据和标签
  • __getitem()__:返回数据集的长度
from torch.utils.data import Dataset

class AudioDataset(Dataset):
def __init__(self, ...):
"""类的初始化"""
pass def __getitem__(self, item):
"""每次怎么读数据,返回数据和标签"""
return data, label def __len__(self):
"""返回整个数据集的长度"""
return total

注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本

案例:

  文件目录结构

  • p225

    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

 1 class AudioDataset(Dataset):
2 def __init__(self, data_folder, sr=16000, dimension=8192):
3 self.data_folder = data_folder
4 self.sr = sr
5 self.dim = dimension
6
7 # 获取音频名列表
8 self.wav_list = []
9 for root, dirnames, filenames in os.walk(data_folder):
10 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
11 self.wav_list.append(os.path.join(root, filename))
12
13 def __getitem__(self, item):
14 # 读取一个音频文件,返回每个音频数据
15 filename = self.wav_list[item]
16 wb_wav, _ = librosa.load(filename, sr=self.sr)
17
18 # 取 帧
19 if len(wb_wav) >= self.dim:
20 max_audio_start = len(wb_wav) - self.dim
21 audio_start = np.random.randint(0, max_audio_start)
22 wb_wav = wb_wav[audio_start: audio_start + self.dim]
23 else:
24 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
25
26 return wb_wav, filename
27
28 def __len__(self):
29 # 音频文件的总数
30 return len(self.wav_list)

注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,

第二步

  实例化 Dataset 对象

Dataset= AudioDataset("./p225", sr=16000)

如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000) for i, data in enumerate(train_set):
wb_wav, filname = data
print(i, wb_wav.shape, filname) if i == 3:
break
# 0 (8192,) ./p225\p225_001.wav
# 1 (8192,) ./p225\p225_002.wav
# 2 (8192,) ./p225\p225_003.wav
# 3 (8192,) ./p225\p225_004.wav

第三步

  如果想要通过batch读取数据,需要使用DataLoader进行包装

为何要使用DataLoader?

  1. 深度学习的输入是mini_batch形式
  2. 样本加载时候可能需要随机打乱顺序,shuffle操作
  3. 样本加载需要采用多线程

  pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

参数

  • dataset:加载的数据集(Dataset对象)
  • batch_size:每个批次要加载多少个样本(默认值:1)
  • shuffle:每个epoch是否将数据打乱
  • sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
  • batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多线程
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

返回:数据加载器

案例:

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True) for (i, data) in enumerate(train_loader):
wav_data, wav_name = data
print(wav_data.shape) # torch.Size([8, 8192])
print(i, wav_name)
# ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
# './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

我们来吃几个栗子消化一下:

栗子1

  这个例子就是本文一直举例的,栗子1只是合并了一下而已

  文件目录结构

  • p225

    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

 1 import fnmatch
2 import os
3 import librosa
4 import numpy as np
5 from torch.utils.data import Dataset
6 from torch.utils.data import DataLoader
7
8
9 class Aduio_DataLoader(Dataset):
10 def __init__(self, data_folder, sr=16000, dimension=8192):
11 self.data_folder = data_folder
12 self.sr = sr
13 self.dim = dimension
14
15 # 获取音频名列表
16 self.wav_list = []
17 for root, dirnames, filenames in os.walk(data_folder):
18 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
19 self.wav_list.append(os.path.join(root, filename))
20
21 def __getitem__(self, item):
22 # 读取一个音频文件,返回每个音频数据
23 filename = self.wav_list[item]
24 print(filename)
25 wb_wav, _ = librosa.load(filename, sr=self.sr)
26
27 # 取 帧
28 if len(wb_wav) >= self.dim:
29 max_audio_start = len(wb_wav) - self.dim
30 audio_start = np.random.randint(0, max_audio_start)
31 wb_wav = wb_wav[audio_start: audio_start + self.dim]
32 else:
33 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
34
35 return wb_wav, filename
36
37 def __len__(self):
38 # 音频文件的总数
39 return len(self.wav_list)
40
41
42 train_set = Aduio_DataLoader("./p225", sr=16000)
43 train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
44
45
46 for (i, data) in enumerate(train_loader):
47 wav_data, wav_name = data
48 print(wav_data.shape) # torch.Size([8, 8192])
49 print(i, wav_name)
50 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
51 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

注意事项

  1. 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
  2. 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意

栗子2

  相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。

  我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。

具体实现代码:

  第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:

  第二步:通过Dataset从h5格式文件中读取数据

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py def load_h5(h5_path):
# load training data
with h5py.File(h5_path, 'r') as hf:
print('List of arrays in input file:', hf.keys())
X = np.array(hf.get('data'), dtype=np.float32)
Y = np.array(hf.get('label'), dtype=np.float32)
return X, Y class AudioDataset(Dataset):
"""数据加载器"""
def __init__(self, data_folder):
self.data_folder = data_folder
self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1) def __getitem__(self, item):
# 返回一个音频数据
X = self.X[item]
Y = self.Y[item] return X, Y def __len__(self):
return len(self.X) train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True) for (i, wav_data) in enumerate(train_loader):
X, Y = wav_data
print(i, X.shape)
# 0 torch.Size([64, 8192, 1])
# 1 torch.Size([64, 8192, 1])
# ...

我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,

参考

pytorch学习(四)—自定义数据集(讲的比较详细)

pytorch加载语音类自定义数据集的更多相关文章

  1. pytorch 加载mnist数据集报错not gzip file

    利用pytorch加载mnist数据集的代码如下 import torchvision import torchvision.transforms as transforms from torch.u ...

  2. 解决Eclipse中“诡异”的错误:找不到或无法加载主类

    记录下来遇到的(问题,解决方法),是更有效的解决问题的方式.(原谅我领悟的太晚与懒,从此用更有意义的方法,做一个更有意义的人) 因为遇到了多次,参考同一个方法,原文连接:https://blog.cs ...

  3. JVM如何加载一个类的过程,双亲委派模型中有哪些方法

    1.类加载过程:加载.验证.准备.解析.初始化   加载   在加载阶段,虚拟机主要完成三件事: 1.通过一个类的全限定名来获取定义此类的二进制字节流. 2.将这个字节流所代表的静态存储结构转化为方法 ...

  4. 找不到或无法加载主类 ide 正常执行,但是打包jar后报错 maven 引入本地包

    错误: 找不到或无法加载主类 com.myali.TTSmy 问题原因: ide中编译能找到相关包,但是,打包成jar时,本地的jar引入失败 maven将系统用到的包从线上maven仓库下载到本地的 ...

  5. javac 不是内部或外部命令 和 错误 找不到或无法加载主类 的解决方法

    使用package语句与import语句. 实验要求:按实验要求使用package语句,并用import语句使用Java平台提供的包中的类以及自定义包中的类.掌握一些重要的操作步骤. 代码: 模板1: ...

  6. 使用Huggingface在矩池云快速加载预训练模型和数据集

    作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...

  7. eclipse 下找不到或无法加载主类的解决办法

    有时候 Eclipse 会发神经,好端端的 project 就这么编译不了了,连 Hello World 都会报“找不到或无法加载主类”的错误,我已经遇到好几次了,以前是懒得深究就直接重建projec ...

  8. java HelloWorld 提示“错误: 找不到或无法加载主类 HelloWorld“解决方案

    在检查环境变量等前提工作准确无误后,注意要配好CLASSPATH,仍然报“错误: 找不到或无法加载主类 HelloWorld“. 本人工程目录:mygs-maven/src/main/java/hel ...

  9. maven project中,在main方法上右键Run as Java Application时,提示错误:找不到或无法加载主类XXX.XXXX.XXX

    新建了一个maven project项目,经过一大堆的修改操作之后,突然发现在main方法上右键运行时,竟然提示:错误:找不到或无法加载主类xxx.xxx.xxx可能原因1.eclipse出问题了,在 ...

随机推荐

  1. 087 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 01 封装的概念和特点

    087 01 Android 零基础入门 02 Java面向对象 02 Java封装 01 封装的实现 01 封装的概念和特点 本文知识点:封装的概念和特点 说明:因为时间紧张,本人写博客过程中只是对 ...

  2. 010 01 Android 零基础入门 01 Java基础语法 02 Java常量与变量 04 变量的三个元素的详细介绍之二——变量类型——即Java中的数据类型

    010 01 Android 零基础入门 01 Java基础语法 02 Java常量与变量 04 变量的三个元素的详细介绍之二--变量类型--即Java中的数据类型 Java中变量的三要素 变量名 变 ...

  3. OpenCV中threshold函数的使用

    转自:https://blog.csdn.net/u012566751/article/details/77046445 一篇很好的介绍threshold文章: 图像的二值化就是将图像上的像素点的灰度 ...

  4. Linux就该这么学28期——Day02 2.1-2.3

    本文记录必须掌握的Linux命令,部分内容引用自https://www.linuxprobe.com/basic-learning-02.html 工作中可使用https://www.linuxcoo ...

  5. ubuntu 18.04 搭建flask服务器(大合集,个人实操)

    ubuntu 18.04 搭建flask服务器(大合集) Ubuntu python flask 服务器 本次使用的Ubuntu版本为:Ubuntu 18.04.5 LTS (GNU/Linux 4. ...

  6. JS-YAML -YAML 1.2 JavaScript解析器/编写器

    下载 JS-YAML -YAML 1.2 JavaScript解析器/编写器JS-YAML -YAML 1.2 JavaScript解析器/编写器 在线演示 这是YAML的实现,YAML是一种对人友好 ...

  7. ASP。使用依赖注入的asp.net Core 2.0用户角色库动态菜单管理

    下载source code - 2.2 MB 介绍 在开始这篇文章之前,请阅读我的前一篇文章: 开始使用ASP.NET Core 2.0身份和角色管理 在上一篇文章中,我们详细讨论了如何使用ASP.N ...

  8. Oracle - ascii为0的陷阱

    一.概述 ascii0是个空字符,如果将这个字符插入到oracle数据库中会是什么现象,是null吗? 二.正式实验 创建一张测试表 create table test(id int, name va ...

  9. JavaScript常用对象介绍

    目录 对象(object) 对象的创建方式 点语法 括号表示法 内置对象 Array 数组创建方式 检测数组 转换方法 分割字符串 栈方法 队列方法 重排序方法 操作方法 位置方法 迭代方法 Stri ...

  10. python中的对文件的读写

    简单的实例 open函数获取文件,w是写权限,可以对文件进行io操作 file=open('C:/Users/Administrator/Desktop/yes.txt','w') file.writ ...