1. 利用tensorboard看loss:

tensorflow和pytorch环境是好的的话,链接中的logger.py拉到自己的工程里,train.py里添加相应代码,直接能用。

关于环境,小小折腾了下,大概一小时:

大概一年前用过tensorflow, mac里的环境还在,当时装的虚拟环境,由于工程中用到了caffe,在虚拟环境中编译caffe,去掉之前工程中用的caffe路径,把虚拟环境中的caffe路径添加到虚拟环境下的pythonpath

export PYTHONPATH=/Users/tensorflow虚拟环境中的caffe路径/python:$PYTHONPATH
pip install torch torchvision

然后进入python,依次import caffe torch tensorflow没问题就可以了,一个粗陋的loss曲线到手:

2. 多进程load数据

这个折腾了一天,最开始打算用Queue实现,陷在里面半天,返回label没问题,一旦加入图像就死掉了,自己生成一个超大的数据也不行,怀疑是Queue容量有限,暂且存疑;

考虑pytorch的__getitem__就是多进程的,打算在里面判断,如果self.src_map没有对应的key就读入,有就直接用。结果在mac上一个worker没啥问题,放到服务器上拆开多个子线程工作,会一直重新读入数据,因为self.src_map是属于主进程的,主进程并不会跟子进程共享这个字典,所以对每个子进程来说self.src_map都是空的,定位到这个问题就好办,最后是用链接里面“进程之间共享数据”方法实现的:

class FaceDataSet(data.Dataset):
def __init__(self, root, list, dst_size = 128, n_worker = 6):
super(DataSet, self).__init__()
self.all_data = []
self.dst_size = dst_size
self.src_map = {}
self.n_worker = n_worker fread = open(root +'/'+ list, 'r')
for line in fread.readlines():
img_filename = line.strip()
pt_filename = img_filename.replace('.jpg', '.txt')
imgfile_fullpath = os.path.join(root, img_filename)
ptfile_fullpath = os.path.join(root, pt_filename)
label = np.loadtxt(ptfile_fullpath, dtype=float)
#路径保存到all_data
if os.path.exists(ptfile_fullpath) and os.path.exists(imgfile_fullpath):
self.all_data.append(DataPath(imgfile_fullpath=imgfile_fullpath, ptfile_fullpath=ptfile_fullpath)) print time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) +\
' All data length: %d' % len(self.all_data) + \
" load workers:%d"%(self.n_worker)
data_length = len(self.all_data)
q_in = [[] for i in range(self.n_worker)]
for data_idx in range(data_length):
q_in[data_idx % len(q_in)].append(self.all_data[data_idx])
with multiprocessing.Manager() as MG:
p_map = [ multiprocessing.Manager().dict() for i in range(self.n_worker) ]
readers = [multiprocessing.Process( target=self.load_func, args=(q_in[i], p_map[i]) )\
for i in range(self.n_worker) ]
for p in readers:
p.start()
for p in readers:
p.join()
# 至此,主程序会等待最后一个进程执行完
for map in p_map: #把每个子进程读取结果拼到一起
self.src_map.update(map)
print time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " src_map length:%d"%( len(self.src_map.keys()) ) def load_func(self, qin, pmap):
for item in qin:
img_path, pt_path = item
img = cv2.imread(img_path) #读入数据,每个子进程读入的先放在自己的字典里
label = np.loadtxt(pt_path).astype(np.float32)
pmap[img_path] = self.__get_small_img(img, label, dst_size=self.dst_size, data_aug=False)

3. 打印网络每层输出形状

pip install torchsummary #命令行安装
from torchsummary import summary #代码
summary(model, (, , ))

4. view()的用法

define __init__:
self.conv = conv2d(512,512,kernel_size=7,stride=1,pad=0,group=512)
self.fc = nn.Linear(512,128)
define forward(self, x): #假设batchsize = 128
x = self.conv #(128,512,1,1,)
x = x.view(x.size(0), -1) #(128,512) - x.size(0)=batchsize, 按batchsize拉平
x = self.fc(x) #(128,128)
return x

5. 单机多卡显存占用不均衡: 不是很明显,应该没有抓住主要矛盾,聊胜于无。

checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))  #map到cpu只能起到一点点效果

6. 在gpu卡n上训练的网络换卡加载报错:Attempting to deserialize object on CUDA device 4 but torch.cuda.device_count()

 torch.load(model_path, map_location='cuda:0' )

7. Apex半精度训练:

虽然没有宣称的3行代码那么简单,也是很容易了,过一遍示例代码,花不到一小时可以跑起来。参考官方git quick start,完全安装没有成功,只安装了python版,然后执行下main_amp.py可以正常import了,再按示例代码添加几行即可。其他参考这里

8. pytorch和numpy默认数据格式不一致导致的CUDNN错误:

一个大坑,从自己写的dataloader加载数据莫名其妙的报错:RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM;

1) 训练代码中用torch.randn()生成的数据作为input没问题,生成的数据格式是troch.float32;

2) np.loadtxt()加载数据默认是np.float64格式,pytroch网络是float,所以在执行input.to(Device)之前先input = input.float()强制转换一下;

3) np.load()加载数据保留了保存数据时候的np.float32格式,送给pytorch的dataloader之前需要astype成float64,否则会报错"RuntimmeError: Expected object of scalar type Double but got scalar type Float ....",所以在送给dataloader之前先转成np.float64,训练代码中再和第二条一样,强制转为float

总结就是dataloader需要double型的数据,网络需要float型数据,匹配一致即可。

-------------2019.4.22----------

关于环境, cuda 9.1 + anaconda2用pip install和conda install都失败,最后分别下载pytorch和vision源码

python setup.py install

装完pytorch需要重启下才能找到

nn.ReLU和F.ReLU的区别

torch.where和np.where的区别,另外,np.where可以缺省后面两个参数返回满足条件的索引,而torch.where省略参数会报错

打印网络参数

for name, parameters in TEACHER.named_parameters():
print(name,parameters.size())

pytorch使用不完全文档的更多相关文章

  1. info sed 中文不完全文档

    快速指南: sed 的一般使用方法:sed -option 'adress|command' -f scpritfiles(1)'|' 只是用来说明性的分隔 adress 和 command,实际使用 ...

  2. Effective Java 第三版——82. 线程安全文档化

    Tips 书中的源代码地址:https://github.com/jbloch/effective-java-3e-source-code 注意,书中的有些代码里方法是基于Java 9 API中的,所 ...

  3. PyTorch 1.4 中文文档校对活动正式启动 | ApacheCN

    一如既往,PyTorch 1.4 中文文档校对活动启动了! 认领须知 请您勇敢地去翻译和改进翻译.虽然我们追求卓越,但我们并不要求您做到十全十美,因此请不要担心因为翻译上犯错--在大部分情况下,我们的 ...

  4. 【PyTorch v1.1.0文档研习】60分钟快速上手

    阅读文档:使用 PyTorch 进行深度学习:60分钟快速入门. 本教程的目标是: 总体上理解 PyTorch 的张量库和神经网络 训练一个小的神经网络来进行图像分类 PyTorch 是个啥? 这是基 ...

  5. 备战春招!开源社区系统 Echo 超全文档助力面试

    博主东南大学硕士在读,寒假前半个月到现在差不多一个多月,断断续续做完了这个项目,现在终于可以开源出来了,我的想法是为这个项目编写一套完整的教程,包括技术选型分析.架构分析.业务逻辑分析.核心技术点分析 ...

  6. Docker学习 ,超全文档!

    我们的口号是:再小的帆也能远航,人生不设限!!        一.学习规划: Docker概述 Docker安装 Docker命令 Docker镜像 镜像命令 容器命令 操作命令 容器数据卷  Doc ...

  7. Docker精华 ,超全文档!

    我们的口号是:再小的帆也能远航,人生不设限!!    学习规划:继续上篇 <Docker入门>https://www.cnblogs.com/dk1024/p/13121389.html  ...

  8. EditPlus软件自动补全文档htmlbar.acp设置 及 模板文件格式

    1.在htmlbar.acp文件末尾添加如下内容,可自动补全: #T=HTML <html>    ^! </html>   #T=HEAD <head>    ^ ...

  9. 新品成熟EMR源码电子病历系统软件NET网络版CS可用带数据库全文档

    查看电子病历系统演示 医院医疗信息管理系统,EMR电子病历系统,功能模块如下所示: 1.住院医生站 2.住院护士站 3.病案浏览工作站 4.质量控制工作站 5.系统维护工作站  本店出售系统全套源码, ...

随机推荐

  1. php实现多进程和关闭进程

    一.php实现多进程 PHP有个pcntl_fork的函数可以实现多进程,但要加载pcntl拓展,而且只有在linux下才能编译这个拓展. 先代码: <?php$arr = ['30000000 ...

  2. [jzoj]3875.【NOIP2014八校联考第4场第2试10.20】星球联盟(alliance)

    Link https://jzoj.net/senior/#main/show/3875 Problem 在遥远的S星系中一共有N个星球,编号为1…N.其中的一些星球决定组成联盟,以方便相互间的交流. ...

  3. 转 mysql Next-Key Locking

    原文:http://dev.mysql.com/doc/refman/5.5/en/innodb-next-key-locking.html 14.5.2.5 Avoiding the Phantom ...

  4. 10_常见的get和post请求_路由器_ejs服务器渲染模板引擎

    1. 常见的 get 和 post 请求有哪些? 常见的发送 get 请求方式: 在浏览器地址栏输入 url 地址访问 所有的标签默认发送的是 get 请求:如 script link img a f ...

  5. Web版记账本开发记录(二)开发过程遇到的问题小结1 对数据库的区间查询

    问题1 对数据库的区间查询 如功能显示,想要按照年份和月份查询相应的记录,就要使用区间查询 对应的代码如下 servlet层的ChaXun java.sql.Date sDate = new java ...

  6. yum安装mysql5.7

    [root@ycj ~]# wget -i -c http://dev.mysql.com/get/mysql57-community-release-el7-10.noarch.rpm //下载安装 ...

  7. CS(计算机科学)知识体

    附 录 A                   CS( 计算机科学)知识体 计算教程 2001 报告的这篇附录定义了计算机科学本科教学计划中可能讲授的知识领域.该分类方案的依据及其历史.结构和应用的其 ...

  8. oracle 安装介绍

    oracle 分为客户端和服务器 全局数据库是 实例名通常就是所说的服务,就是说数据库和操作系统之间的交互用的是数据库实例名 导入 sql文件 @路径    例如@d:/my.sql [oracle@ ...

  9. mayan游戏

    这道题超级好 就是我太菜了写了几个小时不算是debug了几个小时. 我只想出了几个小剪枝 可能是状态不太好吧 写完这道题真的是完美诠释了什么,叫做: 暴力出奇迹!!! 真的是太暴力了. 最多只移动5步 ...

  10. 召回率(Recall),精确率(Precision),平均正确率

    https://blog.csdn.net/yanhx1204/article/details/81017134 摘要 在训练YOLO v2的过程中,系统会显示出一些评价训练效果的值,如Recall, ...