TensorFlow官方发布剪枝优化工具:参数减少80%,精度几乎不变
去年TensorFlow官方推出了模型优化工具,最多能将模型尺寸减小4倍,运行速度提高3倍。
最近现又有一款新工具加入模型优化“豪华套餐”,这就是基于Keras的剪枝优化工具。
训练AI模型有时需要大量硬件资源,但不是每个人都有4个GPU的豪华配置,剪枝优化可以帮你缩小模型尺寸,以较小的代价进行推理。
什么是权重剪枝?
权重剪枝(Weight Pruning)优化,就是消除权重张量中不必要的值,减少神经网络层之间的连接数量,减少计算中涉及的参数,从而降低操作次数。
这样做的好处是压缩了网络的存储空间,尤其是稀疏张量特别适合压缩。例如,经过处理可以将MNIST的90%稀疏度模型从12MB压缩到2MB。
此外,权重剪枝与量化(quantization)兼容,从而产生复合效益。通过训练后量化(post-training quantization),还能将剪枝后的模型从2MB进一步压缩到仅0.5MB 。
TensorFlow官方承诺,将来TensorFlow Lite会增加对稀疏表示和计算的支持,从而扩展运行内存的压缩优势,并释放性能提升。
优化效果
权重剪枝优化可以用于不同任务、不同类型的模型,从图像处理的CNN用于语音处理的RNN。下表显示了其中一些实验结果。
以GNMT从德语翻译到英语的模型为例,原模型的BLEU为29.47。指定80%的稀疏度,经优化后,张量中的非零参数可以从211M压缩到44M,准确度基本没有损失。
使用方法
现在的权重剪枝API建立在Keras之上,因此开发者可以非常方便地将此技术应用于任何现有的Keras训练模型中。
开发者可以指定最终目标稀疏度(比如50%),以及执行剪枝的计划(比如2000步开始剪枝,在4000步时停止,并且每100步进行一次),以及剪枝结构的可选配置。
import tensorflow_model_optimization as tfmotmodel = build_your_model()pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.0, final_sparsity=0.5,begin_step=2000, end_step=4000)model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)…model_for_pruning.fit(…) tensorflow_model_optimization as tfmot
model = build_your_model()
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
initial_sparsity=0.0, final_sparsity=0.5,
begin_step=2000, end_step=4000)
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)
…
model_for_pruning.fit(…)
△ 三个不同张量,左边的没有稀疏度,中心的有多个单独0值,右边的有1x2的稀疏块。
随着训练的进行,剪枝过程开始被执行。在这个过程中,它会消除消除张量中最接近零的权重,直到达到当前稀疏度目标。
每次计划执行剪枝程序时,都会重新计算当前稀疏度目标,根据平滑上升函数逐渐增加稀疏度来达到最终目标稀疏度,从0%开始直到结束。
用户也可以根据需要调整这个上升函数。在某些情况下,可以安排训练过程在某个步骤达到一定收敛级别之后才开始优化,或者在训练总步数之前结束剪枝,以便在达到最终目标稀疏度时进一步微调系统。
△权重张量剪枝动画,黑色的点表示非零权重,随着训练的进行,稀疏度逐渐增加
GitHub地址:
https://github.com/tensorflow/model-optimization
官方教程:
https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras
— 完 —
欢迎关注磐创博客资源汇总站:http://docs.panchuang.net/
欢迎关注PyTorch官方中文教程站:http://pytorch.panchuang.net/
TensorFlow官方发布剪枝优化工具:参数减少80%,精度几乎不变的更多相关文章
- Qt程序打包发布方法(使用官方提供的windeployqt工具)
Qt程序打包发布方法(使用官方提供的windeployqt工具) 转自:http://tieba.baidu.com/p/3730103947?qq-pf-to=pcqq.group Qt 官方开发环 ...
- PyTorch官方中文文档:torch.optim 优化器参数
内容预览: step(closure) 进行单次优化 (参数更新). 参数: closure (callable) –...~ 参数: params (iterable) – 待优化参数的iterab ...
- 进程优化工具Process Lasso Pro 8.4官方版+激活破解方法
Process Lasso是一款来自美国的系统进程优化工具,基于特殊算法动态调整进程的优先级别,通过合理的设置进程优先级来实现降低系统负担的功能.可有效避免蓝 屏.假死.进程停止响应.进程占用 CPU ...
- 发布《Linux工具快速教程》
发布<Linux工具快速教程> 阶段性的完成了这本书开源书籍,发布出来给有需要的朋友,同时也欢迎更多的朋友加入进来,完善这本书: 本书Github地址:https://github.com ...
- SQLSERVER复制优化之一《减少包大小》
原文:SQLSERVER复制优化之一<减少包大小> SQLSERVER复制优化之一<减少包大小> 自从搭了复制之后以为可以安枕无忧了,谁不知问题接踵而来 这次遇到的问题是丢包, ...
- 教你使用Android SDK布局优化工具layoutopt
创建好看的Android布局是个不小的挑战,当你花了数小时调整好它们适应多种设备后,你通常不想再重新调整,但笨重的嵌套布局效率往往非常低下,幸运的是,在Android SDK中有一个工具可以帮助你优化 ...
- 解析Tensorflow官方English-Franch翻译器demo
今天我们来解析下Tensorflow的Seq2Seq的demo.继上篇博客的PTM模型之后,Tensorflow官方也开放了名为translate的demo,这个demo对比之前的PTM要大了很多(首 ...
- 解析Tensorflow官方PTB模型的demo
RNN 模型作为一个可以学习时间序列的模型被认为是深度学习中比较重要的一类模型.在Tensorflow的官方教程中,有两个与之相关的模型被实现出来.第一个模型是围绕着Zaremba的论文Recurre ...
- mysql优化———第二篇:数据库优化调整参数
摘要 参数调优内容: 1. 内存利用方面 2. 日志控制方面 3.文件IO分配,空间占用方面 4. 其它相关参数 一 摘要 通过参数提高MYSQL的性能.核心思想如下: 1 提高my ...
随机推荐
- [置顶]
利用Python 提醒实验室同学值日(自动发送邮件)
前言: 在实验室里一直存在着一个问题,就是老是有人忘记提醒下一个人值日,然后值日就被迫中断了.毕竟良好的 卫生环境需要大家一起来维护的!没办法只能想出一些小对策了. 解决思路: 首先,我 ...
- C++走向远洋——25(项目二,游戏类)
*/ * Copyright (c) 2016,烟台大学计算机与控制工程学院 * All rights reserved. * 文件名:game.cpp * 作者:常轩 * 微信公众号:Worldhe ...
- 自动清理IIS log 日志脚本
系统环境:windows server 2012 r2 IIS 版本:IIS8 操作实现清理IIS log File 脚本如下: @echo off ::自动清理IIS Log file set lo ...
- 跟我猜Spring-Boot:bean的创建
废话在前 最近几年的技术路子很杂,先是node,然后是php,后来是openresty,再后来转到了java,而接触的框架(Framework),也越发的复杂,从最开始的express/koa,到lu ...
- Java基础--数组的定义
1.数组的定义 数组:一组能够储存相同数据类型值的变量的集合. 2.数组的赋值方式 (1)使用默认的初始值来初始化数组中的每一个元素 语法:数组元素类型[]数组名 = new数组元素类型[数组中元素的 ...
- Python基础--动态传参
形参的顺序: 位置 *arg 默认值 **args ps:可以随便搭配,但是*和**以及默认值的位置顺序不能变 *,** 形参:聚合 位置参数* >>元祖 关键字** > ...
- ASP.NET Core中的Http缓存
ASP.NET Core中的Http缓存 Http响应缓存可减少客户端或代理对web服务器发出的请求数.响应缓存还减少了web服务器生成响应所需的工作量.响应缓存由Http请求中的header控制. ...
- 误删除所有redo日志的一组成员的处理过程
系统中共有3个日志文件组,每个组中各有一个日志文件成员.往系统中添加一个日志文件组,组中日志文件成员数量是2.SQL> alter database add logfile group 4 (' ...
- 为企业提供存储功能的Red Hat Stratis 2.0.1发布了
导读 Red Hat的Stratis存储项目用于在Linux上提供企业存储功能,以与ZFS和Btrfs之类的产品竞争,同时在LVM和XFS之上构建,这是其2020年守护进程的首次更新. 通过Strat ...
- Python3 面向对象之:多继承
多继承 class ShenXian: # 神仙 def fei(self): print("神仙都会⻜") class Monkey: # 猴 def chitao(self): ...