莫烦theano学习自修第八天【分类问题】
1. 代码实现
from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T
def compute_accuracy(y_target, y_predict):
correct_prediction = np.equal(y_predict, y_target)
accuracy = np.sum(correct_prediction)/len(correct_prediction)
return accuracy
rng = np.random
N = 400 # training sample size
feats = 784 # number of input variables
# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))
# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")
# initialize the weights and biases
W = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")
# Construct Theano expression graph
p_1 = T.nnet.sigmoid(T.dot(x, W) + b) # Logistic Probability that target = 1 (activation function)
prediction = p_1 > 0.5 # The prediction thresholded
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) # Cross-entropy loss function
# or
# xent = T.nnet.binary_crossentropy(p_1, y) # this is provided by theano
cost = xent.mean() + 0.01 * (W ** 2).sum()# The cost to minimize (l2 regularization)
gW, gb = T.grad(cost, [W, b]) # Compute the gradient of the cost
# Compile
learning_rate = 0.1
train = theano.function(
inputs=[x, y],
outputs=[prediction, xent.mean()],
updates=((W, W - learning_rate * gW), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)
# Training
for i in range(500):
pred, err = train(D[0], D[1])
if i % 50 == 0:
print('cost:', err)
print("accuracy:", compute_accuracy(D[1], predict(D[0])))
print("target values for D:")
print(D[1])
print("prediction on D:")
print(predict(D[0]))
莫烦theano学习自修第八天【分类问题】的更多相关文章
- 莫烦theano学习自修第九天【过拟合问题与正规化】
如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...
- 莫烦sklearn学习自修第八天【过拟合问题】
1. 什么是过拟合问题 所谓过拟合问题指的是使用训练样本进行训练时100%正确分类或规划,当使用测试样本时则不能正确分类和规划 2. 代码实战(模拟过拟合问题) from __future__ imp ...
- 莫烦theano学习自修第十天【保存神经网络及加载神经网络】
1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...
- 莫烦theano学习自修第七天【回归结果可视化】
1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...
- 莫烦theano学习自修第六天【回归】
1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...
- 莫烦theano学习自修第五天【定义神经层】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
- 莫烦theano学习自修第四天【激励函数】
1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...
- 莫烦theano学习自修第三天【共享变量】
1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...
- 莫烦theano学习自修第二天【激励函数】
1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...
随机推荐
- maven-resources-plugin插件关于占位符不生效问题
插件版本: <plugin> <artifactId>maven-resources-plugin</artifactId> <version>3.0. ...
- php 依赖注入的实现
当A类需要依赖于B类,也就是说需要在A类中实例化B类的对象来使用时候,如果B类中的功能发生改变,也会导致A类中使用B类的地方也要跟着修改,导致A类与B类高耦合.这个时候解决方式是,A类应该去依赖B类的 ...
- 4、原生jdbc链接数据库常用资源名
原生jdbc链接数据库要素:#MySql:String url="jdbc:mysql://localhost:3306/数据库名";String name="root& ...
- VS2015P配置opencv340
1添加系统环境变量 F:\dongdong\0tool\navidia_cuda_opencv\opencv\build\x64\vc14\bin 注销重启 2 工程配置 选择好工程 x64 包含目 ...
- JDK动态代理(3)--------动态代理具体实现
写个HelloWorld 接口 package com.spring.aop.proxy; public interface HelloWorld { public void sayHello(); ...
- OpenCV3计算机视觉Python语言实现笔记(三)
一.使用OpenCV处理图像 1.不同颜色空间的转换 OpenCV中有数百种关于在不同色彩空间之间转换的方法.当前,在计算机视觉中有三种常用的色彩空间:灰度.BGR以及HSV(Hue, Saturat ...
- 【P1941】 飞扬的小鸟
题目描述 游戏界面是一个长为 nn,高为 mm 的二维平面,其中有 kk 个管道(忽略管道的宽度). 小鸟始终在游戏界面内移动.小鸟从游戏界面最左边任意整数高度位置出发,到达游戏界面最右边时,游戏完成 ...
- JavaScript原生秒表、计时器
可以开始.暂停.清除. 效果图: 下面贴代码: <!DOCTYPE html> <html lang="en"> <head> <meta ...
- 性能调优3:硬盘IO性能
数据库系统严重依赖服务器的资源:CPU,内存和硬盘IO,通常情况下,内存是数据的读写性能最高的存储介质,但是,内存的价格昂贵,这使得系统能够配置的内存容量受到限制,不能大规模用于数据存储:并且内存是易 ...
- WIFI智能配网 - SmartConfig
要开始IoT项目的第一步是什么?当然不是硬件,而是硬件与硬件的连接!即使有各种各样的通信协议没有好的连接方式绝对不行.那外设上没有的屏幕,没有键盘怎末输入密码怎末选择网络?对,这就是WIFI模块最重要 ...