Cyr E C, Gulian M, Patel R G, et al. Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.[J]. arXiv: Learning, 2019.

@article{cyr2019robust,

title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.},

author={Cyr, Eric C and Gulian, Mamikon and Patel, Ravi G and Perego, Mauro and Trask, Nathaniel},

journal={arXiv: Learning},

year={2019}}

这篇文章介绍了一种梯度下降的改进, 以及Box参数初始化方法.

主要内容

\[\tag{6}
\arg \min_{\xi^L \xi^H} \sum_{k=1}^K \epsilon_k \|\mathcal{L}_k[u] - \sum_i \xi_i^L \mathcal{L}_k [\Phi_i(x, \xi^H)]\|^2_{\ell_2(\mathcal{X}_k)}.
\]

LSGD

固定\(\xi^H, \mathcal{X}_k\), 并令\(\epsilon_k=1\), 则问题(6)退化为一个最小二乘问题

\[\arg \min_{\xi^L} \|A\xi^L -b\|^2_{\ell_2 (\mathcal{X})},
\]

其中\(b_i = \mathcal{L}[u](x_i)\), \(A_{ij}=\mathcal{L} [\Phi_j (x_i, \xi^H)]\), \(x_i \in \mathcal{X}, i=1,\ldots, N, j=1, \ldots, w\).

所以算法如下

Box 初始化

该算法期望使得feature-rich,但是我不知道这个rich从何而来.

假设第\(l\)层的输入为\(x \in \mathbb{R}^{d_1}\), 输出为\(y \in \mathbb{R}^{d_2}\), 则该层的权重矩阵\(W \in \mathbb{R}^{d_2 \times d_1}\). 我们逐行地定义\(W\):

  1. 采样\(p\), \(p\sim U[0 ,1]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, 1]^{d_1}} \sigma(k(x-p) \cdot n)=1.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).

其中\(\sigma\)表示激活函数, 文中指的是ReLU.

求解参数\(k\):

  1. \(p_{max} = \max (0, \mathrm{sign}(n))\);
  2. \(k=\frac{1}{(p_{max}-p) \cdot n}\)

此\(k\)即为所需\(k\), 只需证明\(p_{max}\)是最大化

\[(x - p)\cdot n, \quad x \in [0,1]^{d_1}
\]

的解. 最大化上式, 可以分解为

\[\max_{x_i \in [0, 1]} x_in_i,
\]

故\(x_i = \max(0, \mathrm{sign}(n_i))\).

这个初始化有什么好处呢, 可以发现, 输入\(x \in[0,1]^{d_1}\)满足, 则输出\(y \in [0, 1]^{d_2}\), 保证二者的"值域"范围一致, 以此类推整个网络节点值范围近似.



如果, 作者构建了一个2-2-2-2-2-2-2-2的网络, 可以发现, Xavier 和 Kaiming的初始化方法经过一定层数后, 就会塌缩在某个点, 而Box初始化方法能够缓解这一现象.

下面是文中列出的算法(与这里的符号有一点点不同, 另外\(b\)作者应该是遗漏了负号).

Box for Resnet

因为Resnet特殊的结构,

\[y=(W+I)x+b.
\]

假设\(x \in [0,m]^{d_1}\), 则:

  1. 采样\(p\), \(p\sim U[0 ,m]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, m]^{d_1}} \sigma(k(x-p) \cdot n)=\delta m.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).
\[k=\frac{\delta m}{(mp_{max}-p) \cdot n}.
\]

若第一层输入\(x_i \in [0,1]\), 去\(\delta=1/L\), 其中\(L\)为总的层数, 则

\[[0,1] \rightarrow [0,1+\frac{1}{L}] \rightarrow [0,(1+\frac{1}{L})^2] \rightarrow \cdots
\]

代码



'''
initialization.py
'''
import torch
import torch.nn as nn
import warnings def generate(size, m, delta):
p = torch.rand(size) * m
n = torch.randn(size)
temp = 1 / torch.norm(n, p=2, dim=1, keepdim=True)
n = temp * n
pmax = nn.functional.relu(torch.sign(n)) * m
temp = (pmax - p) * n
k = (m * delta) / temp.sum(dim=1, keepdim=True)
w = k * n
b = -(w * p).sum(dim=1)
return w, b def box_init(module, m=1, delta=1):
if isinstance(module, nn.Linear):
w, b = generate(module.weight.shape, m, delta)
try:
module.weight.data = w
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) elif isinstance(module, nn.Conv2d):
outc, inc, h, w = module.weight.size()
w, b = generate((outc, inc * h * w), m, delta)
try:
module.weight.data = w.reshape(module.weight.size())
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) else:
pass


"""config.py"""

nums = 10
layers = 6
method = "kaiming" #box/xavier/kaiming
net = "Net" #Net/ResNet


"""
测试
""" import torch
import torch.nn as nn
import config
from initialization import box_init class Net(nn.Module): def __init__(self, l):
super(Net, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init()
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self):
for module in self.modules():
box_init(module) def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp)
out.append(temp)
return out class ResNet(nn.Module): def __init__(self, l):
super(ResNet, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init(l)
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self, layers):
delta = 1 / layers
m = 1. + delta
l = 0
for module in self.modules():
if isinstance(module, (nn.Linear)):
if l == 0:
box_init(module, 1, 1)
else:
box_init(module, m ** l, delta)
l += 1 def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp) + temp
out.append(temp)
return out if config.net == "Net":
net = Net(config.layers)
else:
net = ResNet(config.layers) x = torch.linspace(0, 1, config.nums)
y = torch.linspace(0, 1, config.nums) grid_x, grid_y = torch.meshgrid(x, y) x = grid_x.flatten()
y = grid_y.flatten()
data = torch.stack((x, y), dim=1)
outs = net(data) import matplotlib.pyplot as plt def axplot(x, y, ax):
x = x.detach().numpy()
y = y.detach().numpy()
ax.scatter(x, y) def plot(x, y, outs):
fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
axs[0].scatter(x, y)
axs[0].set(title="layer0")
for i in range(config.layers):
ax = axs[i+1]
out = outs[i]
x = out[:, 0]
y = out[:, 1]
axplot(x, y, ax)
ax.set(title="layer"+str(i+1))
plt.tight_layout()
plt.savefig("C:/Users/pkavs/Desktop/fig.png")
#plt.show()
plot(x, y, outs)

[Box] Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint的更多相关文章

  1. Training Deep Neural Networks

    http://handong1587.github.io/deep_learning/2015/10/09/training-dnn.html  //转载于 Training Deep Neural ...

  2. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Initialization)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. Initialization Welcome to the first assignment of "Improving D ...

  3. Training (deep) Neural Networks Part: 1

    Training (deep) Neural Networks Part: 1 Nowadays training deep learning models have become extremely ...

  4. Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks

    目录 概 主要内容 深度 宽度 代码 Huang H., Wang Y., Erfani S., Gu Q., Bailey J. and Ma X. Exploring architectural ...

  5. [C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization

    About this Course This course will teach you the "magic" of getting deep learning to work ...

  6. (转)Understanding, generalisation, and transfer learning in deep neural networks

    Understanding, generalisation, and transfer learning in deep neural networks FEBRUARY 27, 2017   Thi ...

  7. On Explainability of Deep Neural Networks

    On Explainability of Deep Neural Networks « Learning F# Functional Data Structures and Algorithms is ...

  8. Classifying plankton with deep neural networks

    Classifying plankton with deep neural networks The National Data Science Bowl, a data science compet ...

  9. Must Know Tips/Tricks in Deep Neural Networks

    Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei)   Deep Neural Networks, especially C ...

随机推荐

  1. A Child's History of England.7

    After the death of Ethelbert, Edwin, King of Northumbria [公元616年,隋朝末年], who was such a good king tha ...

  2. java加密方式

    加密,是以某种特殊的算法改变原有的信息数据,使得未授权的用户即使获得了已加密的信息,但因不知解密的方法,仍然无法了解信息的内容.大体上分为双向加密和单向加密,而双向加密又分为对称加密和非对称加密(有些 ...

  3. Linux基础命令---wget下载工具

    wget wget是一个免费的文件下载工具,可以从指定的URL下载文件到本地主机.它支持HTTP和FTP协议,经常用来抓取大量的网页文件. 此命令的适用范围:RedHat.RHEL.Ubuntu.Ce ...

  4. 使用 ACE 库框架在 UNIX 中开发高性能并发应用

    使用 ACE 库框架在 UNIX 中开发高性能并发应用来源:developerWorks 中国 作者:Arpan Sen ACE 开放源码工具包可以帮助开发人员创建健壮的可移植多线程应用程序.本文讨论 ...

  5. jquery datatable使用简单示例

    目标: 使用 jQuery Datatable 构造数据列表,并且增加或者隐藏相应的列,已达到数据显示要求.同时, jQuery Datatable 强大的功能支持:排序,分页,搜索等. Query ...

  6. dom4j解析XML学习

    原理:把dom与SAX进行了封装 优点:JDOM的一个智能分支.扩充了其灵活性增加了一些额外的功能. package com.dom4j.xml; import java.io.FileNotFoun ...

  7. js和jquery之间的转换

    <!DOCTYPE html><html lang="en"><head> <meta charset="UTF-8" ...

  8. java基础---局部变量和全局变量

    1.成员变量的概念: 成员变量就是属于类的变量,在类中,方法体外定义的变量 1)成员变量又分为两种: 类变量(又称静态变量) 实例变量(又称非静态变量) 类变量(静态变量)   :是被static所修 ...

  9. Mysql资料 查询条件

    目录 一.计算 二.比较 三.逻辑运算符 四.位运算符 五.优先顺序 一.计算 二.比较 三.逻辑运算符 四.位运算符 五.优先顺序 实际上,很少有人能将这些优先级熟练记忆,很多情况下我们都是用&qu ...

  10. 【cs231n笔记】assignment1之KNN

    k-Nearest Neighbor (kNN) 练习 这篇博文是对cs231n课程assignment1的第一个问题KNN算法的完成,参考了一些网上的博客,不具有什么创造性,以个人学习笔记为目的发布 ...