一、使用urllib下载cifar-10数据集,并读取再存为图片(TensorFlow v1.14.0)

 # -*- coding:utf-8 -*-
__author__ = 'Leo.Z' import sys
import os # 给定url下载文件
def download_from_url(url, dir=''):
_file_name = url.split('/')[-1]
_file_path = os.path.join(dir, _file_name) # 打印下载进度
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(_file_name, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush() # 如果不存在dir,则创建文件夹
if not os.path.exists(dir):
print("Dir is not exsit,Create it..")
os.makedirs(dir) if not os.path.exists(_file_path):
print("Start downloading..")
# 开始下载文件
import urllib
urllib.request.urlretrieve(url, _file_path, _progress)
else:
print("File already exists..") return _file_path # 使用tarfile解压缩
def extract(filepath, dest_dir):
if os.path.exists(filepath) and not os.path.exists(dest_dir):
import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_dir) if __name__ == '__main__':
FILE_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
FILE_DIR = 'cifar10_dir/' loaded_file_path = download_from_url(FILE_URL, FILE_DIR)
extract(loaded_file_path)

 按BATCH_SIZE读取二进制文件中的图片数据,并存放为jpg:

# -*- coding:utf-8 -*-
__author__ = 'Leo.Z' # Tensorflow Version:1.14.0 import os import tensorflow as tf
from PIL import Image BATCH_SIZE = 128 def read_cifar10(filenames):
label_bytes = 1
height = 32
width = 32
depth = 3
image_bytes = height * width * depth record_bytes = label_bytes + image_bytes # lamda函数体
# def load_transform(x):
# # Convert these examples to dense labels and processed images.
# per_record = tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes])
# return per_record # tf v1.14.0版本的FixedLengthRecordDataset(filename_list,bin_data_len)
datasets = tf.data.FixedLengthRecordDataset(filenames=filenames, record_bytes=record_bytes)
# 是否打乱数据
# datasets.shuffle()
# 重复几轮epoches
datasets = datasets.shuffle(buffer_size=BATCH_SIZE).repeat(2).batch(BATCH_SIZE) # 使用map,也可使用lamda(注意,后面使用迭代器的时候这里转换为uint8没用,后面还得转一次,否则会报错)
# datasets.map(load_transform)
# datasets.map(lamda x : tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes])) # 创建一起迭代器tf v1.14.0
iter = tf.compat.v1.data.make_one_shot_iterator(datasets)
# 获取下一条数据(label+image的二进制数据1+32*32*3长度的bytes)
rec = iter.get_next()
# 这里转uint8才生效,在map中转貌似有问题?
rec = tf.decode_raw(rec, tf.uint8) label = tf.cast(tf.slice(rec, [0, 0], [BATCH_SIZE, label_bytes]), tf.int32) # 从第二个字节开始获取图片二进制数据大小为32*32*3
depth_major = tf.reshape(
tf.slice(rec, [0, label_bytes], [BATCH_SIZE, image_bytes]),
[BATCH_SIZE, depth, height, width])
# 将维度变换顺序,变为[H,W,C]
image = tf.transpose(depth_major, [0, 2, 3, 1]) # 返回获取到的label和image组成的元组
return (label, image) def get_data_from_files(data_dir):
# filenames一共5个,从data_batch_1.bin到data_batch_5.bin
# 读入的都是训练图像
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in range(1, 6)]
# 判断文件是否存在
for f in filenames:
if not tf.io.gfile.exists(f):
raise ValueError('Failed to find file: ' + f) # 获取一张图片数据的数据,格式为(label,image)
data_tuple = read_cifar10(filenames)
return data_tuple if __name__ == "__main__": # 获取label和type的对应关系
label_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
name_list = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
label_map = dict(zip(label_list, name_list)) with tf.compat.v1.Session() as sess:
batch_data = get_data_from_files('cifar10_dir/cifar-10-batches-bin')
# 在之前的旧版本中,因为使用了filename_queue,所以要使用start_queue_runners进行数据填充
# 1.14.0由于没有使用filename_queue所以不需要
# threads = tf.train.start_queue_runners(sess=sess) sess.run(tf.compat.v1.global_variables_initializer())
# 创建一个文件夹用于存放图片
if not os.path.exists('cifar10_dir/raw'):
os.mkdir('cifar10_dir/raw') # 存放30张,以index-typename.jpg命名,例如1-frog.jpg
for i in range(30):
# 获取一个batch的数据,BATCH_SIZE
# batch_data中包含一个batch的image和label
batch_data_tuple = sess.run(batch_data)
# 打印(128, 1)
print(batch_data_tuple[0].shape)
# 打印(128, 32, 32, 3)
print(batch_data_tuple[1].shape) # 每个batch存放第一张图片作为实验
Image.fromarray(batch_data_tuple[1][0]).save("cifar10_dir/raw/{index}-{type}.jpg".format(
index=i, type=label_map[batch_data_tuple[0][0][0]]))

简要代码流程图:

[深度学习] 各种下载深度学习数据集方法(In python)的更多相关文章

  1. ReLeQ:一种自动强化学习的神经网络深度量化方法

    ReLeQ:一种自动强化学习的神经网络深度量化方法     ReLeQ:一种自动强化学习的神经网络深度量化方法ReLeQ: An Automatic Reinforcement Learning Ap ...

  2. 腾讯优图&港科大提出一种基于深度学习的非光流 HDR 成像方法

    目前最好的高动态范围(HDR)成像方法通常是先利用光流将输入图像对齐,随后再合成 HDR 图像.然而由于输入图像存在遮挡和较大运动,这种方法生成的图像仍然有很多缺陷.最近,腾讯优图和香港科技大学的研究 ...

  3. 【深度学习系列】迁移学习Transfer Learning

    在前面的文章中,我们通常是拿到一个任务,譬如图像分类.识别等,搜集好数据后就开始直接用模型进行训练,但是现实情况中,由于设备的局限性.时间的紧迫性等导致我们无法从头开始训练,迭代一两百万次来收敛模型, ...

  4. OpenGL学习脚印:深度測试(depth testing)

    写在前面 上一节我们使用AssImp载入了3d模型,效果已经令人激动了.可是绘制效率和场景真实感还存在不足,接下来我们还是要保持耐心,继续学习一些高级主题,等学完后面的高级主题,我们再次来改进我们载入 ...

  5. OpenCV 学习笔记 04 深度估计与分割——GrabCut算法与分水岭算法

    1 使用普通摄像头进行深度估计 1.1 深度估计原理 这里会用到几何学中的极几何(Epipolar Geometry),它属于立体视觉(stereo vision)几何学,立体视觉是计算机视觉的一个分 ...

  6. (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇

    This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...

  7. 深度强化学习day01初探强化学习

    深度强化学习 基本概念 强化学习 强化学习(Reinforcement Learning)是机器学习的一个重要的分支,主要用来解决连续决策的问题.强化学习可以在复杂的.不确定的环境中学习如何实现我们设 ...

  8. 小菜学习设计模式(三)—工厂方法(Factory Method)模式

    前言 设计模式目录: 小菜学习设计模式(一)—模板方法(Template)模式 小菜学习设计模式(二)—单例(Singleton)模式 小菜学习设计模式(三)—工厂方法(Factory Method) ...

  9. VC++/MFC(VC6)开发技术精品学习资料下载汇总

    工欲善其事,必先利其器,VC开发MFC Windows程序,Visual C++或Visual Studio是必须的,恩,这里都给你总结好了,拿去吧:VC/MFC开发必备Visual C++.Visu ...

随机推荐

  1. 菜鸟系列Fabric——Fabric 基本概念(1)

    Fabric 基本概念 1.区块链介绍 区块链之所以引来关注是因为比特币开源项目,尤其是比特币价值的飙升,让大家开始关注数字货币以及相关技术.那么区块链究竟是什么? 1.1 区块链定义 狭义上,区块链 ...

  2. JavaScript中好用的对象数组去重

    对象数组去重 Demo数据如下: var items= [{ "specItems": [{ "id": "966480614728069122&qu ...

  3. C++练习 | 铁轨问题

    #include <iostream> #include <cmath> #include <cstring> #include <string> #i ...

  4. python_0基础开始_day04

    第四节 一.列表 list 数据类型之一,存储大量的,不同类型的数据 列表中只要用逗号隔开的就是一个元素 有序可变的. 1.1列表的索引 列表和字符串一样也拥有索引,但是列表可以修改: lst = [ ...

  5. SpringBoot中使用Websocket进行消息推送

    WebsocketConfig.java @Configuration public class WebSocketConfig { @Bean public ServerEndpointExport ...

  6. java限流工具类

    代码 import com.google.common.util.concurrent.RateLimiter; import java.util.concurrent.ConcurrentHashM ...

  7. Zabbix 系统概述与部署

    Zabbix是一个非常强大的监控系统,是企业级的软件,来监控IT基础设施的可用性和性能.它是一个能够快速搭建起来的开源的监控系统,Zabbix能监视各种网络参数,保证服务器系统的安全运营,并提供灵活的 ...

  8. 模块之re模块 正则表达式

    正则表达式,正则表达式在处理字符串上有先天的优势,尤其大数量的字符串.先来记一个网站,此网站功能就是关于正则表达式方面的应用http://tool.chinaz.com/regex/ 单纯的正则表达式 ...

  9. ftp服务器上传下载共享文件

    1 windows下搭建ftp服务器 https://blog.csdn.net/qq_34610293/article/details/79210539 搭建好之后浏览器输入 ftp://ip就可以 ...

  10. 108、如何使用 Secret? (Swarm15)

    参考https://www.cnblogs.com/CloudMan6/p/8068057.html   我们经常要想容器传递敏感信息,最常见的就是密码.比如:   docker run -e MYS ...