1. 代码实现

from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T

def compute_accuracy(y_target, y_predict):
    correct_prediction = np.equal(y_predict, y_target)
    accuracy = np.sum(correct_prediction)/len(correct_prediction)
    return accuracy

rng = np.random

N = 400                                   # training sample size
feats = 784                               # number of input variables

# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))

# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")

# initialize the weights and biases
W = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")

# Construct Theano expression graph
p_1 = T.nnet.sigmoid(T.dot(x, W) + b)   # Logistic Probability that target = 1 (activation function)
prediction = p_1 > 0.5                    # The prediction thresholded
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) # Cross-entropy loss function
# or
# xent = T.nnet.binary_crossentropy(p_1, y) # this is provided by theano
cost = xent.mean() + 0.01 * (W ** 2).sum()# The cost to minimize (l2 regularization)
gW, gb = T.grad(cost, [W, b])             # Compute the gradient of the cost

# Compile
learning_rate = 0.1
train = theano.function(
          inputs=[x, y],
          outputs=[prediction, xent.mean()],
          updates=((W, W - learning_rate * gW), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)

# Training
for i in range(500):
    pred, err = train(D[0], D[1])
    if i % 50 == 0:
        print('cost:', err)
        print("accuracy:", compute_accuracy(D[1], predict(D[0])))

print("target values for D:")
print(D[1])
print("prediction on D:")
print(predict(D[0]))

莫烦theano学习自修第八天【分类问题】的更多相关文章

  1. 莫烦theano学习自修第九天【过拟合问题与正规化】

    如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...

  2. 莫烦sklearn学习自修第八天【过拟合问题】

    1. 什么是过拟合问题 所谓过拟合问题指的是使用训练样本进行训练时100%正确分类或规划,当使用测试样本时则不能正确分类和规划 2. 代码实战(模拟过拟合问题) from __future__ imp ...

  3. 莫烦theano学习自修第十天【保存神经网络及加载神经网络】

    1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...

  4. 莫烦theano学习自修第七天【回归结果可视化】

    1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...

  5. 莫烦theano学习自修第六天【回归】

    1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...

  6. 莫烦theano学习自修第五天【定义神经层】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

  7. 莫烦theano学习自修第四天【激励函数】

    1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...

  8. 莫烦theano学习自修第三天【共享变量】

    1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...

  9. 莫烦theano学习自修第二天【激励函数】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

随机推荐

  1. 【转】使用ffmpeg转码的MP4文件需要加载完了才能播放的解决办法

    1.前一段时间做了一个ffmpeg转码MP4的项目,但是转出来的MP4部署在网站上需要把整个视频加载完成才能播放,到处找资料,最后找到解决方案记录于此备忘. FFMpeg转码由此得到的mp4文件中, ...

  2. SQLAlchemy中的自引用

    SQLALCHEMY采用adjacency list pattern来表示类的自引用. 例如,对于类Node自引用: class Node(Base): __tablename__='node' id ...

  3. docker-1-环境安装及例子实践

    1.安装go 先新建一个Go的工作空间文件夹,文件夹路径建议放在$HOME下: userdeMacBook-Pro:~ user$ cd $HOME userdeMacBook-Pro:~ user$ ...

  4. GFF高仿QQ客户端及服务器

    一.GFF简介 GFF是仿QQ界面,通信基于SAEA.MessageSocket.SAEA.Http.SAEA.MVC实现包含客户端和服务器的程序,源码完全公开,项目源码地址:https://gith ...

  5. Python_装饰器习题_31

    # 1.编写装饰器,为多个函数加上认证的功能(用户的账号密码来源于文件), # 要求登录成功一次,后续的函数都无需再输入用户名和密码 FLAG = False def login(func): def ...

  6. Unix / Linux 线程的实质

    线程与进程的比较 概述: 进程是具有一定独立功能的程序关于某个数据集合上的一次运行活动,进程是系统进行资源分配和调度的一个独立单位. 线程是进程的一个实体,是CPU调度和分派的基本单位,它是比进程更小 ...

  7. 1060E Sergey and Subway(思维题,dfs)

    题意:给出一颗树,现在,给哪些距离为2的点对,加上一条边,问所有点对的距离和 题解:如果没有加入新的边,距离和就会等于每条边的贡献,由于是树,我们用点来代表点上面的边,对于每条边,它的贡献将是(子树大 ...

  8. 1. FPGA内部的逻辑资源

    CLB(包括LUT.加法器.寄存器.MUX(多路选择器)) 时钟网络资源(全局时钟网络,区域时钟网络,IO时钟网络),理解时钟网络的本质和意义 时钟处理单元(PLL,DCM),理解时钟网络资源和时钟处 ...

  9. AtCoder Beginner Contest 122 D - We Like AGC (DP)

    D - We Like AGC Time Limit: 2 sec / Memory Limit: 1024 MB Score : 400400 points Problem Statement Yo ...

  10. Day15 Python基础之logging模块(十三)

    参考源:http://www.cnblogs.com/yuanchenqi/articles/5732581.html logging模块 (****重点***) 一 (简单应用) import lo ...