再来一个tensorflow的测试性能的代码
感觉这个比前一套,容易理解些~~
关于数据提前下载的问题:
https://www.jianshu.com/p/5116046733fe
如果使用keras的cifar10.load_data()函数,你会发现,代码会自动去下载 cifar-10-python.tar.gz 文件
实际上,通过查看cifar10.py和site-packages/keras/utils/data_utils.py的get_file函数,你会发现,代码将将下载后的文件存放在 ~./keras/datasets目录下,但是!!!!文件名却被改成了 cifar-10-batches-py.tar.gz
惊不惊喜,意不意外?所以如果要避免下载,已经有数据集了,应该:
cp cifar-10-python.tar.gz ~./keras/datasets/cifar-10-batches-py.tar.gz
完美解决问题!
作者:不爱吃饭的小孩怎么办
链接:https://www.jianshu.com/p/5116046733fe
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
import timeit
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets.cifar10 import load_data
def model():
x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
y = tf.placeholder(tf.float32, shape=[None, 10])
rate = tf.placeholder(tf.float32)
# convolutional layer 1
conv_1 = tf.layers.conv2d(x, 32, [3, 3], padding='SAME', activation=tf.nn.relu)
max_pool_1 = tf.layers.max_pooling2d(conv_1, [2, 2], strides=2, padding='SAME')
drop_1 = tf.layers.dropout(max_pool_1, rate=rate)
# convolutional layer 2
conv_2 = tf.layers.conv2d(drop_1, 64, [3, 3], padding="SAME", activation=tf.nn.relu)
max_pool_2 = tf.layers.max_pooling2d(conv_2, [2, 2], strides=2, padding="SAME")
drop_2 = tf.layers.dropout(max_pool_2, rate=rate)
# convolutional layers 3
conv_3 = tf.layers.conv2d(drop_2, 128, [3, 3], padding="SAME", activation=tf.nn.relu)
max_pool_3 = tf.layers.max_pooling2d(conv_3, [2, 2], strides=2, padding="SAME")
drop_3 = tf.layers.dropout(max_pool_3, rate=rate)
# fully connected layer 1
flat = tf.reshape(drop_3, shape=[-1, 4 * 4 * 128])
fc_1 = tf.layers.dense(flat, 80, activation=tf.nn.relu)
drop_4 = tf.layers.dropout(fc_1 , rate=rate)
# fully connected layer 2 or the output layers
fc_2 = tf.layers.dense(drop_4, 10)
output = tf.nn.relu(fc_2)
# accuracy
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# loss
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=y))
# optimizer
optimizer = tf.train.AdamOptimizer(1e-4, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss)
return x, y, rate, accuracy, loss, optimizer
def one_hot_encoder(y):
ret = np.zeros(len(y) * 10)
ret = ret.reshape([-1, 10])
for i in range(len(y)):
ret[i][y[i]] = 1
return (ret)
def train(x_train, y_train, sess, x, y, rate, optimizer, accuracy, loss):
batch_size = 128
y_train_cls = one_hot_encoder(y_train)
start = end = 0
for i in range(int(len(x_train) / batch_size)):
if (i + 1) % 100 == 1:
start = timeit.default_timer()
batch_x = x_train[i * batch_size:(i + 1) * batch_size]
batch_y = y_train_cls[i * batch_size:(i + 1) * batch_size]
_, batch_loss, batch_accuracy = sess.run([optimizer, loss, accuracy], feed_dict={x:batch_x, y:batch_y, rate:0.4})
if (i + 1) % 100 == 0:
end = timeit.default_timer()
print("Time:", end-start, "s the loss is ", batch_loss, " and the accuracy is ", batch_accuracy * 100, "%")
def test(x_test, y_test, sess, x, y, rate, accuracy, loss):
batch_size = 64
y_test_cls = one_hot_encoder(y_test)
global_loss = 0
global_accuracy = 0
for t in range(int(len(x_test) / batch_size)):
batch_x = x_test[t * batch_size : (t + 1) * batch_size]
batch_y = y_test_cls[t * batch_size : (t + 1) * batch_size]
batch_loss, batch_accuracy = sess.run([loss, accuracy], feed_dict={x:batch_x, y:batch_y, rate:1})
global_loss += batch_loss
global_accuracy += batch_accuracy
global_loss = global_loss / (len(x_test) / batch_size)
global_accuracy = global_accuracy / (len(x_test) / batch_size)
print("In Test Time, loss is ", global_loss, ' and the accuracy is ', global_accuracy)
EPOCH = 100
(x_train, y_train), (x_test, y_test) = load_data()
print("There is ", len(x_train), " training images and ", len(x_test), " images")
x, y, rate, accuracy, loss, optimizer = model()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(EPOCH):
print("Train on epoch ", i ," start")
train(x_train, y_train, sess, x, y, rate, optimizer, accuracy, loss)
test(x_train, y_train, sess, x, y, rate, accuracy, loss)

再来一个tensorflow的测试性能的代码的更多相关文章
- TensorFlow CNN 测试CIFAR-10数据集
本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...
- 有一个很大的整数list,需要求这个list中所有整数的和,写一个可以充分利用多核CPU的代码,来计算结果(转)
引用 前几天在网上看到一个淘宝的面试题:有一个很大的整数list,需要求这个list中所有整数的和,写一个可以充分利用多核CPU的代码,来计算结果.一:分析题目 从题中可以看到“很大的List”以及“ ...
- OpenCV:Mat元素访问方法、性能、代码复杂度以及安全性分析
欢迎转载,尊重原创,所以转载请注明出处: http://blog.csdn.net/bendanban/article/details/30527785 本文讲述了OpenCV中几种访问矩阵元素的方法 ...
- tensorflow笔记:多层LSTM代码分析
tensorflow笔记:多层LSTM代码分析 标签(空格分隔): tensorflow笔记 tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) ten ...
- 新书《编写可测试的JavaScript代码 》出版,感谢支持
本书介绍 JavaScript专业开发人员必须具备的一个技能是能够编写可测试的代码.不管是创建新应用程序,还是重写遗留代码,本书都将向你展示如何为客户端和服务器编写和维护可测试的JavaScript代 ...
- 编写可测试的JavaScript代码
<编写可测试的JavaScript代码>基本信息作者: [美] Mark Ethan Trostler 托斯勒 著 译者: 徐涛出版社:人民邮电出版社ISBN:9787115373373上 ...
- 20135202闫佳歆--week2 一个简单的时间片轮转多道程序内核代码及分析
一个简单的时间片轮转多道程序内核代码及分析 所用代码为课程配套git库中下载得到的. 一.进程的启动 /*出自mymain.c*/ /* start process 0 by task[0] */ p ...
- 【Head First Servlets and JSP】笔记6:什么是响应首部 & 快速搭建一个简单的测试环境
搭建简单的测试环境 什么是响应首部 最简单的响应首部——Content-Type 设置响应首部 请求重定向与响应首部 在浏览器中查看Response Headers 1.先快速搭建一个简单的测试环境, ...
- tensorflow笔记:多层CNN代码分析
tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...
随机推荐
- nginx入门系列之应用场景介绍
目录 HTTP服务器 反向代理服务器 作为一个虚拟主机下多个应用的反向代理 作为多个虚拟主机的反向代理 负载均衡器 简单轮训策略 最小连接数策略 客户端IP哈希策略 服务器权重策略 邮件代理服务器 官 ...
- python获取https并且写文件日志
# -*- coding: utf-8 -*- import os import os.path import shutil import chardet import urllib.request ...
- 【Spring Cloud学习之一】微服务架构
一.网站架构模式发展 单体应用-->SOA-->微服务 1.分布式项目与项目集群分布式项目:根据业务需求进行拆分成N个子系统,多个子系统相互协作才能完成业务流程子系统之间通讯使用RPC远程 ...
- replace的回调函数。
今天在看算法时,看到一些题目,感觉replace的回调函数好奇葩,$0 .$1什么的: JS的replace方法: str.replace(regexp|substr, newSubStr|funct ...
- MD5用户密码加密工具类 MD5Util
一般记录用户密码,我们都是通过MD5加密配置的形式.这里记录一下,MD5加密的工具类. package com.mms.utils; import java.security.MessageDiges ...
- Python之路【第十六篇】:Python并发编程|进程、线程
一.进程和线程 进程 假如有两个程序A和B,程序A在执行到一半的过程中,需要读取大量的数据输入(I/O操作), 而此时CPU只能静静地等待任务A读取完数据才能继续执行,这样就白白浪费了CPU资源. 是 ...
- Go语言【开发】加载JSON配置文件
JSON配置加载 辅助网址,JSON转结构体对应 http://json2struct.mervine.net/ 从JSON文件中加载配置到全局变量中 配置文件 config.json { &quo ...
- golang微服务框架go-micro 入门笔记2.3 micro工具之消息接收和发布
本章节阐述micro消息订阅和发布相关内容 阅读本文前你可能需要进行如下知识储备 golang分布式微服务框架go-micro 入门笔记1:搭建go-micro环境, golang微服务框架go-mi ...
- pytest_03_pycharm运行pytest (转:上海悠悠)
前言 上一篇pytest文档2-用例运行规则已经介绍了如何在cmd执行pytest用例,平常我们写代码在pycharm比较多 写完用例之后,需要调试看看,是不是能正常运行,如果每次跑去cmd执行,太麻 ...
- 记一次 WPS Pro 2019 设备和驱动器图标删除
1.图标预览 先看样式 2.软件不能关闭 百度和腾讯网盘都会创建,但是可以软件关闭,WPS以前也可以,现在新版作妖了 3.注册表删除 你做那我就删~Code:HKEY_CURRENT_USER\Sof ...