McGan: Mean and Covariance Feature Matching GAN
@article{mroueh2017mcgan:,
title={McGan: Mean and Covariance Feature Matching GAN},
author={Mroueh, Youssef and Sercu, Tom and Goel, Vaibhava},
journal={arXiv: Learning},
year={2017}}
概
利用均值和协方差构建IPM, 获得相应的mean GAN 和 covariance gan.
主要内容
IPM:
\]
当\(\mathscr{F}\)是对称空间, 即\(f \in \mathscr{F} \rightarrow - f \in \mathscr{F}\),可得
\]
Mean Matching IPM
\]
其中\(\|\cdot \|_p\)表示\(\ell_p\)范数, \(\Phi_w\)往往用网络来表示, 我们可通过截断\(w\)来使得\(\mathscr{F}_{v,w,p}\)为有界线性函数空间(有界从而使得后面推导中\(\sup\)成为\(\max\)).

其中
\]
最后一个等式的成立是因为:
\]
又\(\| \cdot \|_p\)的对偶范数是\(\|\cdot\|_q, \frac{1}{p}+\frac{1}{q}=1\).
prime
整个GAN的训练过程即为
\min_{g_\theta} \max_{w \in \Omega} \max_{v, \|v\|_p \le 1} \mathscr{L}_{\mu} (v,w,\theta),
\]
其中
\]
估计形式为

dual
也有对应的dual形态
\min_{g_\theta} \max_{w \in \Omega} \|\mu_w(\mathbb{P}_r) - \mu_w (\mathbb{P}_{\theta})\|_q.
\]

Covariance Feature Matching IPM
\]
等价于
\]
并有

其中\([A]_k\)表示\(A\)的\(k\)阶近似, 如果\(A = \sum_i \sigma_iu_iv_i^T\), \(\sigma_1\ge \sigma_2,\ldots\), 则\([A]_k=\sum_{i=1}^k \sigma_i u_iv_i^T\). \(\mathcal{O}_{m,k} := \{M \in \mathbb{R}^{m \times k} | M^TM = I_k \}\), \(\|A\|_*=\sum_i \sigma_i\)表示算子范数.
prime
\min_{g_\theta} \max_{w \in \Omega} \max_{U,V \in \mathcal{P}_{m, k}} \mathscr{L}_{\sigma} (U, V,w,\theta),
\]
其中
\]
采用下式估计

dual
\min_{g_{\theta}} \max_{w \in \Omega} \| [\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})]_k\|_*.
\]
注: 既然\(\Sigma_w(\mathbb{P}_r) - \Sigma_w(\mathbb{P}_{\theta})\)是对称的, 为什么\(U \not =V\)? 因为虽然其对称, 但是并不(半)正定, 所以\(v_i=-u_i\)也是有可能的.
算法



代码
未经测试.
import torch
import torch.nn as nn
from torch.nn.functional import relu
from collections.abc import Callable
def preset(**kwargs):
def decorator(func):
def wrapper(*args, **nkwargs):
nkwargs.update(kwargs)
return func(*args, **nkwargs)
wrapper.__doc__ = func.__doc__
wrapper.__name__ = func.__name__
return wrapper
return decorator
class Meanmatch(nn.Module):
def __init__(self, p, dim, dual=False, prj='l2'):
super(Meanmatch, self).__init__()
self.norm = p
self.dual = dual
if dual:
self.dualnorm = self.norm
else:
self.init_weights(dim)
self.projection = self.proj(prj)
@property
def dualnorm(self):
return self.__dualnorm
@dualnorm.setter
def dualnorm(self, norm):
if norm == 'inf':
norm = float('inf')
elif not isinstance(norm, float):
raise ValueError("Invalid norm")
p = 1 / (1 - 1 / norm)
self.__dualnorm = preset(p=p, dim=1)(torch.norm)
def init_weights(self, dim):
self.weights = nn.Parameter(torch.rand((1, dim)),
requires_grad=True)
@staticmethod
def _proj1(x):
u = x.max()
if u <= 1.:
return x
l = 0.
c = (u + l) / 2
while (u - l) > 1e-4:
r = relu(x - c).sum()
if r > 1.:
l = c
else:
u = c
c = (u + l) / 2
return relu(x - c)
@staticmethod
def _proj2(x):
return x / torch.norm(x)
@staticmethod
def _proj3(x):
return x / torch.max(x)
def proj(self, prj):
if prj == "l1":
return self._proj1
elif prj == "l2":
return self._proj2
elif prj == "linf":
return self._proj3
else:
assert isinstance(prj, Callable), "Invalid prj"
return prj
def forward(self, real, fake):
temp = (real - fake).mean(dim=1)
if self.dual:
return self.dualnorm(temp)
elif not self.training and self.dual:
raise TypeError("just for training...")
else:
self.weights.data = self.projection(self.weights.data) #some diff here!!!!!!!!!!
return self.weights @ temp
class Covmatch(nn.Module):
def __init__(self, dim, k):
super(Covmatch, self).__init__()
self.init_weights(dim, k)
def init_weights(self, dim, k):
temp1 = torch.rand((dim, k))
temp2 = torch.rand((dim, k))
self.U = nn.Parameter(temp1, requires_grad=True)
self.V = nn.Parameter(temp2, requires_grad=True)
def qr(self, w):
q, r = torch.qr(w)
sign = r.diag().sign()
return q * sign
def update_weights(self):
self.U.data = self.qr(self.U.data)
self.V.data = self.qr(self.V.data)
def forward(self, real, fake):
self.update_weights()
temp1 = real @ self.U
temp2 = real @ self.V
temp3 = fake @ self.U
temp4 = fake @ self.V
part1 = torch.trace(temp1 @ temp2.t()).mean()
part2 = torch.trace(temp3 @ temp4.t()).mean()
return part1 - part2
McGan: Mean and Covariance Feature Matching GAN的更多相关文章
- Computer Vision_33_SIFT:Robust scale-invariant feature matching for remote sensing image registration——2009
此部分是计算机视觉部分,主要侧重在底层特征提取,视频分析,跟踪,目标检测和识别方面等方面.对于自己不太熟悉的领域比如摄像机标定和立体视觉,仅仅列出上google上引用次数比较多的文献.有一些刚刚出版的 ...
- Computer Vision_33_SIFT:Remote Sensing Image Registration With Modified SIFT and Enhanced Feature Matching——2017
此部分是计算机视觉部分,主要侧重在底层特征提取,视频分析,跟踪,目标检测和识别方面等方面.对于自己不太熟悉的领域比如摄像机标定和立体视觉,仅仅列出上google上引用次数比较多的文献.有一些刚刚出版的 ...
- [OpenCV] Feature Matching
得到了杂乱无章的特征点后,要筛选出好的特征点,也就是good matches. BruteForceMatcher FlannBasedMatcher 两者的区别:http://yangshen998 ...
- [转]GAN论文集
really-awesome-gan A list of papers and other resources on General Adversarial (Neural) Networks. Th ...
- [论文理解] Good Semi-supervised Learning That Requires a Bad GAN
Good Semi-supervised Learning That Requires a Bad GAN 恢复博客更新,最近没那么忙了,记录一下学习. Intro 本文是一篇稍微偏理论的半监督学习的 ...
- Generative Adversarial Nets[Improved GAN]
0.背景 Tim Salimans等人认为之前的GANs虽然可以生成很好的样本,然而训练GAN本质是找到一个基于连续的,高维参数空间上的非凸游戏上的纳什平衡.然而不幸的是,寻找纳什平衡是一个十分困难的 ...
- (转) GAN论文整理
本文转自:http://www.jianshu.com/p/2acb804dd811 GAN论文整理 作者 FinlayLiu 已关注 2016.11.09 13:21 字数 1551 阅读 1263 ...
- 常见GAN的应用
深入浅出 GAN·原理篇文字版(完整)|干货 from:http://baijiahao.baidu.com/s?id=1568663805038898&wfr=spider&for= ...
- AI佳作解读系列(六) - 生成对抗网络(GAN)综述精华
注:本文来自机器之心的PaperWeekly系列:万字综述之生成对抗网络(GAN),如有侵权,请联系删除,谢谢! 前阵子学习 GAN 的过程发现现在的 GAN 综述文章大都是 2016 年 Ian G ...
随机推荐
- 爬虫系列:存储 CSV 文件
上一期:爬虫系列:存储媒体文件,讲解了如果通过爬虫下载媒体文件,以及下载媒体文件相关代码讲解. 本期将讲解如果将数据保存到 CSV 文件. 逗号分隔值(Comma-Separated Values,C ...
- 100个Shell脚本——【脚本5】数字求和
[脚本5]数字求和 编写shell脚本,要求输入一个数字,然后计算出从1到输入数字的和,要求,如果输入的数字小于1,则重新输入,直到输入正确的数字为止,示例: 一.脚本 #!/bin/bash whi ...
- LeetCode1579题——圆圈中最后剩下的数字
1.题目描述:0,1,,n-1这n个数字排成一个圆圈,从数字0开始,每次从这个圆圈里删除第m个数字.求出这个圆圈里剩下的最后一个数字.例如,0.1.2.3.4这5个数字组成一个圆圈,从数字0开始每次删 ...
- java中的collection小结
Collection 来源于Java.util包,是非常实用常用的数据结构!!!!!字面意思就是容器.具体的继承实现关系如下图,先整体有个印象,再依次介绍各个部分的方法,注意事项,以及应用场景. ...
- Enumeration遍历http请求参数的一个例子
Enumeration<String> paraNames=request.getParameterNames(); for(Enumeration e=paraNames;e.hasMo ...
- Servlet(2):通过servletContext对象实现数据共享
一,ServletContext介绍 web容器在启动时,它会为每一个web应用程序都创建一个ServletContext对象,它代表当前web应用 多个Servlet通过ServletContext ...
- linux-源码软件管理-yum配置
总结如下:1.源码配置软件管理2.配置yum本地源和网络源及yum 工作原理讲解3.计算机硬盘介绍 1.1 源码管理软件 压缩包管理命令: # 主流的压缩格式包括tar.rar.zip.war.gzi ...
- 【Java 基础】Arrays.asList、ArrayList的subList注意事项
1. 使用Arrays.asList的注意事项 1.1 可能会踩的坑 先来看下Arrays.asList的使用: List<Integer> statusList = Arrays.asL ...
- 【Git】【Gitee】通过git远程删除仓库文件
安装Git Git安装配置-菜鸟教程 没有安装下载的,请读者自行安装下载. 启动与初步配置 配置用户名与邮箱 git config --global user.name "用户名" ...
- Python用matplotlib绘图网格线的设置
一.X轴网格线的设置 import matplotlib.pyplot as plt import numpy as np from pylab import mpl mpl.rcParams['fo ...