要确保已经明白神经网络和卷积神经网络的原理.如果不明白,先学习参考资料1.tensorflow中有很多api,可以分成2大类.1类是比较低层的api(tf.train),叫TensorFlow Core.还有1种相对高层的api(tf.contrib.learn),是建立在TensorFlow Core基础上的,这种api码农用着很方便.

环境

python 3.5.3

tensorflow 1.0.0

Tensors

TensorFlow中的基本数据是tensor. tensor可以直观地理解为把numpy中的数组又包了一层.tensor的runk表示tensor是几维的.比如

[1. ,2., 3.] # runnk为1的tensor,它的shape是[3]
[[1., 2., 3.], [4., 5., 6.]] # runnk为2的tensor,它的shape是[2, 3]
[[[1., 2., 3.]], [[7., 8., 9.]]] # runnk为3的tensor,它的shape是[2, 1, 3]

helloworld级使用

使用tensorflow编程有2个步骤.第1是建立computational graph,第2是运行computational graph.computational graph中的每个结点都有0个或多个tensor作为输入.有一种结点本身是个常量,这种结点没有输入,有固定的输出(即它本身).下面是2个结点:

node1 = tf.constant(3.0, tf.float32)
node2 = tf.constant(4.0) # 默认类型就是tf.float32
print(node1, node2)

运行这个代码结果是

Tensor("Const:0", shape=(), dtype=float32) Tensor("Const_1:0", shape=(), dtype=float32)

注意输出中没有具体的值3.0,4.0.这可以理解成建立computational graph.在通过Session运行computational graph的时候才会把值填到结点中.比如

node1 = tf.constant(3.0, tf.float32)
node2 = tf.constant(4.0) # 默认类型就是tf.float32
sess = tf.Session()
print(sess.run([node1, node2]))

稍微复杂一点的例子.

node1 = tf.constant(3.0, tf.float32)
node2 = tf.constant(4.0) # 默认类型就是tf.float32
node3 = tf.add(node1, node2)
print("node3: ", node3)
sess = tf.Session()
print("sess.run(node3): ",sess.run(node3))

placeholder

placeholder有什么作用?placeholder可以用来先定义一种操作,执行的时候再具体赋值.比如python函数的定义

def add(a, b)
return a + b

a,b都没有具体的值,调用的时候才赋值.不严谨但直观地可以把a,b理解为placeholder.下面看tensorflow的placeholder.

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
adder_node = a + b
print(sess.run(adder_node, {a: 3, b: 4.5})) # 输出为7.5
print(sess.run(adder_node, {a: [1,3], b: [2, 4]})) # 输出为[ 3. 7.]

再看一个例子

import tensorflow as tf
a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
adder_node = a + b
add_and_triple = adder_node * 3
sess = tf.Session()
print(sess.run(add_and_triple, {a: 3, b: 4.5})) # 输出为22.5

Variable

可以简单地认为在训练的各个参数即为Variable.看下面的例子.

import tensorflow as tf
W = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)
x = tf.placeholder(tf.float32)
linear_model = W * x + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
print(sess.run(linear_model, {x: [1, 2, 3, 4]}))

输出为W * x + b的值.下面看怎么使用损失函数.

import tensorflow as tf

b = tf.Variable([-.3], tf.float32)
x = tf.placeholder(tf.float32)
linear_model = W * x + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
y = tf.placeholder(tf.float32)
squared_deltas = tf.square(linear_model - y)
loss = tf.reduce_sum(squared_deltas)
print(sess.run(loss, {x: [1, 2, 3, 4], y: [0, -1, -2, -3]})) # 损失函数是23.66

修改w,b的值再看下损失函数.

import tensorflow as tf
W = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)
x = tf.placeholder(tf.float32)
linear_model = W * x + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) y = tf.placeholder(tf.float32)
squared_deltas = tf.square(linear_model - y)
loss = tf.reduce_sum(squared_deltas) fixW = tf.assign(W, [-1.])
fixb = tf.assign(b, [1.])
sess.run([fixW, fixb])
print(sess.run(loss, {x: [1, 2, 3, 4], y:[0, -1, -2, -3]})) # 损失函数是0

训练方法

现在要解决如下问题:

已知向量x=(1, 2, 3, 4),向量y=(0, -1, -2, -3),w,b是标量.求w,b使y=wx+b

import tensorflow as tf

W = tf.Variable([.3], tf.float32)
b = tf.Variable([-.3], tf.float32)
x = tf.placeholder(tf.float32)
linear_model = W * x + b
y = tf.placeholder(tf.float32)
loss = tf.reduce_sum(tf.square(linear_model - y))
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
x_train = [1,2,3,4]
y_train = [0,-1,-2,-3]
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for i in range(1000):
sess.run(train, {x: x_train, y: y_train})
# print(sess.run(loss, {x: x_train, y: y_train}))输出loss curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x: x_train, y: y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))

上面代码用梯度下降方法求出w,b.现在来验证下wx+b和y相差多少.

import tensorflow as tf

W = tf.constant([-0.9999969], tf.float32)
b = tf.constant([0.99999082], tf.float32)
x = tf.placeholder(tf.float32)
linear_model = W * x + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
print(sess.run(linear_model, {x: [1, 2, 3, 4]}))

输出为

[ -6.07967377e-06  -1.00000298e+00  -1.99999988e+00  -2.99999666e+00]

已经相当接近(0, -1, -2, -3).

上面是用TensorFlow Core中的方法训练.也可以用较高层的api(tf.contrib.learn). 因为这种方法过于抽象,反而会分散初学都注意力.以后再补上.

问题

  • 用较高层的api(tf.contrib.learn)训练

参考资料

1 http://cs231n.github.io/

2 tensorflow官方教程

初次接触tensorflow的更多相关文章

  1. tensorflow初次接触记录,我用python写的tensorflow第一个模型

    tensorflow初次接触记录,我用python写的tensorflow第一个模型 刚用python写的tensorflow机器学习代码,训练60000张手写文字图片,多层神经网络学习拟合17000 ...

  2. 初次接触json...

    这两天发现很多网站显示图片版块都用了瀑布流模式布局的:随着页面滚动条向下滚动,这种布局还会不断加载数据并附加至当前尾部.身为一个菜鸟级的程序员,而且以后可能会经常与网站打交道,我觉得我还是很有必要去尝 ...

  3. 初次接触GWT,知识点总括

    初次接触GWT,知识点概括 前言 本人最近开始研究 GWT(Google Web Toolkit) ,现将个人的一点心得贴出来,希望对刚开始接触 GWT的程序员们有所帮助,也欢迎讨论,共同进步. 先说 ...

  4. [Docker]初次接触

    Docker 初次接触 近期看了不少docker介绍性文章,也听了不少公开课,于是今天去官网逛了逛,发现了一个交互式的小教程于是决定跟着学习下. 仅仅是把认为重点的知识记录下来,不是非常系统的学习和笔 ...

  5. 初次接触:DirectDraw

    第六章 初次接触:DirectDraw 本章,你将初次接触DirectX中最重要的组件:DirectDraw.DirectDraw可能是DirectX中最强大的技术,因为其贯穿着2D图形绘制同时其帧缓 ...

  6. 初次接触scrapy框架

    初次接触这个框架,先订个小目标,抓取QQ首页,然后存入记事本. 安装框架(http://scrapy-chs.readthedocs.io/zh_CN/0.24/intro/install.html) ...

  7. javaweb中的乱码问题(初次接触时写)

    javaweb中的乱码问题 在初次接触javaweb中就遇到了乱码问题,下面是我遇到这些问题的解决办法 1. 页面乱码(jsp) 1. 在页面最前方加上 <%@ page language=&q ...

  8. 初次接触Java

    今天初次接触Eclipse,学着用他来建立java工程,话不多说,来看看今天的成果! 熟悉自己手中的开发工具,热热身 刚上手别慌,有问题找度娘 刚刚拿到这个软件的安装包我是一脸懵逼的,因为是从官网下载 ...

  9. -1.记libgdx初次接触

    学习一门技术最难的是开发环境变量配置和工具配置,以下为我初次接触libgdx时遇到的问题 几个难点记录下 gradle 直接用下到本地,然后放到d盘,链接到就行(gradle-wrapper.prop ...

随机推荐

  1. springboot封装JsonUtil,CookieUtil工具类

    springboot封装JsonUtil,CookieUtil工具类 yls 2019-9-23 JsonUtil public class JsonUtil { private static Obj ...

  2. 网站搭建-IIS Windows系统搭建网站 (不小心看到自己的密码 - 怎么找回网站记住的密码)

    上一期说到IIS可以用自己喜欢的网站来直接玩,然后得得瑟瑟将自己的博客园账号首页拿过去玩(今天第一天水博客园). 然后自己访问啊,访问啊,然后就一直点啊点的,当然,其实后面的链接都是跳转到博客园里面去 ...

  3. PHP laravel+thrift+swoole打造微服务框架

    Laravel作为最受欢迎的php web框架一直广受广大互联网公司的喜爱. 笔者也参与过一些由laravel开发的项目.虽然laravel的性能广受诟病但是业界也有一些比较好的解决方案,比如堆机器, ...

  4. nyoj 277-车牌号 (map, pair, iterator)

    277-车牌号 内存限制:64MB 时间限制:3000ms 特判: No 通过数:9 提交数:13 难度:1 题目描述: 茵茵很喜欢研究车牌号码,从车牌号码上可以看出号码注册的早晚,据研究发现,车牌号 ...

  5. 0MQ宗旨

    先来看<Implementing distributed applications with 0MQ and some other bad guys...>.用0MQ去实现分布应用,或者用 ...

  6. vux组件的样式变量的使用

    使用x-header,查看文档发现有个样式变量,可以改变x-header的样式 这玩意怎么用呢? 1.在项目中创建一个.less样式文件,例如我这里是创建一个src/style/vux_theme.l ...

  7. 程序员用于机器学习编程的Python 数据处理库 pandas 入门教程

    入门介绍 pandas适合于许多不同类型的数据,包括: · 具有异构类型列的表格数据,例如SQL表格或Excel数据 · 有序和无序(不一定是固定频率)时间序列数据. · 具有行列标签的任意矩阵数据( ...

  8. scala学习系列二

    一 scala语言开发注意事项: 1 Scala程序的执行入口是main()函数 2 Scala语言严格区分大小写. 3 Scala方法由一条条语句构成,每个语句后不需要分号(Scala语言会在每行后 ...

  9. Spring中,多个service发生嵌套,事务是怎么样的?

    前言 最近在项目中发现了一则报错:"org.springframework.transaction.UnexpectedRollbackException: Transaction roll ...

  10. CTF比赛时准备的一些shell命令

    防御策略: sudo service apache2 start :set fileformat=unix1.写脚本关闭大部分服务,除了ssh       2.改root密码,禁用除了root之外的所 ...