tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB
这种问题是,对于每一个变量 variable 由于是基于protobuf存在这大小限制(2G),这个时候,我们需要将embedding拆开,拆分成N等分,来使得每一个
variable都在2G以下;
# !/usr/bin/env/python
# coding=utf-8
import tensorflow as tf
import numpy as np input_ids = tf.placeholder(dtype=tf.int32, shape=[None,None]) num_shards = 3
weights = []
weights_shape = np.arange(27).reshape(9, 3)
# assert weights_shape[0] % num_shards == 0
num_shards_len = (weights_shape.shape[0]) / num_shards
assert (weights_shape.shape[0]) % num_shards ==0
begin_ = 0
ends_ = num_shards_len
for i in range(0, num_shards):
if (i + 1) * num_shards_len < weights_shape.shape[0]:
begin_ = i * num_shards_len
if i + 1 == num_shards:
ends_ = weights_shape.shape[0]
else:
ends_ = (i + 1) * num_shards_len
else:
begin_ = i * num_shards_len
ends_ = weights_shape.shape[0]
weights_i = tf.get_variable("words-%02d" % i,
initializer=tf.constant(weights_shape[begin_: ends_, ]))
weights.append(weights_i) input_embedding = tf.nn.embedding_lookup(weights, input_ids,partition_strategy="div") sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(sess.run(weights)) print(sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))
结果为:
[array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]), array([[ 9, 10, 11],
[12, 13, 14],
[15, 16, 17]]), array([[18, 19, 20],
[21, 22, 23],
[24, 25, 26]])]
[[[ 3 4 5]
[ 6 7 8]] [[ 9 10 11]
[ 0 1 2]] [[24 25 26]
[ 6 7 8]] [[15 16 17]
[ 3 4 5]]]
tensorflow 在加载大型的embedding模型参数时,会遇到cannot be larger than 2GB的更多相关文章
- Tensorflow同时加载使用多个模型
在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Sessi ...
- MFC加载大型osg模型
MFC加载模型,发现打开 Navid 缓冲等选项后,加载大型模型的速度就快了很多. #include "stdafx.h" #include "OSGObject.h&q ...
- [CG从零开始] 6. 加载一个柴犬模型学习UV贴图
在第 5 篇文章中,我们成功加载了 fbx 模型,并且做了 MVP 变换,将立方体按照透视投影渲染了出来.但是当时只是随机给顶点颜色,并且默认 fbx 文件里只有一个 mesh,这次我们来加载一个柴犬 ...
- "C:\Program Files\Internet Explorer\iexplore.exe" -extoff 无加载项启动IE 浏览器打开时全屏模式
"C:\Program Files\Internet Explorer\iexplore.exe" -extoff 无加载项启动IE浏览器打开时全屏模式
- tensorflow数据加载、模型训练及预测
数据集 DNN 依赖于大量的数据.可以收集或生成数据,也可以使用可用的标准数据集.TensorFlow 支持三种主要的读取数据的方法,可以在不同的数据集中使用:本教程中用来训练建立模型的一些数据集介绍 ...
- tensorflow学习笔记2:c++程序静态链接tensorflow库加载模型文件
首先需要搞定tensorflow c++库,搜了一遍没有找到现成的包,于是下载tensorflow的源码开始编译: tensorflow的contrib中有一个makefile项目,极大的简化的接下来 ...
- 用TWaver加载大型游戏场景一例
游戏中经常会出现一些大型的户外场景,例如一个小镇.一座古城等.通常这种场景中包含了较多的建筑.道路.桥梁等等元素,其3D模型比较大且复杂.在使用TWaver加载时,可使用一些技巧,让加载速度更快.显示 ...
- tensorflow数据集加载
本篇涉及的内容主要有小型常用的经典数据集的加载步骤,tensorflow提供了如下接口:keras.datasets.tf.data.Dataset.from_tensor_slices(shuffl ...
- Windows下pycharm远程连接服务器调试-tensorflow无法加载问题
最近打算在win系统下使用pycharm开发程序,并远程连接服务器调试程序,其中在import tensorflow时报错如图所示(在远程服务器中执行程序正常): 直观错误为: ImportError ...
随机推荐
- win 10 slmgr.vbs -xpr 无法运行,被豆麦笔记打开解决方法
win 10 slmgr.vbs -xpr 无法运行,被豆麦笔记打开解决方法 删除这个豆麦笔记 如果之前已经在 控制面板 程序中卸载过,那么是找不到的,我们先运行 slmgr.vbs -xpr,这个时 ...
- Tp-validate进阶
阶段1:基础 application/controller/v1/Banner.php <?php namespace app\api\controller\v1; use think\Cont ...
- NSL:CPK_NN神经网络实现预测哪个样本与哪个样本处在同一层,从而科学规避我国煤矿突水灾难—Jason niu
load water_data.mat attributes = mapminmax(attributes); P_train = attributes(:,1:35); T_train = clas ...
- Django基础(四)
Django-4 知识预览 分页器(paginator) COOKIE 与 SESSION Django的用户认证 FORM 回到顶部 分页器(paginator) 分页器的使用 1 2 3 4 5 ...
- 【Excel】SUMIF 或用 筛选器 实现挑选含有某些字段的值,然后把这些值所对应的后面某列上的值相加
Background: 挑选含有某些字段的值,然后把这些值所对应的后面某列上的值相加.比如挑选下表中,所有带有“MX104”这个字段的值,然后把它的后面total那一列的值相加. Solution: ...
- 2017-9-17-MDIO信号线串联小电阻作用【转】
今天做集成测试的时候被领导说测到的MDIO信号过冲较大(正反向过冲都很大),容易损坏接口或阻容,万一那个电容耐压值不够就挂了. 我原本是不屑的,私以为MDIO.IIC.SPI等只要抓到的波形不影响判决 ...
- React生命周期函数详解
React生命周期函数 生命周期函数是指在某一个周期自动执行的函数. React中的生命周期执行过程 以下是React中的常用的生命周期函数,按个部分中按照自动执行顺序列出,这几个过程可能存在同时进行 ...
- BZOJ1889 : Maximal
二分答案,判断是否存在合法方案使得每个数都不超过$mid$. 考虑网络流建图: $i$点的流量下限为$\max(a_i-mid,0)$,费用为$1$,故拆点进行限制. $i$向$i+1$.$S$向$i ...
- [P3957][NOIP2017]跳房子 (DP+二分/队列?)
看到GREED_VI大佬在打这题 我这个蒟蒻偷偷看一眼洛谷上目前普及难度里最难的一题 题目还是能看懂的,不想道路游戏那题,我完全不知道题目是什么意思…… GREED_VI大佬第一次用的是二分的思想,于 ...
- Html图像标签、绝对路径和相对路径:
Html图像标签: <img>标签可以在网页上插入一张图片,它是独立使用的标签,它的常用属性有: (1)src 属性 定义图片的引用地址 (2)alt 属性 定义图片加载失败时显示的文字, ...