tensorflow--非线性回归
算法步骤:
1. 给定训练样本,x_data和y_data
2. 定义两个占位符分别接收输入x和输出y
3. 中间层操作实际为:权值w与输入x矩阵相乘,加上偏差b后,得到中间层输出
4. 使用tanh函数激活后传给输出层
5. 输出层操作实际为:权值w与中间层结果矩阵相乘,加上偏差b后,得到输出层输出
6. 使用tanh函数激活后得到最终结果
7. 利用y的预测值,与实际的y求出它们间的平均方差,即损失值
8. 最后使用梯度下降法进行训练,使loss最小化
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt # 生成一组离散点
x_data = np.linspace(-0.5, 0.5, 200)[:, np.newaxis]
noise = np.random.normal(0, 0.02, x_data.shape)
y_data = np.square(x_data) + noise # 定义两个占位符
x = tf.placeholder(tf.float32, [None, 1])
y= tf.placeholder(tf.float32, [None, 1]) # 中间层操作
Weights_L1 = tf.Variable(tf.random_normal([1,10]))
biases_L1 = tf.Variable(tf.zeros([1, 10]))
L1 = tf.nn.tanh(tf.matmul(x, Weights_L1)+biases_L1) # 输出层操作
Weights_L2 = tf.Variable(tf.random_normal([10,1]))
biases_L2 = tf.Variable(tf.zeros([1,1]))
prediction = tf.nn.tanh(tf.matmul(L1,Weights_L2)+biases_L2) # 计算损失率
loss = tf.reduce_mean(tf.square(y-prediction))
# 使用梯度下降法训练
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) with tf.Session() as sess:
# 初始化所有变量
sess.run(tf.global_variables_initializer())
for i in range(1000):
sess.run(train_step, feed_dict={x:x_data, y:y_data})
prediction_value = sess.run(prediction, feed_dict={x:x_data, y:y_data}) # 画图展示预测结果
plt.figure()
plt.scatter(x_data, y_data)
plt.plot(x_data, prediction_value,'r-*', lw=5)
plt.show()
总结:
1. tensorflow中训练模型前,必须需先初始化变量,否则会报错
2. 激活函数除了tanh(),还有rule(),Sigmoid(),据说Leaky ReLU 、 PReLU 或者 Maxout效果更佳
3. 梯度下降法中的学习率需小心设置,避免出现过的死亡神经元
如tf.train.GradientDescentOptimizer(0.1).minimize(loss)中0.1为学习率
tensorflow--非线性回归的更多相关文章
- tensorflow非线性回归(03-1)
这个程序为简单的三层结构组成:输入层.中间层.输出层 要理清各层间变量个数 import numpy as np import matplotlib.pyplot as plt import tens ...
- tensorflow版helloworld---拟合线性函数的k和b(02-4)
给不明白深度学习能干什么的同学,感受下深度学习的power import tensorflow as tf import numpy as np #使用numpy生成100个随机点 x_data=np ...
- tensorflow中的Fetch、Feed(02-3)
import tensorflow as tf #Fetch概念 在session中同时运行多个op input1=tf.constant(3.0) #constant()是常量不用进行init初始化 ...
- tensorflow变量的使用(02-2)
import tensorflow as tf x=tf.Variable([1,2]) a=tf.constant([3,3]) sub=tf.subtract(x,a) #增加一个减法op add ...
- tensorflow中的图(02-1)
由于tensorflow版本迭代较快且不同版本的接口会有差距,我这里使用的是1.14.0的版本 安装指定版本的方法:pip install tensorflow==1.14.0 如果你之前安 ...
- tensorflow简介、目录
目前工作为nlp相关的分类及数据治理,之前也使用tensorflow写过一些简单分类的代码,感受到深度学习确实用处较大,想更加系统和全面的学习下tensorflow的相关知识,于是我默默的打开了b站: ...
- MNIST手写数字分类simple版(03-2)
simple版本nn模型 训练手写数字处理 MNIST_data数据 百度网盘链接:https://pan.baidu.com/s/19lhmrts-vz0-w5wv2A97gg 提取码:cgnx ...
- TensorFlow(三):非线性回归
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt # 非线性回归 # 使用numpy生成200个随机 ...
- tensorflow 使用 4 非线性回归
# 输入一个 x 会计算出 y 值 y 是预测值,如果与 真的 y 值(y_data)接近就成功了 import tensorflow as tf import numpy as np # py 的画 ...
- Tensorflow学习教程------非线性回归
自己搭建神经网络求解非线性回归系数 代码 #coding:utf-8 import tensorflow as tf import numpy as np import matplotlib.pypl ...
随机推荐
- 安装Ruby 2.3.0
安装最新的redis集群需要用到的rb脚本要ruby2.3.0版本,centos7默认的是2.0.0,很显然报错不行,所以安装呗,使用rvm安装,安装步骤如下: yum -y install curl ...
- NOIp初赛题目整理
NOIp初赛题目整理 这个 blog 用来整理扶苏准备第一轮 csp 时所做的与 csp 没 有 关 系 的历年 noip-J/S 初赛题目,记录了一些我从不知道的细碎知识点,还有一些憨憨题目,不定期 ...
- C程序获取命令行参数
命令行参数 命令行界面中,可执行文件可以在键入命令的同一行中获取参数用于具体的执行命令.无论是Python.Java还是C等等,这些语言都能够获取命令行参数(Command-line argument ...
- PHP重命名文件夹下的文件后缀名
PHP重命名文件夹下的文件后缀名<pre> public function zhuanhouzuiming(){ $lujings='upload/'; $filesnames = sca ...
- kafka topic查看删除
1,查看kafka topic列表,使用--list参数 >bin/kafka-topics.sh --zookeeper 127.0.0.1:2181 --list __consumer_of ...
- python(二)面向对象知识点
模块 别名 import my_module as xxx(别名) 先导入内置模块 再导入第三方模块 再导入自定义模块 from my_module(导入的文件) import *(变量) __all ...
- 025 Linux基础入门-----历史、简介、版本、安装
1.linux历史 Linux最初是由芬兰赫尔辛基大学学生Linus Torvalds由于自己不满意教学中使用的MINIX操作系统, 所以在1990年底由于个人爱好设计出了LINUX系统核心.后来发布 ...
- 【Python】处理Excel的库Xlwings
# # 引入库 import xlwings as xw import time # 打开Excel程序,默认设置:程序可见,只打开不新建工作薄 # app = xw.App(visible=True ...
- Java的多路分支代码,感觉有点意思
/** * @Author hty * @Date 2019-12-16 16:39 * @Version 1.0 */ import java.util.Random; // 比赛结果 enum O ...
- Java学习:static 关键字概述
static 关键字概述 一旦用了static关键字,那么这样的内容不再属于对象自己.而是属于类的,所以凡是本类的对象,都共享同一份. 如果没有static关键字,那么必须首先创建对象,然后通过对象才 ...