算法步骤:

1. 给定训练样本,x_data和y_data

2. 定义两个占位符分别接收输入x和输出y

3. 中间层操作实际为:权值w与输入x矩阵相乘,加上偏差b后,得到中间层输出

4. 使用tanh函数激活后传给输出层

5. 输出层操作实际为:权值w与中间层结果矩阵相乘,加上偏差b后,得到输出层输出

6. 使用tanh函数激活后得到最终结果

7. 利用y的预测值,与实际的y求出它们间的平均方差,即损失值

8. 最后使用梯度下降法进行训练,使loss最小化

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt # 生成一组离散点
x_data = np.linspace(-0.5, 0.5, 200)[:, np.newaxis]
noise = np.random.normal(0, 0.02, x_data.shape)
y_data = np.square(x_data) + noise # 定义两个占位符
x = tf.placeholder(tf.float32, [None, 1])
y= tf.placeholder(tf.float32, [None, 1]) # 中间层操作
Weights_L1 = tf.Variable(tf.random_normal([1,10]))
biases_L1 = tf.Variable(tf.zeros([1, 10]))
L1 = tf.nn.tanh(tf.matmul(x, Weights_L1)+biases_L1) # 输出层操作
Weights_L2 = tf.Variable(tf.random_normal([10,1]))
biases_L2 = tf.Variable(tf.zeros([1,1]))
prediction = tf.nn.tanh(tf.matmul(L1,Weights_L2)+biases_L2) # 计算损失率
loss = tf.reduce_mean(tf.square(y-prediction))
# 使用梯度下降法训练
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) with tf.Session() as sess:
# 初始化所有变量
sess.run(tf.global_variables_initializer())
for i in range(1000):
sess.run(train_step, feed_dict={x:x_data, y:y_data})
prediction_value = sess.run(prediction, feed_dict={x:x_data, y:y_data}) # 画图展示预测结果
plt.figure()
plt.scatter(x_data, y_data)
plt.plot(x_data, prediction_value,'r-*', lw=5)
plt.show()

总结:

1. tensorflow中训练模型前,必须需先初始化变量,否则会报错

2. 激活函数除了tanh(),还有rule(),Sigmoid(),据说Leaky ReLU 、 PReLU 或者 Maxout效果更佳

3. 梯度下降法中的学习率需小心设置,避免出现过的死亡神经元

如tf.train.GradientDescentOptimizer(0.1).minimize(loss)中0.1为学习率

tensorflow--非线性回归的更多相关文章

  1. tensorflow非线性回归(03-1)

    这个程序为简单的三层结构组成:输入层.中间层.输出层 要理清各层间变量个数 import numpy as np import matplotlib.pyplot as plt import tens ...

  2. tensorflow版helloworld---拟合线性函数的k和b(02-4)

    给不明白深度学习能干什么的同学,感受下深度学习的power import tensorflow as tf import numpy as np #使用numpy生成100个随机点 x_data=np ...

  3. tensorflow中的Fetch、Feed(02-3)

    import tensorflow as tf #Fetch概念 在session中同时运行多个op input1=tf.constant(3.0) #constant()是常量不用进行init初始化 ...

  4. tensorflow变量的使用(02-2)

    import tensorflow as tf x=tf.Variable([1,2]) a=tf.constant([3,3]) sub=tf.subtract(x,a) #增加一个减法op add ...

  5. tensorflow中的图(02-1)

    由于tensorflow版本迭代较快且不同版本的接口会有差距,我这里使用的是1.14.0的版本 安装指定版本的方法:pip install tensorflow==1.14.0      如果你之前安 ...

  6. tensorflow简介、目录

    目前工作为nlp相关的分类及数据治理,之前也使用tensorflow写过一些简单分类的代码,感受到深度学习确实用处较大,想更加系统和全面的学习下tensorflow的相关知识,于是我默默的打开了b站: ...

  7. MNIST手写数字分类simple版(03-2)

    simple版本nn模型 训练手写数字处理 MNIST_data数据   百度网盘链接:https://pan.baidu.com/s/19lhmrts-vz0-w5wv2A97gg 提取码:cgnx ...

  8. TensorFlow(三):非线性回归

    import tensorflow as tf import numpy as np import matplotlib.pyplot as plt # 非线性回归 # 使用numpy生成200个随机 ...

  9. tensorflow 使用 4 非线性回归

    # 输入一个 x 会计算出 y 值 y 是预测值,如果与 真的 y 值(y_data)接近就成功了 import tensorflow as tf import numpy as np # py 的画 ...

  10. Tensorflow学习教程------非线性回归

    自己搭建神经网络求解非线性回归系数 代码 #coding:utf-8 import tensorflow as tf import numpy as np import matplotlib.pypl ...

随机推荐

  1. 深入js系列-类型(隐式强制转换)

    隐式强制转换 在其可控的情况下,减少冗余,让代码更简洁,很多地方都进行了隐式转换,比如常见的三目表达式.if().for().while.逻辑运算符 || &&,适当通过语言机制,抽象 ...

  2. Luogu P3228 HNOI2013 数列 组合数学

    题面 看了题解的推导发现其实并不复杂,但是如果你想要用多项式或者组合数求解的话,就GG了 其实如果把式子列出来的话,不需要怎么推导就能算出来,关键是要想到这个巧妙的式子. 设\(b_i=a_{i+1} ...

  3. nuxtjs如何通过路由meta信息控制路由查看权限

    我们知道NUXTJS可以通过路由中间件进行路由鉴权,中间件允许您定义一个自定义函数运行在一个页面或一组页面渲染之前. 但是我在实际使用过程中发现,中间件只有在路由跳转到路由中时才会进入,而在强制刷新网 ...

  4. shell编程题(五)

    打印root可以使用可执行文件数. echo "root's bins: $(find ./ -type f | xargs ls -l | sed '/-..x/p' | wc -l)&q ...

  5. SqlServer事务语法及使用方法(转)

    原博:http://blog.csdn.net/xiaouncle/article/details/52891563 事务是关于原子性的.原子性的概念是指可以把一些事情当做一个不可分割的单元来看待.从 ...

  6. cocos:C++ 导出到lua, genbindings.py修改

    cocos:C++ 导出到lua, genbindings.py修改 1. 准备 把tools目录下的cocos2dx_extension.ini, genbindings.py, userconf. ...

  7. 【Gamma阶段】第七次Scrum Meeting

    冰多多团队-Gamma阶段第七次Scrum会议 工作情况 团队成员 已完成任务 待完成任务 卓培锦 编辑器风格切换(添加夜间模式) UI界面手势切换 牛雅哲 语音输入shell应用:基于pytorch ...

  8. CentOS安装SonarQube7.9.1

    1.准备 SonarQube版本:sonarqube-7.9.1.zip,官网地址:https://www.sonarqube.org/downloads/ jdk版本:jdk-11.0.4_linu ...

  9. select列表遍历和触发事件

    1.以下两种都是jquery获取select列表被选中的value.var strText=$("#select_id").find("option:selected&q ...

  10. CI框架从哪里看起?CI框架怎么开始学习,CI的初始设置

    很多朋友不知道CI框架从哪里开始学起,想学一个新的框架其实并不难.只要你认真研究,自习摸索都很简单! 概述和基本配置参数 配置CI: application/config/config.php:14配 ...