pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合

  • torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法

    • __len()__ :返回数据集中数据的数量
    • __getitem()__ :返回支持下标索引方式获取的一个数据
  • torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....

第一步

  自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:

  • __len()__:读取数据,返回数据和标签
  • __getitem()__:返回数据集的长度
from torch.utils.data import Dataset

class AudioDataset(Dataset):
def __init__(self, ...):
"""类的初始化"""
pass def __getitem__(self, item):
"""每次怎么读数据,返回数据和标签"""
return data, label def __len__(self):
"""返回整个数据集的长度"""
return total

注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本

案例:

  文件目录结构

  • p225

    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

 1 class AudioDataset(Dataset):
2 def __init__(self, data_folder, sr=16000, dimension=8192):
3 self.data_folder = data_folder
4 self.sr = sr
5 self.dim = dimension
6
7 # 获取音频名列表
8 self.wav_list = []
9 for root, dirnames, filenames in os.walk(data_folder):
10 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
11 self.wav_list.append(os.path.join(root, filename))
12
13 def __getitem__(self, item):
14 # 读取一个音频文件,返回每个音频数据
15 filename = self.wav_list[item]
16 wb_wav, _ = librosa.load(filename, sr=self.sr)
17
18 # 取 帧
19 if len(wb_wav) >= self.dim:
20 max_audio_start = len(wb_wav) - self.dim
21 audio_start = np.random.randint(0, max_audio_start)
22 wb_wav = wb_wav[audio_start: audio_start + self.dim]
23 else:
24 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
25
26 return wb_wav, filename
27
28 def __len__(self):
29 # 音频文件的总数
30 return len(self.wav_list)

注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,

第二步

  实例化 Dataset 对象

Dataset= AudioDataset("./p225", sr=16000)

如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000) for i, data in enumerate(train_set):
wb_wav, filname = data
print(i, wb_wav.shape, filname) if i == 3:
break
# 0 (8192,) ./p225\p225_001.wav
# 1 (8192,) ./p225\p225_002.wav
# 2 (8192,) ./p225\p225_003.wav
# 3 (8192,) ./p225\p225_004.wav

第三步

  如果想要通过batch读取数据,需要使用DataLoader进行包装

为何要使用DataLoader?

  1. 深度学习的输入是mini_batch形式
  2. 样本加载时候可能需要随机打乱顺序,shuffle操作
  3. 样本加载需要采用多线程

  pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

参数

  • dataset:加载的数据集(Dataset对象)
  • batch_size:每个批次要加载多少个样本(默认值:1)
  • shuffle:每个epoch是否将数据打乱
  • sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
  • batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
  • num_workers:使用多进程加载的进程数,0代表不使用多线程
  • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
  • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

返回:数据加载器

案例:

# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True) for (i, data) in enumerate(train_loader):
wav_data, wav_name = data
print(wav_data.shape) # torch.Size([8, 8192])
print(i, wav_name)
# ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
# './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

我们来吃几个栗子消化一下:

栗子1

  这个例子就是本文一直举例的,栗子1只是合并了一下而已

  文件目录结构

  • p225

    • ***.wav
    • ***.wav
    • ***.wav
    • ...
  • dataset.py

目的:读取p225文件夹中的音频数据

 1 import fnmatch
2 import os
3 import librosa
4 import numpy as np
5 from torch.utils.data import Dataset
6 from torch.utils.data import DataLoader
7
8
9 class Aduio_DataLoader(Dataset):
10 def __init__(self, data_folder, sr=16000, dimension=8192):
11 self.data_folder = data_folder
12 self.sr = sr
13 self.dim = dimension
14
15 # 获取音频名列表
16 self.wav_list = []
17 for root, dirnames, filenames in os.walk(data_folder):
18 for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
19 self.wav_list.append(os.path.join(root, filename))
20
21 def __getitem__(self, item):
22 # 读取一个音频文件,返回每个音频数据
23 filename = self.wav_list[item]
24 print(filename)
25 wb_wav, _ = librosa.load(filename, sr=self.sr)
26
27 # 取 帧
28 if len(wb_wav) >= self.dim:
29 max_audio_start = len(wb_wav) - self.dim
30 audio_start = np.random.randint(0, max_audio_start)
31 wb_wav = wb_wav[audio_start: audio_start + self.dim]
32 else:
33 wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
34
35 return wb_wav, filename
36
37 def __len__(self):
38 # 音频文件的总数
39 return len(self.wav_list)
40
41
42 train_set = Aduio_DataLoader("./p225", sr=16000)
43 train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
44
45
46 for (i, data) in enumerate(train_loader):
47 wav_data, wav_name = data
48 print(wav_data.shape) # torch.Size([8, 8192])
49 print(i, wav_name)
50 # ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
51 # './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')

注意事项

  1. 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
  2. 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意

栗子2

  相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。

  我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。

具体实现代码:

  第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:

  第二步:通过Dataset从h5格式文件中读取数据

import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py def load_h5(h5_path):
# load training data
with h5py.File(h5_path, 'r') as hf:
print('List of arrays in input file:', hf.keys())
X = np.array(hf.get('data'), dtype=np.float32)
Y = np.array(hf.get('label'), dtype=np.float32)
return X, Y class AudioDataset(Dataset):
"""数据加载器"""
def __init__(self, data_folder):
self.data_folder = data_folder
self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1) def __getitem__(self, item):
# 返回一个音频数据
X = self.X[item]
Y = self.Y[item] return X, Y def __len__(self):
return len(self.X) train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True) for (i, wav_data) in enumerate(train_loader):
X, Y = wav_data
print(i, X.shape)
# 0 torch.Size([64, 8192, 1])
# 1 torch.Size([64, 8192, 1])
# ...

我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,

参考

pytorch学习(四)—自定义数据集(讲的比较详细)

pytorch加载语音类自定义数据集的更多相关文章

  1. pytorch 加载mnist数据集报错not gzip file

    利用pytorch加载mnist数据集的代码如下 import torchvision import torchvision.transforms as transforms from torch.u ...

  2. 解决Eclipse中“诡异”的错误:找不到或无法加载主类

    记录下来遇到的(问题,解决方法),是更有效的解决问题的方式.(原谅我领悟的太晚与懒,从此用更有意义的方法,做一个更有意义的人) 因为遇到了多次,参考同一个方法,原文连接:https://blog.cs ...

  3. JVM如何加载一个类的过程,双亲委派模型中有哪些方法

    1.类加载过程:加载.验证.准备.解析.初始化   加载   在加载阶段,虚拟机主要完成三件事: 1.通过一个类的全限定名来获取定义此类的二进制字节流. 2.将这个字节流所代表的静态存储结构转化为方法 ...

  4. 找不到或无法加载主类 ide 正常执行,但是打包jar后报错 maven 引入本地包

    错误: 找不到或无法加载主类 com.myali.TTSmy 问题原因: ide中编译能找到相关包,但是,打包成jar时,本地的jar引入失败 maven将系统用到的包从线上maven仓库下载到本地的 ...

  5. javac 不是内部或外部命令 和 错误 找不到或无法加载主类 的解决方法

    使用package语句与import语句. 实验要求:按实验要求使用package语句,并用import语句使用Java平台提供的包中的类以及自定义包中的类.掌握一些重要的操作步骤. 代码: 模板1: ...

  6. 使用Huggingface在矩池云快速加载预训练模型和数据集

    作为NLP领域的著名框架,Huggingface(HF)为社区提供了众多好用的预训练模型和数据集.本文介绍了如何在矩池云使用Huggingface快速加载预训练模型和数据集. 1.环境 HF支持Pyt ...

  7. eclipse 下找不到或无法加载主类的解决办法

    有时候 Eclipse 会发神经,好端端的 project 就这么编译不了了,连 Hello World 都会报“找不到或无法加载主类”的错误,我已经遇到好几次了,以前是懒得深究就直接重建projec ...

  8. java HelloWorld 提示“错误: 找不到或无法加载主类 HelloWorld“解决方案

    在检查环境变量等前提工作准确无误后,注意要配好CLASSPATH,仍然报“错误: 找不到或无法加载主类 HelloWorld“. 本人工程目录:mygs-maven/src/main/java/hel ...

  9. maven project中,在main方法上右键Run as Java Application时,提示错误:找不到或无法加载主类XXX.XXXX.XXX

    新建了一个maven project项目,经过一大堆的修改操作之后,突然发现在main方法上右键运行时,竟然提示:错误:找不到或无法加载主类xxx.xxx.xxx可能原因1.eclipse出问题了,在 ...

随机推荐

  1. 003 01 Android 零基础入门 01 Java基础语法 01 Java初识 03 Java程序的执行流程

    003 01 Android 零基础入门 01 Java基础语法 01 Java初识 03 Java程序的执行流程 Java程序长啥样? 首先编写一个Java程序 记事本编写程序 打开记事本 1.wi ...

  2. VS调试时查看动态数组的全部元素

    转载:https://blog.csdn.net/sinat_36219858/article/details/80720527

  3. Arduino 中 EEprom 写入读取清除

    转自:https://www.arduino.cn/thread-1157-1-1.html EEPROM (Electrically Erasable Programmable Read-Only ...

  4. JavaFX ComboBox的选中事项

    参考1:https://blog.csdn.net/mexel310/article/details/37909205 参考2:https://blog.csdn.net/maosijunzi/art ...

  5. 详解Class加载过程

    1.Class文件内容格式 2.一个class文件是被加载到内存的过程是怎样的? loading 把一个class文件装到内存里,class文件是一个二进制,一个个的字节 linking Verifi ...

  6. 2020年了,IT外企还香吗?

    本来是刚发了<世上有不用加班的程序员吗?>,有朋友问到IT外企不加班福利好什么的,就回复了几句. 老王观点: 现在IT外企已经不值得羡慕了,08.09年那会,ibm,惠普还是香饽饽,当时人 ...

  7. Java 集合看这一篇就够了

    大家好,这里是<齐姐聊数据结构>系列之大集合. 话不多说,直接上图: Java 集合,也称作容器,主要是由两大接口 (Interface) 派生出来的: Collection 和 Map ...

  8. go-zero 如何应对海量定时/延迟任务?

    一个系统中存在着大量的调度任务,同时调度任务存在时间的滞后性,而大量的调度任务如果每一个都使用自己的调度器来管理任务的生命周期的话,浪费cpu的资源而且很低效. 本文来介绍 go-zero 中 延迟操 ...

  9. C语言实现表达式求值,支持+、-、*、/四则运算,并且支持多级括号,自定义了栈的操作。

    以下是代码的实现使用gcc已经成功运行了,下面是效果图 #include <stdio.h> #include <stdlib.h> #define OPT_ADD 43 /* ...

  10. python面试题-python相关

    1. __new__.__init__区别,如何实现单例模式,有什么优点 __new__是一个静态方法,__init__是一个实例方法 __new__返回一个创建的实例,__init__什么都不返回 ...