莫烦theano学习自修第七天【回归结果可视化】
1.代码实现
from __future__ import print_function
import theano
import theano.tensor as T
import numpy as np
import matplotlib.pyplot as plt
class Layer(object):
def __init__(self, inputs, in_size, out_size, activation_function=None):
self.W = theano.shared(np.random.normal(0, 1, (in_size, out_size)))
self.b = theano.shared(np.zeros((out_size, )) + 0.1)
self.Wx_plus_b = T.dot(inputs, self.W) + self.b
self.activation_function = activation_function
if activation_function is None:
self.outputs = self.Wx_plus_b
else:
self.outputs = self.activation_function(self.Wx_plus_b)
# Make up some fake data
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise # y = x^2 - 0.5
# show the fake data
#plt.scatter(x_data, y_data)
plt.show()
# determine the inputs dtype
x = T.dmatrix("x")
y = T.dmatrix("y")
# add layers
l1 = Layer(x, 1, 10, T.nnet.relu)
l2 = Layer(l1.outputs, 10, 1, None)
# compute the cost
cost = T.mean(T.square(l2.outputs - y))
# compute the gradients
gW1, gb1, gW2, gb2 = T.grad(cost, [l1.W, l1.b, l2.W, l2.b])
# apply gradient descent
learning_rate = 0.05
train = theano.function(
inputs=[x, y],
outputs=[cost],
updates=[(l1.W, l1.W - learning_rate * gW1),
(l1.b, l1.b - learning_rate * gb1),
(l2.W, l2.W - learning_rate * gW2),
(l2.b, l2.b - learning_rate * gb2)])
# prediction
predict = theano.function(inputs=[x], outputs=l2.outputs)
# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()
for i in range(1000):
# training
err = train(x_data, y_data)
if i % 50 == 0:
# to visualize the result and improvement
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value = predict(x_data)
# plot the prediction
lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
plt.pause(.5)
结果:

莫烦theano学习自修第七天【回归结果可视化】的更多相关文章
- 莫烦theano学习自修第九天【过拟合问题与正规化】
如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...
- 莫烦theano学习自修第十天【保存神经网络及加载神经网络】
1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...
- 莫烦theano学习自修第八天【分类问题】
1. 代码实现 from __future__ import print_function import numpy as np import theano import theano.tensor ...
- 莫烦theano学习自修第六天【回归】
1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...
- 莫烦theano学习自修第五天【定义神经层】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第三天【共享变量】
1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...
- 莫烦theano学习自修第二天【激励函数】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第一天【常量和矩阵的运算】
1. 代码实现如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ # 导入numpy模块,因为numpy是常用的计算模块 import numpy as ...
- 莫烦sklearn学习自修第七天【交叉验证】
1. 什么是交叉验证 所谓交叉验证指的是将样本分为两组,一组为训练样本,一组为测试样本:对于哪些数据分为训练样本,哪些数据分为测试样本,进行多次拆分,每次将整个样本进行不同的拆分,对这些不同的拆分每个 ...
随机推荐
- [教程]教你如何制作彩色的3D打印Groot
http://mc.dfrobot.com.cn/forum.php?mod=viewthread&tid=24916 准备工作: <ignore_js_op> 3D打印高精度G ...
- Linux中查看你的用户是否为root用户
可以使用sudo -l命令: user@fafsf:/opt/user$ sudo -l [sudo] password for user: //这里是要输入你的密码 Sorry, user user ...
- Linux 实例如何开启 MySQL 慢查询功能
运行 MySQL 时,查询速度比较慢的语句对数据库的影响非常大,这些慢语句大多是写的不够合理或者大数据环境下多表并发查询造成的.MySQL 自带慢查询功能,能记录查询时间超过参数 long_query ...
- oracle SQL 执行进度
SELECT SS.USERNAME, SS.SID, SS.SERIAL#, SS.MACHINE, SS.PROGRAM, SL.OPNAME, SL.TARGET, SL.START_TIME, ...
- C语言之建立线性表
#include<stdio.h> #include<stdlib.h> #define MaxSize 60 #define ElemType int typedef str ...
- Spark笔记-gz压缩存储到HDFS【转】
参考:http://blog.csdn.net/u010454030/article/details/69291663 mergedRDD.saveAsTextFile(outputPath, cla ...
- 从零开始搭建django前后端分离项目 系列四(实战之实时进度)
本项目实现了任务执行的实时进度查询 实现方式 前端websocket + 后端websocket + 后端redis订阅/发布 实现原理 任务执行后,假设用变量num标记任务执行的进度,然后将num发 ...
- 【C# 复习总结】类、继承和接口
1 类 定义新的数据类型以及这些新的数据类型进行相互操作的方法 定义方式: class Cat { } class Cat:object { } C#中所有的类都是默认由object类派生来的,显示指 ...
- [C#]SQL Server Express LocalDb(SqlLocalDb)的一些体会
真觉得自己的知识面还是比较窄,在此之前,居然还不知道SqlLocalDb. SqlLocalDb是啥?其实就是简化SQL Server的本地数据库,可以这样子说,SQL Server既可以作为远程,也 ...
- 教你使用HTML5原生对话框元素,轻松创建模态框组件
HTML 5.2草案加入了新的dialog元素.但是是一种实验技术. 以前,如果我们想要构建任何形式的模式对话框或对话框,我们需要有一个背景,一个关闭按钮,将事件绑定在对话框中的方式安排我们的标记,找 ...