Cyr E. C., Gulian M., Patel R. G., Perego M., Trask N. Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint. arXiv preprint, 2019.

@article{cyr2019robust,
  title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint},
  author={Cyr, Eric C and Gulian, Mamikon and Patel, Ravi G and Perego, Mauro and Trask, Nathaniel},
  journal={arXiv preprint},
  year={2019}
}

This paper introduces an improvement to gradient-descent training (LSGD) and the Box parameter-initialization scheme.

Main content

In the adaptive-basis view, the network output is the expansion \(\sum_i \xi_i^L \Phi_i(x, \xi^H)\): the \(\Phi_i\) are basis functions produced by the hidden layers (parameters \(\xi^H\)), and \(\xi^L\) are the coefficients of the final linear layer. Training solves

\[\tag{6}
\arg \min_{\xi^L, \xi^H} \sum_{k=1}^K \epsilon_k \|\mathcal{L}_k[u] - \sum_i \xi_i^L \mathcal{L}_k [\Phi_i(x, \xi^H)]\|^2_{\ell_2(\mathcal{X}_k)},
\]

where the \(\mathcal{L}_k\) are the operators of the problem (e.g. a PDE and its boundary conditions), sampled on point sets \(\mathcal{X}_k\) with weights \(\epsilon_k\).

LSGD

Fixing \(\xi^H\) and \(\mathcal{X}_k\), and setting \(\epsilon_k=1\), problem (6) reduces to the least-squares problem

\[\arg \min_{\xi^L} \|A\xi^L -b\|^2_{\ell_2 (\mathcal{X})},
\]

where \(b_i = \mathcal{L}[u](x_i)\), \(A_{ij}=\mathcal{L} [\Phi_j (x_i, \xi^H)]\), with \(x_i \in \mathcal{X}\), \(i=1,\ldots, N\), \(j=1, \ldots, w\) (\(w\) being the number of basis functions, i.e. the width of the last hidden layer).

LSGD therefore alternates two steps: solve this least-squares problem exactly for \(\xi^L\), then take gradient-descent steps on the hidden parameters \(\xi^H\), and repeat.
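As a sketch, one LSGD iteration might look as follows (my own reconstruction, not the paper's code; the names lsgd_step, hidden, and last, and the plain SGD update on \(\xi^H\), are assumptions):

import torch
import torch.nn as nn

def lsgd_step(hidden, last, x, b, lr=1e-3):
    # hidden: module producing the basis Phi(x, xi^H), output shape (N, w)
    # last:   nn.Linear(w, 1, bias=False) holding the coefficients xi^L
    A = hidden(x)                                        # A_ij = Phi_j(x_i, xi^H)
    xi_L = torch.linalg.lstsq(A.detach(), b).solution    # exact LS solve for xi^L
    with torch.no_grad():
        last.weight.copy_(xi_L.T)                        # nn.Linear stores (1, w)
    # gradient step on the hidden parameters xi^H only
    loss = ((last(hidden(x)) - b) ** 2).sum()
    loss.backward()
    with torch.no_grad():
        for param in hidden.parameters():
            param -= lr * param.grad
            param.grad = None
    last.weight.grad = None
    return loss.item()

For instance, hidden = nn.Sequential(nn.Linear(1, 20), nn.Tanh()) and last = nn.Linear(20, 1, bias=False) would fit a one-dimensional regression problem.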

Box initialization

This scheme is meant to keep the layers feature-rich, though I do not see where exactly this richness comes from.

Suppose the input to layer \(l\) is \(x \in \mathbb{R}^{d_1}\) and the output is \(y \in \mathbb{R}^{d_2}\); the layer's weight matrix is then \(W \in \mathbb{R}^{d_2 \times d_1}\). We define \(W\) row by row:

  1. Sample \(p\), \(p\sim U[0 ,1]^{d_1}\);
  2. Sample \(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), and set \(n=n/\|n\|\);
  3. Find the parameter \(k\) such that
\[\max_{x \in [0, 1]^{d_1}} \sigma(k(x-p) \cdot n)=1;
\]
  4. Set the \(i\)-th row of \(W\) to \(w_i=kn^T\) and the bias to \(b_i=-kp \cdot n\).

Here \(\sigma\) is the activation function; in the paper it is ReLU.
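To make steps 1-4 concrete, here is a small self-contained check (my own sketch, not from the paper) that builds one row and verifies that the maximal activation over \([0,1]^{d_1}\) is exactly 1:

import torch

d1 = 4
p = torch.rand(d1)                  # step 1: p ~ U[0, 1]^{d1}
n = torch.randn(d1)
n = n / n.norm()                    # step 2: unit normal direction
pmax = torch.relu(torch.sign(n))    # maximizer of (x - p) . n over [0, 1]^{d1}
k = 1.0 / torch.dot(pmax - p, n)    # step 3: scale so the max pre-activation is 1
w_i = k * n                         # step 4: one row of W ...
b_i = -k * torch.dot(p, n)          # ... and its bias

# the maximum of sigma(w_i . x + b_i) over the cube is attained at pmax:
print(torch.relu(torch.dot(w_i, pmax) + b_i))   # tensor(1.)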

Solving for the parameter \(k\):

  1. \(p_{max} = \max (0, \mathrm{sign}(n))\) (taken componentwise);
  2. \(k=\frac{1}{(p_{max}-p) \cdot n}\).

This \(k\) is the required one; it only remains to show that \(p_{max}\) maximizes

\[(x - p)\cdot n, \quad x \in [0,1]^{d_1}.
\]

The maximization decouples into the componentwise problems

\[\max_{x_i \in [0, 1]} x_in_i,
\]

so \(x_i = \max(0, \mathrm{sign}(n_i))\).
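For a concrete example, take \(d_1=2\), \(n=(0.6,-0.8)\) (already unit norm) and \(p=(0.25,0.5)\). Then \(p_{max}=(1,0)\),

\[(p_{max}-p)\cdot n = 0.75\cdot 0.6 + (-0.5)\cdot(-0.8) = 0.85,
\]

so \(k = 1/0.85 \approx 1.18\), and indeed \(\max_x \sigma(k(x-p)\cdot n) = k\cdot 0.85 = 1\).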

What does this initialization buy us? If the input satisfies \(x \in[0,1]^{d_1}\), then by construction the output satisfies \(y \in [0, 1]^{d_2}\), so input and output have the same "range"; by induction, the node values of the whole network stay in approximately the same range.



For example, the authors build a 2-2-2-2-2-2-2-2 network and observe that with Xavier or Kaiming initialization the points collapse onto a single point after a few layers, whereas Box initialization alleviates this.

Below is the algorithm as listed in the paper (its notation differs slightly from the one used here; also, the authors appear to have dropped the minus sign in \(b\)).

Box for ResNet

Because of the special residual structure of a ResNet layer,

\[y=(W+I)x+b,
\]

the construction needs to be adapted.

Suppose \(x \in [0,m]^{d_1}\); then:

  1. Sample \(p\), \(p\sim U[0 ,m]^{d_1}\);
  2. Sample \(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), and set \(n=n/\|n\|\);
  3. Find the parameter \(k\) such that
\[\max_{x \in [0, m]^{d_1}} \sigma(k(x-p) \cdot n)=\delta m,
\]
which, by the same argument as before (the maximizer is now \(mp_{max}\)), gives
\[k=\frac{\delta m}{(mp_{max}-p) \cdot n};
\]
  4. Set the \(i\)-th row of \(W\) to \(w_i=kn^T\) and the bias to \(b_i=-kp \cdot n\).

If the inputs to the first layer satisfy \(x_i \in [0,1]\), take \(\delta=1/L\), where \(L\) is the total number of layers; the ranges then grow as

\[[0,1] \rightarrow [0,1+\frac{1}{L}] \rightarrow [0,(1+\frac{1}{L})^2] \rightarrow \cdots
\]
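Since \((1+\frac{1}{L})^L \le e\) for every \(L\), after all \(L\) residual layers the range is still contained in \([0, e]\): the node values grow, but stay bounded independently of the depth.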

Code



'''
initialization.py
'''
import warnings

import torch
import torch.nn as nn


def generate(size, m, delta):
    """Generate Box weights and biases for a layer of shape (d2, d1),
    assuming inputs lie in [0, m]^{d1} and a target maximum of delta * m."""
    p = torch.rand(size) * m                          # one p ~ U[0, m]^{d1} per row
    n = torch.randn(size)
    n = n / torch.norm(n, p=2, dim=1, keepdim=True)   # unit normal per row
    pmax = nn.functional.relu(torch.sign(n)) * m      # maximizer of (x - p) . n over [0, m]^{d1}
    k = (m * delta) / ((pmax - p) * n).sum(dim=1, keepdim=True)
    w = k * n
    b = -(w * p).sum(dim=1)                           # b_i = -k p . n
    return w, b


def box_init(module, m=1, delta=1):
    """Apply Box initialization to nn.Linear / nn.Conv2d modules; ignore the rest."""
    if isinstance(module, nn.Linear):
        w, b = generate(module.weight.shape, m, delta)
        try:
            module.weight.data = w
            module.bias.data = b
        except AttributeError as e:       # e.g. a layer built with bias=False
            warnings.warn("Error:\n" + str(e) +
                          "\nstops the initialization for this module: {}".format(module))
    elif isinstance(module, nn.Conv2d):
        outc, inc, h, kw = module.weight.size()
        w, b = generate((outc, inc * h * kw), m, delta)   # treat each filter as one row
        try:
            module.weight.data = w.reshape(module.weight.size())
            module.bias.data = b
        except AttributeError as e:
            warnings.warn("Error:\n" + str(e) +
                          "\nstops the initialization for this module: {}".format(module))
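To initialize a whole model, box_init combines naturally with Module.apply, which calls a function on every submodule. A minimal usage sketch (the two-layer model here is just a placeholder):

import functools
import torch.nn as nn
from initialization import box_init

model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 2))
model.apply(functools.partial(box_init, m=1, delta=1))  # visits every submodule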


"""config.py"""

nums = 10
layers = 6
method = "kaiming" #box/xavier/kaiming
net = "Net" #Net/ResNet


"""
测试
""" import torch
import torch.nn as nn
import config
from initialization import box_init class Net(nn.Module): def __init__(self, l):
super(Net, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init()
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self):
for module in self.modules():
box_init(module) def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp)
out.append(temp)
return out class ResNet(nn.Module): def __init__(self, l):
super(ResNet, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init(l)
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self, layers):
delta = 1 / layers
m = 1. + delta
l = 0
for module in self.modules():
if isinstance(module, (nn.Linear)):
if l == 0:
box_init(module, 1, 1)
else:
box_init(module, m ** l, delta)
l += 1 def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp) + temp
out.append(temp)
return out if config.net == "Net":
net = Net(config.layers)
else:
net = ResNet(config.layers) x = torch.linspace(0, 1, config.nums)
y = torch.linspace(0, 1, config.nums) grid_x, grid_y = torch.meshgrid(x, y) x = grid_x.flatten()
y = grid_y.flatten()
data = torch.stack((x, y), dim=1)
outs = net(data) import matplotlib.pyplot as plt def axplot(x, y, ax):
x = x.detach().numpy()
y = y.detach().numpy()
ax.scatter(x, y) def plot(x, y, outs):
fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
axs[0].scatter(x, y)
axs[0].set(title="layer0")
for i in range(config.layers):
ax = axs[i+1]
out = outs[i]
x = out[:, 0]
y = out[:, 1]
axplot(x, y, ax)
ax.set(title="layer"+str(i+1))
plt.tight_layout()
plt.savefig("C:/Users/pkavs/Desktop/fig.png")
#plt.show()
plot(x, y, outs)
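If everything is wired up correctly, running the script with config.method = "box" should keep the layer-by-layer scatter plots spread out over the unit square, while "xavier" or "kaiming" should reproduce the collapse onto a single point described above.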
