原文地址:

https://www.zhangshengrong.com/p/9MNlDK09NJ/

================================================

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。

全部设置可以分为三部分:

1. CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

from torch.backends import cudnn
cudnn.benchmark = False # if benchmark=True, deterministic will be False
cudnn.deterministic = True

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

2. Pytorch

torch.manual_seed(seed)      # 为CPU设置随机种子
torch.cuda.manual_seed(seed) # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed) # 为所有GPU设置随机种子

3. Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

最后,关于dataloader:

注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

对于不同线程的随机数种子设置,主要通过DataLoader的 worker_init_fn 参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:

GLOBAL_SEED = 1

def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
global GLOBAL_WORKER_ID
GLOBAL_WORKER_ID = worker_id
set_seed(GLOBAL_SEED + worker_id) dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)

以上这篇浅谈PyTorch的可重复性问题(如何使实验结果可复现)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

=================================================

【转载】 浅谈PyTorch的可重复性问题(如何使实验结果可复现)的更多相关文章

  1. 转载-浅谈Ddos攻击攻击与防御

    EMail: jianxin#80sec.comSite: http://www.80sec.comDate: 2011-2-10From: http://www.80sec.com/ [ 目录 ]一 ...

  2. [转载]浅谈JavaScript函数重载

     原文地址:浅谈JavaScript函数重载 作者:ChessZhang 上个星期四下午,接到了网易的视频面试(前端实习生第二轮技术面试).面了一个多小时,自我感觉面试得很糟糕的,因为问到的很多问题都 ...

  3. <转载>浅谈C/C++的浮点数在内存中的存储方式

    C/C++浮点数在内存中的存储方式 任何数据在内存中都是以二进制的形式存储的,例如一个short型数据1156,其二进制表示形式为00000100 10000100.则在Intel CPU架构的系统中 ...

  4. 转载--浅谈spring4泛型依赖注入

    转载自某SDN-4O4NotFound Spring 4.0版本中更新了很多新功能,其中比较重要的一个就是对带泛型的Bean进行依赖注入的支持.Spring4的这个改动使得代码可以利用泛型进行进一步的 ...

  5. (转载) 浅谈python编码处理

    最近业务中需要用 Python 写一些脚本.尽管脚本的交互只是命令行 + 日志输出,但是为了让界面友好些,我还是决定用中文输出日志信息. 很快,我就遇到了异常: UnicodeEncodeError: ...

  6. [转载]浅谈组策略设置IE受信任站点

    在企业中,通常会有一些业务系统,要求必须加入到客户端IE受信任站点,才能完全正常运行访问,在没有域的情况下,可能要通过管理员手动设置,或者通过其它网络推送方法来设置. 有了域之后,这项工作就可以很好的 ...

  7. [转载]浅谈C/C++内存泄漏及其检测工具

    http://dev.yesky.com/147/2356147_3.shtml 对于一个c/c++程序员来说,内存泄漏是一个常见的也是令人头疼的问题.已经有许多技术被研究出来以应对这个问题,比如Sm ...

  8. 浅谈Java中的深拷贝和浅拷贝

    转载: 浅谈Java中的深拷贝和浅拷贝 假如说你想复制一个简单变量.很简单: int apples = 5; int pears = apples; 不仅仅是int类型,其它七种原始数据类型(bool ...

  9. 浅谈Oracle事务【转载竹沥半夏】

    浅谈Oracle事务[转载竹沥半夏] 所谓事务,他是一个操作序列,这些操作要么都执行,要么都不执行,是一个不可分割的工作单元.通俗解释就是事务是把很多事情当成一件事情来完成,也就是大家都在一条船上,要 ...

  10. 浅谈Java中的深拷贝和浅拷贝(转载)

    浅谈Java中的深拷贝和浅拷贝(转载) 原文链接: http://blog.csdn.net/tounaobun/article/details/8491392 假如说你想复制一个简单变量.很简单: ...

随机推荐

  1. LNMP单机架构

    黄金架构LNMP LNMP是网站架构初期最合适的单体架构.因为初创型技术团队对于技术的选型,需要考虑如下因素 在创业初期,研发资源有限,研发人力有限,技术储备有限,需要选择一个易维护.简单的技术架构: ...

  2. 《Objective-C Direct Methods》学习笔记

    原文通过对Objective-C发展史.Objective-C中Runtime的动态派发,C语言的直接派发进行铺垫介绍,引出了direct methods这个"新特性"(文章写于2 ...

  3. P6623 [省选联考 2020 A 卷] 树

    day2t2但难度不大,和AGC044C解法类似 题目大意: 给定一棵 \(n\) 个结点的有根树 \(T\),结点从 \(1\) 开始编号,根结点为 \(1\) 号结点,每个结点有一个正整数权值 \ ...

  4. 码农的转型之路-全力以赴升级物联网浏览器(IoTBrowser)

    在人生的重要时刻,我站在了毕业的门槛上,望着前方的道路,心中涌动着对未来的无限憧憬与些许忐忑.面前,两条道路蜿蜒伸展:一是继续在职场中寻求稳定,一是勇敢地走出一条属于自己的创新之路.尽管面临年龄和现实 ...

  5. Xilinx SDK 开发Linux APP

    Xilinx SDK 开发Linux APP 步骤 配置环境变量 将工具链需要的程序的所在目录添加到 系统环境变量中,例如: D:\Xilinx_201803\SDK\2018.3\gnu\micro ...

  6. 详解Web应用安全系列(6)安全配置错误

    Web攻击中的安全配置错误漏洞是一个重要的安全问题,它涉及到对应用程序.框架.应用程序服务器.Web服务器.数据库服务器等组件的安全配置不当.这类漏洞往往由于配置过程中的疏忽或错误,使得攻击者能够未经 ...

  7. 深度学习领域的名词解释:SOTA、端到端模型、泛化、RLHF、涌现 ..

    SOTA (State-of-the-Art) 在深度学习领域,SOTA指的是"当前最高技术水平"或"最佳实践".它用来形容在特定任务或领域中性能最优的模型或方 ...

  8. Java中代码Bug记录--泛型失效、数组删除、HashMap死循环

    最近在工作的过程中,遇到了不少奇怪自己或者同事的Bug,都是一些出乎意料的,不太容易发现的,记录一下来帮助可能也遇到了这些Bug的人 1. 编译时泛型校验失效 Map<String, Strin ...

  9. 跟我一起学习和开发动态表单系统-后端用spring boot、mybatis实现方法(4)

    ## 动态表单系统:利用 Spring Boot 和 MyBatis 实现后端服务 在现代企业应用中,表单是数据收集和处理的核心部分.然而,传统的表单系统难以适应快速变化的需求.为了解决这个问题,我们 ...

  10. python rsa加密

    rsa简单加密: 1 import rsa 2 import base64 3 4 rsa_key_pair = rsa.newkeys(2048) # 生成密钥对,返回(PublicKey(n,e) ...