1、心得: 在使用TensorFlow做非线性拟合的时候注意的一点就是输出层不能使用激活函数,这样就会把整个区间映射到激活函数的值域范围内无法收敛。

# coding:utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 准备需要拟合的数据点
x_data = np.arange(-2*np.pi,2*np.pi,0.1).reshape(-1,1)
y_data = np.sin(x_data).reshape(-1,1)*2 # 建立TensorFlow网络模型
x = tf.placeholder(tf.float32,[None,1])
y = tf.placeholder(tf.float32,[None,1]) # 定义权重
weights = {
'w1':tf.Variable(tf.random_normal([1,10],stddev=0.1)),
'w2':tf.Variable(tf.random_normal([10,20],stddev=0.1)),
'out':tf.Variable(tf.random_normal([20,1],stddev=0.1))
} biases = {
'b1':tf.Variable(tf.random_normal([10])),
'b2':tf.Variable(tf.random_normal([20])),
'out':tf.Variable(tf.random_normal([1]))
} # 定义模型
def deep_liner_model(_x,_weights,_biases):
y1 = tf.nn.tanh(tf.add(tf.matmul(_x,_weights['w1']),_biases['b1']))
y2 = tf.nn.tanh(tf.add(tf.matmul(y1,_weights['w2']),_biases['b2']))
# 在计算的时候最后一层别使用激活函数,会进行映射不收敛的。
out = tf.add(tf.matmul(y2,_weights['out']),_biases['out'])
return out y_pred = deep_liner_model(x,weights,biases) # 损失函数:使用欧式距离
# loss = tf.sqrt(tf.reduce_sum(tf.pow(y-y_pred,2)))
loss = tf.reduce_mean(tf.square(y-y_pred))
# 优化器:训练方法
optm = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)
#optm = tf.train.AdadeltaOptimizer(learning_rate=0.01).minimize(loss)
# 准确率:R方评估
R2 = 1 - tf.reduce_sum(tf.pow(y-y_pred,2))/tf.reduce_sum(tf.pow(y-tf.reduce_mean(y_pred),2))
acc_score = tf.reduce_mean(tf.cast(R2,tf.float32)) # 万事俱备只欠训练了。 with tf.Session() as sess:
# 初始化全局变量
sess.run(tf.global_variables_initializer())
# 开始迭代首先使用一万次
for i in range(20000):
sess.run(optm,feed_dict={x:x_data,y:y_data}) if (i+1)%1000==0:
acc = sess.run(acc_score,feed_dict={x:x_data,y:y_data})
avg_loss = sess.run(loss,feed_dict={x:x_data,y:y_data})
print('epoch:%s loss:%s acc:%s'%(i+1,str(avg_loss),str(acc))) y_predict = sess.run(y_pred,feed_dict={x:x_data}) plt.figure('tensorflow',figsize=(12,6))
plt.scatter(x_data, y_data,label='sin(x)的值')
plt.plot(x_data,y_predict,'r',linewidth=1,label='tensorflow拟合值')
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体为SimHei显示中文
plt.rcParams['axes.unicode_minus'] = False # 设置正常显示符号
plt.title('tensorflow实现y=sin(x)拟合')
plt.xlabel('x-values',{'size':15})
plt.ylabel('y-values-sin(x)',{'size':15})
plt.legend(loc='upper right')
plt.show()

  

TensorFlow非线性拟合的更多相关文章

  1. Java 使用 Apache commons-math3 线性拟合、非线性拟合实例(带效果图)

    Java 使用 CommonsMath3 的线性和非线性拟合实例,带效果图 例子查看 GitHub Gitee 运行src/main/java/org/wfw/chart/Main.java 即可查看 ...

  2. tensorflow神经网络拟合非线性函数与操作指南

    本实验通过建立一个含有两个隐含层的BP神经网络,拟合具有二次函数非线性关系的方程,并通过可视化展现学习到的拟合曲线,同时随机给定输入值,输出预测值,最后给出一些关键的提示. 源代码如下: # -*- ...

  3. AI - TensorFlow - 过拟合(Overfitting)

    过拟合 过拟合(overfitting,过度学习,过度拟合): 过度准确地拟合了历史数据(精确的区分了所有的训练数据),而对新数据适应性较差,预测时会有很大误差. 过拟合是机器学习中常见的问题,解决方 ...

  4. 2层感知机(神经网络)实现非线性回归(非线性拟合)【pytorch】

    import torch import numpy import random from torch.autograd import Variable import torch.nn.function ...

  5. MATLAB实例:多元函数拟合(线性与非线性)

    MATLAB实例:多元函数拟合(线性与非线性) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 更多请看:随笔分类 - MATLAB作图 之前写过一篇博 ...

  6. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  7. Matlab:拟合(2)

    非线性最小二乘拟合: 解法一:用命令lsqcurvefit function f = curvefun(x, tdata) f = x() + x()*exp() * tdata); %其中x() = ...

  8. matlab最小二乘法数据拟合函数详解

    定义: 最小二乘法(又称最小平方法)是一种数学优化技术.它通过最小化误差的平方和寻找数据的最佳函数匹配.利用最小二乘法可 以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小. ...

  9. scipy插值与拟合

    原文链接:https://zhuanlan.zhihu.com/p/28149195 1.最小二乘拟合 实例1 import numpy as np import matplotlib.pyplot ...

随机推荐

  1. JXM 监控tomcat 7(含代码

    1.在tomcat的server.xml中加入: <Listener className="org.apache.catalina.mbeans.JmxRemoteLifecycleL ...

  2. C#中的反射和扩展方法的运用

    前段时间做了一个练手的小项目,采用的是三层架构,也就是Models,IDAL,DAL,BLL 和 Web , 在DAL层中各个类中有一个方法比较常用,那就是 RowToClass ,顾名思义,也就是将 ...

  3. Go语言【第六篇】:Go循环语句

    Go语言循环语句 在不少实际问题中有许多具有规律性的重复操作,因此在程序中就需要重复执行某些语句,以下为大多数编程语言循环程序的流程如: Go语言提供了以下几种类型循环处理语句: 循环类型 描述 fo ...

  4. HTML5 Web SQL 数据库总结

    Web SQL 数据库 API 并不是 HTML5 规范的一部分,但是它是一个独立的规范,引入了一组使用 SQL 操作客户端数据库的 APIs. 如果你是一个 Web 后端程序员,应该很容易理解 SQ ...

  5. linux虚拟机磁盘扩展与分区大小调整

    有段时间觉得linux虚拟机上的磁盘不太够用,研究了下其磁盘扩展 1.linux虚拟机磁盘扩展 step1. 先关机在编辑虚拟机中,找到硬盘选项增加空间,进行扩展step2. 进入root fdisk ...

  6. JS详细图解作用域链与闭包

    JS详细图解作用域链与闭包 攻克闭包难题 初学JavaScript的时候,我在学习闭包上,走了很多弯路.而这次重新回过头来对基础知识进行梳理,要讲清楚闭包,也是一个非常大的挑战. 闭包有多重要?如果你 ...

  7. mysql 迁移 mariadb

    背景: mysql5.7数据库安装在windows环境中,数据需要迁移到CentOS7.4的mariadb5.5中.web应用是采用springboot2.x开发的,迁移数据完成后,还需要简单修改一些 ...

  8. CentOS scp远程拷贝

    scp(secure copy)是一个基于 SSH 协议在网络之间进行安全传输的命令, 其格式为“scp [参数] 本地文件 远程帐户@远程 IP 地址:远程目录”. 1.主要参数 -v 显示详细的连 ...

  9. BZOJ1486:[HNOI2009]最小圈——题解

    https://www.lydsy.com/JudgeOnline/problem.php?id=1486 https://www.luogu.org/problemnew/show/P3199 题面 ...

  10. BZOJ1010:[HNOI2008]玩具装箱——题解

    http://www.lydsy.com/JudgeOnline/problem.php?id=1010 P教授要去看奥运,但是他舍不下他的玩具,于是他决定把所有的玩具运到北京.他使用自己的压缩器进行 ...