Cyr E C, Gulian M, Patel R G, et al. Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.[J]. arXiv: Learning, 2019.

@article{cyr2019robust,

title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.},

author={Cyr, Eric C and Gulian, Mamikon and Patel, Ravi G and Perego, Mauro and Trask, Nathaniel},

journal={arXiv: Learning},

year={2019}}

这篇文章介绍了一种梯度下降的改进, 以及Box参数初始化方法.

主要内容

\[\tag{6}
\arg \min_{\xi^L \xi^H} \sum_{k=1}^K \epsilon_k \|\mathcal{L}_k[u] - \sum_i \xi_i^L \mathcal{L}_k [\Phi_i(x, \xi^H)]\|^2_{\ell_2(\mathcal{X}_k)}.
\]

LSGD

固定\(\xi^H, \mathcal{X}_k\), 并令\(\epsilon_k=1\), 则问题(6)退化为一个最小二乘问题

\[\arg \min_{\xi^L} \|A\xi^L -b\|^2_{\ell_2 (\mathcal{X})},
\]

其中\(b_i = \mathcal{L}[u](x_i)\), \(A_{ij}=\mathcal{L} [\Phi_j (x_i, \xi^H)]\), \(x_i \in \mathcal{X}, i=1,\ldots, N, j=1, \ldots, w\).

所以算法如下

Box 初始化

该算法期望使得feature-rich,但是我不知道这个rich从何而来.

假设第\(l\)层的输入为\(x \in \mathbb{R}^{d_1}\), 输出为\(y \in \mathbb{R}^{d_2}\), 则该层的权重矩阵\(W \in \mathbb{R}^{d_2 \times d_1}\). 我们逐行地定义\(W\):

  1. 采样\(p\), \(p\sim U[0 ,1]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, 1]^{d_1}} \sigma(k(x-p) \cdot n)=1.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).

其中\(\sigma\)表示激活函数, 文中指的是ReLU.

求解参数\(k\):

  1. \(p_{max} = \max (0, \mathrm{sign}(n))\);
  2. \(k=\frac{1}{(p_{max}-p) \cdot n}\)

此\(k\)即为所需\(k\), 只需证明\(p_{max}\)是最大化

\[(x - p)\cdot n, \quad x \in [0,1]^{d_1}
\]

的解. 最大化上式, 可以分解为

\[\max_{x_i \in [0, 1]} x_in_i,
\]

故\(x_i = \max(0, \mathrm{sign}(n_i))\).

这个初始化有什么好处呢, 可以发现, 输入\(x \in[0,1]^{d_1}\)满足, 则输出\(y \in [0, 1]^{d_2}\), 保证二者的"值域"范围一致, 以此类推整个网络节点值范围近似.



如果, 作者构建了一个2-2-2-2-2-2-2-2的网络, 可以发现, Xavier 和 Kaiming的初始化方法经过一定层数后, 就会塌缩在某个点, 而Box初始化方法能够缓解这一现象.

下面是文中列出的算法(与这里的符号有一点点不同, 另外\(b\)作者应该是遗漏了负号).

Box for Resnet

因为Resnet特殊的结构,

\[y=(W+I)x+b.
\]

假设\(x \in [0,m]^{d_1}\), 则:

  1. 采样\(p\), \(p\sim U[0 ,m]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, m]^{d_1}} \sigma(k(x-p) \cdot n)=\delta m.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).
\[k=\frac{\delta m}{(mp_{max}-p) \cdot n}.
\]

若第一层输入\(x_i \in [0,1]\), 去\(\delta=1/L\), 其中\(L\)为总的层数, 则

\[[0,1] \rightarrow [0,1+\frac{1}{L}] \rightarrow [0,(1+\frac{1}{L})^2] \rightarrow \cdots
\]

代码



'''
initialization.py
'''
import torch
import torch.nn as nn
import warnings def generate(size, m, delta):
p = torch.rand(size) * m
n = torch.randn(size)
temp = 1 / torch.norm(n, p=2, dim=1, keepdim=True)
n = temp * n
pmax = nn.functional.relu(torch.sign(n)) * m
temp = (pmax - p) * n
k = (m * delta) / temp.sum(dim=1, keepdim=True)
w = k * n
b = -(w * p).sum(dim=1)
return w, b def box_init(module, m=1, delta=1):
if isinstance(module, nn.Linear):
w, b = generate(module.weight.shape, m, delta)
try:
module.weight.data = w
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) elif isinstance(module, nn.Conv2d):
outc, inc, h, w = module.weight.size()
w, b = generate((outc, inc * h * w), m, delta)
try:
module.weight.data = w.reshape(module.weight.size())
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) else:
pass


"""config.py"""

nums = 10
layers = 6
method = "kaiming" #box/xavier/kaiming
net = "Net" #Net/ResNet


"""
测试
""" import torch
import torch.nn as nn
import config
from initialization import box_init class Net(nn.Module): def __init__(self, l):
super(Net, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init()
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self):
for module in self.modules():
box_init(module) def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp)
out.append(temp)
return out class ResNet(nn.Module): def __init__(self, l):
super(ResNet, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init(l)
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self, layers):
delta = 1 / layers
m = 1. + delta
l = 0
for module in self.modules():
if isinstance(module, (nn.Linear)):
if l == 0:
box_init(module, 1, 1)
else:
box_init(module, m ** l, delta)
l += 1 def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp) + temp
out.append(temp)
return out if config.net == "Net":
net = Net(config.layers)
else:
net = ResNet(config.layers) x = torch.linspace(0, 1, config.nums)
y = torch.linspace(0, 1, config.nums) grid_x, grid_y = torch.meshgrid(x, y) x = grid_x.flatten()
y = grid_y.flatten()
data = torch.stack((x, y), dim=1)
outs = net(data) import matplotlib.pyplot as plt def axplot(x, y, ax):
x = x.detach().numpy()
y = y.detach().numpy()
ax.scatter(x, y) def plot(x, y, outs):
fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
axs[0].scatter(x, y)
axs[0].set(title="layer0")
for i in range(config.layers):
ax = axs[i+1]
out = outs[i]
x = out[:, 0]
y = out[:, 1]
axplot(x, y, ax)
ax.set(title="layer"+str(i+1))
plt.tight_layout()
plt.savefig("C:/Users/pkavs/Desktop/fig.png")
#plt.show()
plot(x, y, outs)

[Box] Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint的更多相关文章

  1. Training Deep Neural Networks

    http://handong1587.github.io/deep_learning/2015/10/09/training-dnn.html  //转载于 Training Deep Neural ...

  2. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Initialization)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. Initialization Welcome to the first assignment of "Improving D ...

  3. Training (deep) Neural Networks Part: 1

    Training (deep) Neural Networks Part: 1 Nowadays training deep learning models have become extremely ...

  4. Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks

    目录 概 主要内容 深度 宽度 代码 Huang H., Wang Y., Erfani S., Gu Q., Bailey J. and Ma X. Exploring architectural ...

  5. [C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization

    About this Course This course will teach you the "magic" of getting deep learning to work ...

  6. (转)Understanding, generalisation, and transfer learning in deep neural networks

    Understanding, generalisation, and transfer learning in deep neural networks FEBRUARY 27, 2017   Thi ...

  7. On Explainability of Deep Neural Networks

    On Explainability of Deep Neural Networks « Learning F# Functional Data Structures and Algorithms is ...

  8. Classifying plankton with deep neural networks

    Classifying plankton with deep neural networks The National Data Science Bowl, a data science compet ...

  9. Must Know Tips/Tricks in Deep Neural Networks

    Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei)   Deep Neural Networks, especially C ...

随机推荐

  1. Rational Rose的安装及使用教程(包括菜单命令解释、操作向导说明、快捷命令说明)

    一.安装教程 我安装时用的是镜像文件,所以安装前需要辅助软件来处理镜像文件.我用到的是UltraISO.UltraISO中文名叫软碟通 是一款功能强大而又方便实用的光盘映像文件的制作/编辑/转换工具, ...

  2. day 03Linux修改命令提示符

    day 03Linux修改命令提示符 昨日回顾 1.选择客户机操作系统: Microsoft Windows # 一次只能安装一台电脑 Linux(推荐) VMware ESX # 服务器版本VNwa ...

  3. Sharding-JDBC 实现水平分库分表

    1.需求分析

  4. 【Linux】【Basis】文件系统

    FHS:Filesystem Hierarchy Standard Web site: https://wiki.linuxfoundation.org/lsb/fhs http://www.path ...

  5. 二、SpringBoot实现上传文件到fastDFS文件服务器

    上篇文章介绍了如何使用docker安装fastDFS文件服务器,这一篇就介绍整合springBoot实现文件上传到fastDFS文件服务器 1.pom.xml文件添加依赖 <!-- 连接fast ...

  6. 【C/C++】习题3-7 DNA/算法竞赛入门经典/数组与字符串

    [题目] 输入m组n长的DNA序列,要求找出和其他Hamming距离最小的那个序列,求其与其他的Hamming距离总和. 如果有多个序列,求字典序最小的. [注]这道题是我理解错误,不是找出输入的序列 ...

  7. 基于Web的质量和测试度量指标

    直观了解软件质量和测试的完整性 VectorCAST/Analytics可提供便于用户理解的web仪表盘视图来显示软件代码质量和测试完整性指标,让用户能够掌握单个代码库的趋势,或对比多个代码库的度量指 ...

  8. 如何用Serverless让SaaS获得更灵活的租户隔离和更优的资源开销

    关于SaaS和Serverless,相信关注我的很多读者都已经不陌生,所以这篇不会聊它们的技术细节,而将重点放在SaaS软件架构中引入Serverless之后,能给我们的SaaS软件带来多大的收益. ...

  9. WebDriver驱动下载地址

    chrome的webdriver: http://chromedriver.storage.googleapis.com/index.html Firefox驱动下载地址为:https://githu ...

  10. C# .exe和.dll文件图标资源提取工具

    Windows 可执行文件(.exe)和动态库文件(.dll)图标资源提取工具 GitHub 功能 图标资源预览 图标资源导出(仅支持导出 PNG 格式) 代码 获取图标资源使用了 Win32 API ...