第一章

1、什么是占位符和变量?

无论是占位符还是变量,都是tensor,tensor是tensorflow计算的节点。

占位符和变量是不同类型的tensor。占位符的值由用户自行传递,不依赖于其他tensor,通常用来存储样本数据和标签。

tf.Tensor类是核心类,占位符(tf.placeholder)和变量(tf.Variable)都可以看作特殊的tensor。

可以参阅http://www.tensorfly.cn/tfdoc/how_tos/variables.html

2、什么是会话?变量和占位符在会话中如何传递?

会话是一个核心概念,tensor是图计算的节点,会话是对这些节点进行计算的上下文。

变量是计算过程中可以改变的值的tensor,变量的值会被保存下来。在对变量进行操作前必须进行变量初始化,即在会话中保存变量的初始值。

训练时,每次提取一部分数据进行训练,把他们放入对应的占位符中,在会话中,不需要计算占位符的值,而是直接把占位符的值传递给会话。

会话中,变量的值会被保存下来,占位符的值不会被保存,每次可以给占位符传递不同的值。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# x是一个占位符,表示待识别的图片
# 形状是[None, 784],None表示这一维的大小可以是任意的
x = tf.placeholder(tf.float32, [None, 784])
# 变量参数用tf.Variable
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x, W) + b)
# y_是一个占位符,表示实际的图像标签,独热表示
y_ = tf.placeholder(tf.float32, [None, 10]) # 交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)))
# 梯度下降,学习率是0.01
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 创建session,只有在session中才能运行优化步骤train_step
sess = tf.InteractiveSession()
# 运行之前必须要初始化所有变量,分配内存
tf.global_variables_initializer().run()
print('start training...') for _ in range(1000):
# batch_xs: (100, 784), batch_ys: (100, 10)
batch_xs, batch_ys = mnist.train.next_batch(100)
# sess中运行train_step,运行时要使用feed_dict传入对应占位符的值
sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})

3、计算图流程(画出思维导图

# 独热表示的y_ 需要通过sess.run(y_)才能获取此tensor的值
print(tf.argmax(y, 1))
# output: Tensor("ArgMax:0", shape=(?,), dtype=int64)
print(tf.argmax(y_, 1))
# output: Tensor("ArgMax_1:0", shape=(?,), dtype=int64) # tf.equal 比较是否相等,输出true和false
# tf.argmax(y,1), tf.argmax(y_,1), 取出数组中最大值的下标,可以用独热表示以及模型输出转换为数字标签
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# tf.cast 将比较值转换为float32型的变量,true转换为1,false转换为0
# tf.reduce_mean 计算数组中的所有元素的平均值,得到模型的预测准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 使用全体测试样本预测,mnist.test.images, mnist.test.labels
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) # 只有输入了x,y_,通过sess.run才可以计算出correct_prediction,accuracy

4、扩展阅读

第二章

tensorflow的数据读取原理

画出思维导图

Deep learning with PyTorch: A 60 minute blitz _note(1) Tensors的更多相关文章

  1. DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | TENSORS

    Tensor是一种特殊的数据结构,非常类似于数组和矩阵.在PyTorch中,我们使用tensor编码模型的输入和输出,以及模型的参数. Tensor类似于Numpy的数组,除了tensor可以在GPU ...

  2. DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | TORCH.AUTOGRAD

    torch.autograd 是PyTorch的自动微分引擎,用以推动神经网络训练.在本节,你将会对autograd如何帮助神经网络训练的概念有所理解. 背景 神经网络(NNs)是在输入数据上执行的嵌 ...

  3. DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | NEURAL NETWORKS

    神经网络可以使用 torch.nn包构建. 现在你已经对autograd有所了解,nn依赖 autograd 定义模型并对其求微分.nn.Module 包括层,和一个返回 output 的方法 - f ...

  4. DEEP LEARNING WITH PYTORCH: A 60 MINUTE BLITZ | TRAINING A CLASSIFIER

    你已经知道怎样定义神经网络,计算损失和更新网络权重.现在你可能会想, 那么,数据呢? 通常,当你需要解决有关图像.文本或音频数据的问题,你可以使用python标准库加载数据并转换为numpy arra ...

  5. Summary on deep learning framework --- PyTorch

    Summary on deep learning framework --- PyTorch  Updated on 2018-07-22 21:25:42  import osos.environ[ ...

  6. Neural Network Programming - Deep Learning with PyTorch with deeplizard.

    PyTorch Prerequisites - Syllabus for Neural Network Programming Series PyTorch先决条件 - 神经网络编程系列教学大纲 每个 ...

  7. Neural Network Programming - Deep Learning with PyTorch - YouTube

    百度云链接: 链接:https://pan.baidu.com/s/1xU-CxXGCvV6o5Sksryj3fA 提取码:gawn

  8. (zhuan) Where can I start with Deep Learning?

    Where can I start with Deep Learning? By Rotek Song, Deep Reinforcement Learning/Robotics/Computer V ...

  9. rlpyt(Deep Reinforcement Learning in PyTorch)

    rlpyt: A Research Code Base for Deep Reinforcement Learning in PyTorch Github:https://github.com/ast ...

随机推荐

  1. Codeforces Round #505 D. Recovering BST(区间DP)

    首先膜一发网上的题解.大佬们tql. 给你n个单调递增的数字,问是否能够把这些数字重新构成一棵二叉搜索树(BST),且所有的父亲结点和叶子结点之间的gcd > 1? 这个题场上是想暴力试试的.结 ...

  2. Persona5

    65536K   Persona5 is a famous video game. In the game, you are going to build relationship with your ...

  3. 百度地图的API接口----多地址查询和经纬度

    最近看了百度地图的API的接口,正想自己做点小东西,主要是多地址查询和经纬度坐标跟踪, 下面的代码直接另存为html就可以了,目前测试Chrome和360浏览器可以正常使用. <!DOCTYPE ...

  4. 【Codeforces Round #476 (Div. 2) [Thanks, Telegram!] C】Greedy Arkady

    [链接] 我是链接,点我呀:) [题意] 在这里输入题意 [题解] 枚举那个人收到了几次糖i. 最好的情况显然是其他人都只收到i-1次糖. 然后这个人刚好多收了一次糖 也即 (i-1)kx + x & ...

  5. dubbo基础文档

    随着互联网的发展,网站应用的规模不断扩大,常规的垂直应用架构已无法应对,分布式服务架构以及流动计算架构势在必行,亟需一个治理系统确保架构有条不紊的演进. 单一应用架构 当网站流量很小时,只需一个应用, ...

  6. 鼠标在窗口中的坐标转换到 canvas 中的坐标

        鼠标在窗口中的坐标转换到 canvas 中的坐标 由于需要用到isPointInPath函数,所以必须得将鼠标在窗口中的坐标位置转换到canvas画布中的坐标,今天发现网上这种非常常见的写法其 ...

  7. Jeddict研究过程中的总结

    一.与作者交流的总结 说来也是惭愧,没有太多的经验,先给大家贴两张图,看看大家能不能发现问题: 在最开始的时候,都处于Gaurav Gupta让我给材料的过程,因为我不是缺这个就是缺那个,根本说不清楚 ...

  8. Hibernate框架简述(转)

    转自:http://www.cnblogs.com/eflylab/archive/2007/01/09/615338.html Hibernate的核心组件在基于MVC设计模式的JAVA WEB应用 ...

  9. Linux Shell系列教程之(七)Shell输出

    本文是Linux Shell系列教程的第(七)篇,更多shell教程请看:Linux Shell系列教程 与其他语言一样,Shell中也有输出操作,而且在实际应用中也是非常重要的,今天就为大家介绍下S ...

  10. JSON之解析

    JSON之解析通过TouchJSON\SBJSON\JSONKit\NSJSONSerialization JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式 ...