1. 代码实现

from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T

def compute_accuracy(y_target, y_predict):
    correct_prediction = np.equal(y_predict, y_target)
    accuracy = np.sum(correct_prediction)/len(correct_prediction)
    return accuracy

rng = np.random

N = 400                                   # training sample size
feats = 784                               # number of input variables

# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))

# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")

# initialize the weights and biases
W = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")

# Construct Theano expression graph
p_1 = T.nnet.sigmoid(T.dot(x, W) + b)   # Logistic Probability that target = 1 (activation function)
prediction = p_1 > 0.5                    # The prediction thresholded
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) # Cross-entropy loss function
# or
# xent = T.nnet.binary_crossentropy(p_1, y) # this is provided by theano
cost = xent.mean() + 0.01 * (W ** 2).sum()# The cost to minimize (l2 regularization)
gW, gb = T.grad(cost, [W, b])             # Compute the gradient of the cost

# Compile
learning_rate = 0.1
train = theano.function(
          inputs=[x, y],
          outputs=[prediction, xent.mean()],
          updates=((W, W - learning_rate * gW), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)

# Training
for i in range(500):
    pred, err = train(D[0], D[1])
    if i % 50 == 0:
        print('cost:', err)
        print("accuracy:", compute_accuracy(D[1], predict(D[0])))

print("target values for D:")
print(D[1])
print("prediction on D:")
print(predict(D[0]))

莫烦theano学习自修第八天【分类问题】的更多相关文章

  1. 莫烦theano学习自修第九天【过拟合问题与正规化】

    如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...

  2. 莫烦sklearn学习自修第八天【过拟合问题】

    1. 什么是过拟合问题 所谓过拟合问题指的是使用训练样本进行训练时100%正确分类或规划,当使用测试样本时则不能正确分类和规划 2. 代码实战(模拟过拟合问题) from __future__ imp ...

  3. 莫烦theano学习自修第十天【保存神经网络及加载神经网络】

    1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...

  4. 莫烦theano学习自修第七天【回归结果可视化】

    1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...

  5. 莫烦theano学习自修第六天【回归】

    1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...

  6. 莫烦theano学习自修第五天【定义神经层】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

  7. 莫烦theano学习自修第四天【激励函数】

    1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...

  8. 莫烦theano学习自修第三天【共享变量】

    1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...

  9. 莫烦theano学习自修第二天【激励函数】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

随机推荐

  1. flash设置里面:您的 Flash 设置会一直保留到您退出 Chrome 为止。

    疑问:flash设置里面:您的 Flash 设置会一直保留到您退出 Chrome 为止. 我记得以前的版本配置后就一直用啊,允许的就可以一直允许,现在这个sb版本退出后就恢复到默认,允许列表的网站就清 ...

  2. windows下简单的缓冲区溢出

    缓冲区溢出是什么? 当缓冲区边界限制不严格时,由于变量传入畸形数据或程序运行错误,导致缓冲区被“撑爆”,从而覆盖了相邻内存区域的数据 成功修改内存数据,可造成进程劫持,执行恶意代码,获取服务器控制权等 ...

  3. 如何让Node.js运行在浏览器端

    Node.js又称服务端JavaScript.今天我为了解决一个问题,通过搜索引擎找到了如何将Node.js转成浏览器端可以运行的javascript.尽管这种方式有其局限性,但是还是可以用的. 1. ...

  4. 机器学习三剑客之Numpy库基本操作

    NumPy是Python语言的一个扩充程序库.支持高级大量的维度数组与矩阵运算,此外也针对数组运算提供大量的数学函数库.Numpy内部解除了Python的PIL(全局解释器锁),运算效率极好,是大量机 ...

  5. 隐写工具Hydan的安装使用方法

    Hydan是可以在32位ELF二进制文件里隐藏信息的工具,主要原理是利用了i386指令中的冗余信息. 官网地址:http://www.crazyboy.com/hydan/ 但这个工具最后更新好像是在 ...

  6. Ubuntu Server 16.04修改IP、DNS、hosts

    本文记录下Ubuntu Server 16.04修改IP.DNS.hosts的方法 -------- 1. Ubuntu Server 16.04修改IP sudo vi /etc/network/i ...

  7. mysql及python交互

    mysql在之前写过一次,那时是我刚刚进入博客,今天介绍一下mysql的python交互,当然前面会把mysql基本概述一下. 目录: 一.命令脚本(mysql) 1.基本命令 2.数据库操作命令 3 ...

  8. python3 pip 安装Scrapy在win10 安装报错error: Microsoft Visual C++ 14.0 is required. Get it with "Microsoft Visual C++ Build Tools": http://landinghub.visualstudio.com/visual-cpp-build-tools

    问题描述 当前环境win10,python_3.6.1,64位. 在windows下,在dos中运行pip install Scrapy报错: building 'twisted.test.raise ...

  9. HDU - 1754 线段树-单点修改+询问区间最大值

    这个也是线段树的经验问题,待修改的,动态询问区间的最大值,只需要每次更新的时候,去把利用子节点的信息进行修改即可以. 注意更新的时候区间的选择,需要对区间进行二分. #include<iostr ...

  10. 利用tushare进行对兴业银行股价的爬取,并使用numpy进行分析

    import sysimport tushare as tsimport numpy as npdata=ts.get_h_data('601066')print(data)#读出兴业银行7列数据da ...