感觉这个比前一套,容易理解些~~

关于数据提前下载的问题:

https://www.jianshu.com/p/5116046733fe

如果使用keras的cifar10.load_data()函数,你会发现,代码会自动去下载 cifar-10-python.tar.gz 文件
实际上,通过查看cifar10.py和site-packages/keras/utils/data_utils.py的get_file函数,你会发现,代码将将下载后的文件存放在 ~./keras/datasets目录下,但是!!!!文件名却被改成了 cifar-10-batches-py.tar.gz

惊不惊喜,意不意外?所以如果要避免下载,已经有数据集了,应该:
cp cifar-10-python.tar.gz ~./keras/datasets/cifar-10-batches-py.tar.gz

完美解决问题!

作者:不爱吃饭的小孩怎么办
链接:https://www.jianshu.com/p/5116046733fe
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

import timeit
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets.cifar10 import load_data

def model():
    x = tf.placeholder(tf.float32, shape=[None, 32, 32, 3])
    y = tf.placeholder(tf.float32, shape=[None, 10])
    rate = tf.placeholder(tf.float32)
    # convolutional layer 1
    conv_1 = tf.layers.conv2d(x, 32, [3, 3], padding='SAME', activation=tf.nn.relu)
    max_pool_1 = tf.layers.max_pooling2d(conv_1, [2, 2], strides=2, padding='SAME')
    drop_1 = tf.layers.dropout(max_pool_1, rate=rate)
    # convolutional layer 2
    conv_2 = tf.layers.conv2d(drop_1, 64, [3, 3], padding="SAME", activation=tf.nn.relu)
    max_pool_2 = tf.layers.max_pooling2d(conv_2, [2, 2], strides=2, padding="SAME")
    drop_2 = tf.layers.dropout(max_pool_2, rate=rate)
    # convolutional layers 3
    conv_3 = tf.layers.conv2d(drop_2, 128, [3, 3], padding="SAME", activation=tf.nn.relu)
    max_pool_3 = tf.layers.max_pooling2d(conv_3, [2, 2], strides=2, padding="SAME")
    drop_3 = tf.layers.dropout(max_pool_3, rate=rate)
    # fully connected layer 1
    flat = tf.reshape(drop_3, shape=[-1, 4 * 4 * 128])
    fc_1 = tf.layers.dense(flat, 80, activation=tf.nn.relu)
    drop_4 = tf.layers.dropout(fc_1 , rate=rate)
    # fully connected layer 2 or the output layers
    fc_2 = tf.layers.dense(drop_4, 10)
    output = tf.nn.relu(fc_2)
    # accuracy
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(output, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # loss
    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, labels=y))
    # optimizer
    optimizer = tf.train.AdamOptimizer(1e-4, beta1=0.9, beta2=0.999, epsilon=1e-8).minimize(loss)
    return x, y, rate, accuracy, loss, optimizer

def one_hot_encoder(y):
    ret = np.zeros(len(y) * 10)
    ret = ret.reshape([-1, 10])
    for i in range(len(y)):
        ret[i][y[i]] = 1
    return (ret)

def train(x_train, y_train, sess, x, y, rate, optimizer, accuracy, loss):
    batch_size = 128
    y_train_cls = one_hot_encoder(y_train)
    start = end = 0
    for i in range(int(len(x_train) / batch_size)):
        if (i + 1) % 100 == 1:
            start = timeit.default_timer()
        batch_x = x_train[i * batch_size:(i + 1) * batch_size]
        batch_y = y_train_cls[i * batch_size:(i + 1) * batch_size]
        _, batch_loss, batch_accuracy = sess.run([optimizer, loss, accuracy], feed_dict={x:batch_x, y:batch_y, rate:0.4})
        if (i + 1) % 100 == 0:
            end = timeit.default_timer()
            print("Time:", end-start, "s the loss is ", batch_loss, " and the accuracy is ", batch_accuracy * 100, "%")

def test(x_test, y_test, sess, x, y, rate, accuracy, loss):
    batch_size = 64
    y_test_cls = one_hot_encoder(y_test)
    global_loss = 0
    global_accuracy = 0
    for t in range(int(len(x_test) / batch_size)):
        batch_x = x_test[t * batch_size : (t + 1) * batch_size]
        batch_y = y_test_cls[t * batch_size : (t + 1) * batch_size]
        batch_loss, batch_accuracy = sess.run([loss, accuracy], feed_dict={x:batch_x, y:batch_y, rate:1})
        global_loss += batch_loss
        global_accuracy += batch_accuracy
    global_loss = global_loss / (len(x_test) / batch_size)
    global_accuracy = global_accuracy / (len(x_test) / batch_size)
    print("In Test Time, loss is ", global_loss, ' and the accuracy is ', global_accuracy)

EPOCH = 100
(x_train, y_train), (x_test, y_test) = load_data()
print("There is ", len(x_train), " training images and ", len(x_test), " images")
x, y, rate, accuracy, loss, optimizer = model()
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(EPOCH):
    print("Train on epoch ", i ," start")
    train(x_train, y_train, sess, x, y, rate, optimizer, accuracy, loss)
    test(x_train, y_train, sess, x, y, rate, accuracy, loss)

再来一个tensorflow的测试性能的代码的更多相关文章

  1. TensorFlow CNN 测试CIFAR-10数据集

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/50738311 1 CIFAR-10 数 ...

  2. 有一个很大的整数list,需要求这个list中所有整数的和,写一个可以充分利用多核CPU的代码,来计算结果(转)

    引用 前几天在网上看到一个淘宝的面试题:有一个很大的整数list,需要求这个list中所有整数的和,写一个可以充分利用多核CPU的代码,来计算结果.一:分析题目 从题中可以看到“很大的List”以及“ ...

  3. OpenCV:Mat元素访问方法、性能、代码复杂度以及安全性分析

    欢迎转载,尊重原创,所以转载请注明出处: http://blog.csdn.net/bendanban/article/details/30527785 本文讲述了OpenCV中几种访问矩阵元素的方法 ...

  4. tensorflow笔记:多层LSTM代码分析

    tensorflow笔记:多层LSTM代码分析 标签(空格分隔): tensorflow笔记 tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) ten ...

  5. 新书《编写可测试的JavaScript代码 》出版,感谢支持

    本书介绍 JavaScript专业开发人员必须具备的一个技能是能够编写可测试的代码.不管是创建新应用程序,还是重写遗留代码,本书都将向你展示如何为客户端和服务器编写和维护可测试的JavaScript代 ...

  6. 编写可测试的JavaScript代码

    <编写可测试的JavaScript代码>基本信息作者: [美] Mark Ethan Trostler 托斯勒 著 译者: 徐涛出版社:人民邮电出版社ISBN:9787115373373上 ...

  7. 20135202闫佳歆--week2 一个简单的时间片轮转多道程序内核代码及分析

    一个简单的时间片轮转多道程序内核代码及分析 所用代码为课程配套git库中下载得到的. 一.进程的启动 /*出自mymain.c*/ /* start process 0 by task[0] */ p ...

  8. 【Head First Servlets and JSP】笔记6:什么是响应首部 & 快速搭建一个简单的测试环境

    搭建简单的测试环境 什么是响应首部 最简单的响应首部——Content-Type 设置响应首部 请求重定向与响应首部 在浏览器中查看Response Headers 1.先快速搭建一个简单的测试环境, ...

  9. tensorflow笔记:多层CNN代码分析

    tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) tensorflow笔记:多层CNN代码分析 (三) tensorflow笔记:多层LSTM代码分析 ...

随机推荐

  1. Word 插入目录详细教程 -- 视频教程(6)

    >> 视频教程链接:B站,速度快,清晰 更多关于插入目录的方法,参看:Word插入目录系列 未完 ...... 点击访问原文(进入后根据右侧标签,快速定位到本文)

  2. php数组到json的转变

    今天做项目遇到个问题,一个接口,输出二维数组,前端说他要的数据格式是数组,而不是对象,就像上个数据一样,我当时就懵逼了,,,什么对象?我明明输出的是数组啊...然后我看了看我返回的json串,emmm ...

  3. python 之 面向对象(元类、__call__、单例模式)

    7.13 元类 元类:类的类就是元类,我们用class定义的类来产生我们自己的对象的,内置元类type是用来专门产生class定义的类 code=""" global x ...

  4. Django框架之第四篇(视图层)--HttpRequest对象、HttpResponse对象、JsonResponse、CBV和FBV、文件上传

    视图层 一.视图函数 一个视图函数,简称视图,是一个简单的python函数,它接收web请求并且会返回web响应.响应可以是一张网页的html,一个重定向,或者是一张图片...任何东西都可以.无论是什 ...

  5. SQL语句报错:Incorrect string value: '\xE9\x98\xBF\xE6\x96\xAF...'

    很明显是编码的问题.检查了一下$conn->query("set names utf8");已经加在代码里了.那莫非是数据库编码不是utf8? 看了一下 还真不是 于是右键要 ...

  6. 【LEETCODE】63、数组分类,hard级别,题目:85、4、84

    package y2019.Algorithm.array.hard; /** * @ProjectName: cutter-point * @Package: y2019.Algorithm.arr ...

  7. CF704D Captain America 上下界网络流

    传送门 现在相当于说每一个条件都有一个染成红色的盾牌的数量限制\([l,r]\),需要满足所有限制且染成红色的盾牌数量最小/最大. 注意到一个盾牌染成红色对于一行和一列都会产生影响.如果选中一个物品对 ...

  8. Java爬虫https网页内容报错SSLHandshakeException信任(忽略)所有SSL证书

    javax.net.ssl.SSLHandshakeException: sun.security.validator.ValidatorException: PKIX path building f ...

  9. 使用Powershell实现自动化安装/卸载程序

    最近需要制作软件安装包,需要附带VC运行时和.Net Framework的安装,但又不想让用户自己点下一步,所以就有了以下操作. 微软提供了一个程序叫msiexec.exe,位于C:\Windows\ ...

  10. .net Dapper 实践系列(4) ---数据查询(Layui+Ajax+Dapper+MySQL)

    写在前面 上一小节,总结了数据显示时,会出现的日期问题.以及如何处理格式化日期.这个小节,主要总结的是使用Dapper 中的QueryMultiple方法依次显示查询多表的数据. 实践步骤 1.在Bo ...