tensorflow学习之搭建最简单的神经网络
这几天在B站看莫烦的视频,学习一波,给出视频地址:https://www.bilibili.com/video/av16001891/?p=22
先放出代码
#####搭建神经网络测试
def add_layer(inputs,in_size,out_size,activation_function=None):
Weights = tf.Variable(tf.random_normal([in_size, out_size],dtype=np.float32))
biases = tf.Variable(tf.zeros([1,out_size])+0.1)
Wx_plus_b = tf.matmul(inputs, Weights)+biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0,0.05,x_data.shape)
y_data = np.square(x_data)-0.5+noise xs = tf.placeholder(tf.float32,[None,1])
ys = tf.placeholder(tf.float32,[None,1])
l1 = add_layer(xs,1,10,activation_function=tf.nn.relu) prediction = add_layer(l1,10,1,activation_function=None)
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for i in range(1000):
sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
if i% 50 ==0:
print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
#####
首先,在add_layer函数中,参数有inputs,in_size,out_size,activation_function=None
其中inupts是输入,in_size是输入维度,out_size是输出维度, activation_function是激活函数,
Weights是权重,维度是(in_size*out_size);
bias是偏置,维度是(1*out_size);
Wx_plus_b的维度和out_size相同;
x_data = np.linspace(-1,1,300)[:, np.newaxis]这步操作,表示生成-1到1之间均匀分布的300个数,然后转换维度,变成(300,1);noise和y_data的维度均和
x_data相同;
xs = tf.placeholder(tf.float32,[None,1])和ys = tf.placeholder(tf.float32,[None,1])表示生成xs和ys变量的占位符,维度是(None,1),不知道有多少行,但只要1列;
l1 = add_layer(xs,1,10,activation_function=tf.nn.relu)表示xs是inputs,in_size是1,out_size是10,激活函数是relu;添加了一层神经网络
prediction = add_layer(l1,10,1,activation_function=None)表示输入是l1,in_size是10,out_size是1,没有激活函数
接下去是计算损失,loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
之后一步是用梯度下降来优化损失函数;
解释一下为什么不直接在add_layer函数中使用x_data:x_data是ndarray格式,Weights是Variable格式,不能直接相乘,所以要在session会话中用字典格式传入x_data和y_data, 也就是sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
tensorflow学习之搭建最简单的神经网络的更多相关文章
- TensorFlow学习笔记(六)循环神经网络
一.循环神经网络简介 循环神经网络的主要用途是处理和预测序列数据.循环神经网络刻画了一个序列当前的输出与之前信息的关系.从网络结构上,循环神经网络会记忆之前的信息,并利用之前的信息影响后面节点的输出. ...
- python日记:用pytorch搭建一个简单的神经网络
最近在学习pytorch框架,给大家分享一个最最最最基本的用pytorch搭建神经网络并且训练的方法.本人是第一次写这种分享文章,希望对初学pytorch的朋友有所帮助! 一.任务 首先说下我们要搭建 ...
- TensorFlow学习笔记13-循环、递归神经网络
循环神经网络(RNN) 卷积网络专门处理网格化的数据,而循环网络专门处理序列化的数据. 一般的神经网络结构为: 一般的神经网络结构的前提假设是:元素之间是相互独立的,输入.输出都是独立的. 现实世界中 ...
- TensorFlow学习笔记(二)深层神经网络
一.深度学习与深层神经网络 深层神经网络是实现“多层非线性变换”的一种方法. 深层神经网络有两个非常重要的特性:深层和非线性. 1.1线性模型的局限性 线性模型:y =wx+b 线性模型的最大特点就是 ...
- 深度学习环境搭建部署(DeepLearning 神经网络)
工作环境 系统:Ubuntu LTS 显卡:GPU NVIDIA驱动:410.93 CUDA:10.0 Python:.x CUDA以及NVIDIA驱动安装,详见https://www.cnblogs ...
- 『TensorFlow』读书笔记_简单卷积神经网络
如果你可视化CNN的各层级结构,你会发现里面的每一层神经元的激活态都对应了一种特定的信息,越是底层的,就越接近画面的纹理信息,如同物品的材质. 越是上层的,就越接近实际内容(能说出来是个什么东西的那些 ...
- tensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试
刚开始学习tf时,我们从简单的地方开始.卷积神经网络(CNN)是由简单的神经网络(NN)发展而来的,因此,我们的第一个例子,就从神经网络开始. 神经网络没有卷积功能,只有简单的三层:输入层,隐藏层和输 ...
- Tensorflow学习:(二)搭建神经网络
一.神经网络的实现过程 1.准备数据集,提取特征,作为输入喂给神经网络 2.搭建神经网络结构,从输入到输出 3.大量特征数据喂给 NN,迭代优化 NN 参数 4.使 ...
- 深度学习(TensorFlow)环境搭建:(三)Ubuntu16.04+CUDA8.0+cuDNN7+Anaconda4.4+Python3.6+TensorFlow1.3
紧接着上一篇的文章<深度学习(TensorFlow)环境搭建:(二)Ubuntu16.04+1080Ti显卡驱动>,这篇文章,主要讲解如何安装CUDA+CUDNN,不过前提是我们是已经把N ...
随机推荐
- Ubuntu安装DroidCamX网络摄像头
1.安装依赖项 sudo apt-get install gcc make linux-headers-`uname -r` 2.安装 cd /tmp/ bits=`getconf LONG_BIT` ...
- Flask【第9篇】:Flask-script组件
flask-script组件 Flask Script扩展提供向Flask插入外部脚本的功能,包括运行一个开发用的服务器,一个定制的Python shell,设置数据库的脚本,cronjobs,及其他 ...
- Python之常用模二(时间、序列号等等)
一.time模块 表示时间的三种方式: 时间戳:数字(计算机能认识的) 时间字符串:t='2012-12-12' 结构化时间:time.struct_time(tm_year=2017, tm_mon ...
- Linux系统中的硬件问题如何排查?(1)
在Linux系统中,对于硬件故障问题的排查可能是计算机管理领域最棘手的工作,即使是经验相当丰富的用户有时也会遇上自己搞不定的状况,本文分享一些实用的技巧与处理方法,希望有助于读者朋友理解.查明并最终搞 ...
- BZOJ 3193: [JLOI2013]地形生成 计数 + 组合 + 动态规划
第一问: 先不考虑山的高度有相同的:直接按照高度降序排序,试着将每一座山插入到前面山的缝隙中. 当然,这并不代表这些山的相对位置是固定的,因为后面高度更低的山是有机会插入进来的,所以就可以做到将所有情 ...
- Devexpress MVC Gridview 获取到增删改的所有行数据(JSON) 并使用SQL事物保存数据
//ModalChargeGridView Gridview的名字//Con_Shp_Chg 数据库表名//ConShpChgUID UID或者是标识列//gs_Language 语言(中英文)//l ...
- React 初试
小Demo, 后面会进行拓展的 import React from 'react'; import ReactDOM from 'react-dom'; class Welcome extends R ...
- 【BZOJ3931】[CQOI2015]网络吞吐量
Description 路由是指通过计算机网络把信息从源地址传输到目的地址的活动,也是计算机网络设计中的重点和难点.网络中实现路由转发的硬件设备称为路由器.为了使数据包最快的到达目的地,路由器需要选择 ...
- 织梦DedeCms技术资料
Dedecms调用文章发布时间的方法 11-20 样式 ([field:pubdate function='strftime("%m-%d",@me)'/]) May 15, 20 ...
- 【Python】学习笔记二:基本数据类型
变量 python的变量不需要提前声明,可以直接输入: >>> str = 'oliver' 此时,str已经被赋值字符串oliver,在赋值之前并没有提前定义与事先声明 打印值 & ...