莫烦theano学习自修第八天【分类问题】
1. 代码实现
from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T
def compute_accuracy(y_target, y_predict):
correct_prediction = np.equal(y_predict, y_target)
accuracy = np.sum(correct_prediction)/len(correct_prediction)
return accuracy
rng = np.random
N = 400 # training sample size
feats = 784 # number of input variables
# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))
# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")
# initialize the weights and biases
W = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")
# Construct Theano expression graph
p_1 = T.nnet.sigmoid(T.dot(x, W) + b) # Logistic Probability that target = 1 (activation function)
prediction = p_1 > 0.5 # The prediction thresholded
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) # Cross-entropy loss function
# or
# xent = T.nnet.binary_crossentropy(p_1, y) # this is provided by theano
cost = xent.mean() + 0.01 * (W ** 2).sum()# The cost to minimize (l2 regularization)
gW, gb = T.grad(cost, [W, b]) # Compute the gradient of the cost
# Compile
learning_rate = 0.1
train = theano.function(
inputs=[x, y],
outputs=[prediction, xent.mean()],
updates=((W, W - learning_rate * gW), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)
# Training
for i in range(500):
pred, err = train(D[0], D[1])
if i % 50 == 0:
print('cost:', err)
print("accuracy:", compute_accuracy(D[1], predict(D[0])))
print("target values for D:")
print(D[1])
print("prediction on D:")
print(predict(D[0]))
莫烦theano学习自修第八天【分类问题】的更多相关文章
- 莫烦theano学习自修第九天【过拟合问题与正规化】
如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...
- 莫烦sklearn学习自修第八天【过拟合问题】
1. 什么是过拟合问题 所谓过拟合问题指的是使用训练样本进行训练时100%正确分类或规划,当使用测试样本时则不能正确分类和规划 2. 代码实战(模拟过拟合问题) from __future__ imp ...
- 莫烦theano学习自修第十天【保存神经网络及加载神经网络】
1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...
- 莫烦theano学习自修第七天【回归结果可视化】
1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...
- 莫烦theano学习自修第六天【回归】
1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...
- 莫烦theano学习自修第五天【定义神经层】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第四天【激励函数】
1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...
- 莫烦theano学习自修第三天【共享变量】
1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...
- 莫烦theano学习自修第二天【激励函数】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
随机推荐
- flash设置里面:您的 Flash 设置会一直保留到您退出 Chrome 为止。
疑问:flash设置里面:您的 Flash 设置会一直保留到您退出 Chrome 为止. 我记得以前的版本配置后就一直用啊,允许的就可以一直允许,现在这个sb版本退出后就恢复到默认,允许列表的网站就清 ...
- windows下简单的缓冲区溢出
缓冲区溢出是什么? 当缓冲区边界限制不严格时,由于变量传入畸形数据或程序运行错误,导致缓冲区被“撑爆”,从而覆盖了相邻内存区域的数据 成功修改内存数据,可造成进程劫持,执行恶意代码,获取服务器控制权等 ...
- 如何让Node.js运行在浏览器端
Node.js又称服务端JavaScript.今天我为了解决一个问题,通过搜索引擎找到了如何将Node.js转成浏览器端可以运行的javascript.尽管这种方式有其局限性,但是还是可以用的. 1. ...
- 机器学习三剑客之Numpy库基本操作
NumPy是Python语言的一个扩充程序库.支持高级大量的维度数组与矩阵运算,此外也针对数组运算提供大量的数学函数库.Numpy内部解除了Python的PIL(全局解释器锁),运算效率极好,是大量机 ...
- 隐写工具Hydan的安装使用方法
Hydan是可以在32位ELF二进制文件里隐藏信息的工具,主要原理是利用了i386指令中的冗余信息. 官网地址:http://www.crazyboy.com/hydan/ 但这个工具最后更新好像是在 ...
- Ubuntu Server 16.04修改IP、DNS、hosts
本文记录下Ubuntu Server 16.04修改IP.DNS.hosts的方法 -------- 1. Ubuntu Server 16.04修改IP sudo vi /etc/network/i ...
- mysql及python交互
mysql在之前写过一次,那时是我刚刚进入博客,今天介绍一下mysql的python交互,当然前面会把mysql基本概述一下. 目录: 一.命令脚本(mysql) 1.基本命令 2.数据库操作命令 3 ...
- python3 pip 安装Scrapy在win10 安装报错error: Microsoft Visual C++ 14.0 is required. Get it with "Microsoft Visual C++ Build Tools": http://landinghub.visualstudio.com/visual-cpp-build-tools
问题描述 当前环境win10,python_3.6.1,64位. 在windows下,在dos中运行pip install Scrapy报错: building 'twisted.test.raise ...
- HDU - 1754 线段树-单点修改+询问区间最大值
这个也是线段树的经验问题,待修改的,动态询问区间的最大值,只需要每次更新的时候,去把利用子节点的信息进行修改即可以. 注意更新的时候区间的选择,需要对区间进行二分. #include<iostr ...
- 利用tushare进行对兴业银行股价的爬取,并使用numpy进行分析
import sysimport tushare as tsimport numpy as npdata=ts.get_h_data('601066')print(data)#读出兴业银行7列数据da ...