紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取、建立网络以及模型训练整理成一个小样例,完整代码如下。

#coding:utf-8
import tensorflow as tf
import os
def read_and_decode(filename):
#根据文件名生成一个队列
filename_queue = tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
}) img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [227, 227, 3])
img = (tf.cast(img, tf.float32) * (1. / 255) - 0.5)*2
label = tf.cast(features['label'], tf.int32)
print img,label
return img, label def get_batch(image, label, batch_size,crop_size):
#数据扩充变换
distorted_image = tf.random_crop(image, [crop_size, crop_size, 3])#随机裁剪
distorted_image = tf.image.random_flip_up_down(distorted_image)#上下随机翻转
distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)#亮度变化
distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)#对比度变化 #生成batch
#shuffle_batch的参数:capacity用于定义shuttle的范围,如果是对整个训练数据集,获取batch,那么capacity就应该够大
#保证数据打的足够乱
images, label_batch = tf.train.shuffle_batch([distorted_image, label],batch_size=batch_size,
num_threads=1,capacity=2000,min_after_dequeue=1000) return images, label_batch class network(object):
#构造函数初始化 卷积层 全连接层
def __init__(self):
with tf.variable_scope("weights"):
self.weights={ 'conv1':tf.get_variable('conv1',[4,4,3,20],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'conv2':tf.get_variable('conv2',[3,3,20,40],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'conv3':tf.get_variable('conv3',[3,3,40,60],initializer=tf.contrib.layers.xavier_initializer_conv2d()), 'fc1':tf.get_variable('fc1',[3*3*60,120],initializer=tf.contrib.layers.xavier_initializer()),
'fc2':tf.get_variable('fc2',[120,2],initializer=tf.contrib.layers.xavier_initializer()), }
with tf.variable_scope("biases"):
self.biases={
'conv1':tf.get_variable('conv1',[20,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
'conv2':tf.get_variable('conv2',[40,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
'conv3':tf.get_variable('conv3',[60,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)), 'fc1':tf.get_variable('fc1',[120,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
'fc2':tf.get_variable('fc2',[2,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)), } def buildnet(self,images):
#向量转为矩阵
images = tf.reshape(images, shape=[-1, 39,39, 3])# [batch, in_height, in_width, in_channels]
images=(tf.cast(images,tf.float32)/255.-0.5)*2#归一化处理 #第一层
conv1=tf.nn.bias_add(tf.nn.conv2d(images, self.weights['conv1'], strides=[1, 1, 1, 1], padding='SAME'),
self.biases['conv1'])
relu1= tf.nn.relu(conv1)
pool1=tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID') #第二层
conv2=tf.nn.bias_add(tf.nn.conv2d(pool1, self.weights['conv2'], strides=[1, 1, 1, 1], padding='VALID'),
self.biases['conv2'])
relu2= tf.nn.relu(conv2)
pool2=tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID') # 第三层
conv3=tf.nn.bias_add(tf.nn.conv2d(pool2, self.weights['conv3'], strides=[1, 1, 1, 1], padding='VALID'),
self.biases['conv3'])
relu3= tf.nn.relu(conv3)
pool3=tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID') # 全连接层1,先把特征图转为向量
flatten = tf.reshape(pool3, [-1, self.weights['fc1'].get_shape().as_list()[0]])
drop1=tf.nn.dropout(flatten,0.5)
fc1=tf.matmul(drop1, self.weights['fc1'])+self.biases['fc1']
fc_relu1=tf.nn.relu(fc1)
fc2=tf.matmul(fc_relu1, self.weights['fc2'])+self.biases['fc2']
return fc2 #计算softmax交叉熵损失函数
def softmax_loss(self,predicts,labels):
predicts=tf.nn.softmax(predicts)
labels=tf.one_hot(labels,self.weights['fc2'].get_shape().as_list()[1])
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = predicts, labels =labels))
self.cost= loss
return self.cost
#梯度下降
def optimer(self,loss,lr=0.01):
train_optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss) return train_optimizer def train():
image,label=read_and_decode("./train.tfrecords")
batch_image,batch_label=get_batch(image,label,batch_size=30,crop_size=39)
#建立网络,训练所用
net=network()
inf=net.buildnet(batch_image)
loss=net.softmax_loss(inf,batch_label) #计算loss
opti=net.optimer(loss) #梯度下降 init=tf.global_variables_initializer()
with tf.Session() as session:
with tf.device("/gpu:0"):
session.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
max_iter=1000
iter=0
if os.path.exists(os.path.join("model",'model.ckpt')) is True:
tf.train.Saver(max_to_keep=None).restore(session, os.path.join("model",'model.ckpt'))
while iter<max_iter:
loss_np,_,label_np,image_np,inf_np=session.run([loss,opti,batch_image,batch_label,inf])
if iter%50==0:
print 'trainloss:',loss_np
iter+=1
coord.request_stop()#queue需要关闭,否则报错
coord.join(threads)
if __name__ == '__main__':
train()

结果如下:

Total memory: 10.91GiB
Free memory: 10.16GiB
2018-02-02 10:13:24.462286: I tensorflow/core/common_runtime/gpu/gpu_device.cc:961] DMA: 0
2018-02-02 10:13:24.462294: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] 0: Y
2018-02-02 10:13:24.462303: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:03:00.0)
trainloss: 0.745739
trainloss: 0.330364
trainloss: 0.317668
trainloss: 0.314964
trainloss: 0.314613
trainloss: 0.314483
trainloss: 0.314132
trainloss: 0.313661

Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例的更多相关文章

  1. Tensorflow学习教程------过拟合

    Tensorflow学习教程------过拟合   回归:过拟合情况 / 分类过拟合 防止过拟合的方法有三种: 1 增加数据集 2 添加正则项 3 Dropout,意思就是训练的时候隐层神经元每次随机 ...

  2. ASP.NET MVC 5 学习教程:数据迁移之添加字段

    原文 ASP.NET MVC 5 学习教程:数据迁移之添加字段 起飞网 ASP.NET MVC 5 学习教程目录: 添加控制器 添加视图 修改视图和布局页 控制器传递数据给视图 添加模型 创建连接字符 ...

  3. Tensorflow学习教程------代价函数

    Tensorflow学习教程------代价函数   二次代价函数(quadratic cost): 其中,C表示代价函数,x表示样本,y表示实际值,a表示输出值,n表示样本的总数.为简单起见,使用一 ...

  4. TensorFlow queue多线程读取数据

    一.tensorflow读取机制图解 我们必须要把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率. 解决 ...

  5. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  6. Tensorflow机器学习入门——读取数据

    TensorFlow 中可以通过三种方式读取数据: 一.通过feed_dict传递数据: input1 = tf.placeholder(tf.float32) input2 = tf.placeho ...

  7. Tensorflow学习教程------实现lenet并且进行二分类

    #coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...

  8. TensorFlow从0到1之TensorFlow csv文件读取数据(14)

    大多数人了解 Pandas 及其在处理大数据文件方面的实用性.TensorFlow 提供了读取这种文件的方法. 前面章节中,介绍了如何在 TensorFlow 中读取文件,本节将重点介绍如何从 CSV ...

  9. tensorflow 学习教程

    tensorflow 学习手册 tensorflow 学习手册1:https://cloud.tencent.com/developer/section/1475687 tensorflow 学习手册 ...

随机推荐

  1. Stuts2与SpringMVC

    Struts2:一个基于MVC设计模式的Web应用框架,本质上相当于一个servlet.以WebWork为核心,采用拦截器的机制处理用户的请求(Filter). 轻量级的MVC框架.低侵入性,与业务代 ...

  2. 了解OOM

    1)什么是OOM? OOM,全称“Out Of Memory”,翻译成中文就是“内存用完了”,来源于java.lang.OutOfMemoryError.看下关于的官方说明: Thrown when ...

  3. 关于SOA的架构设计案例分析

    SOA体系架构及相关技术,主要应用在企业应用集成领域,它能够以服务的方式共享和复用企业现有应用资产,保护用户IT投资,并能够以服务的方式构建新的业务流程,对企业流程进行灵活重构和优化,增强业务的敏捷性 ...

  4. Day 4:集合——迭代器与List接口

    Collection-迭代方法 1.toArray()  返回Object类型数据,接收也需要Object对象! Object[] toArray(); Collection c = new Arra ...

  5. CodeForces - 748D Santa Claus and a Palindrome (贪心+构造)

    题意:给定k个长度为n的字符串,每个字符串有一个魅力值ai,在k个字符串中选取字符串组成回文串,使得组成的回文串魅力值最大. 分析: 1.若某字符串不是回文串a,但有与之对称的串b,将串a和串b所有的 ...

  6. CodeForces - 131C The World is a Theatre(组合数)

    题意:已知有n个男生,m个女生.现在要选t个人,要求有至少4个男生,至少1个女生,求有多少种选法. 分析: 1.展开,将分子中的m!与分母中n!相约,即可推出函数C. #pragma comment( ...

  7. zabbix监控tcp/nginx/memcache连接数自定义监控shell

    #!/bin/bashtcp_status_fun(){ TCP_STAT=$1 #netstat -n | awk '/^tcp/ {++state[$NF]} END {for(key in st ...

  8. For循环的几个练习

    1.括号里面只能放加或减,如果要使等式成立,括号里面应该放什么运算符12()34()56()78()9 = 59 2.蓝球弹起的高度篮球从10米高的地方落下,每次弹起的高度是原来的0.3倍,问弹跳10 ...

  9. zTree & ckeditor &ValidateCode.jar 使用个人心得总结

    zTree:依靠 jQuery 实现的多功能 “树插件”. 使用时只需要将下载的压缩包接用,复制里边的css 和 js 到指定目录即可. 如图所示: 在zTree的官网可以找到各种类型树的示例: 地址 ...

  10. SQL基础教程(第2版)第4章 数据更新:4-4 事务

    ●事务是需要在同一个处理单元中执行的一系列更新处理的集合. ● 事务处理的终止指令包括COMMIT(提交处理)和ROLLBACK(取消处理)两种. ● DBMS的事务具有原子性(Atomicity). ...