相关算法

python代码参考http://blog.csdn.net/zc02051126/article/details/9668439#(作少量修改与注释)

 #coding:utf8
import matplotlib.pylab as plt
import numpy as np
import cPickle class RBM:
def __init__(self,n_visul, n_hidden, max_epoch = 50, batch_size = 110, penalty = 2e-4):
self.n_visible = n_visul
self.n_hidden = n_hidden
self.max_epoch = max_epoch
self.batch_size = batch_size
self.penalty = penalty
self.w = np.random.random((self.n_visible, self.n_hidden)) * 0.1
self.v_bias = np.zeros((1, self.n_visible))
self.h_bias = np.zeros((1, self.n_hidden)) def sigmoid(self, z):
return 1.0 / (1.0 + np.exp( -z )) def forward(self, vis):
return self.sigmoid(np.dot(vis.T, self.w) + self.h_bias) def backward(self, vis):
return self.sigmoid(np.dot(vis, self.w.T) + self.v_bias) def batch(self):
d, N = self.x.shape
num_batchs = int(round(N / self.batch_size)) + 1
groups = np.ravel(np.repeat([range(0, num_batchs)], self.batch_size, axis = 0))
groups=groups[:N]
np.random.shuffle(groups)
batch_data = []
for i in range(0, num_batchs):
index = groups == i
batch_data.append(self.x[:, index])
return batch_data def rbmBB(self, x):
self.x = x
eta = 0.1
momentum = 0.5 #动量项
W = self.w
b = self.h_bias
c = self.v_bias
Winc = np.zeros((self.n_visible, self.n_hidden))
binc = np.zeros(self.n_hidden)
cinc = np.zeros(self.n_visible)
batch_data = self.batch()
num_batch = len(batch_data)
errors = []
for epoch in range(0, self.max_epoch):
err_sum = 0.0
for batch in range(0, num_batch):
num_dims, num_cases = batch_data[batch].shape
data = batch_data[batch]
# 已知可见层,采样出隐藏层
ph = self.forward(data)
ph_states = np.zeros((num_cases, self.n_hidden))
ph_states[ph > np.random.random((num_cases, self.n_hidden))] = 1
# 已知隐藏层,采样出可见层
neg_data = self.backward(ph_states)
neg_data_states = np.zeros((num_cases, num_dims))
neg_data_states[neg_data > np.random.random((num_cases, num_dims))] = 1
neg_data_states = neg_data_states.transpose()
nh = self.forward(neg_data_states)
# CD算法
dW = np.dot(data, ph) - np.dot(neg_data_states, nh)
dc = np.sum(data, axis = 1) - np.sum(neg_data_states, axis = 1)
db = np.sum(ph, axis = 0) - np.sum(nh, axis = 0)
# 刷新参数
Winc = momentum * Winc + eta * (dW / num_cases - self.penalty * W)
binc = momentum * binc + eta * (db / num_cases);
cinc = momentum * cinc + eta * (dc / num_cases);
W = W + Winc
b = b + binc
c = c + cinc
self.w = W
self.h_bais = b
self.v_bias = c
err = np.linalg.norm(data - neg_data.transpose())
err_sum += err
print epoch, err_sum
errors.append(err_sum)
self.errors = errors
self.hiden_value = self.forward(self.x)
h_row, h_col = self.hiden_value.shape
hiden_states = np.zeros((h_row, h_col))
hiden_states[self.hiden_value > np.random.random((h_row, h_col))] = 1
self.rebuild_value = self.backward(hiden_states) def visualize(self, X): #可视化
D, N = X.shape
s = int(np.sqrt(D))
num = int(np.ceil(np.sqrt(N)))
a = np.zeros((num*s + num + 1, num * s + num + 1)) - 1.0
x = 0
y = 0
for i in range(0, N):
z = X[:,i]
z = z.reshape(s,s,order='F')
z = z.transpose()
a[x*s+x:x*s+s+x , y*s+y:y*s+s+y] = z
x = x + 1
if(x >= num):
x = 0
y = y + 1
return a def readData(path):
data = []
for line in open(path, 'r'):
ele = line.split(' ')
tmp = []
for e in ele:
if e != '':
tmp.append(float(e.strip(' ')))
data.append(tmp)
return data if __name__ == '__main__':
f = open('mnist.pkl', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
training_inputs = [np.reshape(x, 784) for x in training_data[0]]
data =training_inputs[:5000]
data = np.array(data)
data = data.transpose()
rbm = Rbm(784, 100,max_epoch = 50)
rbm.rbmBB(data) a = rbm.visualize(data) #(2060L, 2060L)
fig = plt.figure(1)
ax = fig.add_subplot(111)
ax.imshow(a)
plt.title('original data') rebuild_value = rbm.rebuild_value.transpose()
b = rbm.visualize(rebuild_value) #(2060L, 2060L)
fig = plt.figure(2)
ax = fig.add_subplot(111)
ax.imshow(b)
plt.title('rebuild data') hidden_value = rbm.hiden_value.transpose()
c = rbm.visualize(hidden_value) #(782L, 782L)
fig = plt.figure(3)
ax = fig.add_subplot(111)
ax.imshow(c)
plt.title('hidden data') w_value = rbm.w
d = rbm.visualize(w_value) #(291L, 291L)
fig = plt.figure(4)
ax = fig.add_subplot(111)
ax.imshow(d)
plt.title('weight value(w)')
plt.show()

受限玻尔兹曼机RBM的更多相关文章

  1. 基于受限玻尔兹曼机(RBM)的协同过滤

    受限玻尔兹曼机是一种生成式随机神经网络(generative stochastic neural network), 详细介绍可见我的博文<受限玻尔兹曼机(RBM)简介>, 本文主要介绍R ...

  2. 深度学习方法:受限玻尔兹曼机RBM(一)基本概念

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 最近在复习经典机器学习算法的同 ...

  3. 深度学习方法:受限玻尔兹曼机RBM(四)对比散度contrastive divergence,CD

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入 上篇讲到,如果用Gibbs Sa ...

  4. 深度学习方法:受限玻尔兹曼机RBM(三)模型求解,Gibbs sampling

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 接下来重点讲一下RBM模型求解 ...

  5. 深度学习方法:受限玻尔兹曼机RBM(二)网络模型

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入 上解上一篇RBM(一)基本概念, ...

  6. 受限玻尔兹曼机RBM—简易详解

  7. 受限玻尔兹曼机(Restricted Boltzmann Machine,RBM)

    这篇写的主要是翻译网上一篇关于受限玻尔兹曼机的tutorial,看了那篇博文之后感觉算法方面讲的很清楚,自己收获很大,这里写下来作为学习之用. 原文网址为:http://imonad.com/rbm/ ...

  8. 受限玻尔兹曼机(RBM)

    能量模型 RBM用到了能量模型. 简单的概括一下能量模型.假设一个孤立系统(总能量$E$一定,粒子个数$N$一定),温度恒定为1,每个粒子有$m$个可能的状态,每个状态对应一个能量$e_i$.那么,在 ...

  9. 受限玻尔兹曼机(RBM)原理总结

    在前面我们讲到了深度学习的两类神经网络模型的原理,第一类是前向的神经网络,即DNN和CNN.第二类是有反馈的神经网络,即RNN和LSTM.今天我们就总结下深度学习里的第三类神经网络模型:玻尔兹曼机.主 ...

随机推荐

  1. mine layer(2008 World Final C)

    类似于扫雷游戏,在一些格子中散布着一些地雷,具体的埋藏位置并不清楚,但知道每个格子及其周围八个格子的地雷总数.请问此时正中间那一行最多可能有多少地雷(题目假定所有的输入都是奇数行的)? 输入: 第一行 ...

  2. 转:115个Java面试题和答案——终极列表(上)

    转自:http://www.importnew.com/10980.html 本文我们将要讨论Java面试中的各种不同类型的面试题,它们可以让雇主测试应聘者的Java和通用的面向对象编程的能力.下面的 ...

  3. 刷固件Layer1到手机FLASH(硬刷)

    开头: 注意:本文章并不是做GSM 嗅探必须的,平时我们刷机叫软刷是刷到内存里面的,断电就消失了,这个是硬刷,刷到flash里面的,断电不消失,开机就运行的. 本文章经过作者实测可行,这只是单个应用程 ...

  4. NSAttributedString的用法

    标签: 以前看到这种字号和颜色不一样的字符串,想出个讨巧的办法就是“¥150”一个UILable,“元/位”一个UILable.今天翻看以前的工程,command点进UITextField中看到[at ...

  5. 初见Gnuplot——时间序列的描述

    研读一本书,<数据之魅:基于开源工具的数据分析>(Data Analysis with Open Source Tools),写的很好.这里,复述一下书中用Gnuplot分析时间序列数据的 ...

  6. Bash简介

    Bash(GNU bourne-Again Shell)是一个为GNU计划编写的Unix shell,它是很多Linux平台默认的使用的shell. shell是一个命令解析器,是介于操作系统内核与用 ...

  7. iOS LaunchScreen启动图设置

    新建的iOS 项目启动画面默认为LaunchScreen.xib 如果想实现一张图片作为启动页,如下图 如果启动不行  记得clear 一下工程 是启动页停留一段时间  只需要在 AppDelegat ...

  8. XMLParser解析xml--内容源自网络(在静态库中不能用GDATA来解析,因为静态库不能加动态库)

    </Books> 从其文档结构我们可以看出,要定义一个Book实体类描述具体的书籍信息,其中用于存储的相关xml文档元素的实例变量与对应元素同名(本例:title.author.summa ...

  9. SQL技巧

    数据查询    且不说你是否正在从事编程方面的工作或者不打算学习SQL,可事实上几乎每一位开发者最终都会遭遇它.你多半还用不着负责创建和维持某个数据库,但你怎么着也该知道以下的一些有关的SQL知识.我 ...

  10. 使用SSMS 2014将本地数据库迁移到Azure SQL Database

    使用SQL Server Management Studio 2014将本地数据库迁移到Azure SQL Database的过程比较简单,在SSMS2014中,有一个任务选项为“将数据库部署到Win ...