1. 代码实现

from __future__ import print_function
import numpy as np
import theano
import theano.tensor as T

def compute_accuracy(y_target, y_predict):
    correct_prediction = np.equal(y_predict, y_target)
    accuracy = np.sum(correct_prediction)/len(correct_prediction)
    return accuracy

rng = np.random

N = 400                                   # training sample size
feats = 784                               # number of input variables

# generate a dataset: D = (input_values, target_class)
D = (rng.randn(N, feats), rng.randint(size=N, low=0, high=2))

# Declare Theano symbolic variables
x = T.dmatrix("x")
y = T.dvector("y")

# initialize the weights and biases
W = theano.shared(rng.randn(feats), name="w")
b = theano.shared(0., name="b")

# Construct Theano expression graph
p_1 = T.nnet.sigmoid(T.dot(x, W) + b)   # Logistic Probability that target = 1 (activation function)
prediction = p_1 > 0.5                    # The prediction thresholded
xent = -y * T.log(p_1) - (1-y) * T.log(1-p_1) # Cross-entropy loss function
# or
# xent = T.nnet.binary_crossentropy(p_1, y) # this is provided by theano
cost = xent.mean() + 0.01 * (W ** 2).sum()# The cost to minimize (l2 regularization)
gW, gb = T.grad(cost, [W, b])             # Compute the gradient of the cost

# Compile
learning_rate = 0.1
train = theano.function(
          inputs=[x, y],
          outputs=[prediction, xent.mean()],
          updates=((W, W - learning_rate * gW), (b, b - learning_rate * gb)))
predict = theano.function(inputs=[x], outputs=prediction)

# Training
for i in range(500):
    pred, err = train(D[0], D[1])
    if i % 50 == 0:
        print('cost:', err)
        print("accuracy:", compute_accuracy(D[1], predict(D[0])))

print("target values for D:")
print(D[1])
print("prediction on D:")
print(predict(D[0]))

莫烦theano学习自修第八天【分类问题】的更多相关文章

  1. 莫烦theano学习自修第九天【过拟合问题与正规化】

    如下图所示(回归的过拟合问题):如果机器学习得到的回归为下图中的直线则是比较好的结果,但是如果进一步控制减少误差,导致机器学习到了下图中的曲线,则100%正确的学习了训练数据,看似较好,但是如果换成另 ...

  2. 莫烦sklearn学习自修第八天【过拟合问题】

    1. 什么是过拟合问题 所谓过拟合问题指的是使用训练样本进行训练时100%正确分类或规划,当使用测试样本时则不能正确分类和规划 2. 代码实战(模拟过拟合问题) from __future__ imp ...

  3. 莫烦theano学习自修第十天【保存神经网络及加载神经网络】

    1. 为何保存神经网络 保存神经网络指的是保存神经网络的权重W及偏置b,权重W,和偏置b本身是一个列表,将这两个列表的值写到列表或者字典的数据结构中,使用pickle的数据结构将列表或者字典写入到文件 ...

  4. 莫烦theano学习自修第七天【回归结果可视化】

    1.代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy as ...

  5. 莫烦theano学习自修第六天【回归】

    1. 代码实现 from __future__ import print_function import theano import theano.tensor as T import numpy a ...

  6. 莫烦theano学习自修第五天【定义神经层】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

  7. 莫烦theano学习自修第四天【激励函数】

    1. 定义 激励函数通常用于隐藏层,是将特征值进行过滤或者激活的算法 2.常见的激励函数 1. sigmoid (1)sigmoid() (2)ultra_fast_sigmoid() (3)hard ...

  8. 莫烦theano学习自修第三天【共享变量】

    1. 代码实现 #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T i ...

  9. 莫烦theano学习自修第二天【激励函数】

    1. 代码如下: #!/usr/bin/env python #! _*_ coding:UTF-8 _*_ import numpy as np import theano.tensor as T ...

随机推荐

  1. pytorch torch.Storage学习

    tensor分为头信息区(Tensor)和存储区(Storage) 信息区主要保存着tensor的形状(size).步长(stride).数据类型(type)等信息,而真正的数据则保存成连续数组,存储 ...

  2. 缓存表 内存表(将表keep到内存)

    缓存表 内存表(将表keep到内存) 一.引言:     有时候一些基础表需要非常的频繁访问,尤其是在一些循环中,对该表中的访问速度将变的非常重要.为了提高系统的处理性能,可以考虑将一些表及索引读取并 ...

  3. NOIP2018初赛游记

    NOIP2018初赛游记 (编辑中)

  4. 序列化与ArrayList 的elementData的修饰关键字transient

    transient用来表示一个域不是该对象序行化的一部分,当一个对象被序行化的时候,transient修饰的变量不会被序列化 ArrayList的动态数组elementData被transient  ...

  5. web安全:通俗易懂,以实例讲述破解网站的原理及如何进行防护!如何让网站变得更安全。

    本篇以我自己的网站为例来通俗易懂的讲述网站的常见漏洞,如何防止网站被入侵,如何让网站更安全. 要想足够安全,首先得知道其中的道理. 本文例子通俗易懂,主要讲述了 各种漏洞 的原理及防护,相比网上其它的 ...

  6. Java基础之数据比较Integer、Short、int、short

    基础很重要,基础很重要,基础很重要.重要的事情说三遍,. 今天聊一聊Java的数据比较,这个范围比较大,基础类型的比较.引用类型的比较. 前提: 1.Java和c#都提供自动装箱和自动拆箱操作,何为自 ...

  7. 从 0 到 1 实现 react - 9.onChange 事件以及受控组件

    该系列文章在实现 cpreact 的同时理顺 React 框架的核心内容 项目地址 从一个疑问点开始 接上一章 HOC 探索 抛出的问题 ---- react 中的 onChange 事件和原生 DO ...

  8. ZOJ - 1610 经典线段树染色问题

    这个是一个经典线段树染色问题,不过题目给的是左右左右坐标,即[0,3]包含0-1这一段 1-2这一段 2-3这一段,和传统的染色不太一样,不过其实也不用太着急. 我们把左边的坐标+1,即可,那么[0, ...

  9. JS 执行上下文

    先看个小例子 function fn(){ console.log(a);//undefined; var a = 1; } fn(); 为什么打印出来的是 undefined 呢? 执行上下文概念 ...

  10. hibernate添加数据时Exception in thread "main" org.hibernate.PropertyValueException: not-null property references a null or transient value: com.javakc.hibernate.test.entity.TestEntity.testName

    意思是,一个非null属性引用了一个null或瞬态值.就是在对应实体类配置文件hbm.xml中该属性配置了not-null="true",将其去掉即可.