Cyr E C, Gulian M, Patel R G, et al. Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.[J]. arXiv: Learning, 2019.

@article{cyr2019robust,

title={Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint.},

author={Cyr, Eric C and Gulian, Mamikon and Patel, Ravi G and Perego, Mauro and Trask, Nathaniel},

journal={arXiv: Learning},

year={2019}}

这篇文章介绍了一种梯度下降的改进, 以及Box参数初始化方法.

主要内容

\[\tag{6}
\arg \min_{\xi^L \xi^H} \sum_{k=1}^K \epsilon_k \|\mathcal{L}_k[u] - \sum_i \xi_i^L \mathcal{L}_k [\Phi_i(x, \xi^H)]\|^2_{\ell_2(\mathcal{X}_k)}.
\]

LSGD

固定\(\xi^H, \mathcal{X}_k\), 并令\(\epsilon_k=1\), 则问题(6)退化为一个最小二乘问题

\[\arg \min_{\xi^L} \|A\xi^L -b\|^2_{\ell_2 (\mathcal{X})},
\]

其中\(b_i = \mathcal{L}[u](x_i)\), \(A_{ij}=\mathcal{L} [\Phi_j (x_i, \xi^H)]\), \(x_i \in \mathcal{X}, i=1,\ldots, N, j=1, \ldots, w\).

所以算法如下

Box 初始化

该算法期望使得feature-rich,但是我不知道这个rich从何而来.

假设第\(l\)层的输入为\(x \in \mathbb{R}^{d_1}\), 输出为\(y \in \mathbb{R}^{d_2}\), 则该层的权重矩阵\(W \in \mathbb{R}^{d_2 \times d_1}\). 我们逐行地定义\(W\):

  1. 采样\(p\), \(p\sim U[0 ,1]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, 1]^{d_1}} \sigma(k(x-p) \cdot n)=1.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).

其中\(\sigma\)表示激活函数, 文中指的是ReLU.

求解参数\(k\):

  1. \(p_{max} = \max (0, \mathrm{sign}(n))\);
  2. \(k=\frac{1}{(p_{max}-p) \cdot n}\)

此\(k\)即为所需\(k\), 只需证明\(p_{max}\)是最大化

\[(x - p)\cdot n, \quad x \in [0,1]^{d_1}
\]

的解. 最大化上式, 可以分解为

\[\max_{x_i \in [0, 1]} x_in_i,
\]

故\(x_i = \max(0, \mathrm{sign}(n_i))\).

这个初始化有什么好处呢, 可以发现, 输入\(x \in[0,1]^{d_1}\)满足, 则输出\(y \in [0, 1]^{d_2}\), 保证二者的"值域"范围一致, 以此类推整个网络节点值范围近似.



如果, 作者构建了一个2-2-2-2-2-2-2-2的网络, 可以发现, Xavier 和 Kaiming的初始化方法经过一定层数后, 就会塌缩在某个点, 而Box初始化方法能够缓解这一现象.

下面是文中列出的算法(与这里的符号有一点点不同, 另外\(b\)作者应该是遗漏了负号).

Box for Resnet

因为Resnet特殊的结构,

\[y=(W+I)x+b.
\]

假设\(x \in [0,m]^{d_1}\), 则:

  1. 采样\(p\), \(p\sim U[0 ,m]^{d_1}\);
  2. 采样\(n\), \(n \sim \mathcal{N}(0,I_{d_1})\), 并令\(n=n/\|n\|\);
  3. 求参数\(k\)使得
\[\max_{x \in [0, m]^{d_1}} \sigma(k(x-p) \cdot n)=\delta m.
\]
  1. \(W\)第\(i\)行\(w_i=kn^T\), \(b_i=-kp \cdot n\).
\[k=\frac{\delta m}{(mp_{max}-p) \cdot n}.
\]

若第一层输入\(x_i \in [0,1]\), 去\(\delta=1/L\), 其中\(L\)为总的层数, 则

\[[0,1] \rightarrow [0,1+\frac{1}{L}] \rightarrow [0,(1+\frac{1}{L})^2] \rightarrow \cdots
\]

代码



'''
initialization.py
'''
import torch
import torch.nn as nn
import warnings def generate(size, m, delta):
p = torch.rand(size) * m
n = torch.randn(size)
temp = 1 / torch.norm(n, p=2, dim=1, keepdim=True)
n = temp * n
pmax = nn.functional.relu(torch.sign(n)) * m
temp = (pmax - p) * n
k = (m * delta) / temp.sum(dim=1, keepdim=True)
w = k * n
b = -(w * p).sum(dim=1)
return w, b def box_init(module, m=1, delta=1):
if isinstance(module, nn.Linear):
w, b = generate(module.weight.shape, m, delta)
try:
module.weight.data = w
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) elif isinstance(module, nn.Conv2d):
outc, inc, h, w = module.weight.size()
w, b = generate((outc, inc * h * w), m, delta)
try:
module.weight.data = w.reshape(module.weight.size())
module.bias.data = b
except AttributeError as e:
s = "Error: \n" + str(e) + "\n stops the initialization" \
" for this module: {}".format(module)
warnings.warn(s) else:
pass


"""config.py"""

nums = 10
layers = 6
method = "kaiming" #box/xavier/kaiming
net = "Net" #Net/ResNet


"""
测试
""" import torch
import torch.nn as nn
import config
from initialization import box_init class Net(nn.Module): def __init__(self, l):
super(Net, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init()
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self):
for module in self.modules():
box_init(module) def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp)
out.append(temp)
return out class ResNet(nn.Module): def __init__(self, l):
super(ResNet, self).__init__() self.linears = []
for i in range(l):
name = "linear" + str(i)
self.__setattr__(name, nn.Sequential(nn.Linear(2, 2),
nn.ReLU()))
self.linears.append(self.__getattr__(name))
if config.method == 'box':
self.box_init(l)
elif config.method == "xavier":
self.xavier_init()
else:
self.kaiming_init() def box_init(self, layers):
delta = 1 / layers
m = 1. + delta
l = 0
for module in self.modules():
if isinstance(module, (nn.Linear)):
if l == 0:
box_init(module, 1, 1)
else:
box_init(module, m ** l, delta)
l += 1 def xavier_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.xavier_normal_(module.weight) def kaiming_init(self):
for module in self.modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
nn.init.kaiming_normal_(module.weight) def forward(self, x):
out = []
temp = x
for linear in self.linears:
temp = linear(temp) + temp
out.append(temp)
return out if config.net == "Net":
net = Net(config.layers)
else:
net = ResNet(config.layers) x = torch.linspace(0, 1, config.nums)
y = torch.linspace(0, 1, config.nums) grid_x, grid_y = torch.meshgrid(x, y) x = grid_x.flatten()
y = grid_y.flatten()
data = torch.stack((x, y), dim=1)
outs = net(data) import matplotlib.pyplot as plt def axplot(x, y, ax):
x = x.detach().numpy()
y = y.detach().numpy()
ax.scatter(x, y) def plot(x, y, outs):
fig, axs = plt.subplots(1, config.layers+1, sharey=True, figsize=(12, 2))
axs[0].scatter(x, y)
axs[0].set(title="layer0")
for i in range(config.layers):
ax = axs[i+1]
out = outs[i]
x = out[:, 0]
y = out[:, 1]
axplot(x, y, ax)
ax.set(title="layer"+str(i+1))
plt.tight_layout()
plt.savefig("C:/Users/pkavs/Desktop/fig.png")
#plt.show()
plot(x, y, outs)

[Box] Robust Training and Initialization of Deep Neural Networks: An Adaptive Basis Viewpoint的更多相关文章

  1. Training Deep Neural Networks

    http://handong1587.github.io/deep_learning/2015/10/09/training-dnn.html  //转载于 Training Deep Neural ...

  2. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week1, Assignment(Initialization)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. Initialization Welcome to the first assignment of "Improving D ...

  3. Training (deep) Neural Networks Part: 1

    Training (deep) Neural Networks Part: 1 Nowadays training deep learning models have become extremely ...

  4. Exploring Architectural Ingredients of Adversarially Robust Deep Neural Networks

    目录 概 主要内容 深度 宽度 代码 Huang H., Wang Y., Erfani S., Gu Q., Bailey J. and Ma X. Exploring architectural ...

  5. [C4] Andrew Ng - Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization

    About this Course This course will teach you the "magic" of getting deep learning to work ...

  6. (转)Understanding, generalisation, and transfer learning in deep neural networks

    Understanding, generalisation, and transfer learning in deep neural networks FEBRUARY 27, 2017   Thi ...

  7. On Explainability of Deep Neural Networks

    On Explainability of Deep Neural Networks « Learning F# Functional Data Structures and Algorithms is ...

  8. Classifying plankton with deep neural networks

    Classifying plankton with deep neural networks The National Data Science Bowl, a data science compet ...

  9. Must Know Tips/Tricks in Deep Neural Networks

    Must Know Tips/Tricks in Deep Neural Networks (by Xiu-Shen Wei)   Deep Neural Networks, especially C ...

随机推荐

  1. Android系统编程入门系列之硬件交互——多媒体麦克风

    在多媒体摄像头及相关硬件文章中,对摄像头的使用方式需要区分应用程序的目标版本以使用不同的代码流程,而与之相比,麦克风硬件的使用就简单多了. 麦克风及相关硬件 麦克风硬件在移动设备上作为音频的采集设备, ...

  2. day9 文件处理

    day09 文件处理 一.注册与登录功能 username = input('请输入您的密码:').strip() password = input('请输入您的密码:').strip() f = o ...

  3. 打破砂锅问到底!HTTP和HTTPS详解

    HTTP 引自维基百科HTTP:超文本传输协议(英文:HyperText Transfer Protocol,缩写:HTTP)是一种用于分布式.协作式和超媒体信息系统的应用层协议.HTTP是万维网的数 ...

  4. jmeter进阶

    1.如果(if)控制器的使用     2.参数的调用 3.数据库的链接

  5. adb命令对app进行测试

    1.何为adb adb android  debug  bridge ,sdk包中的工具,将Platform-tooks 和tools  两个路径配置到环境变量中 2.SDK下载链接:http://t ...

  6. Docker的常用命令总结

    一.普通指令 启动 Docker sudo systemctl start docker 停止 Docker sudo systemctl stop docker 普通重启 Docker sudo s ...

  7. html之table的tr加间隔

    <table style="border-collapse:separate; border-spacing:0px 10px;"> <tr> <td ...

  8. 【Java多线程】ExecutorService和ThreadPoolExecutor

    ExecutorService Java.util.concurrent.ExecutorService接口代表一种异步执行机制,它能够在后台执行任务.因此ExecutorService与thread ...

  9. solr8.2

    https://www.cnblogs.com/carlosouyang/p/11352779.html

  10. ssm动态查询向前台传json

    1.数据协议层 public User selectById(Integer id);//通过id值查询用户 2.数据层 <select id="selectById" re ...