Numpy计算近邻表时间对比
技术背景
所谓的近邻表求解,就是给定N个原子的体系,找出满足cutoff要求的每一对原子。在前面的几篇博客中,我们分别介绍过CUDA近邻表计算与JAX-MD关于格点法求解近邻表的实现。虽然我们从理论上可以知道,用格点法求解近邻表,在复杂度上肯定是要优于传统的算法。本文主要从Python代码的实现上来具体测试一下二者的速度差异,这里使用的硬件还是CPU。
算法解析
若一对原子A和B满足下述条件,则称A、B为一对近邻原子:
\]
传统的求解方法,就是把所有原子间距都计算一遍,然后对每个原子的近邻原子进行排序,最终按照给定的cutoff截断值确定相关的近邻原子。在Python中的实现,因为有numpy这样的强力工具,我们在计算原子两两间距时,只需要对一组维度为(N,D)的原子坐标进行扩维,分别变成(1,N,D)和(N,1,D)大小的原子坐标。然后将二者相减,计算过程中会自动广播(Broadcast)成(N,N,D)和(N,N,D)的两个数组进行计算。对得到的结果做一个Norm,就可以得到维度为(N,N)的两两间距矩阵。该算法的计算复杂度为O(N^2)
。
相对高效的一种求解方案是将原子坐标所在的空间划分成众多的小区域,通常我们设定这些小区域为边长等于cutoff的小正方体。这种设定有一个好处是,我们可以确定每一个正方体的近邻原子,一定在最靠近其周边的26个小正方体区域内。这样一来,我们就不需要去计算全局的两两间距,只需要计算单个小正方体内(假定有M个原子)的两两间距(M,M),以及单个正方体与周边正方体内原子的配对间距(M,26M)。之所以这样分开计算,是为了减少原子跟自身间距的这一项重复计算。那么对于整个空间的原子,就需要计算(N,27M)这么多次的原子间距,是一个复杂度为O(NlogN)
的算法。
Numpy代码实现
这里我们基于Python中的numpy框架来实现这两个不同的计算近邻表的算法。其实当我们使用numpy来进行计算的时候,应当尽可能的避免循环体的使用。但是这里仅演示两种算法的差异性,因此在实现格点法的时候偷了点懒,用了两个for循环,感兴趣的童鞋可以自行优化。
import time
from itertools import chain
from operator import itemgetter
import numpy as np
# 在格点法中,为了避免重复计算,我们可以仅计算一半的近邻格点中的原子间距
NEIGHBOUR_GRID = np.array([
[-1, 1, 0],
[-1, -1, 1],
[-1, 0, 1],
[-1, 1, 1],
[ 0, -1, 1],
[ 0, 0, 1],
[ 0, 1, 0],
[ 0, 1, 1],
[ 1, -1, 1],
[ 1, 0, 0],
[ 1, 0, 1],
[ 1, 1, 0],
[ 1, 1, 1]], np.int32)
# 原始的两两间距计算方法,需要排序
def get_neighbours_by_dist(crd, cutoff):
large_dis = np.tril(np.ones((crd.shape[0], crd.shape[0])) * 999)
# (N, N)
dis = np.linalg.norm(crd[None] - crd[:, None], axis=-1) + large_dis
# (N, M)
neigh = np.argsort(dis, axis=-1)
# (N, M)
cut = np.take_along_axis(dis, neigh, axis=1)
# (2, P)
pairs = np.where(cut <= cutoff)
# (P, )
pairs_id0 = pairs[0]
pairs_id1 = neigh[pairs]
# (P, 2)
sort_args = np.argsort(pairs_id0)
return np.hstack((pairs_id0[..., None], pairs_id1[..., None]))[sort_args]
# 格点法计算近邻表,先分格点,然后分两个模块计算单格点内原子间距,和中心格点-周边格点内的原子间距
def get_neighbours_by_grid(crd, cutoff):
# (D, )
min_xyz = np.min(crd, axis=0)
max_xyz = np.max(crd, axis=0)
space = max_xyz - min_xyz
grids = np.ceil(space / cutoff).astype(np.int32)
num_grids = np.product(grids)
buffer = (grids * cutoff - space) / 2
start_crd = min_xyz - buffer
# (N, D)
grid_id = ((crd - start_crd) // cutoff).astype(np.int32)
grid_coe = np.array([1, grids[0], grids[1]], np.int32)
# (N, )
grid_id_1d = np.sum(grid_id * grid_coe, axis=-1).astype(np.int32)
# (N, 2)
grid_id_dict = np.ndenumerate(grid_id_1d)
# (G, *)
grid_dict = dict.fromkeys(range(num_grids), ())
for index, value in grid_id_dict:
grid_dict[value] += index
neighbour_grid = (NEIGHBOUR_GRID * grid_coe).sum(axis=-1).astype(np.int32)
neighbour_pairs = []
for i in range(num_grids):
if grid_dict[i]:
keeps = np.where((neighbour_grid + i < num_grids) & (neighbour_grid + i >= 0))[0]
neighbour_grid_keep = neighbour_grid[keeps] + i
grid_atoms = np.array(list(grid_dict[i]), np.int32)
try:
grid_neighbours = np.array(list(chain(*itemgetter(*neighbour_grid_keep)(grid_dict))), np.int32)
except TypeError:
if neighbour_grid_keep.size == 0:
grid_neighbours = np.array([], np.int32)
else:
grid_neighbours = np.array(list(itemgetter(*neighbour_grid_keep)(grid_dict)), np.int32)
grid_crds = crd[grid_atoms]
grid_neighbour_crds = crd[grid_neighbours]
large_dis = np.tril(np.ones((grid_crds.shape[0], grid_crds.shape[0])) * 999)
# 单格点内部原子间距
grid_dis = np.linalg.norm(grid_crds[None] - grid_crds[:, None], axis=-1) + large_dis
grid_pairs = np.argsort(grid_dis, axis=-1)
grid_cut = np.take_along_axis(grid_dis, grid_pairs, axis=-1)
pairs = np.where(grid_cut <= cutoff)
pairs_id0 = grid_atoms[pairs[0]]
pairs_id1 = grid_atoms[grid_pairs[pairs]]
neighbour_pairs.extend(list(np.hstack((pairs_id0[..., None], pairs_id1[..., None]))))
# 中心格点-周边格点内原子间距
grid_dis = np.linalg.norm(grid_crds[:, None] - grid_neighbour_crds[None], axis=-1)
grid_pairs = np.argsort(grid_dis, axis=-1)
grid_cut = np.take_along_axis(grid_dis, grid_pairs, axis=-1)
pairs = np.where(grid_cut <= cutoff)
pairs_id0 = grid_atoms[pairs[0]]
pairs_id1 = grid_neighbours[grid_pairs[pairs]]
neighbour_pairs.extend(list(np.hstack((pairs_id0[..., None], pairs_id1[..., None]))))
neighbour_pairs = np.sort(np.array(neighbour_pairs), axis=-1)
sort_args = np.argsort(neighbour_pairs[:, 0])
return neighbour_pairs[sort_args]
# 时间测算函数
def benchmark(N, cutoff=0.3, D=3):
crd = np.random.random((N, D)).astype(np.float32) * np.array([3., 4., 5.], np.float32)
# Solution 1
time0 = time.time()
neighbours_1 = get_neighbours_by_dist(crd, cutoff)
time1 = time.time()
record_1 = time1 - time0
# Solution 2
time0 = time.time()
neighbours_2 = get_neighbours_by_grid(crd, cutoff)
time1 = time.time()
record_2 = time1 - time0
for pair in neighbours_1:
if (np.isin(neighbours_2, pair).sum(axis=-1) < 2).all():
print (pair)
assert neighbours_1.shape == neighbours_2.shape
return record_1, record_2
# 绘图主函数
if __name__ == '__main__':
import matplotlib.pyplot as plt
sizes = range(1000, 10000, 1000)
time_dis = []
time_grid = []
for size in sizes:
print (size)
times = benchmark(size)
time_dis.append(times[0])
time_grid.append(times[1])
plt.figure()
plt.title('Neighbour List Calculation Time')
plt.plot(sizes, time_dis, color='black', label='Full Connect')
plt.plot(sizes, time_grid, color='blue', label='Cell List')
plt.xlabel('Size')
plt.ylabel('Time/s')
plt.legend()
plt.grid()
plt.show()
上述代码的运行结果如下图所示:

其实因为格点法中使用了for循环的问题,函数效率并不高。因此在体系非常小的场景下(比如只有几十个原子的体系),本文用到的格点法代码效率并不如计算所有的原子两两间距。但是毕竟格点法的复杂度较低,因此在运行过程中随着体系的增长,格点法的优势也越来越大。
近邻表计算与分子动力学模拟
在分子动力学模拟中计算长程相互作用时,会经常使用到近邻表。如果要在GPU上实现格点近邻算法,有可能会遇到这样的一些问题:
- GPU更加擅长处理静态Shape的张量,因此往往会使用一个
最大近邻数
,对每一个原子的近邻原子标号进行限制,一般不允许满足cutoff的近邻原子数超过最大近邻数,否则这个cutoff就失去意义了。而如果单个原子的近邻原子数量低于最大近邻数,这时候就会用一个没有意义的数对剩下分配好的张量空间进行填充(Padding),这样一来会带来很多不必要的计算。 - 在运行分子动力学模拟的过程中,体系原子的坐标在不断的变化,近邻表也会随之变化,而此时的最大近邻数有可能无法存储完整的cutoff内的原子。
总结概要
本文介绍了在Python的numpy框架下计算近邻表的两种不同算法的原理以及复杂度,另有分别对应的两种代码实现。在实际使用中,我们更偏向于第二种算法的使用。因为对于第一种算法来说,哪怕是一个10000个原子的小体系,如果要计算两两间距,也会变成10000*10000这么大的一个张量的运算。可想而知,这样计算的效率肯定是比较低下的。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/cell-list.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
Numpy计算近邻表时间对比的更多相关文章
- JAX-MD在近邻表的计算中,使用了什么奇技淫巧?(一)
技术背景 JAX-MD是一款基于JAX的纯Python高性能分子动力学模拟软件,应该说在纯Python的软件中很难超越其性能.当然,比一部分直接基于CUDA的分子动力学模拟软件性能还是有些差距.而在计 ...
- Python的GPU编程实例——近邻表计算
技术背景 GPU加速是现代工业各种场景中非常常用的一种技术,这得益于GPU计算的高度并行化.在Python中存在有多种GPU并行优化的解决方案,包括之前的博客中提到的cupy.pycuda和numba ...
- select … into outfile 备份恢复(load data)以及mysqldump时间对比
select … into outfile 'path' 备份 此种方式恢复速度非常快,比insert的插入速度要快的多,他跟有备份功能丰富的mysqldump不同的是,他只能备份表中的数据,并不能包 ...
- numpy计算数组中满足条件的个数
Numpy计算数组中满足条件元素个数 需求:有一个非常大的数组比如1亿个数字,求出里面数字小于5000的数字数目 1. 使用numpy的random模块生成1亿个数字 2. 使用Python原生语法实 ...
- mysql对比表结构对比同步,sqlyog架构同步工具
mysql对比表结构对比同步,sqlyog架构同步工具 对比后的结果示例: 执行后的结果示例: 点击:"另存为(S)" 按钮可以把更新sql导出来.
- C#计算两个时间年份月份差
C#计算两个时间年份月份差 https://blog.csdn.net/u011127019/article/details/79142612
- python init 方法 与 sql语句当前时间对比
def init(self,cr): tools.sql.drop_view_if_exists(cr, 'custrom_product_infomation_report') cr.execute ...
- 计算2个时间之间经过多少Ticks
Ticks是一个周期,存储的是一百纳秒,换算为秒,一千万分之一秒.我们需要计算2个时间之间,经过多少Ticks,可以使用下面的方法来实现,使用2个时间相减. 得到结果为正数,是使用较晚的时间减去较早的 ...
- C# 计算传入的时间距离今天的时间差
/// <summary> /// 计算传入的时间距离今天的时间差 /// </summary> /// <param name="dt">&l ...
- numpy计算路线距离
numpy计算路线距离 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献 enumerate遍历数组 np.diff函数 numpy适用数组作为索引 标记路线上的点 \[X={X1,X ...
随机推荐
- QQ机器人整理合集
QQ机器人有什么用呢? QQ机器人可以实现包括自动回复.定时推送.发送图片QQ机器人,营销圈用的比较多,可以开发各种自动功能等等.用其制作的QQ机器人程序 机器人框架+插件 小栗子机器人 官网:htt ...
- 圆角android
资源地址 <shape xmlns:android="http://schemas.android.com/apk/res/android"> <solid an ...
- 自学一周python做的一个小游戏《大球吃小球》
需求 1,显示一个窗口. 2,我们要做到的功能有鼠标点击屏幕生成小球. 3,生成的小球大小随机,颜色随机,向随机方向移动,速度也随机. 4,大的球碰到小球时可以吃掉小球,吃掉后会变大. 5,球碰到边界 ...
- LibOciLib使用说明(2017-1-26更新)
LibOciLib使用说明 整理者:赤勇玄心行天道 QQ:280604597 Email:280604597@qq.com 大家有什么不明白的地方,或者想要详细了解的地方可以联系我,我会认真回复的! ...
- Educational Codeforces Round 148 [Rated for Div. 2]A~C
A #include <bits/stdc++.h> using namespace std; typedef long long LL; const int N = 60; char c ...
- reverse--[HZNUCTF 2023 preliminary]easyAPK
首先这是一个apk文件,一开始我是用jadx打开的,发现要aes加密啥的,后面我用jeb打开,发现账号和密码都已经解密出来了 真的很方便,然后根据代码逻辑判断,这应该是安卓程序的一个登录界面,接下来我 ...
- 《实现领域驱动设计》笔记——DDD入门
设计不只是感观,设计就是产品的工作方式. 我们的目标应该是创造一个可观测的.可伸缩的.组织良好的软件模型. DDD同时提供了战略上的战术上的建模工具. 我能DDD吗? DDD首先并不是关于技术的,而是 ...
- 开源一套快速部署程序的工具(CI/CD)
随着微服务越写越多,程序发布就成了一个麻烦事,所以写了一个部署工具 Vela,只要填写一个git地址.编译命令等简单信息,就能自动完成程序的部署. Vela 特性: 代码可在任意一台电脑自动完成编译, ...
- 总结---Django部分
Django是走大而全的方向,它最出名的是其全自动化的管理后台:只需要使用起ORM,做简单的对象定义,它就能自动生成数据库结构.以及全功能的管理后台. Django内置的ORM跟框架内的其他模块耦合程 ...
- Java核心知识体系7:线程安全性讨论
Java核心知识体系1:泛型机制详解 Java核心知识体系2:注解机制详解 Java核心知识体系3:异常机制详解 Java核心知识体系4:AOP原理和切面应用 Java核心知识体系5:反射机制详解 J ...