Deeplearning知识蒸馏
Deeplearning知识蒸馏
merge
paddleslim.dist.merge(teacher_program, student_program, data_name_map, place, scope=fluid.global_scope(), name_prefix='teacher_')
merge将teacher_program融合到student_program中。在融合的program中,可以为其中合适的teacher特征图和student特征图添加蒸馏损失函数,从而达到用teacher模型的暗知识(Dark Knowledge)指导student模型学习的目的。
参数:
- teacher_program (Program)-定义了teacher模型的 paddle program
- student_program (Program)-定义了student模型的 paddle program
- data_name_map (dict)-teacher输入接口名与student输入接口名的映射,其中dict的 key 为teacher的输入名,value 为student的输入名
- place (fluid.CPUPlace()|fluid.CUDAPlace(N))-该参数表示程序运行在何种设备上,这里的N为GPU对应的ID
- scope (Scope)-该参数表示程序使用的变量作用域,如果不指定将使用默认的全局作用域。默认值: fluid.global_scope()
- name_prefix (str)-merge操作将统一为teacher的 Variables 添加的名称前缀name_prefix。默认值:’teacher_‘
返回: 无
注解
data_name_map 是 teacher_var name到student_var name的映射 ,如果写反可能无法正确进行merge
使用示例:
import paddle.fluid as fluid
import paddleslim.dist as dist
student_program = fluid.Program()
with fluid.program_guard(student_program):
x = fluid.layers.data(name='x', shape=[1, 28, 28])
conv = fluid.layers.conv2d(x, 32, 1)
out = fluid.layers.conv2d(conv, 64, 3, padding=1)
teacher_program = fluid.Program()
with fluid.program_guard(teacher_program):
y = fluid.layers.data(name='y', shape=[1, 28, 28])
conv = fluid.layers.conv2d(y, 32, 1)
conv = fluid.layers.conv2d(conv, 32, 3, padding=1)
out = fluid.layers.conv2d(conv, 64, 3, padding=1)
data_name_map = {'y':'x'}
USE_GPU = False
place = fluid.CUDAPlace(0) if USE_GPU else fluid.CPUPlace()
dist.merge(teacher_program, student_program,
data_name_map, place)
fsp_loss
paddleslim.dist.fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name, student_var2_name, program=fluid.default_main_program())
fsp_loss为program内的teacher var和student var添加fsp loss,出自论文 A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning
参数:
- teacher_var1_name (str): teacher_var1的名称. 对应的variable是一个形为`[batch_size, x_channel, height, width]`的4-D特征图Tensor,数据类型为float32或float64
- teacher_var2_name (str): teacher_var2的名称. 对应的variable是一个形为`[batch_size, y_channel, height, width]`的4-D特征图Tensor,数据类型为float32或float64。只有y_channel可以与teacher_var1的x_channel不同,其他维度必须与teacher_var1相同
- student_var1_name (str): student_var1的名称. 对应的variable需与teacher_var1尺寸保持一致,是一个形为`[batch_size, x_channel, height, width]`的4-D特征图Tensor,数据类型为float32或float64
- student_var2_name (str): student_var2的名称. 对应的variable需与teacher_var2尺寸保持一致,是一个形为`[batch_size, y_channel, height, width]`的4-D特征图Tensor,数据类型为float32或float64。只有y_channel可以与student_var1的x_channel不同,其他维度必须与student_var1相同
- program (Program): 用于蒸馏训练的fluid program。默认值: fluid.default_main_program()
返回: 由teacher_var1, teacher_var2, student_var1, student_var2组合得到的fsp_loss
使用示例:
l2_loss
paddleslim.dist.l2_loss(teacher_var_name, student_var_name, program=fluid.default_main_program())[[]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L118)
: l2_loss为program内的teacher var和student var添加l2 loss
参数:
- teacher_var_name (str): teacher_var的名称.
- student_var_name (str): student_var的名称.
- program (Program): 用于蒸馏训练的fluid program。默认值: fluid.default_main_program()
返回: 由teacher_var, student_var组合得到的l2_loss
使用示例:
soft_label_loss
paddleslim.dist.soft_label_loss(teacher_var_name, student_var_name, program=fluid.default_main_program(), teacher_temperature=1., student_temperature=1.)[[]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L136)
soft_label_loss为program内的teacher var和student var添加soft label loss,出自论文 Distilling the Knowledge in a Neural Network
参数:
- teacher_var_name (str): teacher_var的名称.
- student_var_name (str): student_var的名称.
- program (Program): 用于蒸馏训练的fluid program。默认值: fluid.default_main_program()
- teacher_temperature (float): 对teacher_var进行soft操作的温度值,温度值越大得到的特征图越平滑
- student_temperature (float): 对student_var进行soft操作的温度值,温度值越大得到的特征图越平滑
返回: 由teacher_var, student_var组合得到的soft_label_loss
使用示例:
loss
paddleslim.dist.loss(loss_func, program=fluid.default_main_program(), **kwargs) [[]](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/dist/single_distiller.py#L165)
: loss函数支持对任意多对teacher_var和student_var使用自定义损失函数
参数:
- **loss_func**( python function): 自定义的损失函数,输入为teacher var和student var,输出为自定义的loss
- program (Program): 用于蒸馏训练的fluid program。默认值: fluid.default_main_program()
- **kwargs : loss_func输入名与对应variable名称
返回 :自定义的损失函数loss
使用示例:
注解
在添加蒸馏loss时会引入新的variable,需要注意新引入的variable不要与student variables命名冲突。这里建议两种用法(两种方法任选其一即可):
- 建议与student_program使用同一个命名空间,以避免一些未指定名称的variables(例如tmp_0, tmp_1...)多次定义为同一名称出现命名冲突
- 建议在添加蒸馏loss时指定一个命名空间前缀
Deeplearning知识蒸馏的更多相关文章
- 知识蒸馏(Distillation)
蒸馏神经网络取名为蒸馏(Distill),其实是一个非常形象的过程. 我们把数据结构信息和数据本身当作一个混合物,分布信息通过概率分布被分离出来.首先,T值很大,相当于用很高的温度将关键的分布信息从原 ...
- 【论文考古】知识蒸馏 Distilling the Knowledge in a Neural Network
论文内容 G. Hinton, O. Vinyals, and J. Dean, "Distilling the Knowledge in a Neural Network." 2 ...
- 【DKNN】Distilling the Knowledge in a Neural Network 第一次提出神经网络的知识蒸馏概念
原文链接 小样本学习与智能前沿 . 在这个公众号后台回复"DKNN",即可获得课件电子资源. 文章已经表明,对于将知识从整体模型或高度正则化的大型模型转换为较小的蒸馏模型,蒸馏非常 ...
- 通过Python包来剪枝、蒸馏DNN
用 Distiller 压缩 PyTorch 模型 作者: PyTorch 中文网发布: 2018年7月15日 5,101阅读 0评论 近日,Intel 开源了一个用于神经网络压缩的开源 Python ...
- ICCV2021 | 简单有效的长尾视觉识别新方案:蒸馏自监督(SSD)
前言 本文提出了一种概念上简单但特别有效的长尾视觉识别的多阶段训练方案,称为蒸馏自监督(Self Supervision to Distillation, SSD).在三个长尾识别基准:Ima ...
- Bert不完全手册1. 推理太慢?模型蒸馏
模型蒸馏的目标主要用于模型的线上部署,解决Bert太大,推理太慢的问题.因此用一个小模型去逼近大模型的效果,实现的方式一般是Teacher-Stuent框架,先用大模型(Teacher)去对样本进行拟 ...
- DeiT:注意力也能蒸馏
DeiT:注意力也能蒸馏 <Training data-efficient image transformers & distillation through attention> ...
- 知识图谱顶刊综述 - (2021年4月) A Survey on Knowledge Graphs: Representation, Acquisition, and Applications
知识图谱综述(2021.4) 论文地址:A Survey on Knowledge Graphs: Representation, Acquisition, and Applications 目录 知 ...
- Bag of Tricks for Image Classification with Convolutional Neural Networks论文笔记
一.高效的训练 1.Large-batch training 使用大的batch size可能会减小训练过程(收敛的慢?我之前训练的时候挺喜欢用较大的batch size),即在相同的迭代次数 ...
随机推荐
- 4- MySQL创建表以及增删改查
查看表结构 查看表的结构,使用命令:desc 表明: 创建表(命令) 格式:使用create table创建表,必须给出下列信息: 1.新表的名字. 2.表中列的名字和定义,用逗号隔开. 语法: cr ...
- 【JDK8】Java8 新增BASE64加解密API
什么是Base64编码? Base64是网络上最常见的用于传输8Bit字节码的编码方式之一,Base64就是一种基于64个可打印字符来表示二进制数据的方法 基于64个字符A-Z,a-z,0-9,+,/ ...
- hdu1251 hash或者字典树
题意: 统计难题 Problem Description Ignatius最近遇到一个难题,老师交给他很多单词(只有小写字母组成,不会有重复的单词出现),现在老师要他统计出以某个字符串为前缀的单词数量 ...
- Win64 驱动内核编程-25.X64枚举和隐藏内核模块
X64枚举和隐藏内核模块 在 WIN64 上枚举内核模块的人方法:使用 ZwQuerySystemInformation 的第 11 号功能和枚举 KLDR_DATA_TABLE_ENTRY 中的 I ...
- spring中注解@Resource 与@Autowire 区别
① .@Resource 是根据名字进行自动装配:@Autowire是通过类型进行装配. ②. @Resource 注解是 jdk 的:@Autowire 是spring的.
- linux命令解压压缩rar文件
一.widonds下打包rar文件并上传 yum install lrzsz rz test.rar 二.下载并安装rar软件 2.1 下载 mkdir -p /home/oldboy/tools c ...
- 5分钟,教你用Python每天跟女朋友说1000遍土味情话!
- 轻量级工具Vite到底牛在哪——一文全知道
转载请注明出处:葡萄城官网,葡萄城为开发者提供专业的开发工具.解决方案和服务,赋能开发者. 原文参考:https://www.sitepoint.com/vitejs-front-end-build- ...
- If-Else 太多,如何优化!!!
完全不必要的 Else 块 public void consumer(int product) { if (product > 1) { // do something } else { // ...
- SpringBoot系列——事件发布与监听
前言 日常开发中,我们经常会碰到这样的业务场景:用户注册,注册成功后需要发送邮箱.短信提示用户,通常我们都是这样写: /** * 用户注册 */ @GetMapping("/userRegi ...