【EM】代码理解
本来想自己写一个EM算法的,但是操作没两步就进行不下去了。对那些数学公式着实不懂。只好从网上找找代码,看看别人是怎么做的。
代码:来自http://blog.sina.com.cn/s/blog_98b365150101f2xb.html 经验证可用
%EM
M=; % M个高斯分布混合
N=; % 样本数
th=0.000001; % 收敛阈值
K=; % 样本维数
% 待生成数据的参数
a_real =[/;/;/];%混合模型中基模型高斯密度函数的权重
mu_real=[ ; ];%均值
cov_real(:,:,)=[ ; 0.2];%协方差
cov_real(:,:,)=[0.1 ; 0.1];
cov_real(:,:,)=[0.1 ; 0.1];
%生成符合标准的样本数据(每一列为一个样本)
x=[ mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ]; %初始化参数
a=[/;/;/];
mu=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];
t=inf;
while t>=th
a_old = a;
mu_old = mu;
cov_old= cov;
rznk_temp=zeros(M,N);
for k=:M
for n=:N
%计算P(x|mu_cm,cov_cm)
rznk_temp(k,n)=exp(-/*(x(:,n)-mu(:,k))'*inv(cov(:,:,k))*(x(:,n)-mu(:,k)));
end
rznk_temp(k,:)=rznk_temp(k,:)/sqrt(det(cov(:,:,k)));
end
rznk_temp=rznk_temp*(*pi)^(-K/);
%E step
%求rznk
rznk=zeros(M,N);
for n=:N
for k=:M
rznk(k,n)=a(k)*rznk_temp(k,n);
end
rznk(:,n)=rznk(:,n)/sum(rznk(:,n));
end
% M step
%求Nk
nk=zeros(,M);
nk=sum(rznk'); % 求a
a=nk/N; % 求MU
for k=:M
mu_k_sum=;
for n=:N
mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
end
mu(:,k)=mu_k_sum/nk(k);
end % 求COV
for k=:M
cov_k_sum=;
for n=:N
cov_k_sum=cov_k_sum+rznk(k,n)*(x(:,n)-mu(:,k))*(x(:,n)-mu(:,k))';
end
cov(:,:,k)=cov_k_sum/nk(k);
end t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]);
end
分解说明:
M=; % M个高斯分布混合
N=; % 样本数
th=0.000001; % 收敛阈值
K=; % 样本维数
% 待生成数据的参数
a_real =[/;/;/];%混合模型中基模型高斯密度函数的权重
mu_real=[ ; ];%均值
cov_real(:,:,)=[ ; 0.2];%协方差
cov_real(:,:,)=[0.1 ; 0.1];
cov_real(:,:,)=[0.1 ; 0.1];
%生成符合标准的样本数据(每一列为一个样本)
x=[ mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ];
这一部分是产生原始的数据。有600个样本,产生自3个高斯模型,每个模型样本数的比重为 2/3、 1/6、 1/6。
%初始化参数
a=[/;/;/];
mu=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];
EM算法第一步,初始化参数。注意,隐含有多少类是要提前知道的。即这里,我们必须知道有3类。
需要初始化的有:三个模型所占的比例、三个模型的均值、三个模型的协方差
之后进入大循环,不断迭代E步和M步
rznk_temp=zeros(M,N);
for k=:M
for n=:N
%计算P(x|mu_cm,cov_cm)
rznk_temp(k,n)=exp(-/*(x(:,n)-mu(:,k))'*inv(cov(:,:,k))*(x(:,n)-mu(:,k)));
end
rznk_temp(k,:)=rznk_temp(k,:)/sqrt(det(cov(:,:,k)));
end
rznk_temp=rznk_temp*(*pi)^(-K/);
rznk_temp是一个3行600列的矩阵。每列对应一个样本,每列中的一个数据表示这个样本从第k个高斯模型中抽取到的概率。
比如第100个样本,那rznk_temp(1,100)表示该样本从第1个高斯分布中被抽中的概率。具体求解就是代入高斯模型。
%E step
%求rznk
rznk=zeros(M,N);
for n=:N
for k=:M
rznk(k,n)=a(k)*rznk_temp(k,n);
end
rznk(:,n)=rznk(:,n)/sum(rznk(:,n));
end
E步:
从原理上:根据参数初始值或上一次迭代的模型参数来计算出隐性变量的后验概率,其实就是隐性变量的期望。作为隐藏变量的现估计值:

就是我们要根据现有数据求出这600个样本分别来自某一个高斯分布的概率。
代码上:
rznk(k,n)=a(k)*rznk_temp(k,n); 计算选中第k个高斯模型,且抽中样本n的概率
rznk(:,n)=rznk(:,n)/sum(rznk(:,n)); 计算第n个模型属于这三个模型概率的百分比。即对于第n个模型来说,它分别属于第1个模型、第2个模型、第3个模型的概率,这三个值加起来为1。
% M step
%求Nk
nk=zeros(,M);
nk=sum(rznk'); % 求a
a=nk/N; % 求MU
for k=:M
mu_k_sum=;
for n=:N
mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
end
mu(:,k)=mu_k_sum/nk(k);
end % 求COV
for k=:M
cov_k_sum=;
for n=:N
cov_k_sum=cov_k_sum+rznk(k,n)*(x(:,n)-mu(:,k))*(x(:,n)-mu(:,k))';
end
cov(:,:,k)=cov_k_sum/nk(k);
end
M步:
原理上:将似然函数最大化以获得新的参数值,就是最大似然估计。

公式看起来非常的复杂。
代码上:没有用那个复杂的公式,只是单纯的用得到的rznk更新比例,均值和方差。
nk=sum(rznk'); 综合这600个数据,每个模型被选中的概率和
a=nk/N; 每个模型被选中的概率 除以600是因为之前的和加起来等于600 我们需要归一化
for k=1:M
mu_k_sum=0;
for n=1:N
mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
end
mu(:,k)=mu_k_sum/nk(k); 每个模型的均值估计是通过 sum(样本均值*样本被该模型抽中的概率)/sum(样本被该模型抽中的概率)
end
方差同理。
t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]);
最后,计算新的值与过去的值的变化率。作为是否迭代完成的条件。
【EM】代码理解的更多相关文章
- linux io的cfq代码理解
内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...
- 通过汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的
秦鼎涛 <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 实验一 通过汇编一个简单的C程序,分析汇编代码 ...
- 『TensorFlow』通过代码理解gan网络_中
『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...
- 通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的
实验一:通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的 学号:20135114 姓名:王朝宪 注: 原创作品转载请注明出处 <Linux内核分析>MOOC课程http: ...
- EM算法理解的九层境界
EM算法理解的九层境界 EM 就是 E + M EM 是一种局部下限构造 K-Means是一种Hard EM算法 从EM 到 广义EM 广义EM的一个特例是VBEM 广义EM的另一个特例是WS算法 广 ...
- linux内核分析作业:以一简单C程序为例,分析汇编代码理解计算机如何工作
一.实验 使用gcc –S –o main.s main.c -m32 命令编译成汇编代码,如下代码中的数字请自行修改以防与他人雷同 int g(int x) { return x + 3; } in ...
- (原创)JS闭包看代码理解
<html xmlns="http://www.w3.org/1999/xhtml"> <head> <meta http-equiv="C ...
- junit4X系列源码--Junit4 Runner以及test case执行顺序和源代码理解
原文出处:http://www.cnblogs.com/caoyuanzhanlang/p/3534846.html.感谢作者的无私分享. 前一篇文章我们总体介绍了Junit4的用法以及一些简单的测试 ...
- java代码理解
public int maxProfit(int k, int[] prices) { int pl = prices.length; int nothin ...
随机推荐
- VTK初学一,b_PolyVertex_CellArray多个点的绘制
#ifndef INITIAL_OPENGL #define INITIAL_OPENGL #include <vtkAutoInit.h> VTK_MODULE_INIT(vtkRend ...
- 绿书模拟day10 单词前缀
[题目描述]一组单词是安全的,当且仅当不存在一个单词是另一个单词的前缀,这样才能保证数据不容易被误解,现在你手上有一个单词集合s,你需要计算有多少个自己是安全的.注意空集永远是安全的.[输入格式]第一 ...
- ajax基础了解
使用Ajax的最大优点,就是能在不更新整个页面的前提下维护数据.这使得Web应用程序更为迅捷地回应用户动作,并避免了在网络上发送那些没有改变过的信息.AJAX即“Asynchronous JavaSc ...
- 2014牡丹江K Known Notation
Known Notation Time Limit: 2 Seconds Memory Limit: 65536 KB Do you know reverse Polish notation ...
- ie下获取上传文件全路径
ie下获取上传文件全路径,3.5之后的火狐是没法获取上传文件全路径的 /*获取上传文件路径*/ function getFilePath(obj) { var form = $(this).paren ...
- Android在TextView中实现RichText风格
参考: Android实战技巧:用TextView实现Rich Text---在同一个TextView中设置不同的字体风格 Demo: private SpannableStringBuilder c ...
- Redis学习笔记一:数据结构与对象
1. String(SDS) Redis使用自定义的一种字符串结构SDS来作为字符串的表示. 127.0.0.1:6379> set name liushijie OK 在如上操作中,name( ...
- IE6对png图片的处理
在学习phpcms系统搜索模块的时候,发现下面这段代码: <!--[if IE 6]> <script type="text/javascript" src=&q ...
- php中mysql与mysqli的区别
两个函数都是用来处理DB 的.首先, mysqli 连接是永久连接,而mysql是非永久连接. mysql连接每当第二次使用的时候,都会重新打开一个新的进程,而mysqli则只使用同一个进程,这样可以 ...
- 忧桑三角形,调了半天,真忧桑TAT
忧桑三角形 试题描述 小J是一名文化课选手,他十分喜欢做题,尤其是裸题.有一棵树,树上每个点都有点权,现在有以下两个操作: 1. 修改某个点的点权 2. 查询点u和点v构成的简单路径上是否能选出三个点 ...