前言

这篇博客以PTB数据集为例,详细讲解了如何将txt格式的数据集文件,转换为pytorch框架可以直接处理的tensor变量,并附上相应代码

@


1. PTB 数据集

PTB数据集含有三个txt文件,分别作为训练集(train),验证集(valid)和测试集(test);这三个txt文件分别包含42000,3000和3000句英文;

我们要将其转化为pytorch可处理的tensor类型数据集,需要以下几步:

  • 依次读取每一行的训练集文件(train.txt),为每一个读到的单词分配序号,构建词汇表

    • 出现频率低于min_ooc(通常默认为3)次的词汇,单词一率变为未知单词 < unk > ,分配序号1
    • < sos >为每句话的起始信号,分配序号2
    • < eos >为每句话的结束信号,分配序号3
    • 由于每句话长度不一样,而pytorch批处理数据,需要统一句子长度,因此长度较短的句子用 < pad > 填充,分配序号0
  • 统一句子长度为max_sentence_length(默认50)
    • 高于50个单词的句子,只保留前50个单词;
    • 低于50个单词的句子,用 < pad > 信号填充到50
  • 根据训练集构建的词汇表,将训练集,验证集和测试集都变成数字序号表示的句子,如 a cat not is dog变成 2 25 54 12 0 0
  • 构建三个数据集加载转换后的用数字序号表示的句子,并将其错位句子作为该句子的标签(target),例如, a cat not is dog变成 2 25 54 12 0 0, 那它对应的target就是 25 54 12 3 0 0 了 (2 3 分别为起始,末尾信号)
  • 将其转换为批处理的tensor变量

这样我们就能得到pytorch可以直接加载处理的tensor类型数据集了

2. 构建词汇表

我们先定义一个字典类型的变量,该字典类型变量,会将输入的句子里的新单词添加到字典中,并记录该单词的出现次数

引入库文件:

  1. import json
  2. import torch
  3. import numpy as np
  4. from nltk.tokenize import TweetTokenizer
  5. from collections import Counter, OrderedDict,defaultdict
  6. import io
  7. import os
  8. from torch.utils.data import DataLoader
  1. class OrderedCounter(Counter, OrderedDict): #这样定义的字典类型变量,会将输入的句子里的新单词添加到字典中,并记录该单词的出现次数
  2. """Counter that remembers the order elements are first encountered"""
  3. def __repr__(self):
  4. return '%s(%r)' % (self.__class__.__name__, OrderedDict(self))
  5. def __reduce__(self):
  6. return self.__class__, (OrderedDict(self),)

依次读入ptb.train.txt的每一句话,并对其进行分词,不区分大小写;

  • 分词:由于默认可以通过空格来分开每个单词,但专业的分词函数更好些
  1. def create_vocab(split):
  2. tokenizer = TweetTokenizer(preserve_case=False) #分词,不区分大小写
  3. w2c = OrderedCounter()
  4. w2i = dict()
  5. i2w = dict()
  6. special_tokens = ['<pad>', '<unk>', '<sos>', '<eos>']
  7. for st in special_tokens:
  8. i2w[len(w2i)] = st
  9. w2i[st] = len(w2i)
  10. with open(split, 'r') as file:
  11. for i, line in enumerate(file):
  12. words = tokenizer.tokenize(line)
  13. w2c.update(words) #这段程序将文件中出现过的所有单词加载到字典类型变量w2c中,并存储了他们出现的次数
  14. for w, c in w2c.items():
  15. if c > 3 and w not in special_tokens: #依次为出现次数大于3,且不是那4种特殊信号的单词分配序号
  16. i2w[len(w2i)] = w
  17. w2i[w] = len(w2i) #w2i格式为'cat':50这种,i2w为50:'cat'这种
  18. return w2i,i2w

实例化一下试试:

3. 将训练集,验证集和测试集根据词汇表转换为数字序号,并转换为tensor

  1. def create_data(split,w2i): #split为待转换的txt文件地址
  2. tokenizer = TweetTokenizer(preserve_case=False) #分词,不区分大小写
  3. data = defaultdict(dict)
  4. with open(split, 'r') as file: #读取该文件的每一行
  5. for i, line in enumerate(file):
  6. words = tokenizer.tokenize(line) #分词
  7. input = ['<sos>'] + words #输入的开头增加<sos>信号
  8. input = input[:50] #只保留前50个(起始信号<sos> + 文本的前49个单词)
  9. target = words[:50-1] #输入对应的target,也只保留50个(取文本的前49个单词+ 结束信号<eos>)
  10. target = target + ['<eos>']
  11. length = len(input)
  12. input.extend(['<pad>'] * (50-length)) #输入和target,不足50个的,用<pad>补足50个
  13. target.extend(['<pad>'] * (50-length))
  14. input = [w2i.get(w, w2i['<unk>']) for w in input]
  15. target = [w2i.get(w, w2i['<unk>']) for w in target]
  16. id = len(data) #id表示该数据集的第id句话
  17. inpu_t = torch.from_numpy(np.asarray(input)) #转换为tensor形式
  18. targe_t = torch.from_numpy(np.asarray(target))
  19. data[id]['input'] = inpu_t
  20. data[id]['target'] = targe_t
  21. data[id]['length'] = length
  22. return data

实例化一下试试:

3. 转换为批处理的tensor变量

  1. data_loader = DataLoader(
  2. dataset= data,
  3. batch_size= 64,#批处理大小
  4. shuffle=True #是否打乱排序
  5. )

实例化试试:

总结

这篇博客以PTB数据集为例,介绍了如何将txt形式的数据集转换为pytorch框架中可以使用的,批处理的tensor形式

参考项目:github上以PTB数据集训练的一个语言模型的项目

Pytorch加载txt格式的数据集文件(以PTB数据集为例)的更多相关文章

  1. Away3D 学习笔记(一): 加载3DS格式的模型文件

    加载外部的3DS文件分为两种: 1: 模型与贴图独立于程序的,也就是从外部的文件夹中读取 private function load3DSFile():Loader3D { loader = new ...

  2. pytorch 加载mnist数据集报错not gzip file

    利用pytorch加载mnist数据集的代码如下 import torchvision import torchvision.transforms as transforms from torch.u ...

  3. 神坑 Resources.Load 不能实时加载TXT文件

    Resources.Load(fileName) as TextAsset; 这句话并不能实时加载文本文件,对文本文件进行修改之后,若是没有刷新的话,加载的还是之前的文件: 要实时读取文本文件还是要以 ...

  4. hive 压缩全解读(hive表存储格式以及外部表直接加载压缩格式数据);HADOOP存储数据压缩方案对比(LZO,gz,ORC)

    数据做压缩和解压缩会增加CPU的开销,但可以最大程度的减少文件所需的磁盘空间和网络I/O的开销,所以最好对那些I/O密集型的作业使用数据压缩,cpu密集型,使用压缩反而会降低性能. 而hive中间结果 ...

  5. 为不同分辨率单独做样式文件,在页面头部用js判断分辨率后动态加载定义好的样式文件

    为不同分辨率单独做样式文件,在页面头部用js判断分辨率后动态加载定义好的样式文件.样式文件命名格式如:forms[_屏幕宽度].css,样式文件中只需重新定义文本框和下拉框的宽度即可. 在包含的头文件 ...

  6. 使用getJSON()方法异步加载JSON格式数据

    使用getJSON()方法异步加载JSON格式数据 使用getJSON()方法可以通过Ajax异步请求的方式,获取服务器中的数组,并对获取的数据进行解析,显示在页面中,它的调用格式为: jQuery. ...

  7. cesium模型加载-加载fbx格式模型

    整体思路: fbx格式→dae格式→gltf格式→cesium加载gltf格式模型 具体方法: 1. fbx格式→dae格式 工具:3dsMax, 3dsMax插件:OpenCOLLADA, 下载地址 ...

  8. Lab_1:练习4——分析bootloader加载ELF格式的OS的过程

    一.实验内容 通过阅读bootmain.c,了解bootloader如何加载ELF文件.通过分析源代码和通过qemu来运行并调试bootloader&OS, bootloader如何读取硬盘扇 ...

  9. Lab1:练习四——分析bootloader加载ELF格式的OS的过程

    练习四:分析bootloader加载ELF格式的OS的过程. 1.题目要求 通过阅读bootmain.c,了解bootloader如何加载ELF文件.通过分析源代码和通过qemu来运行并调试bootl ...

  10. 如何实现通过Leaflet加载dwg格式的CAD图

    前言 ​ 在前面介绍了通过openlayers加载dwg格式的CAD图并与互联网地图叠加,openlayers功能很全面,但同时也很庞大,入门比较难,适合于大中型项目中.而在中小型项目中,一般用开源的 ...

随机推荐

  1. 《Terraform 101 从入门到实践》 Functions函数

    <Terraform 101 从入门到实践>这本小册在南瓜慢说官方网站和GitHub两个地方同步更新,书中的示例代码也是放在GitHub上,方便大家参考查看. Terraform的函数 T ...

  2. 解决org.apache.ibatis.binding.BindingException: Invalid bound statement (not found)问题

    解决org.apache.ibatis.binding.BindingException: Invalid bound statement (not found)问题 需要检查的步骤: 1.是否map ...

  3. mysql-01数据库基本简介

    1.数据库的概念 DB:数据库(database):存储数据的"仓库".它保存了一系列有组织的数据. DBMS:数据库管理系统(Database Management System ...

  4. Python接口自动化测试(1)

    接口自动化测试三部曲:1.构造请求  2.判断结果  3.数据库查询 1.Python的第三方包:requests 简介:requests可以用来做接口测试.接口自动化测试.爬虫等 requests的 ...

  5. 五大数据类型 - 字符串 - 列表 list - 集合set - 有序集合 - 哈希 hashMap

    基础知识 redis默认有16个数据库:默认使用的是第0个. 可以使用select num切换 查看DB大小 DBSIZE 查看所有的key **keys ** 清空当前数据库 flushdb 清空全 ...

  6. Cobaltstrike —— shellcode分析(一)

    前言 搞iot搞久了,换个方向看看,改改口味. windows 常见结构体 在分析Cobaltstrike-shellcode之前我们得先了解一下windows下一些常见的结构体. X86 Threa ...

  7. mssql 常用sql 语句

    ----insert ----delete----update----select ----选择数据库进行操作select top 1 * from smzx2018.dbo.tbuseruse sm ...

  8. Socket.io + Knex 实现私聊聊天室

    前言 本文只介绍实现的核心代码,目的是记录和分享知识.若感兴趣可以往下看,在文章最后贴上了仓库地址.前端使用 Vite + Vue3:后端使用 Knex + Express. Room 的概念 私密 ...

  9. PostGIS之空间连接

    1. 概述 PostGIS 是PostgreSQL数据库一个空间数据库扩展,它添加了对地理对象的支持,允许在 SQL 中运行空间查询 PostGIS官网:About PostGIS | PostGIS ...

  10. LeetCode-23 合并K个升序链表

    来源:力扣(LeetCode)链接:https://leetcode-cn.com/problems/merge-k-sorted-lists 题目描述 给你一个链表数组,每个链表都已经按升序排列. ...