论文

Belghazi, Mohamed Ishmael, et al. “ Mutual information neural estimation .”  International Conference on Machine Learning . 2018.

利用神经网络的梯度下降法可以实现快速高维连续随机变量之间互信息的估计,上述论文提出了Mutual Information Neural Estimator (MINE)。NN在维度和样本量上都是线性可伸缩的,MI的计算可以通过反向传播进行训练。

核心

Python实现

现有github上的代码无法计算和估计高维随机变量,只能计算一维随机变量,下面的代码给出的修改方案能够计算真实和估计高维随机变量的真实互信息。

其中,为了计算理论的真实互信息,我们不直接暴力求解矩阵(耗时,这也是为什么要有MINE的原因),我们采用给定生成随机变量的参数计算理论互信息。

SIGNAL_NOISE = 0.2
SIGNAL_POWER = 3

完整代码基于pytorch

# Name: MINE_simple
# Author: Reacubeth
# Time: 2020/12/15 18:49
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
# *_*coding:utf-8 *_*
 
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
 
 
SIGNAL_NOISE = 0.2
SIGNAL_POWER = 3
 
data_dim = 3
num_instances = 20000
 
 
def gen_x(num, dim):
    return np.random.normal(0., np.sqrt(SIGNAL_POWER), [num, dim])
 
 
def gen_y(x, num, dim):
    return x + np.random.normal(0., np.sqrt(SIGNAL_NOISE), [num, dim])
 
 
def true_mi(power, noise, dim):
    return dim * 0.5 * np.log2(1 + power/noise)
 
 
mi = true_mi(SIGNAL_POWER, SIGNAL_NOISE, data_dim)
print('True MI:', mi)
 
 
hidden_size = 10
n_epoch = 500
 
 
class MINE(nn.Module):
    def __init__(self, hidden_size=10):
        super(MINE, self).__init__()
        self.layers = nn.Sequential(nn.Linear(2 * data_dim, hidden_size),
                                    nn.ReLU(),
                                    nn.Linear(hidden_size, 1))
 
    def forward(self, x, y):
        batch_size = x.size(0)
        tiled_x = torch.cat([x, x, ], dim=0)
        idx = torch.randperm(batch_size)
 
        shuffled_y = y[idx]
        concat_y = torch.cat([y, shuffled_y], dim=0)
        inputs = torch.cat([tiled_x, concat_y], dim=1)
        logits = self.layers(inputs)
 
        pred_xy = logits[:batch_size]
        pred_x_y = logits[batch_size:]
        loss = - np.log2(np.exp(1)) * (torch.mean(pred_xy) - torch.log(torch.mean(torch.exp(pred_x_y))))
        # compute loss, you'd better scale exp to bit
        return loss
 
 
model = MINE(hidden_size)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
plot_loss = []
all_mi = []
for epoch in tqdm(range(n_epoch)):
    x_sample = gen_x(num_instances, data_dim)
    y_sample = gen_y(x_sample, num_instances, data_dim)
 
    x_sample = torch.from_numpy(x_sample).float()
    y_sample = torch.from_numpy(y_sample).float()
 
    loss = model(x_sample, y_sample)
 
    model.zero_grad()
    loss.backward()
    optimizer.step()
    all_mi.append(-loss.item())
 
 
fig, ax = plt.subplots()
ax.plot(range(len(all_mi)), all_mi, label='MINE Estimate')
ax.plot([0, len(all_mi)], [mi, mi], label='True Mutual Information')
ax.set_xlabel('training steps')
ax.legend(loc='best')
plt.show()

结果

变量维度为1

变量维度为3

需要指出的是在计算最终的互信息时需要将基数e转为基数2。如果只是求得一个比较值,在真实使用的过程中可以省略。

本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理

想要获取更多Python学习资料可以加
QQ:2955637827私聊
或加Q群630390733
大家一起来学习讨论吧!

神经网络高维互信息计算Python实现(MINE)的更多相关文章

  1. 基于神经网络的混合计算(DNC)-Hybrid computing using a NN with dynamic external memory

    前言: DNC可以称为NTM的进一步发展,希望先看看这篇译文,关于NTM的译文:人工机器-NTM-Neutral Turing Machine 基于神经网络的混合计算 Hybrid computing ...

  2. 北京地铁月度消费总金额计算(Python版)

    最近业余时间在学习Python,这是那天坐地铁时突发奇想,想看看我这一个月的地铁费共多少钱,所以简单的构思了下思路,就直接开写了,没想到用Python来实现还挺简单的. 设计思路: 每次乘车正常消费7 ...

  3. 函数计算 Python 连接 SQL Server 小结

    python 连接数据库通常要安装第三方模块,连接 MS SQL Server 需要安装 pymssql .由于 pymsql 依赖于 FreeTDS,对于先于 2.1.3 版本的 pymssql,需 ...

  4. GIL计算python 2 和 python 3 计算密集型

    首先我画了一张图来表示GIL运行的方式: Python 3执行如下计算代码:#-*-conding:utf-8-*-import threading import timedef add(): n = ...

  5. 计算Python运行时间

    可以调用datetime 或者 time库实现得到Python运行时间 方法1 import datetime start_t  = datetime.datetime.now() #运行大型代码 e ...

  6. 机器学习作业(四)神经网络参数的拟合——Python(numpy)实现

    题目下载[传送门] 题目简述:识别图片中的数字,训练该模型,求参数θ. 出现了一个问题:虽然训练的模型能够有很好的预测准确率,但是使用minimize函数时候始终无法成功,无论设计的迭代次数有多大,如 ...

  7. 相似度与距离计算python代码实现

    #定义几种距离计算函数 #更高效的方式为把得分向量化之后使用scipy中定义的distance方法 from math import sqrt def euclidean_dis(rating1, r ...

  8. 计算Python代码运行时间长度方法

    在代码中有时要计算某部分代码运行时间,便于分析. import time start = time.clock() run_function() end = time.clock() print st ...

  9. 菜鸟之路——机器学习之BP神经网络个人理解及Python实现

    关键词: 输入层(Input layer).隐藏层(Hidden layer).输出层(Output layer) 理论上如果有足够多的隐藏层和足够大的训练集,神经网络可以模拟出任何方程.隐藏层多的时 ...

随机推荐

  1. Spring 事件监听机制及原理分析

    简介 在JAVA体系中,有支持实现事件监听机制,在Spring 中也专门提供了一套事件机制的接口,方便我们实现.比如我们可以实现当用户注册后,给他发送一封邮件告诉他注册成功的一些信息,比如用户订阅的主 ...

  2. Java线程的死锁和活锁

    目录 1.概览 2.死锁 2.1.什么是死锁 2.2 死锁举例 2.3 避免死锁 3.活锁 3.1 什么是活锁 3.2 活锁举例 3.3 避免活锁 1.概览 当多线程帮助我们提高应用性能的同时,它同时 ...

  3. Linux中redis服务开启docker运行redis并设置密码

    //查询目前可用的reids镜像 docker search redis //选择拉取官网的镜像 docker pull redis //查看本地是否有redis镜像 docker images // ...

  4. 趣文分享:C 语言和 C++、C# 的区别在什么地方?

    任务: 把大象放到冰箱里.

  5. CentOS升级参考

    CentOS生产系统升级策略: 1)升级前评估 a)确认kernel或包bug. b)用评估工具 c) 测试验证 2)确认升级内容 a)单独升级kernel b)单独升级包 c)都升级 4)确认升级方 ...

  6. Maven 依赖树的解析规则

    对于 Java 开发工程师来说,Maven 是依赖管理和代码构建的标准.遵循「约定大于配置」理念.Maven 是 Java 开发工程师日常使用的工具,本篇文章简要介绍一下 Maven 的依赖树解析. ...

  7. bypass disable_function

    windows 1.com组件绕过 <?php$command=$_POST['a'];$wsh = new COM('WScript.shell'); // 生成一个COM对象 Shell.A ...

  8. go语言的指针类型

    一.指针与引用的相关概念 什么是指针? 指针,全称为指针变量,是用来存储内存地址的一种变量.程序中,一般通过指针来访问其指向的内存地址中的内容(数据). 什么是引用? 引用,是C++中提出来的一种新的 ...

  9. moviepy音视频剪辑:使用fl_time报错OSError: MoviePy error: failed to read the first frame of video file

    专栏:Python基础教程目录 专栏:使用PyQt开发图形界面Python应用 专栏:PyQt+moviepy音视频剪辑实战 专栏:PyQt入门学习 老猿Python博文目录 老猿学5G博文目录 在m ...

  10. PyQt(Python+Qt)学习随笔:窗口的布局设置及访问

    老猿Python博文目录 老猿Python博客地址 在Qt Designer中,可以在一个窗体上拖拽左边的布局部件,在窗口中进行布局管理,但除了基于窗体之上进行布局之外,还需要窗体本身也进行布局管理才 ...