RBM:深度学习之Restricted Boltzmann Machine的BRBM学习+LR分类—Jason niu
from __future__ import print_function
print(__doc__) import numpy as np
import matplotlib.pyplot as plt from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.cross_validation import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline def nudge_dataset(X, Y):
direction_vectors = [
[[0, 1, 0],[0, 0, 0],[0, 0, 0]],
[[0, 0, 0],[1, 0, 0],[0, 0, 0]],
[[0, 0, 0],[0, 0, 1],[0, 0, 0]],
[[0, 0, 0],[0, 0, 0],[0, 1, 0]]
] shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',weights=w).ravel()
X = np.concatenate([X] +
[np.apply_along_axis(shift, 1, X, vector)
for vector in direction_vectors])
Y = np.concatenate([Y for _ in range(5)], axis=0)
return X, Y digits = datasets.load_digits()
X = np.asarray(digits.data, 'float32')
X, Y = nudge_dataset(X, digits.target)
X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001) X_train, X_test, Y_train, Y_test = train_test_split(X, Y,test_size=0.2,random_state=0) logistic = linear_model.LogisticRegression()
rbm = BernoulliRBM(random_state=0, verbose=True) classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)]) rbm.learning_rate = 0.06
rbm.n_iter = 20
# More components tend to give better prediction performance, but larger fitting time
rbm.n_components = 100
logistic.C = 6000.0 classifier.fit(X_train, Y_train) logistic_classifier = linear_model.LogisticRegression(C=100.0)
logistic_classifier.fit(X_train, Y_train) print()
print("Logistic regression using RBM features:\n%s\n" % (
metrics.classification_report(
Y_test,classifier.predict(X_test)
)
)) print("Logistic regression using raw pixel features:\n%s\n" % (
metrics.classification_report(
Y_test,
logistic_classifier.predict(X_test)))) plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(rbm.components_):
plt.subplot(10, 10, i + 1)
plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
interpolation='nearest')
plt.xticks(())
plt.yticks(())
plt.suptitle('100 components extracted by RBM', fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23) plt.show()
RBM:深度学习之Restricted Boltzmann Machine的BRBM学习+LR分类—Jason niu的更多相关文章
- 限制玻尔兹曼机(Restricted Boltzmann Machine)RBM
假设有一个二部图,每一层的节点之间没有连接,一层是可视层,即输入数据是(v),一层是隐藏层(h),如果假设所有的节点都是随机二值变量节点(只能取0或者1值)同时假设全概率分布满足Boltzmann 分 ...
- 受限玻尔兹曼机(Restricted Boltzmann Machine, RBM) 简介
受限玻尔兹曼机(Restricted Boltzmann Machine,简称RBM)是由Hinton和Sejnowski于1986年提出的一种生成式随机神经网络(generative stochas ...
- 受限玻尔兹曼机(Restricted Boltzmann Machine,RBM)
这篇写的主要是翻译网上一篇关于受限玻尔兹曼机的tutorial,看了那篇博文之后感觉算法方面讲的很清楚,自己收获很大,这里写下来作为学习之用. 原文网址为:http://imonad.com/rbm/ ...
- 受限玻尔兹曼机(Restricted Boltzmann Machine)
受限玻尔兹曼机(Restricted Boltzmann Machine) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 1. 生成模型 2. 参数学 ...
- 限制Boltzmann机(Restricted Boltzmann Machine)
起源:Boltzmann神经网络 Boltzmann神经网络的结构是由Hopfield递归神经网络改良过来的,Hopfield中引入了统计物理学的能量函数的概念. 即,cost函数由统计物理学的能量函 ...
- 机器学习理论基础学习19---受限玻尔兹曼机(Restricted Boltzmann Machine)
一.背景介绍 玻尔兹曼机 = 马尔科夫随机场 + 隐结点 二.RBM的Representation BM存在问题:inference 精确:untractable: 近似:计算量太大 因此为了使计算简 ...
- Boltzmann Machine 玻尔兹曼机入门
Generative Models 生成模型帮助我们生成新的item,而不只是存储和提取之前的item.Boltzmann Machine就是Generative Models的一种. Boltzma ...
- 受限波兹曼机导论Introduction to Restricted Boltzmann Machines
Suppose you ask a bunch of users to rate a set of movies on a 0-100 scale. In classical factor analy ...
- 限制波尔兹曼机(Restricted Boltzmann Machines)
能量模型的概念从统计力学中得来,它描述着整个系统的某种状态,系统越有序,系统能量波动越小,趋近于平衡状态,系统越无序,能量波动越大.例如:一个孤立的物体,其内部各处的温度不尽相同,那么热就从温度较高的 ...
随机推荐
- Confluence 6 XML 备份失败的问题解决
XML 站点备份仅仅被用于整合到一个新的数据库.设置一个测试服务器 或者 创建一个可用的备份策略 相对 XML 备份来说是更合适的策略. 相关页面: Enabling detailed SQL log ...
- Bootstrap补充
一.一个小知识点 1.截取长屏的操作 2.设置默认格式 3.md,sm, xs 4.空格和没有空格的选择器 二.响应式介绍 - 响应式布局是什么? 同一个网页在不同的终端上呈现不同的布局等 - 响应式 ...
- LeetCode(89):格雷编码
Medium! 题目描述: 格雷编码是一个二进制数字系统,在该系统中,两个连续的数值仅有一个位数的差异. 给定一个代表编码总位数的非负整数 n,打印格雷码序列.格雷码序列必须以 0 开头. 例如,给定 ...
- git如何创建 .gitignore文件
1.右键 点击git bash here 2.输入 touch .gitignore 生成 .gitignore文件 过滤 不上传 node_modules/
- ubuntu下如何编译C语言
ubuntu下如何编译C语言 如果没有gcc编译器的话,使用以下命令获取 ~# sudo apt-get install gcc同时要下载辅助工具 ~# sudo apt-get instal ...
- 常见的HTTP响应状态码解析
概要 状态码的职责是当客户端向服务器端发送请求时,描述返回的请求结果.借助于状态码,浏览器(或者说用户)可以知道服务器是正常的处理了请求,还是出现了错误. 状态码以3位数字和原因短语组成,例如 200 ...
- Python(列表操作应用实战)
# 输入一个数据,删除一个列表中的所有指定元素# 给定的列表数据data = [1,2,3,4,5,6,7,8,9,0,5,4,3,5,"b","a",&quo ...
- MySQL 存储过程与事物
一:存储过程 存储过程可以说是一个记录集吧,它是由一些T-SQL语句组成的代码块,这些T-SQL语句代码像一个方法一样实现一些功能 存储过程的好处: 1.由于数据库执行动作时,是先编 ...
- bat循环打印输出1到10
关键词:for cmd中查看帮助: for /? 循环一个数字序列: FOR /L %variable IN (start,step,end) DO command [command-paramete ...
- webpack学习笔记--配置module
Module module 配置如何处理模块. 配置 Loader rules 配置模块的读取和解析规则,通常用来配置 Loader.其类型是一个数组,数组里每一项都描述了如何去处理部分文件. 配置 ...