from sys import path
path.append('/home/ustcjing/models/tutorials/image/cifar10/')
import cifar10,cifar10_input
import tensorflow as tf
import math
import numpy as np
import time max_steps=300
batch_size=128
data_dir='/tmp/cifar10_data/cifar-10-batches-bin' def variable_with_weight_loss(shape,stddev,w1):
var=tf.Variable(tf.truncated_normal(shape,stddev=stddev))
if w1 is not None:
weight_loss=tf.multiply(tf.nn.l2_loss(var),w1,name='weight_loss')
tf.add_to_collection('losses','weight_loss') return var cifar10.maybe_download_and_extract()
images_train,labels_train=cifar10_input.distorted_inputs(data_dir=data_dir,batch_size=batch_size)
images_test,labels_test=cifar10_input.inputs(eval_data=True,data_dir=data_dir,batch_size=batch_size) image_holder=tf.placeholder(tf.float32,[batch_size,24,24,3])
label_holder=tf.placeholder(tf.int32,[batch_size]) weight1=variable_with_weight_loss(shape=[5,5,3,64],stddev=5e-2,w1=0.0)
kernel1=tf.nn.conv2d(image_holder,weight1,[1,1,1,1],padding='SAME')
bias1=tf.Variable(tf.constant(0.0,shape=[64]))
conv1=tf.nn.relu(tf.nn.bias_add(kernel1,bias1))
pool1=tf.nn.max_pool(conv1,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME')
norm1=tf.nn.lrn(pool1,4,bias=1.0,alpha=0.001/9.0,beta=0.75) weight2=variable_with_weight_loss(shape=[5,5,64,64],stddev=5e-2,w1=0.0)
kernel2=tf.nn.conv2d(norm1,weight2,[1,1,1,1],padding='SAME')
bias2=tf.Variable(tf.constant(0.1,shape=[64]))
conv2=tf.nn.relu(tf.nn.bias_add(kernel2,bias2))
norm2=tf.nn.lrn(conv2,4,bias=1.0,alpha=0.001/9.0,beta=0.75)
pool2=tf.nn.max_pool(norm2,ksize=[1,3,3,1],strides=[1,2,2,1],padding='SAME') reshape=tf.reshape(pool2,[batch_size,-1])
dim=reshape.get_shape()[1].value
weight3=variable_with_weight_loss(shape=[dim,384],stddev=0.04,w1=0.004)
bias3=tf.variable(tf.constant(0.1,shape=[384]))
local3=tf.nn.relu(tf.matmul(reshape,weight3)+bias3) weight4=variable_with_weight_loss(shape=[384,192],stddev=0.04,w1=0.004)
bias4=tf.Variable9tf.constant(0.1,shape=[192])
local4=tf.nn.relu(tf.matmul(local3,weight4)+bias4) weight5=variable_with_weight_loss(shape=[192,10],stddev=1/192.0,w1=0.0)
bias5=tf.Variable(tf.constant(0.0,shape=[10]))
logits=tf.add(tf.matmul(local4,weight5),bias5) def loss(logits,labels):
labels=tf.cast(labels,tf.int64)
cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,labels=labels,name='cross_entropy_per_example')
cross_entropy_mean=tf.reduce_mean(cross_entropy,name='cross_entropy')
tf.add_to_collection('losses',cross_entropy_mean)
return tf.add_n(tf.get_collection('losses'),name='total_loss') loss=loss(logits,label_holder)
train_op=tf.train.AdamOptimizer(1e-3).minimize(loss)
top_k_op=tf.nn.in_top_k(logits,label_holder,1)
sess=tf.InteractiveSession()
tf.initialize_all_variables().run()
tf.train.start_queue_runners() for step in range(max_steps):
start_time=time.time()
image_batch,label_batch=sess.run([images_train,labels_train])
loss_value=sess.run([train_op,loss],feed_dict={image_holder:image_batch,label_holder:label_batch})
duration=time.time()-start_time
if step%10==0:
examples_per_sec=batch_size/duration
sec_per_batch=float(duration)
format_str=('step %d,loss=%.2f (%.1f examples/sec;%.3f sec/batch)')
print(format_str % (step,loss_value,examples_per_sec,sec_per_batch)) num_examples=1000
num_iter=int(math.ceil(num_examples / batch_size))
true_count=0;
total_sample_count=num_iter*batch_size
step=0
while step<num_iter:
image_batch,label_batch=sess.run([images_test,labels_test])
predictions=sess.run([top_k_op],feed_dict={image_holder:image_batch,label_holder:label_batch}) true_count+=np.sum(predictions)
step+=1 precision=true_count/total_sample_count
print('precision @ 1=%.3f' % precision)

CNN Advanced的更多相关文章

  1. Paper/ Overview | CNN(未完待续)

    目录 I. 基础知识 II. 早期尝试 1. Neocognitron, 1980 2. LeCun, 1989 A. 概况 B. Feature maps & Weight sharing ...

  2. [Tensorflow] Cookbook - CNN

    Convolutional Neural Networks (CNNs) are responsible for the major breakthroughs in image recognitio ...

  3. Tensorflow在CIFAR-10构建CNN

    使用Tensorflow在CIFAR-10二进制数据集上构建CNN 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 Tensorflow机器学习实战指南 利用Tensorflow读取 ...

  4. Paper | FFDNet: Toward a Fast and Flexible Solution for CNN based Image Denoising

    目录 故事背景 核心思想 FFDNet 网络设置 噪声水平图 对子图像的去噪 保证噪声水平图的有效性 如何盲处理 为啥不用短连接 裁剪像素范围 实验 关于噪声水平图的敏感性 盲处理 发表在2018 T ...

  5. Advanced search keywords

    Advanced search options Find what you're looking for in less time. Use the following symbols to quic ...

  6. 【翻译】借助 NeoCPU 在 CPU 上进行 CNN 模型推理优化

    本文翻译自 Yizhi Liu, Yao Wang, Ruofei Yu.. 的  "Optimizing CNN Model Inference on CPUs" 原文链接: h ...

  7. Deep learning:五十一(CNN的反向求导及练习)

    前言: CNN作为DL中最成功的模型之一,有必要对其更进一步研究它.虽然在前面的博文Stacked CNN简单介绍中有大概介绍过CNN的使用,不过那是有个前提的:CNN中的参数必须已提前学习好.而本文 ...

  8. 卷积神经网络(CNN)学习算法之----基于LeNet网络的中文验证码识别

    由于公司需要进行了中文验证码的图片识别开发,最近一段时间刚忙完上线,好不容易闲下来就继上篇<基于Windows10 x64+visual Studio2013+Python2.7.12环境下的C ...

  9. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

随机推荐

  1. 十四课 slam&gmapping

    gmapping 根据激光数据(或者深度数据模拟的激光数据)建立地图,在turtlebot里面应用的就是深度数据模拟的激光数据.如果没有激光雷达的话可以使用Kinect. SLAM 机器人在未知环境中 ...

  2. Cloud Design Patterns: Prescriptive Architecture Guidance for Cloud Applications

    January 2014 Containing twenty-four design patterns and ten related guidance topics, this guide arti ...

  3. jqGrid查询案例(实用)

    var ThisTime = getNowFormatDate(); //加载表格 function GetGrid() { var selectedRowIndex = 0; var $gridTa ...

  4. c#缓存介绍

    #缓存介绍(转) 本章导读 缓存主要是为了提高数据的读取速度.因为服务器和应用客户端之间存在着流量的瓶颈,所以读取大容量数据时,使用缓存来直接为客户端服务,可以减少客户端与服务器端的数据交互,从而大大 ...

  5. Spring MVC Hibernate MySQL Integration(集成) CRUD Example Tutorial【摘】

    Spring MVC Hibernate MySQL Integration(集成) CRUD Example Tutorial We learned how to integrate Spring ...

  6. LibreOJ 6000 搭配飞行员(最大流)

    题解:最基础的最大流,按照主飞行员与起点建边,副飞行员与终点建边,可以同坐的主副飞行员之间建边,值均为一,然后跑一边最大流就完美了! 代码如下: #include<queue> #incl ...

  7. MongoDB整理笔记のSharding分片

    这是一种将海量的数据水平扩展的数据库集群系统,数据分表存储在sharding 的各个节点上,使用者通过简单的配置就可以很方便地构建一个分布式MongoDB 集群.MongoDB 的数据分块称为 chu ...

  8. Log--事务日志

    由于日志是顺序写入,而修改数据分散在数据库各个页面,属于随机写入,而磁盘顺序写入速度远高于随机写入,因此主流数据库都采用预写日志的方式来确保数据完整性 1.日志记录的是数据的变化而不是引发数据的操作2 ...

  9. Android 学习笔记 文本文件的读写操作

    activity_main.xml <LinearLayout xmlns:android="http://schemas.android.com/apk/res/android&qu ...

  10. C#时常需要调用C++DLL

    在合作开发时,C#时常需要调用C++DLL,当传递参数时时常遇到问题,尤其是传递和返回字符串是,现总结一下,分享给大家: VC++中主要字符串类型为:LPSTR,LPCSTR, LPCTSTR, st ...