1、心得: 在使用TensorFlow做非线性拟合的时候注意的一点就是输出层不能使用激活函数,这样就会把整个区间映射到激活函数的值域范围内无法收敛。

# coding:utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 准备需要拟合的数据点
x_data = np.arange(-2*np.pi,2*np.pi,0.1).reshape(-1,1)
y_data = np.sin(x_data).reshape(-1,1)*2 # 建立TensorFlow网络模型
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1]) # 定义权重
weights = {
'w1':tf.Variable(tf.random_normal([1,10],stddev=0.1)),
'w2':tf.Variable(tf.random_normal([10,20],stddev=0.1)),
'out':tf.Variable(tf.random_normal([20,1],stddev=0.1))
} biases = {
'b1':tf.Variable(tf.random_normal([10])),
'b2':tf.Variable(tf.random_normal([20])),
'out':tf.Variable(tf.random_normal([1]))
} # 定义模型
def deep_liner_model(_x,_weights,_biases):
y1 = tf.nn.tanh(tf.add(tf.matmul(_x,_weights['w1']),_biases['b1']))
y2 = tf.nn.tanh(tf.add(tf.matmul(y1,_weights['w2']),_biases['b2']))
# 在计算的时候最后一层别使用激活函数,会进行映射不收敛的。
out = tf.add(tf.matmul(y2,_weights['out']),_biases['out'])
return out y_pred = deep_liner_model(x,weights,biases) # 损失函数:使用欧式距离
# loss = tf.sqrt(tf.reduce_sum(tf.pow(y-y_pred,2)))
loss = tf.reduce_mean(tf.square(y-y_pred))
# 优化器:训练方法
optm = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)
#optm = tf.train.AdadeltaOptimizer(learning_rate=0.01).minimize(loss)
# 准确率:R方评估
R2 = 1 - tf.reduce_sum(tf.pow(y-y_pred,2))/tf.reduce_sum(tf.pow(y-tf.reduce_mean(y_pred),2))
acc_score = tf.reduce_mean(tf.cast(R2,tf.float32)) # 万事俱备只欠训练了。 with tf.Session() as sess:
# 初始化全局变量
sess.run(tf.global_variables_initializer())
# 开始迭代首先使用一万次
for i in range(20000):
sess.run(optm,feed_dict={x:x_data,y:y_data}) if (i+1)%1000==0:
acc = sess.run(acc_score,feed_dict={x:x_data,y:y_data})
avg_loss = sess.run(loss,feed_dict={x:x_data,y:y_data})
print('epoch:%s loss:%s acc:%s'%(i+1,str(avg_loss),str(acc))) y_predict = sess.run(y_pred,feed_dict={x:x_data}) plt.figure('tensorflow',figsize=(12,6))
plt.scatter(x_data, y_data,label='sin(x)的值')
plt.plot(x_data,y_predict,'r',linewidth=1,label='tensorflow拟合值')
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体为SimHei显示中文
plt.rcParams['axes.unicode_minus'] = False # 设置正常显示符号
plt.title('tensorflow实现y=sin(x)拟合')
plt.xlabel('x-values',{'size':15})
plt.ylabel('y-values-sin(x)',{'size':15})
plt.legend(loc='upper right')
plt.show()

  

TensorFlow非线性拟合的更多相关文章

  1. Java 使用 Apache commons-math3 线性拟合、非线性拟合实例(带效果图)

    Java 使用 CommonsMath3 的线性和非线性拟合实例,带效果图 例子查看 GitHub Gitee 运行src/main/java/org/wfw/chart/Main.java 即可查看 ...

  2. tensorflow神经网络拟合非线性函数与操作指南

    本实验通过建立一个含有两个隐含层的BP神经网络,拟合具有二次函数非线性关系的方程,并通过可视化展现学习到的拟合曲线,同时随机给定输入值,输出预测值,最后给出一些关键的提示. 源代码如下: # -*- ...

  3. AI - TensorFlow - 过拟合(Overfitting)

    过拟合 过拟合(overfitting,过度学习,过度拟合): 过度准确地拟合了历史数据(精确的区分了所有的训练数据),而对新数据适应性较差,预测时会有很大误差. 过拟合是机器学习中常见的问题,解决方 ...

  4. 2层感知机(神经网络)实现非线性回归(非线性拟合)【pytorch】

    import torch import numpy import random from torch.autograd import Variable import torch.nn.function ...

  5. MATLAB实例:多元函数拟合(线性与非线性)

    MATLAB实例:多元函数拟合(线性与非线性) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 更多请看:随笔分类 - MATLAB作图 之前写过一篇博 ...

  6. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  7. Matlab:拟合(2)

    非线性最小二乘拟合: 解法一:用命令lsqcurvefit function f = curvefun(x, tdata) f = x() + x()*exp() * tdata); %其中x() = ...

  8. matlab最小二乘法数据拟合函数详解

    定义: 最小二乘法(又称最小平方法)是一种数学优化技术.它通过最小化误差的平方和寻找数据的最佳函数匹配.利用最小二乘法可 以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小. ...

  9. scipy插值与拟合

    原文链接:https://zhuanlan.zhihu.com/p/28149195 1.最小二乘拟合 实例1 import numpy as np import matplotlib.pyplot ...

随机推荐

  1. iOS- 优化与封装 APP音效的播放

    1.关于音效 音效又称短音频,是一个声音文件,在应用程序中起到点缀效果,用于提升应用程序的整体用户体验.   我们手机里常见的APP几乎都少不了音效的点缀.   显示实现音效并不复杂,但对我们App很 ...

  2. 【Python】Python中*args 和**kwargs的用法

    好久没有学习Python了,应为工作的需要,再次拾起python,唤起记忆. 当函数的参数不确定时,可以使用*args 和**kwargs,*args 没有key值,**kwargs有key值. 还是 ...

  3. .Net MVC 实现长轮询

    什么是长轮询? 长轮询是“服务器推”技术实现方式的一种,可以将服务端发生的变化实时传送到客户端而无须客户端频繁的地刷新.发送请求. 长轮询原理? 客户端向服务器发送Ajax请求,服务器接收到请求后,保 ...

  4. 洛谷4578 & LOJ2520:[FJOI2018]所罗门王的宝藏——题解

    https://www.luogu.org/problemnew/show/P4578 https://loj.ac/problem/2520 有点水的. 先转换成图论模型,即每个绿宝石,横坐标向纵坐 ...

  5. NOIP2017 列队——平衡树

    平衡树蒟蒻,敲了半天. 其实思路很简单,就是把许多个人合并成一个区间.必要的时候再拆开.(是不是和这个题的动态开点线段树有异曲同工之妙?) 每次操作最多多出来6个点. 理论上时间复杂度是nlogn,空 ...

  6. 创建JavaScript的哈希表Hashtable

    Hashtable是最常用的数据结构之一,但在JavaScript里没有各种数据结构对象.但是我们可以利用动态语言的一些特性来实现一些常用的数据结构和操作,这样可以使一些复杂的代码逻辑更清晰,也更符合 ...

  7. snmp实用篇

    简单网络管理协议(SNMP)是 TCP/IP协议簇的一个应用层协议.在1988年被制定,并被Internet体系结构委员会(IAB)采纳作为一个短期的网络管理解决方案:由于 SNMP的简单性,在Int ...

  8. 监听scrollview

    http://blog.csdn.net/u012527802/article/details/47320009

  9. bzoj 1218 [HNOI2003]激光炸弹 二维前缀和

    [HNOI2003]激光炸弹 Time Limit: 10 Sec  Memory Limit: 162 MBSubmit: 3022  Solved: 1382[Submit][Status][Di ...

  10. lightoj 1245

    lightoj 1245 Harmonic Number (II) 题意:给定一个 n ,求 n/1 + n/2 + …… + n/n 的值(这里的 "/" 是计算机的整数除法,向 ...