import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs,in_size,out_size,activation_function=None):
Weights = tf.Variable(tf.random_normal([in_size,out_size]))
biases = tf.Variable(tf.zeros([1,out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs,Weights) + biases

if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs

#data
x_data = np.linspace(-1,1,300)[:,np.newaxis]
noise = np.random.normal(0,0.05,x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

xs = tf.placeholder(tf.float32,[None,1])
ys = tf.placeholder(tf.float32,[None,1])

l1 = add_layer(xs,1,10,activation_function=tf.nn.relu)
prediction = add_layer(l1,10,1,activation_function=None)

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init = tf.initialize_all_variables()

sess = tf.Session()
sess.run(init)

fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data,y_data)
plt.ion()
plt.show()
for i in range(1000):
sess.run(train_step,feed_dict={xs:x_data,ys:y_data})
if i%50 == 0:
# print(sess.run(loss,feed_dict={xs:x_data,ys:y_data}))
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value = sess.run(prediction,feed_dict={xs:x_data,ys:y_data})
lines = ax.plot(x_data,prediction_value,'-r',lw=5)
plt.pause(0.1)

莫烦tensorflow(5)-训练二次函数模型并用matplotlib可视化的更多相关文章

  1. 莫烦tensorflow(1)-训练线性函数模型

    import tensorflow as tfimport numpy as np #create datax_data = np.random.rand(100).astype(np.float32 ...

  2. 莫烦PyTorch学习笔记(五)——模型的存取

    import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed() ...

  3. 莫烦tensorflow(9)-Save&Restore

    import tensorflow as tfimport numpy as np ##save to file#rember to define the same dtype and shape w ...

  4. 莫烦tensorflow(8)-CNN

    import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data#number 1 to 10 dat ...

  5. 莫烦tensorflow(7)-mnist

    import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data#number 1 to 10 dat ...

  6. 莫烦tensorflow(6)-tensorboard

    import tensorflow as tfimport numpy as np def add_layer(inputs,in_size,out_size,n_layer,activation_f ...

  7. 莫烦tensorflow(4)-placeholder

    import tensorflow as tf input1 = tf.placeholder(tf.float32)input2 = tf.placeholder(tf.float32) outpu ...

  8. 莫烦tensorflow(3)-Variable

    import tensorflow as tf state = tf.Variable(0,name='counter') one = tf.constant(1) new_value = tf.ad ...

  9. 莫烦tensorflow(2)-Session

    import os os.environ['TF_CPP_MIN_LOG_LEVEL']='2' import tensorflow as tfmatrix1 = tf.constant([[3,3] ...

随机推荐

  1. MySQL插入命令_INSERT INTO

    MySQL允许将一个或多个元组插入已存在的table中. 格式:INSERT INTO  表名 (属性名1,属性名2,属性名3) VALUES (value1,value2,value3);     ...

  2. 使用python玩跳一跳亲测使用步骤详解

    玩微信跳一跳,测测python跳一跳,顺便蹭一蹭热度: 参考博文 使用python玩跳一跳超详细使用教程 WIN10系统,安卓用户请直入此: python辅助作者github账号为:wangshub. ...

  3. Java基础学习-注释的概述和分类

    /*     注释:用于解释说明程序的文字          分类:             单行://             多行:/**/       作用:解释说明程序,提高程序的阅读性 */ ...

  4. python第一阶段总结(2)

    python3第一阶段的总结 首先申明一下,本人是看网络课程“老男孩”过来写博客的,想把自己学到的东西分享一下.同时给老男孩打个广告,其教学水平真的挺好的.仅据我个人多年的学习评价. 好,接下来是我对 ...

  5. Hi3519v101 SDK安装及升级

    安装SDK 1.解压tgz压缩包 将 Hi3519V101_SDK_Vx.x.x.x.tgz 压缩包放入共享文件夹中,并解压到Linux内部如 /sdk 目录下,因为在共享目录中编译容易出现各种错误. ...

  6. 【Git】Git使用记录: 撤回已经commit到本地的提交记录

    话不多说直接上步骤: git bash直接干到你的code. 直接敲命令: git reset --soft HEAD~1 搞定 就是这么简单粗暴. 如有顾虑请自行找个案例测试即可. 参考资料 htt ...

  7. 在 jupyter 中添加菜单和自动完成功能

    在 jupyter 中添加菜单和自动完成功能 参考文档http://www.360doc.com/content/17/1103/14/1489589_700569828.shtmlhttp://to ...

  8. CSS基础学习(二) 之 width min-width max-width属性

    width 1. 设置元素内容区(content area)的宽度. 2. 如果box-szing属性设置为border-box,width表示border area的宽度,如下图 min-width ...

  9. burpsuit 无法导入证书,抓取https的解决办法

    想用burpsuit中转https流量,需要安装证书: 确保浏览器能访问http 后,访问:http://burp/ 点击右上角下载证书. 然后导入,这些网上都有方法. 但如果你试了后: ①提示导入失 ...

  10. 『Python』socket网络编程

    Python3网络编程 '''无论是str2bytes或者是bytes2str其编码方式都是utf-8 str( ,encoding='utf-8') bytes( ,encoding='utf-8' ...