Tensorflow简单CNN实现详解

觉得有用的话,欢迎一起讨论相互学习~

"""转换图像数据格式时需要将它们的颜色空间变为灰度空间,将图像尺寸修改为同一尺寸,并将标签依附于每幅图像"""
import tensorflow as tf sess = tf.Session()
import glob image_filenames = glob.glob("./imagenet-dogs/n02*/*.jpg") # 访问imagenet-dogs文件夹中所有n02开头的子文件夹中所有的jpg文件 # image_filenames[0:2] 此语句表示image_filenames文件中的从第0个编号到第2个编号的值
# ['./imagenet-dogs\\n02085620-Chihuahua\\n02085620_10074.jpg',
# './imagenet-dogs\\n02085620-Chihuahua\\n02085620_10131.jpg']
# 此时image_filenames中保存的全部是类似于以上形式的值
# 注意书上的解释和这个输出和此处的输出与有很大的不同,原因是书本是用linux系统,
# 所以是以"/"对文件名进行分隔符的操作而此处不是windows下使用"\\"对文件名进行操作. from itertools import groupby
from collections import defaultdict training_dataset = defaultdict(list)
testing_dataset = defaultdict(list) # Split up the filename into its breed and corresponding filename. The breed is found by taking the directory name
# 将文件名分解为品种和对应的文件名,品种对应于文件夹名称
image_filename_with_breed = map(lambda filename: (filename.split("/")[1].split("\\")[1], filename), image_filenames)
# 表示定义一个匿名函数lambda传入参数为filename,对filename以"/"为分隔符,然后取第二个值,并且返回filename.split("/")[1]和filename
# 并且以image_filenames作为参数
# ('n02086646-Blenheim_spaniel', './imagenet-dogs\\n02086646-Blenheim_spaniel\\n02086646_3739.jpg') # Group each image by the breed which is the 0th element in the tuple returned above
# 依据品种(上述返回的元组的第0个分量对元素进行分组)
for dog_breed, breed_images in groupby(image_filename_with_breed, lambda x: x[0]):
# Enumerate each breed's image and send ~20% of the images to a testing set
# 美剧每个品种的图像,并将大致20%的图像划入测试集
# 此函数返回的dog_breed即是image_filename_with_breed[0]也就是文件夹的名字即是狗的类别
# breed_images则是一个迭代器是根据狗的类别进行分类的
for i, breed_image in enumerate(breed_images):
# breed_images此时是根据狗的种类进行分类的迭代器
# 返回的i表示品种的代表编号
# 返回的breed_image表示这个标号的种类下狗的图片
if i%5 == 0:
testing_dataset[dog_breed].append(breed_image[1])
else:
training_dataset[dog_breed].append(breed_image[1])
# 表示其中五分之一加入测试集其余进入训练集
# 并且以狗的类别名称进行区分,向同一类型中添加图片
# Check that each breed includes at least 18% of the images for testing
breed_training_count = len(training_dataset[dog_breed])
breed_testing_count = len(testing_dataset[dog_breed])
# 现在,每个字典就按照下列格式包含了所有的Chihuahua图像
# training_dataset["n02085620-Chihuahua"] = ['./imagenet-dogs\\n02085620-Chihuahua\\n02085620_10131.jpg', ...] def write_records_file(dataset, record_location):
"""
Fill a TFRecords file with the images found in `dataset` and include their category.
用dataset中的图像填充一个TFRecord文件,并将其类别包含进来
Parameters
参数
----------
dataset : dict(list)
Dictionary with each key being a label for the list of image filenames of its value.
这个字典的键对应于其值中文件名列表对应的标签
record_location : str
Location to store the TFRecord output.
存储TFRecord输出的路径
"""
writer = None # Enumerating the dataset because the current index is used to breakup the files if they get over 100
# images to avoid a slowdown in writing.
# 枚举dataset,因为当前索引用于对文件进行划分,每个100幅图像,训练样本的信息就被写入到一个新的TFRecord文件中,以加快写操作的速度
current_index = 0
for breed, images_filenames in dataset.items():
# print(breed) n02085620-Chihuahua...
# print(image_filenames) ['./imagenet-dogs\\n02085620-Chihuahua\\n02085620_10074.jpg', ...]
for image_filename in images_filenames:
if current_index%100 == 0: # 如果记录了100个文件的话,write就关闭
if writer:
writer.close()
# 否则开始记录write文件
# record_Location表示当前的目录
# current_index初始值为0,随着文件记录逐渐增加
record_filename = "{record_location}-{current_index}.tfrecords".format(
record_location=record_location,
current_index=current_index)
# format是格式化字符串操作,通过format(){}函数将文件名保存到record_filename中 writer = tf.python_io.TFRecordWriter(record_filename)
current_index += 1 image_file = tf.read_file(image_filename) # In ImageNet dogs, there are a few images which TensorFlow doesn't recognize as JPEGs. This
# try/catch will ignore those images.
# 在ImageNet的狗的图像中,有少量无法被Tensorflow识别为JPEG的图像,利用try/catch可将这些图像忽略
try:
image = tf.image.decode_jpeg(image_file)
except:
print(image_filename)
continue # Converting to grayscale saves processing and memory but isn't required.
# 将其转化为灰度图片的类型,虽然这并不是必需的,但是可以减少计算量和内存占用,
grayscale_image = tf.image.rgb_to_grayscale(image)
resized_image = tf.image.resize_images(grayscale_image, (250, 151)) # 并将图片修改为长250宽151的图片类型 # tf.cast is used here because the resized images are floats but haven't been converted into
# image floats where an RGB value is between [0,1).
# 这里之所以使用tf.cast,是因为 尺寸更改后的图像的数据类型是浮点数,但是RGB值尚未转换到[0,1)的区间之内
image_bytes = sess.run(tf.cast(resized_image, tf.uint8)).tobytes() # Instead of using the label as a string, it'd be more efficient to turn it into either an
# integer index or a one-hot encoded rank one tensor.
# https://en.wikipedia.org/wiki/One-hot
# 将标签按照字符串存储较为高效,推荐的做法是将其转换成整数索引或独热编码的秩1张量
image_label = breed.encode("utf-8") example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
})) writer.write(example.SerializeToString()) # 将其序列化为二进制字符串
writer.close() # 如果你已经运行过一次此程序成功生成了所有Tf_records文件,下次运行时可以将以下两句话注释掉,避免再次运行时浪费时间消耗资源。
write_records_file(testing_dataset, "./output/testing-images/testing-image")
write_records_file(training_dataset, "./output/training-images/training-image")
filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once("./output/training-images/*.tfrecords")) # 生成文件名队列
reader = tf.TFRecordReader()
_, serialized = reader.read(filename_queue)
# 通过阅读器读取value值并将其保存为serialized # 模板化的代码,将label和image分开
features = tf.parse_single_example(
serialized,
features={
'label': tf.FixedLenFeature([], tf.string),
'image': tf.FixedLenFeature([], tf.string),
}) record_image = tf.decode_raw(features['image'], tf.uint8)
# tf.decode_raw()函数将字符串的字节重新解释为一个数字的向量 # Changing the image into this shape helps train and visualize the output by converting it to
# be organized like an image.
# 修改图像的形状有助于训练和输出的可视化
image = tf.reshape(record_image, [250, 151, 1]) label = tf.cast(features['label'], tf.string) min_after_dequeue = 10 # 当一次出列操作完成后,队列中元素的最小数量,往往用于定义元素的混合级别.
batch_size = 3 # 批处理大小
capacity = min_after_dequeue + 3*batch_size # 批处理容量
image_batch, label_batch = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue, num_threads=4)
# 通过随机打乱的方式创建数据批次 # Converting the images to a float of [0,1) to match the expected input to convolution2d
# 将图像转换为灰度值位于[0, 1)的浮点类型,以与convlution2d期望的输入匹配
float_image_batch = tf.image.convert_image_dtype(image_batch, tf.float32) # 第一个卷积层 conv2d_layer_one = tf.contrib.layers.conv2d(
float_image_batch,
num_outputs=32, # 生成的滤波器的数量
kernel_size=(5, 5), # 滤波器的高度和宽度
activation_fn=tf.nn.relu,
weights_initializer=tf.random_normal_initializer, # 设置weight的值是正态分布的随机值
stride=(2, 2), # 对image_batch和imput_channels的跨度值
trainable=True)
# shape(3, 125, 76,32)
# 3表示批处理数据量是3,
# 125和76表示经过卷积操作后的宽和高,这和滤波器的大小还有步长有关系 # 第一个混合/池化层,输出降采样 pool_layer_one = tf.nn.max_pool(conv2d_layer_one,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME')
# shape(3, 63,38,32)
# 混合层ksize,1表示选取一个批处理数据,2表示在宽的维度取2个单位,2表示在高的取两个单位,1表示选取一个滤波器也就数选择一个通道进行操作.
# strides步长表示其分别在四个维度上的跨度 # Note, the first and last dimension of the convolution output hasn't changed but the
# middle two dimensions have.
# 注意卷积输出的第一个维度和最后一个维度没有发生变化,但是中间的两个维度发生了变化 # 第二个卷积层 conv2d_layer_two = tf.contrib.layers.conv2d(
pool_layer_one,
num_outputs=64, # 更多输出通道意味着滤波器数量的增加
kernel_size=(5, 5),
activation_fn=tf.nn.relu,
weights_initializer=tf.random_normal_initializer,
stride=(1, 1),
trainable=True)
# shape(3, 63,38,64) # 第二个混合/池化层,输出降采样 pool_layer_two = tf.nn.max_pool(conv2d_layer_two,
ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1],
padding='SAME')
# shape(3, 32, 19,64) # 光栅化层 # 由于后面要使用softmax,因此全连接层需要修改为二阶张量,张量的第1维用于区分每幅图像,第二维用于对们每个输入张量的秩1张量
flattened_layer_two = tf.reshape(
pool_layer_two,
[
batch_size, # image_batch中的每幅图像
-1 # 输入的其他所有维度
])
# 例如,如果此时一批次有三个数据的时候,则每一行就是一个数据行,然后每一列就是这个图片的数据,
# 这里的-1参数将最后一个池化层调整为一个巨大的秩1张量 # 全连接层1
hidden_layer_three = tf.contrib.layers.fully_connected(
flattened_layer_two,
512,
weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
activation_fn=tf.nn.relu
) # 对一些神经元进行dropout操作.每个神经元以0.1的概率决定是否放电
hidden_layer_three = tf.nn.dropout(hidden_layer_three, 0.1) # The output of this are all the connections between the previous layers and the 120 different dog breeds
# available to train on.
# 输出是前面的层与训练中可用的120个不同品种的狗的品种的全连接
# 全连接层2
final_fully_connected = tf.contrib.layers.fully_connected(
hidden_layer_three,
120, # ImageNet Dogs 数据集中狗的品种数
weights_initializer=tf.truncated_normal_initializer(stddev=0.1)
) """
由于每个标签都是字符串类型,tf.nn.softmax无法直接使用这些字符串,所以需要将这些字符创转换为独一无二的数字,
这些操作都应该在数据预处理阶段进行
"""
import glob # Find every directory name in the imagenet-dogs directory (n02085620-Chihuahua, ...)
# 找到位于imagenet-dogs路径下的所有文件目录名
labels = list(map(lambda c: c.split("/")[-1].split("\\")[1], glob.glob("./imagenet-dogs/*"))) # Match every label from label_batch and return the index where they exist in the list of classes
# 匹配每个来自label_batch的标签并返回它们在类别列表的索引
# 将label_batch作为参数l传入到匿名函数中tf.map_fn函数总体来讲和python中map函数相似,map_fn主要是将定义的函数运用到后面集合中每个元素中
train_labels = tf.map_fn(lambda l: tf.where(tf.equal(labels, l))[0][0], label_batch, dtype=tf.int64) # setup-only-ignore
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=final_fully_connected, labels=train_labels)) global_step = tf.Variable(0) # 相当于global_step,是一个全局变量,在训练完一个批次后自动增加1 # 学习率使用退化学习率的方法
# 设置初始学习率为0.01,
learning_rate = tf.train.exponential_decay(learning_rate=0.01, global_step=global_step, decay_steps=120,
decay_rate=0.95, staircase=True) optimizer = tf.train.AdamOptimizer(learning_rate, 0.9).minimize(loss, global_step=global_step) # 主程序
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op) coord = tf.train.Coordinator()
# 线程控制管理器
threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 训练
training_steps = 1000
for step in range(training_steps):
sess.run(optimizer) if step % 10 == 0:
print("loss:", sess.run(loss)) train_prediction = tf.nn.softmax(final_fully_connected)
# setup-only-ignore
filename_queue.close(cancel_pending_enqueues=True)
coord.request_stop()
coord.join(threads)
sess.close()

新手指南,函数参考

Glob

glob是python自己带的一个文件操作相关模块,用它可以查找符合自己目的的文件,就类似于Windows下的文件搜索,支持通配符操作,*,?,[]这三个通配符,

代表0个或多个字符,?代表一个字符,[]匹配指定范围内的字符,如[0-9]匹配数字。它的主要方法就是glob,该方法返回所有匹配的文件路径列表,

该方法需要一个参数用来指定匹配的路径字符串(本字符串可以为绝对路径也可以为相对路径),其返回的文件名只包括当前目录里的文件名,不包括子文件夹里的文件。

glob.glob(r'c:*.txt')

我这里就是获得C盘下的所有txt文件

glob.glob(r'E:\pic**.jpg')

获得指定目录下的所有jpg文件

使用相对路径:

glob.glob(r'../
.py')

group

groupby(iterable[, keyfunc])

返回:按照keyfunc函数对序列每个元素执行后的结果分组(每个分组是一个迭代器), 返回这些分组的迭代器

例子:

from itertools import *
a = ['aa', 'ab', 'abc', 'bcd', 'abcde']
for i, k in groupby(a, len):#按照字符串的长度对a的每个元素进行分组
for m in k:
print m,
print i
输出:
aa ab 2
abc bcd 3
abcde 5

defaultdict对象

class collections.defaultdict([default_factory[, ...]])

返回一个新的类似字典的对象。defaultdict是内置dict类的子类。它覆盖一个方法,并添加一个可写的实例变量。其余的功能与dict类相同,这里就不再记录。

第一个参数提供default_factory属性的初始值;它默认为None。所有剩余的参数都视为与传递给dict构造函数的参数相同,包括关键字参数。

defaultdict对象除了支持标准的dict操作,还支持以下方法:

missing ( key )

如果default_factory属性为None,则以key作为参数引发KeyError异常。

如果default_factory不为None,则不带参数调用它以用来给key提供默认值,此值将插入到字典中用于key,并返回。

如果调用default_factory引发异常,则该异常会保持原样传播。

当未找到请求的key时,此方法由dict类的__getitem__()方法调用;getitem()将返回或引发它返回或引发的。

请注意,除了__getitem__()之外的任何操作,都不会调用__missing__()。这意味着get()会像正常的字典一样返回None作为默认值,而不是使用default_factory。

defaultdict对象支持以下实例变量:

default_factory

此属性由__missing__()方法使用;如果构造函数的第一个参数存在,则初始化为它,如果不存在,则初始化为None。

defaultdict示例

使用list作为default_factory,可以很容易地将一系列键值对分组为一个列表字典:
>>>
>>> s = [('yellow', 1), ('blue', 2), ('yellow', 3), ('blue', 4), ('red', 1)]
>>> d = defaultdict(list)
>>> for k, v in s:
... d[k].append(v)
...
>>> sorted(d.items())
[('blue', [2, 4]), ('red', [1]), ('yellow', [1, 3])]

当每个键第一次遇到时,它不在映射中;因此使用返回空list的default_factory函数自动创建一个条目。然后,list.append()操作将值附加到新列表。当再次遇到这个键时,查找正常继续(返回该键的列表),并且list.append()操作向列表中添加另一个值。这种技术比使用等效的dict.setdefault()技术更简单和更快:

>>>
>>> d = {}
>>> for k, v in s:
... d.setdefault(k, []).append(v)
...
>>> sorted(d.items())
[('blue', [2, 4]), ('red', [1]), ('yellow', [1, 3])]

将default_factory设置为int可使defaultdict用于计数(如其他语言的bag或multiset):

>>>
>>> s = 'mississippi'
>>> d = defaultdict(int)
>>> for k in s:
... d[k] += 1
...
>>> sorted(d.items())
[('i', 4), ('m', 1), ('p', 2), ('s', 4)]

当一个字母第一次遇到时,映射中缺少该字母,因此default_factory函数调用int()以提供默认计数零。增量操作然后建立每个字母的计数。

始终返回零的函数int()只是常量函数的特殊情况。创建常量函数的更快和更灵活的方法是使用lambda函数,它可以提供任何常量值(不只是零):

>>>
>>> def constant_factory(value):
... return lambda: value
>>> d = defaultdict(constant_factory('<missing>'))
>>> d.update(name='John', action='ran')
>>> '%(name)s %(action)s to %(object)s' % d
'John ran to <missing>'

将default_factory设置为set可使defaultdict有助于构建集合字典:

>>>
>>> s = [('red', 1), ('blue', 2), ('red', 3), ('blue', 4), ('red', 1), ('blue', 4)]
>>> d = defaultdict(set)
>>> for k, v in s:
... d[k].add(v)
...
>>> sorted(d.items())
[('blue', {2, 4}), ('red', {1, 3})]

Lambda

lambda的一般形式是关键字lambda后面跟一个或多个参数,紧跟一个冒号,以后是一个表达式。lambda是一个表达式而不是一个语句。它能够出现在python语法不允许def出现的地方。作为表达式,lambda返回一个值(即一个新的函数)。lambda用来编写简单的函数,而def用来处理更强大的任务。

  1. f = lambda x,y,z : x+y+z
2. print f(1,2,3)
3.
4. g = lambda x,y=2,z=3 : x+y+z
5. print g(1,z=4,y=5)
输出
1. 6
2. 10

Split

Python中有split()和os.path.split()两个函数,具体作用如下:

split():拆分字符串。通过指定分隔符对字符串进行切片,并返回分割后的字符串列表(list)

os.path.split():按照路径将文件名和路径分割开

一、函数说明

1、split()函数

语法:str.split(str="",num=string.count(str))[n]

参数说明:

str:表示为分隔符,默认为空格,但是不能为空('')。若字符串中没有分隔符,则把整个字符串作为列表的一个元素

num:表示分割次数。如果存在参数num,则仅分隔成 num+1 个子字符串,并且每一个子字符串可以赋给新的变量

[n]:表示选取第n个分片

注意:当使用空格作为分隔符时,对于中间为空的项会自动忽略

2、os.path.split()函数

语法:os.path.split('PATH')

参数说明:

1.PATH指一个文件的全路径作为参数:

2.如果给出的是一个目录和文件名,则输出路径和文件名

3.如果给出的是一个目录名,则输出路径和为空文件名

二、分离字符串

string = "www.gziscas.com.cn"

1.以'.'为分隔符

print(string.split('.'))

['www', 'gziscas', 'com', 'cn']

2.分割两次

print(string.split('.',2))

['www', 'gziscas', 'com.cn']

3.分割两次,并取序列为1的项

print(string.split('.',2)[1])

gziscas

4.分割两次,并把分割后的三个部分保存到三个文件

u1, u2, u3 =string.split('.',2)

print(u1)—— www

print(u2)—— gziscas

print(u3) ——com.cn

三、分离文件名和路径

import os

print(os.path.split('/dodo/soft/python/'))

('/dodo/soft/python', '')

print(os.path.split('/dodo/soft/python'))

('/dodo/soft', 'python')

四、实例

str="hello boy<[www.baidu.com]>byebye"

print(str.split("[")[1].split("]")[0])

www.baidu.com

filter

filter()函数包括两个参数,分别是function和list。该函数根据function参数返回的结果是否为真来过滤list参数中的项,最后返回一个新列表,如下例所示

>>>a=[1,2,3,4,5,6,7]
>>>b=filter(lambda x:x>5, a)
>>>print b
>>>[6,7]
如果filter参数值为None,就使用identity()函数,list参数中所有为假的元素都将被删除。如下所示:
>>>a=[0,1,2,3,4,5,6,7]
b=filter(None, a)
>>>print b
>>>[1,2,3,4,5,6,7]

map

map()的两个参数一个是函数名,另一个是列表或元组。

>>>map(lambda x:x+3, a) #这里的a同上
>>>[3,4,5,6,7,8,9,10]
#另一个例子
>>>a=[1,2,3]
>>>b=[4,5,6]
>>>map(lambda x,y:x+y, a,b)
>>>[5,7,9]

reduce

reduce 函数按照指定规则递归求值
>>>reduce(lambda x,y:x*y, [1,2,3,4,5])
>>>120
>>>reduce(lambda x,y:x*y, [1,2,3], 10)
>>>60 # ((1*2)*3)*10

enumerate

enumerate(iteration, start):返回一个枚举的对象。迭代器(iteration)必须是另外一个可以支持的迭代对象。初始值默认为零,也就是你如果不输入start那就代表从零开始。迭代器的输入可以是列表,字符串,集合等,因为这些都是可迭代的对象。返回一个对象,如果你用列表的形式表现出来的话那就是一个列表,列表的每个元素是一个元组,元祖有两个元素,第一个元素代表编号,也就是第几个元素的意思,第二个元素就是迭代器的对应的元素,这是在默认start为零的情况下。如果不为零,那就是列表的第一个元组的第一个元素就是start的值,后面的依次累加,第二个元素还是一样的意思。

例如:


1. str1 = 'lplplp' # string
2. list1 = [1, 5, 6] # list
3. tuple1 = (5, 8, 4, 2) # 元组
4. set1 = {'kl', 'lk'} # 集合
5.
6. print list(enumerate(str1))
7. print list(enumerate(str1, start=2))
8.
9. print list(enumerate(list1))
10. print list(enumerate(list1, start=2))
11.
12. print list(enumerate(tuple1))
13. print list(enumerate(tuple1, start=2))
14.
15. print list(enumerate(set1))
16. print list(enumerate(set1, start=2)) 输出:
[(0, 'l'), (1, 'p'), (2, 'l'), (3, 'p'), (4, 'l'), (5, 'p')]
[(2, 'l'), (3, 'p'), (4, 'l'), (5, 'p'), (6, 'l'), (7, 'p')]
[(0, 1), (1, 5), (2, 6)]
[(2, 1), (3, 5), (4, 6)]
[(0, 5), (1, 8), (2, 4), (3, 2)]
[(2, 5), (3, 8), (4, 4), (5, 2)]
[(0, 'lk'), (1, 'kl')]
[(2, 'lk'), (3, 'kl')]

tf.train.shuffle_batch

tf.train.shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, name=None)

Creates batches by randomly shuffling tensors.

通过随机打乱张量的顺序创建批次.

简单来说就是读取一个文件并且加载一个张量中的batch_size行

This function adds the following to the current Graph:

这个函数将以下内容加入到现有的图中.

A shuffling queue into which tensors from tensor_list are enqueued.

一个由传入张量组成的随机乱序队列

A dequeue_many operation to create batches from the queue.

从张量队列中取出张量的出队操作

A QueueRunner to QUEUE_RUNNER collection, to enqueue the tensors

from tensor_list.

一个队列运行器管理出队操作.

If enqueue_many is False, tensor_list is assumed to represent a

single example. An input tensor with shape [x, y, z] will be output

as a tensor with shape [batch_size, x, y, z].

If enqueue_many is True, tensor_list is assumed to represent a

batch of examples, where the first dimension is indexed by example,

and all members of tensor_list should have the same size in the

first dimension. If an input tensor has shape [*, x, y, z], the

output will have shape [batch_size, x, y, z].

'enqueue_many’主要是设置tensor中的数据是否能重复,如果想要实现同一个样本多次出现可以将其设置为:“True”,如果只想要其出现一次,也就是保持数据的唯一性,这时候我们将其设置为默认值:"False"

The capacity argument controls the how long the prefetching is allowed to grow the queues.

容量控制了预抓取操作对于增加队列长度操作的长度.

For example:

Creates batches of 32 images and 32 labels.

image_batch, label_batch = tf.train.shuffle_batch(

[single_image, single_label],

batch_size=32,

num_threads=4,

capacity=50000,

min_after_dequeue=10000)

Args:

tensor_list: The list of tensors to enqueue.

入队的张量列表

batch_size: The new batch size pulled from the queue.

表示进行一次批处理的tensors数量.

capacity: An integer. The maximum number of elements in the queue.

容量:一个整数,队列中的最大的元素数.

这个参数一定要比min_after_dequeue参数的值大,并且决定了我们可以进行预处理操作元素的最大值.

推荐其值为:

capacity=(min_after_dequeue+(num_threads+a small safety margin∗batchsize)

min_after_dequeue: Minimum number elements in the queue after a

dequeue(出列), used to ensure a level of mixing of elements.

当一次出列操作完成后,队列中元素的最小数量,往往用于定义元素的混合级别.

定义了随机取样的缓冲区大小,此参数越大表示更大级别的混合但是会导致启动更加缓慢,并且会占用更多的内存

num_threads: The number of threads enqueuing tensor_list.

设置num_threads的值大于1,使用多个线程在tensor_list中读取文件,这样保证了同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件,这种方案的优点是:

避免了两个不同的线程从同一文件中读取用一个样本

避免了过多的磁盘操作

seed: Seed for the random shuffling within the queue.

打乱tensor队列的随机数种子

enqueue_many: Whether each tensor in tensor_list is a single example.

定义tensor_list中的tensor是否冗余.

shapes: (Optional) The shapes for each example. Defaults to the

inferred shapes for tensor_list.

用于改变读取tensor的形状,默认情况下和直接读取的tensor的形状一致.

name: (Optional) A name for the operations.

Returns:

A list of tensors with the same number and types as tensor_list.

默认返回一个和读取tensor_list数据和类型一个tensor列表.

tf.where

tf.where(condition, x=None, y=None, name=None)

功能:若x,y都为None,返回condition值为True的坐标;

若x,y都不为None,返回condition值为True的坐标在x内的值,condition值为False的坐标在y内的值

输入:condition:bool类型的tensor

a = tf.constant([True, False, False, True])
x = tf.constant([1, 2, 3, 4])
y = tf.constant([5, 6, 7, 8])
z = tf.where(a)
z2 = tf.where(a, x, y) sess = tf.Session()
print(sess.run(z))
print(sess.run(z2))
sess.close() # z==>[[0]
# [3]]
# z2==>[ 1 6 7 4]

标签的匹配

train_labels = tf.map_fn(lambda l: tf.where(tf.equal(labels, l)
)[0][0], label_batch, dtype=tf.int64)

注解:

label_batch是一个[batch_size,1]的张量,labels储存有所有的图片标签的信息,是一个[pictures_num,1]的张量。

很明显label_batch的行数比picture_num小得多,这时候如果我们直接使用tf.equal函数会出现维度不匹配的问题,使用map_fn主要是将定义的函数运用到后面集合中每个元素中。这里的l其实是label_batch标签张量中的一个秩相同的单个张量。

tf.equal(labels,l)会得到一个[Flase,True,Flase,True,False,False,False]的张量,tf.where会找到此布尔值数组的第一个为True的索引。由于函数返回的是一个二维数组,所以使用[0][0]提取出该值。

Example

"""主要测试tf.where的使用"""
import tensorflow as tf
import numpy as np a = np.array([[5]])
a1 = np.array([[1], [2], [3]])
b = np.array([[1], [7], [8], [4], [5], [2], [3], [2], [3]])
# 对于[n,1]shape张量匹配必须使用map_fn函数,否则会出shape函数维度不匹配的错误
c1 = tf.map_fn(lambda l: tf.where(tf.equal(b, l))[0][0], a1, dtype=tf.int64)
c = tf.where(tf.equal(a, b))[0][0] # c = tf.where(tf.equal(a1, b))[0][0] 这个语句就会出现下面维度不匹配的错误。
# Dimensions must be equal, but are 3 and 7 for 'Equal' (op: 'Equal') with input shapes: [3,1], [7,1].
sess = tf.Session()
print(sess.run(c))
print(sess.run(c1)) # 4
# [0 5 6]

环境

tensorflow 1.2.1 CPU版本

python 3.5.0

windows 10(特别注意,linux系统和windows系统对于文件名表示的区别)

参考资料

面向机器智能的Tensorflow实践

Tensorflow简单CNN实现的更多相关文章

  1. TensorFlow简单介绍和在centos上的安装

    ##tensorflow简单介绍: TensorFlow™ is an open source software library for numerical computation using dat ...

  2. TersorflowTutorial_MNIST数据集上简单CNN实现

    MNIST数据集上简单CNN实现 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 源代码请点击下方链接欢迎加星 Tesorflow实现基于MNI ...

  3. 【TensorFlow/简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型

    初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构 ...

  4. FaceRank-人脸打分基于 TensorFlow 的 CNN 模型

    FaceRank-人脸打分基于 TensorFlow 的 CNN 模型 隐私 因为隐私问题,训练图片集并不提供,稍微可能会放一些卡通图片. 数据集 130张 128*128 张网络图片,图片名: 1- ...

  5. TensorFlow简单线性回归

    TensorFlow简单线性回归 将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.e ...

  6. tensorflow实现一个神经网络简单CNN网络

    本例子用到了minst数据库,通过训练CNN网络,实现手写数字的预测. 首先先把数据集读取到程序中(MNIST数据集大约12MB,如果没在文件夹中找到就会自动下载): mnist = input_da ...

  7. Tensorflow的CNN教程解析

    之前的博客我们已经对RNN模型有了个粗略的了解.作为一个时序性模型,RNN的强大不需要我在这里重复了.今天,让我们来看看除了RNN外另一个特殊的,同时也是广为人知的强大的神经网络模型,即CNN模型.今 ...

  8. [DL学习笔记]从人工神经网络到卷积神经网络_3_使用tensorflow搭建CNN来分类not_MNIST数据(有一些问题)

    3:用tensorflow搭个神经网络出来 为什么用tensorflow呢,应为谷歌是亲爹啊,虽然有些人说caffe更适合图像啊mxnet效率更高等等,但爸爸就是爸爸,Android都能那么火,一个道 ...

  9. TensorflowTutorial_二维数据构造简单CNN

    使用二维数据构造简单卷积神经网络 觉得有用的话,欢迎一起讨论相互学习~Follow Me 图像和一些时序数据集都可以用二维数据的形式表现,我们此次使用随机分布的二位数据构造一个简单的CNN-网络卷积- ...

随机推荐

  1. python抢火车票的脚本

    起因: 想着那么多人,抢不到火车票.感觉到一丝感慨 所以有了抢火车票这个脚本. 0x01 思路:自动打开浏览器,自动输入账号密码 知道查看.自动预定. 0x02 要用到的模块 splinter模块: ...

  2. python网络数据采集(伴奏曲)

    这里是前章,我们做一下预备.之前太多事情没能写博客~..             (此博客只适合python3x,python2x请自行更改代码) 首先你要有bs4模块 windows下安装:pip3 ...

  3. IntelliJ IDEA 17和Maven构建javaWeb项目

    前言 电脑又断电了,眼看着写好的东西就没有了,这是第二次犯这个错误了.很难受呀!还是回到正题吧,我们来使用IDEA和Maven构建一个JavaWeb项目 软件环境: IDEA:2017.2.1 JDK ...

  4. 2017"百度之星"程序设计大赛 - 复赛1003&&HDU 6146 Pokémon GO【数学,递推,dp】

    Pokémon GO Time Limit: 3000/1500 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)Total ...

  5. 【Java学习笔记之二】java标识符命名规范

    什么是标识符 就是程序员在定义java程序时,自定义的一些名字.标识符可以应用在类名.变量.函数名.包名上. 标识符必须遵循以下规则 标识符由26个英文字符大小写(a~zA~Z).数字(0~9).下划 ...

  6. COGS 1299. bplusa【听说比a+b还要水的大水题???】

    1299. bplusa ☆   输入文件:bplusa.in   输出文件:bplusa.out   评测插件 时间限制:1 s   内存限制:128 MB [题目描述] 输入一个整数n,将其拆为两 ...

  7. POI实现大数据EXCLE导入导出,解决内存溢出问题

    使用POI能够导出大数据保证内存不溢出的一个重要原因是SXSSFWorkbook生成的EXCEL为2007版本,修改EXCEL2007文件后缀为ZIP打开可以看到,每一个Sheet都是一个xml文件, ...

  8. SpringMVC框架学习笔记(1)——HelloWorld

    搭建SpringMVC框架 1.添加jar包 jsp-api.jar servlet-api.jar jstl.jar commons-logging-1.1.1.jar spring-beans-4 ...

  9. window下部署Solr

    主要步骤如下: 1.下载solr-4.7.2.zip;下载地址:http://archive.apache.org/dist/lucene/java/ 2.解压缩solr-4.7.2.zip,解压后目 ...

  10. 01 mysql的安装(windows)

    在安装mysql之前,一般是先下载mysql,推荐大家去Oracle的官网下载,而且尽量使用免安装的版本(即压缩版,解压之后就可以使用的版本,不是.exe的安装版本),因为安装版的mysql在安装过程 ...