本来想自己写一个EM算法的,但是操作没两步就进行不下去了。对那些数学公式着实不懂。只好从网上找找代码,看看别人是怎么做的。

代码:来自http://blog.sina.com.cn/s/blog_98b365150101f2xb.html 经验证可用

%EM
M=; % M个高斯分布混合
N=; % 样本数
th=0.000001; % 收敛阈值
K=; % 样本维数
% 待生成数据的参数
a_real =[/;/;/];%混合模型中基模型高斯密度函数的权重
mu_real=[ ; ];%均值
cov_real(:,:,)=[ ; 0.2];%协方差
cov_real(:,:,)=[0.1 ; 0.1];
cov_real(:,:,)=[0.1 ; 0.1];
%生成符合标准的样本数据(每一列为一个样本)
x=[ mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ]; %初始化参数
a=[/;/;/];
mu=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];
t=inf;
while t>=th
a_old = a;
mu_old = mu;
cov_old= cov;
rznk_temp=zeros(M,N);
for k=:M
for n=:N
%计算P(x|mu_cm,cov_cm)
rznk_temp(k,n)=exp(-/*(x(:,n)-mu(:,k))'*inv(cov(:,:,k))*(x(:,n)-mu(:,k)));
end
rznk_temp(k,:)=rznk_temp(k,:)/sqrt(det(cov(:,:,k)));
end
rznk_temp=rznk_temp*(*pi)^(-K/);
%E step
%求rznk
rznk=zeros(M,N);
for n=:N
for k=:M
rznk(k,n)=a(k)*rznk_temp(k,n);
end
rznk(:,n)=rznk(:,n)/sum(rznk(:,n));
end
% M step
%求Nk
nk=zeros(,M);
nk=sum(rznk'); % 求a
a=nk/N; % 求MU
for k=:M
mu_k_sum=;
for n=:N
mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
end
mu(:,k)=mu_k_sum/nk(k);
end % 求COV
for k=:M
cov_k_sum=;
for n=:N
cov_k_sum=cov_k_sum+rznk(k,n)*(x(:,n)-mu(:,k))*(x(:,n)-mu(:,k))';
end
cov(:,:,k)=cov_k_sum/nk(k);
end t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]);
end

分解说明:

M=;          % M个高斯分布混合
N=; % 样本数
th=0.000001; % 收敛阈值
K=; % 样本维数
% 待生成数据的参数
a_real =[/;/;/];%混合模型中基模型高斯密度函数的权重
mu_real=[ ; ];%均值
cov_real(:,:,)=[ ; 0.2];%协方差
cov_real(:,:,)=[0.1 ; 0.1];
cov_real(:,:,)=[0.1 ; 0.1];
%生成符合标准的样本数据(每一列为一个样本)
x=[ mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ,...
mvnrnd( mu_real(:,) , cov_real(:,:,) , round(N*a_real()) )' ];

这一部分是产生原始的数据。有600个样本,产生自3个高斯模型,每个模型样本数的比重为 2/3、  1/6、  1/6。

%初始化参数
a=[/;/;/];
mu=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];
cov(:,:,)=[ ; ];

EM算法第一步,初始化参数。注意,隐含有多少类是要提前知道的。即这里,我们必须知道有3类。

需要初始化的有:三个模型所占的比例、三个模型的均值、三个模型的协方差

之后进入大循环,不断迭代E步和M步

rznk_temp=zeros(M,N);
for k=:M
for n=:N
%计算P(x|mu_cm,cov_cm)
rznk_temp(k,n)=exp(-/*(x(:,n)-mu(:,k))'*inv(cov(:,:,k))*(x(:,n)-mu(:,k)));
end
rznk_temp(k,:)=rznk_temp(k,:)/sqrt(det(cov(:,:,k)));
end
rznk_temp=rznk_temp*(*pi)^(-K/);

rznk_temp是一个3行600列的矩阵。每列对应一个样本,每列中的一个数据表示这个样本从第k个高斯模型中抽取到的概率。

比如第100个样本,那rznk_temp(1,100)表示该样本从第1个高斯分布中被抽中的概率。具体求解就是代入高斯模型。

%E step
%求rznk
rznk=zeros(M,N);
for n=:N
for k=:M
rznk(k,n)=a(k)*rznk_temp(k,n);
end
rznk(:,n)=rznk(:,n)/sum(rznk(:,n));
end

E步:

从原理上:根据参数初始值或上一次迭代的模型参数来计算出隐性变量的后验概率,其实就是隐性变量的期望。作为隐藏变量的现估计值:

就是我们要根据现有数据求出这600个样本分别来自某一个高斯分布的概率。

代码上:

rznk(k,n)=a(k)*rznk_temp(k,n);   计算选中第k个高斯模型,且抽中样本n的概率

rznk(:,n)=rznk(:,n)/sum(rznk(:,n));   计算第n个模型属于这三个模型概率的百分比。即对于第n个模型来说,它分别属于第1个模型、第2个模型、第3个模型的概率,这三个值加起来为1。

% M step
%求Nk
nk=zeros(,M);
nk=sum(rznk'); % 求a
a=nk/N; % 求MU
for k=:M
mu_k_sum=;
for n=:N
mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
end
mu(:,k)=mu_k_sum/nk(k);
end % 求COV
for k=:M
cov_k_sum=;
for n=:N
cov_k_sum=cov_k_sum+rznk(k,n)*(x(:,n)-mu(:,k))*(x(:,n)-mu(:,k))';
end
cov(:,:,k)=cov_k_sum/nk(k);
end

M步:

原理上:将似然函数最大化以获得新的参数值,就是最大似然估计。

    

公式看起来非常的复杂。

代码上:没有用那个复杂的公式,只是单纯的用得到的rznk更新比例,均值和方差。

nk=sum(rznk'); 综合这600个数据,每个模型被选中的概率和
a=nk/N;  每个模型被选中的概率 除以600是因为之前的和加起来等于600 我们需要归一化
for k=1:M
mu_k_sum=0;
for n=1:N
mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
  end
   mu(:,k)=mu_k_sum/nk(k); 每个模型的均值估计是通过 sum(样本均值*样本被该模型抽中的概率)/sum(样本被该模型抽中的概率)
end

方差同理。

t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]); 

最后,计算新的值与过去的值的变化率。作为是否迭代完成的条件。

 

【EM】代码理解的更多相关文章

  1. linux io的cfq代码理解

    内核版本: 3.10内核. CFQ,即Completely Fair Queueing绝对公平调度器,原理是基于时间片的角度去保证公平,其实如果一台设备既有单队列,又有多队列,既有快速的NVME,又有 ...

  2. 通过汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    秦鼎涛  <Linux内核分析>MOOC课程http://mooc.study.163.com/course/USTC-1000029000 实验一 通过汇编一个简单的C程序,分析汇编代码 ...

  3. 『TensorFlow』通过代码理解gan网络_中

    『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 上篇是一个尝试生成minist手写体数据的简单GAN网络,之前有介绍过,图片维度是28*28*1,生成器的上采样使 ...

  4. 通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的

    实验一:通过反汇编一个简单的C程序,分析汇编代码理解计算机是如何工作的 学号:20135114 姓名:王朝宪 注: 原创作品转载请注明出处   <Linux内核分析>MOOC课程http: ...

  5. EM算法理解的九层境界

    EM算法理解的九层境界 EM 就是 E + M EM 是一种局部下限构造 K-Means是一种Hard EM算法 从EM 到 广义EM 广义EM的一个特例是VBEM 广义EM的另一个特例是WS算法 广 ...

  6. linux内核分析作业:以一简单C程序为例,分析汇编代码理解计算机如何工作

    一.实验 使用gcc –S –o main.s main.c -m32 命令编译成汇编代码,如下代码中的数字请自行修改以防与他人雷同 int g(int x) { return x + 3; } in ...

  7. (原创)JS闭包看代码理解

    <html xmlns="http://www.w3.org/1999/xhtml"> <head> <meta http-equiv="C ...

  8. junit4X系列源码--Junit4 Runner以及test case执行顺序和源代码理解

    原文出处:http://www.cnblogs.com/caoyuanzhanlang/p/3534846.html.感谢作者的无私分享. 前一篇文章我们总体介绍了Junit4的用法以及一些简单的测试 ...

  9. java代码理解

    public int maxProfit(int k, int[] prices) {            int pl = prices.length;            int nothin ...

随机推荐

  1. NetBeans使用习惯:升级与保存配置

    如何升级:点击 netbeans 的升级更新 ,即可升级版本:不推荐官网下载进行安装,否则会出现,以前的旧版本8.0的目录和8.0.1目录,虽然它会自动检测到以前版本的配置,提示导入... 如何备份: ...

  2. Java-开启Java之路

    .NET还是我的最爱··· 准备学习下底层知识,为Java打打基础, 然后开始正式学习Java, 也算是曲线救国了, Go,Go,Go ... 二〇一六年十一月十日 18:09:54

  3. hdu.1111.Secret Code(dfs + 秦九韶算法)

    Secret Code Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/32768 K (Java/Others) Tota ...

  4. JVM性能调优监控工具jps、jstack、jmap、jhat、jstat使用详解(转VIII)

    JVM本身就是一个java进程,一个java程序运行在一个jvm进程中.多个java程序同时运行就会有多个jvm进程.一个jvm进程有多个线程至少有一个gc线程和一个用户线程. JDK本身提供了很多方 ...

  5. mysql解决自动断开8小时未曾用过的链接

    今天有运维的同事反映,发布关键词不太稳定,点了没反应.就去线上看了一下日志,发现数据库没有链接,就查了一下问题 关于mysql自动断开的问题研究结果如下,在mysql中有相关参数设定,当数据库连接空闲 ...

  6. struts2-(1)使用Filter作为控制器

    1.使用filter作为控制器 (1)创建类,实现javax.servlet.Filter package com.controller.filter; import java.io.IOExcept ...

  7. Jquery Ajax调用aspx页面方法

    Jquery Ajax调用aspx页面方法 在asp.net webform开发中,用jQuery ajax传值一般有几种玩法 1)普通玩法:通过一般处理程序ashx进行处理: 2)高级玩法:通过as ...

  8. Javascript包含对象的数组去重

    Array.prototype.clearRepeat = function(){ var result = [], obj = {}; for(var i = 0; i < this.leng ...

  9. CentOS 6.5 zabbix 3.0.4 监控MySQL性能

    安装mysql [root@test3 /]# yum -y install mysql mysql-server 初始化数据库 [root@test3 /]# /etc/init.d/mysqld ...

  10. activiti查看流程图,有中文乱码

    第一种 因为服务器缺少必要的字体到这的问题: 解决办法 <!-- 发布流程生成图片是正常显示中文 -->            <property name="activi ...