Tensorflow读取CIFAR-10数据集

觉得有用的话,欢迎一起讨论相互学习~

参考文献

Tensorflow官方文档

tf.transpose函数解析

tf.slice函数解析

CIFAR10/CIFAR100数据集介绍

tf.train.shuffle_batch函数解析

Python urllib urlretrieve函数解析

import os
import tarfile
import tensorflow as tf
from six.moves import urllib
from tensorflow.python.framework import ops ops.reset_default_graph() # 更改工作目录
abspath = os.path.abspath(__file__) # 获取当前文件绝对地址
# E:\GitHub\TF_Cookbook\08_Convolutional_Neural_Networks\03_CNN_CIFAR10\ostest.py
dname = os.path.dirname(abspath) # 获取文件所在文件夹地址
# E:\GitHub\TF_Cookbook\08_Convolutional_Neural_Networks\03_CNN_CIFAR10
os.chdir(dname) # 转换目录文件夹到上层 # Start a graph session
# 初始化Session
sess = tf.Session() # 设置模型超参数
batch_size = 128 # 批处理数量
data_dir = 'temp' # 数据目录
output_every = 50 # 输出训练loss值
generations = 20000 # 迭代次数
eval_every = 500 # 输出测试loss值
image_height = 32 # 图片高度
image_width = 32 # 图片宽度
crop_height = 24 # 裁剪后图片高度
crop_width = 24 # 裁剪后图片宽度
num_channels = 3 # 图片通道数
num_targets = 10 # 标签数
extract_folder = 'cifar-10-batches-bin' # 指数学习速率衰减参数
learning_rate = 0.1 # 学习率
lr_decay = 0.1 # 学习率衰减速度
num_gens_to_wait = 250. # 学习率更新周期 # 提取模型参数
image_vec_length = image_height*image_width*num_channels # 将图片转化成向量所需大小
record_length = 1 + image_vec_length # ( + 1 for the 0-9 label) # 读取数据
data_dir = 'temp'
if not os.path.exists(data_dir): # 当前目录下是否存在temp文件夹
os.makedirs(data_dir) # 如果当前文件目录下不存在这个文件夹,创建一个temp文件夹
# 设定CIFAR10下载路径
cifar10_url = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' # 检查这个文件是否存在,如果不存在下载这个文件
data_file = os.path.join(data_dir, 'cifar-10-binary.tar.gz')
# temp\cifar-10-binary.tar.gz
if os.path.isfile(data_file):
pass
else:
# 回调函数,当连接上服务器、以及相应的数据块传输完毕时会触发该回调,我们可以利用这个回调函数来显示当前的下载进度。
# block_num已经下载的数据块数目,block_size数据块大小,total_size下载文件总大小 def progress(block_num, block_size, total_size):
progress_info = [cifar10_url, float(block_num*block_size)/float(total_size)*100.0]
print('\r Downloading {} - {:.2f}%'.format(*progress_info), end="") # urlretrieve(url, filename=None, reporthook=None, data=None)
# 参数 finename 指定了保存本地路径(如果参数未指定,urllib会生成一个临时文件保存数据。)
# 参数 reporthook 是一个回调函数,当连接上服务器、以及相应的数据块传输完毕时会触发该回调,我们可以利用这个回调函数来显示当前的下载进度。
# 参数 data 指 post 到服务器的数据,该方法返回一个包含两个元素的(filename, headers)元组,filename 表示保存到本地的路径,header 表示服务器的响应头。
# 此处 url=cifar10_url,filename=data_file,reporthook=progress filepath, _ = urllib.request.urlretrieve(cifar10_url, data_file, progress)
# 解压文件
tarfile.open(filepath, 'r:gz').extractall(data_dir) # Define CIFAR reader
# 定义CIFAR读取器
def read_cifar_files(filename_queue, distort_images=True):
reader = tf.FixedLengthRecordReader(record_bytes=record_length)
# 返回固定长度的文件记录 record_length函数参数为一条图片信息即1+32*32*3
key, record_string = reader.read(filename_queue)
# 此处调用tf.FixedLengthRecordReader.read函数返回键值对
record_bytes = tf.decode_raw(record_string, tf.uint8)
# 读出来的原始文件是string类型,此处我们需要用decode_raw函数将String类型转换成uint8类型
image_label = tf.cast(tf.slice(record_bytes, [0], [1]), tf.int32)
# 见slice函数用法,取从0号索引开始的第一个元素。并将其转化为int32型数据。其中存储的是图片的标签 # 截取图像
image_extracted = tf.reshape(tf.slice(record_bytes, [1], [image_vec_length]),
[num_channels, image_height, image_width])
# 从1号索引开始提取图片信息。这和此数据集存储图片信息的格式相关。
# CIFAR-10数据集中
"""第一个字节是第一个图像的标签,它是一个0-9范围内的数字。接下来的3072个字节是图像像素的值。
前1024个字节是红色通道值,下1024个绿色,最后1024个蓝色。值以行优先顺序存储,因此前32个字节是图像第一行的红色通道值。
每个文件都包含10000个这样的3073字节的“行”图像,但没有任何分隔行的限制。因此每个文件应该完全是30730000字节长。""" # Reshape image
image_uint8image = tf.transpose(image_extracted, [1, 2, 0])
# 详见tf.transpose函数,将[channel,image_height,image_width]转化为[image_height,image_width,channel]的数据格式。
reshaped_image = tf.cast(image_uint8image, tf.float32)
# 将图片剪裁或填充至合适大小
final_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, crop_width, crop_height) if distort_images:
# 将图像水平随机翻转,改变亮度和对比度。
final_image = tf.image.random_flip_left_right(final_image)
final_image = tf.image.random_brightness(final_image, max_delta=63)
final_image = tf.image.random_contrast(final_image, lower=0.2, upper=1.8) # 对图片做标准化处理
"""Linearly scales `image` to have zero mean and unit norm.
This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
of all values in image, and `adjusted_stddev = max(stddev, 1.0/sqrt(image.NumElements()))`.
`stddev` is the standard deviation of all values in `image`.
It is capped away from zero to protect against division by 0 when handling uniform images."""
final_image = tf.image.per_image_standardization(final_image)
return (final_image, image_label) # Create a CIFAR image pipeline from reader
# 从阅读器中构造CIFAR图片管道
def input_pipeline(batch_size, train_logical=False):
# train_logical标志用于区分读取训练和测试数据集
if train_logical:
files = [os.path.join(data_dir, extract_folder, 'data_batch_{}.bin'.format(i)) for i in range(1, 6)]
# data_dir=tmp
# extract_folder=cifar-10-batches-bin
else:
files = [os.path.join(data_dir, extract_folder, 'test_batch.bin')]
filename_queue = tf.train.string_input_producer(files)
image, label = read_cifar_files(filename_queue)
print(train_logical, 'after read_cifar_files ops image', sess.run(tf.shape(image)))
print(train_logical, 'after read_cifar_files ops label', sess.run(tf.shape(label)))
# min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 5000
capacity = min_after_dequeue + 3*batch_size
# 批量读取图片数据
example_batch, label_batch = tf.train.shuffle_batch([image, label],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
print(train_logical, 'after shuffle_batch ops image', sess.run(tf.shape(image)))
print(train_logical, 'after shuffle_batch ops example_batch', sess.run(tf.shape(example_batch)))
print(train_logical, 'after shuffle_batch ops label', sess.run(tf.shape(label)))
print(train_logical, 'after shuffle_batch ops label_batch', sess.run(tf.shape(label_batch))) return (example_batch, label_batch) # 获取数据
print('Getting/Transforming Data.')
# 初始化数据管道获取训练数据和对应标签
images, targets = input_pipeline(batch_size, train_logical=True)
# 获取测试数据和对应标签
test_images, test_targets = input_pipeline(batch_size, train_logical=False) sess.close() # True after read_cifar_files ops image [24 24 3]
# True after read_cifar_files ops label [1]
# True after shuffle_batch ops image [24 24 3]
# True after shuffle_batch ops example_batch [128 24 24 3]
# True after shuffle_batch ops label [1]
# True after shuffle_batch ops label_batch [128 1]
# False after read_cifar_files ops image [24 24 3]
# False after read_cifar_files ops label [1]
# False after shuffle_batch ops image [24 24 3]
# False after shuffle_batch ops example_batch [128 24 24 3]
# False after shuffle_batch ops label [1]
# False after shuffle_batch ops label_batch [128 1]

利用Tensorflow读取二进制CIFAR-10数据集的更多相关文章

  1. 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  2. (第二章第三部分)TensorFlow框架之读取二进制数据

    系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html (第二章第二部分)Tens ...

  3. tensorflow读取本地MNIST数据集

    tensorflow读取本地MNIST数据集 数据放入文件夹(不要解压gz): >>> import tensorflow as tf >>> from tenso ...

  4. 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)

    如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...

  5. 第十二节,TensorFlow读取数据的几种方法以及队列的使用

    TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起 ...

  6. TensorFlow 制作自己的TFRecord数据集

    官网的mnist和cifar10数据之后,笔者尝试着制作自己的数据集,并保存,读入,显示. TensorFlow可以支持cifar10的数据格式, 也提供了标准的TFRecord 格式,而关于 ten ...

  7. 利用Tensorflow进行自然语言处理(NLP)系列之一Word2Vec

    同步笔者CSDN博客(https://blog.csdn.net/qq_37608890/article/details/81513882). 一.概述 本文将要讨论NLP的一个重要话题:Word2V ...

  8. Tensorflow读取文件到队列文件

    TensorFlow读取二进制文件数据到队列 2016-11-03 09:30:00      0个评论    来源:diligent_321的博客   收藏   我要投稿 TensorFlow是一种 ...

  9. Tensorflow学习笔记No.10

    多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...

随机推荐

  1. Python:字符串操作总结

    所有标准的序列操作(索引.分片.乘法.判断成员资格.求长度.取最小值最大值)对字符串同样适用,且字符串是不可变的. 一.字符串格式化 转换说明符 [注]: 这些项的顺序至关重要 (1)%字符:标记转换 ...

  2. HDU 5666 Segment 数论+大数

    题目链接: hdu:http://acm.hdu.edu.cn/showproblem.php?pid=5666 bc(中文):http://bestcoder.hdu.edu.cn/contests ...

  3. 缓存-System.Web.Caching.Cache

    实现 Web 应用程序的缓存. 每个应用程序域创建一个此类的实例,只要应用程序域将保持活动状态,保持有效. 有关此类的实例的信息,请通过Cache的属性HttpContext对象或Cache属性的Pa ...

  4. lintcode-392-打劫房屋

    392-打劫房屋 假设你是一个专业的窃贼,准备沿着一条街打劫房屋.每个房子都存放着特定金额的钱.你面临的唯一约束条件是:相邻的房子装着相互联系的防盗系统,且 当相邻的两个房子同一天被打劫时,该系统会自 ...

  5. 《高性能JavaScript》学习笔记——日更中

    ------------------2016-7-20更------------------ 最近在看<高性能JavaScript>一书,里面当中,有讲很多提高js性能的书,正在看的过程中 ...

  6. Windows下基于http的git服务器搭建-gitstack

    版权声明:若无来源注明,Techie亮博客文章均为原创. 转载请以链接形式标明本文标题和地址: 本文标题:Windows下基于http的git服务器搭建-gitstack     本文地址:http: ...

  7. 使用Log4在测试过程中打印执行日志 及配置log4j.properties!

    http://zengxiantao.iteye.com/blog/1881706 1.环境配置:到网上下载log4j-1.2.17.jar包!完后 添加到 项目的build path 中即可! 2. ...

  8. fsockopen 异步非阻塞式请求数据

    index.php <?php ini_set ( "max_execution_time", "0" ); // 要传递的数据 $form_data = ...

  9. floyd最短路

    floyd可以在O(n^3)的时间复杂度,O(n^2)的空间复杂度下求解正权图中任意两点间的最短路长度. 本质是动态规划. 定义f[k][i][j]表示从i出发,途中只允许经过编号小于等于k的点时的最 ...

  10. 【uoj#192】[UR #14]最强跳蚤 Hash

    题目描述 给定一棵 $n$ 个点的树,边有边权.求简单路径上的边的乘积为完全平方数的点对 $(x,y)\ ,\ x\ne y$ 的数目. 题解 Hash 一个数是完全平方数,当且仅当每个质因子出现次数 ...