【技术博客】Pytorch代码生成
开发组在开发过程中,都不可避免地遇到了一些困难或问题,但都最终想出办法克服了。我们认为这样的经验是有必要记录下来的,因此就有了【技术博客】。
Pytorch代码生成经验文档
关于模型代码的生成,主要思路为从根节点开始进行广度优先搜索,从而自顶向下依次生成相关层的代码。这里和搜索相关的主要有三个数据结构:
- Q:队列,记录后续继续搜索的节点,即为后续的Node。
- graph:字典,记录整颗搜索树,每个key对应一个Node,Node为自己封装的一个类,里面包含每层的一些信息。记录搜索树的目的是为了后续的正确性验证,如下为Node的定义:
class Node:
def __init__(self, id = None, name = None, in_channels = 1, out_channels = 1, kernel_size = 3,
stride = 1, padding = 0, data = None, activity = None, pool_way = None, cat_dim = None):
self.fa = np.array([], dtype = str)
self.next = np.array([], dtype = str)
self.id = id
self.name = name
self.data = data
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.pool_way = pool_way
self.activity = activity
self.data_shape = np.array([], dtype = int)
self.cat_dim = cat_dim
def add_fa(self, f):
self.fa = np.append(self.fa, f)
def add_next(self, nx):
self.next = np.append(self.next, nx)
- done:字典,记录某节点相关代码是否已经生成,每个key对应一个boolean值。
同时还有以下需要关注的地方:
广度优先搜索。BFS为代码的主要框架。从’start’节点开始搜索,直到遍历结束,做一个线性的扫描。代码框架如下(省略了主要代码):
def make_graph(nets, nets_conn, init_func, forward_func):
#code here Q = queue.Queue()
Q.put(‘start’) #code here while not Q.empty():
cur_id = Q.get()
if GL.done[cur_id]:
continue '''''''''''' Main codes here '''''''''''' GL.done[cur_id] = True return init_func, forward_func
关于全局变量的处理。由于一开始忽略了python变量的特性(不需要声明),所以在一开始第一全局变量的时候是直接定义在文件开头的,但是这样存在的问题是:如果在局部函数中引用全局变量,则此时则是重新定义了一个变量而不是引用,用global关键字代码看上去又很臃肿。所以采取的办法是重新定义了一个GLOB模块,里面存放着需要的所有全局变量。类似于这样:
class GLOB:
def __init__(self):
self.graph = {}
self.done = {}
self.layer_used_time = {'view_layer': 0, 'linear_layer': 0, 'conv1d_layer': 0, 'conv2d_layer': 0, 'element_wise_add_layer':0, 'concatenate_layer':0}
self.nn_linear = 'torch.nn.Linear'
self.nn_conv1d = 'torch.nn.Conv1d'
self.nn_conv2d = 'torch.nn.Conv2d'
self.nn_view = '.view'
self.nn_sequential = 'torch.nn.Sequential'
self.start_layer = ['start']
self.norm_layer = ['conv1d_layer', 'conv2d_layer', 'view_layer', 'linaer_layer']
self.multi_layer = ['element_wise_add_layer', 'concatenate_layer']
self.layers_except_start = self.norm_layer + self.multi_layer这样,只需要在代码里初始化一个GLOB对象GL,这样在任何地方引用全局变量都不会造成困扰。
关于变量名生成。每层的输出数据的名字格式为:层名 + “data_出现的次数”。有一个数据结构”layer_used_time”(字典)专门负责记录每个层出现的次数,同时,会在该层的代码生成结构后更新layer_used_time和done的值。
关于何时初始化和更新graph。在我们的代码中,当从队列中取出一个节点后会执行一个函数:get_next_nodes_and_update_pre_nodes()。该函数的目的是获取和初始化当前节点的儿子节点,记录前端传入该层的其他参数,更新其父子节点,同时返回当前节点的所有祖先节点代码是否已经生成完毕。另外,在该函数内部也会做模型的一部分正确性验证,主要验证搭建的模型里除了拼接层和相加层以外的层是否存在多个父节点或没有节点。该函数实现的功能较多,后期会考虑重构。
关于正确性验证。考虑到用户在搭建模型时不一定能够保证参数的正确,所以我们对参数的合理性是“宽容”的,但是也有硬性的要求,比如只能有一个开始节点,同时除了拼接层和相加层可以有多个父节点以外,其他层有且仅有一个父节点。
关于生成的模型NET中forward函数的返回值。由于搭建的模型允许出现网状结构,所以不能保证模型的出口只有一个,所以现阶段生成的模型会返回所有出度为0的层的输出值,具体顺序参见代码。
附最终生成的代码效果图(例):

【技术博客】Pytorch代码生成的更多相关文章
- 如何写出高质量的技术博客 这边文章出自http://www.jianshu.com/p/ae9ab21a5730 觉得不错直接拿过来了 好东西要大家分享嘛
如何写出高质量的技术博客?答案是:如果你想,就一定能写出高质量的技术博客.看起来很唯心,但这就是事实.有足够愿力去做一件目标明确,有良好反馈系统的事情往往很简单.就是不停地训练,慢慢地,你自己 ...
- ******IT公司面试题汇总+优秀技术博客汇总
滴滴面试题:滴滴打车数据库如何拆分 前端时间去滴滴面试,有一道题目是这样的,滴滴每天有100万的订单,如果让你去设计数据库,你会怎么去设计? 当时我的想法是根据用户id的最后一位对某个特殊的值取%操作 ...
- 转: BAT等研发团队的技术博客
BAT 技术团队博客 1. 美团技术团队博客: 地址: http://tech.meituan.com/ 2. 腾讯社交用户体验设计(ISUX) 地址:http://isux.tencent.c ...
- 解决Eclipse中文乱码 - 技术博客 - 51CTO技术博客 http://hsj69106.blog.51cto.com/1017401/595598/
解决Eclipse中文乱码 - 技术博客 - 51CTO技术博客 http://hsj69106.blog.51cto.com/1017401/595598/
- 欢迎访问我的最新个人技术博客http://zhangxuefei.top
博客园已停止更新,欢迎访问我的最新个人技术博客http://zhangxuefei.top
- 技术博客(初用markdown)。
技术博客 菜鸟教程在这个网站我学到许多有趣的东西,并且弥补了我之前的一些不足之处. 以下为我学习到的内容 输出不同的三位数 以下为代码和输出结果 *** #include<stdio.h> ...
- 技术博客(初用markdown)
技术博客 菜鸟教程在这个网站我学到许多有趣的东西,并且弥补了我之前的一些不足之处. 以下为我学习到的内容. 1 如果想输出多个多位数的时候,可以尝试用多个if语句.如果需要输出3为数的时候,设置三个变 ...
- 【转】【技术博客】Spark性能优化指南——高级篇
http://mp.weixin.qq.com/s?__biz=MjM5NjQ5MTI5OA==&mid=2651745207&idx=1&sn=3d70d59cede236e ...
- 作业一:创建个人技术博客、自我介绍、简单的C程序
年9月14日中午12点: 一.主要内容 建个人技术博客(博客园 www.cnblogs.com) 本学期将通过写博客的方式提交作业,实际上,最终的目的是希望同学们能通过博客的形式记录我们整个学习过程 ...
随机推荐
- 【转载】C#通过Copy方法快速复制DataTable对象
C#中的Datatable数据变量的操作过程中,可以通过DataTable的Copy方法快速复制当前的DataTable变量到新对象中,复制数据包含当前DataTable的结构信息如列名,同时也包含当 ...
- Vue项目打包发布后CSS中的背景图片不显示
相信有很多同学在学习vue的刚开始都遇到过项目打包发布后发现CSS中的背景图片不显示,具体如何解决只需要更改bind的配置即可 修改 build/utils.js 中的 generateLoaders ...
- H3C 什么是漫游
- 性能测试基础---测试流程,LR安装
·性能测试流程详解: 一般来说,性能测试通常可以分为以下过程: ·前期分析.测试计划.测试方案.测试环境的搭建.测试数据的准备.测试脚本的开发.测试场景的设计.测试场景的实现和执行.资源的监控.分析结 ...
- centos安全加固
设置SSH登录超时时间 /etc/profile export TMOUT=900 设置账户密码策略 /etc/login.defs PASS_MAX_DAYS 180 PASS_MIN_DAYS 0 ...
- 追光的人对Echo,SkyReach的Beta产品测试报告
所属课程 软件工程1916 作业要求 Beta冲刺博客汇总 团队名称 追光的人 作业目标 团队互测 队员学号 队员博客 221600219 小墨 https://www.cnblogs.com/hen ...
- python基础笔记-列表
列表是最常用的Python数据类型,它可以作为一个方括号内的逗号分隔值出现. 列表的数据项不需要具有相同的类型 创建一个列表,只要把逗号分隔的不同的数据项使用方括号括起来即可. def main(): ...
- 牛客NOIP暑期七天营-提高组5
A:deco的abs. 水题,先%,然后相邻两个数min()一下差值. #include<bits/stdc++.h> #define ll long long using namespa ...
- redis 哨兵集群原理及部署
复制粘贴自: https://www.cnblogs.com/kevingrace/p/9004460.html 请点击此链接查看原文. 仅供本人学习参考, 如有侵权, 请联系删除, 多谢! Redi ...
- placeholder 效果的实现,input提示字,获取焦点时消失
<!doctype html><html><head><meta charset="utf-8"><title>plac ...