# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tensorflow.python.platform
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError(
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data
def dense_to_one_hot(labels_dense, num_classes=10):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
def extract_labels(filename, one_hot=False):
"""Extract the labels into a 1D uint8 numpy array [index]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
raise ValueError(
'Invalid magic number %d in MNIST label file: %s' %
(magic, filename))
num_items = _read32(bytestream)
buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8)
if one_hot:
return dense_to_one_hot(labels)
return labels
class DataSet(object):
def __init__(self, images, labels, fake_data=False, one_hot=False,
dtype=tf.float32):
"""Construct a DataSet.
one_hot arg is used only if fake_data is true. `dtype` can be either
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
`[0, 1]`.
"""
dtype = tf.as_dtype(dtype).base_dtype
if dtype not in (tf.uint8, tf.float32):
raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
dtype)
if fake_data:
self._num_examples = 10000
self.one_hot = one_hot
else:
assert images.shape[0] == labels.shape[0], (
'images.shape: %s labels.shape: %s' % (images.shape,
labels.shape))
self._num_examples = images.shape[0]
# Convert shape from [num examples, rows, columns, depth]
# to [num examples, rows*columns] (assuming depth == 1)
assert images.shape[3] == 1
images = images.reshape(images.shape[0],
images.shape[1] * images.shape[2])
if dtype == tf.float32:
# Convert from [0, 255] -> [0.0, 1.0].
images = images.astype(numpy.float32)
images = numpy.multiply(images, 1.0 / 255.0)
self._images = images
self._labels = labels
self._epochs_completed = 0
self._index_in_epoch = 0
@property
def images(self):
return self._images
@property
def labels(self):
return self._labels
@property
def num_examples(self):
return self._num_examples
@property
def epochs_completed(self):
return self._epochs_completed
def next_batch(self, batch_size, fake_data=False):
"""Return the next `batch_size` examples from this data set."""
if fake_data:
fake_image = [1] * 784
if self.one_hot:
fake_label = [1] + [0] * 9
else:
fake_label = 0
return [fake_image for _ in xrange(batch_size)], [
fake_label for _ in xrange(batch_size)]
start = self._index_in_epoch
self._index_in_epoch += batch_size
if self._index_in_epoch > self._num_examples:
# Finished epoch
self._epochs_completed += 1
# Shuffle the data
perm = numpy.arange(self._num_examples)
numpy.random.shuffle(perm)
self._images = self._images[perm]
self._labels = self._labels[perm]
# Start next epoch
start = 0
self._index_in_epoch = batch_size
assert batch_size <= self._num_examples
end = self._index_in_epoch
return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
class DataSets(object):
pass
data_sets = DataSets()
if fake_data:
def fake():
return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
data_sets.train = fake()
data_sets.validation = fake()
data_sets.test = fake()
return data_sets
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
VALIDATION_SIZE = 5000
local_file = maybe_download(TRAIN_IMAGES, train_dir)
train_images = extract_images(local_file)
local_file = maybe_download(TRAIN_LABELS, train_dir)
train_labels = extract_labels(local_file, one_hot=one_hot)
local_file = maybe_download(TEST_IMAGES, train_dir)
test_images = extract_images(local_file)
local_file = maybe_download(TEST_LABELS, train_dir)
test_labels = extract_labels(local_file, one_hot=one_hot)
validation_images = train_images[:VALIDATION_SIZE]
validation_labels = train_labels[:VALIDATION_SIZE]
train_images = train_images[VALIDATION_SIZE:]
train_labels = train_labels[VALIDATION_SIZE:]
data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
data_sets.validation = DataSet(validation_images, validation_labels,
dtype=dtype)
data_sets.test = DataSet(test_images, test_labels, dtype=dtype)
return data_sets

自动下载和安装 MNIST 到 TensorFlow 的 python 源码 (转)的更多相关文章

  1. MySQL安装(yum、二进制、源码)

    MySQL安装(yum.二进制.源码) 目录 1.1 yum安装... 2 1.2 二进制安装-mysql-5.7.17. 3 1.2.1 准备工作... 3 1.2.2 解压.移动.授权... 3 ...

  2. python 源码安装

    1)下载python源码包 http://mirrors.sohu.com/python/3.5.2/Python-3.5.2.tgz 2)安装相关依赖  yum install zlib-devel ...

  3. Linux CentOS 7 安装PostgreSQL 9.5.17 (源码编译)

    近日需要将PostgreSQL数据库从Windows中迁移到Linux中,Linux CentOS 7 安装PostgreSQL 9.5.17 安装过程 特此记录. 安装环境: 数据库:Postgre ...

  4. linux环境安装svn并进行多个源码库区分管理

    关于svn的文档有很多大部分已Windows为例子,因公司没有Windows服务器经过一天的曲折终于初步安装了解了svn.下面一些经验希望能帮助新手 本文采用的yum安装(简单快速没必要源码) 1.y ...

  5. 『TensorFlow』SSD源码学习_其一:论文及开源项目文档介绍

    一.论文介绍 读论文系列:Object Detection ECCV2016 SSD 一句话概括:SSD就是关于类别的多尺度RPN网络 基本思路: 基础网络后接多层feature map 多层feat ...

  6. Python源码剖析|百度网盘免费下载|Python新手入门|Python新手学习资料

    百度网盘免费下载:Python源码剖析|新手免费领取下载 提取码:g78z 目录  · · · · · · 第0章 Python源码剖析——编译Python0.1 Python总体架构0.2 Pyth ...

  7. 学习Tensorflow,使用源码安装

    PC上装好Ubuntu系统,我们一步一步来讲解如何使用源码安装tensorflow?(我的Ubuntu系统是15.10) 安装cuda 根据你的系统型号选择相应的cuda版本下载 https://de ...

  8. 在centos7上安装gcc、node.js(源码下载)

    一.在centos7中安装node.js https://www.cnblogs.com/lpbottle/p/7733397.html 1.从源码下载Nodejs cd /usr/local/src ...

  9. Linux学习(二十)软件安装与卸载(三)源码包安装

    一.概述 源码包安装的优点在于它自由程度比较高,可以指定目录与组件.再有,你要是能改源码也可以. 二.安装方法 步骤 1.从官网或者信任站点下载源码包 [root@localhost ~]# wget ...

随机推荐

  1. AWS S3

    Amazon Simple Storage Service (Amazon S3) Amazon S3 提供了一个简单 Web 服务接口,可用于随时在 Web 上的任何位置存储和检索任何数量的数据.此 ...

  2. Unity QualitySettings.antiAliasing 抗锯齿

    QualitySettings.antiAliasing 抗锯齿 Description 描述 Set The AA Filtering option. 设置AA过滤选项. The AntiAliaz ...

  3. Oracle 事务操作

    在看本文之前,请确保你已经了解了Oracle事务和锁的概念即其作用,不过不了解,请参考数据库事务的一致性和原子性浅析和Oracle TM锁和TX锁 1.提交事务 当执行使用commit语句可以提交事务 ...

  4. 利用C#结合net use命令破解域帐号密码

    背景 我的职业是程序猿,而所在的工作单位因各种原因,对上网帐号有严格控制,近期竟然把我们的上网帐号全部停用,作为程序猿,不能上网,就如同鱼儿没有水,煮饭没有米,必须想办法解决此问题.公司的局域网环境是 ...

  5. HDFS HA和Federaion

    1.HA HA即为High Availability,用于解决NameNode单点故障问题,该特性通过热备的方式为主NameNode提供一个备用者,一旦主NameNode出现故障,可以迅速切换至备Na ...

  6. html控件

    checkbox val = "<li class='layer'><label><input type='checkbox' checked name='la ...

  7. linux系统更改当前主机名

    问题描述:CentOS系统,默认的主机名为localhost.localdomain,刚开始安装的时候会提示修改,但是有时候会忽略,那安装好后怎么修改呢? 解决方法: 1.以根用户登录,输入hostn ...

  8. C#实体对象序列化成Json,格式化,并让字段的首字母小写

    解决办法有两种:第一种:使用对象的字段属性设置JsonProperty来实现(不推荐,因为需要手动的修改每个字段的属性) public class UserInfo { [JsonProperty(& ...

  9. C#语言-03.逻辑控制语句

    a. 逻辑控制语句: i. 条件语句:先对条件判断,然后根据判断结果执行不同的分支 . If 和 if-else:判断“布尔表达式的值”来决定执行那个代码块 a. 语法:if(布尔表达式){ b. 布 ...

  10. javaEE Design Patter(1)初步了解23种常用设计模式

    设计模式分为三大类: 创建型模式,共五种:工厂方法模式.抽象工厂模式.单例模式.建造者模式.原型模式. 结构型模式,共七种:适配器模式.装饰器模式.代理模式.外观模式.桥接模式.组合模式.享元模式. ...