技术背景

在数学中我们都学过偏导数\(\frac{\partial f(x,y)}{\partial x}\),而这里我们提到的偏函数,指的是\(f(y)(x)\)。也就是说,在代码实现的过程中,虽然我们实现的一个函数可能带有很多个变量,但是可以用偏函数的形式把其中一些不需要拆分和变化的变量转变为固有变量。比较典型的两个例子是计算偏导数和多进程优化。虽然大部分支持自动微分的框架都有相应的支持偏导数的接口,多进程操作中也可以指定额外的args,但是这些自带的方法在形式上都是比较tricky的,感觉并不如使用偏函数优雅和简洁。这里我们主要介绍python中可能会用到的偏函数功能--partial

Partial简单案例

我们先来一个最简单的乘法函数\(f(x,y)=xy\)。假如说我们想得到该函数关于y的偏导数,注意,这里y是第二个输入的变量,不是第一个位置,一般自动微分框架都默认都第一个位置的变量计算偏导数。相关代码实现如下所示:

from functools import partial

def mul(x, y):
print (locals())
return x * y x = 2
y = 3
res_0 = mul(x, y)
partial_mul = partial(mul, x=x)
res_1 = partial_mul(y=3)
print ('The result is: {}'.format(res_0))
print ('The result is: {}'.format(res_1))

这段代码的运行结果为:

{'x': 2, 'y': 3}
{'x': 2, 'y': 3}
The result is: 6
The result is: 6

我们现在来分析一下上面这个案例中所体现的信息:

  1. 在使用partial函数时使用的是关键字参数,即时原本的变量不是一个关键字参数,而是一个位置参数。
  2. 虽然得到的偏函数partial_mul运行的方式跟函数一致,但其实它是一个partial的对象类型。
  3. 在生成partial_mul对象时已经执行过一遍函数,因此函数中的打印语句被打印了两次。
  4. 偏函数的计算结果肯定是跟原函数保持一致的,但是在一些特殊场景下,我们可能会用到这种单变量的偏函数。

Concurrent多核并行场景

现在我们稍微修改一下上面的案例,我们要用concurrent这个并行工具去分别执行上述乘法任务,同时输入的x也变成了一个多维的数组。然后为了验证并行算法,这里每计算一次元素乘法,我们都用time.sleep方法让进程休眠2秒钟时间。由于此时的参数y还是一个标量,但是每次乘法计算我们都需要输入这个标量,因此我们直接将其封装到一个partial偏函数中,使得函数变成:\(f(x,y)=f(y)(x)=P(x)\),然后对x这个入参进行并行化操作:

import numpy as np
import concurrent.futures
from functools import partial
import time
# 定义休眠函数
def mul(x, y):
time.sleep(2)
return x * y
# 定义入参
x = np.array([1, 2, 3], np.float32)
y = 3.
# 有阻塞计算
time_0 = time.time()
res_0 = []
for _x in x:
res_0.append(mul(_x, y))
res_0 = np.array(res_0, np.float32)
time_1 = time.time()
# 并行计算
partial_mul = partial(mul, y=y)
time_2 = time.time()
with concurrent.futures.ProcessPoolExecutor(max_workers=x.shape[0]) as executor:
res = executor.map(partial_mul, x)
res_1 = np.array(list(res), np.float32)
time_3 = time.time()
print ('The result is: {}, and for loop time cost is: {}s'.format(res_0, time_1 - time_0))
print ('The result is: {}, and concurrent time cost is : {}s'.format(res_1, time_3 - time_2))

如果有感兴趣的童鞋也可以去尝试一下,在这种场景下的并行运算,如果参量y不是一个可迭代式的变量,是无法用zip压缩传到map函数中去的。上述代码的运行结果如下:

The result is: [3. 6. 9.], and for loop time cost is: 6.005392789840698s
The result is: [3. 6. 9.], and concurrent time cost is : 2.0451698303222656s

这个计算时长其实就约等于休眠时长,因为这里我们开启了3个进程来进行休眠,因此并行时长是2s。

Jax自动微分场景

这里我们用Jax的自动微分框架做一个示例,没有安装Jax和Jaxlib的想运行需要自行安装相关软件。虽然在Jax的grad函数中,支持argnums这样的参数配置,但从代码层面角度来说,总是显得可读性并不好。正常情况下我们算偏导数\(\frac{\partial f(x,y)}{\partial x}\)其实更合理的表述应该是\(\frac{\partial P(x)}{\partial x}\)。而如果按照Jax这种写法,更像是从\([\frac{\partial f(x, y)}{\partial x}, \frac{\partial f(x, y)}{\partial x}]\)两个元素中取了第一个元素。当然,这只是表述上的问题,也是我个人的理解,其实并不影响程序的正确性。这里使用partial偏函数的相关案例如下所示:

from functools import partial
from jax import grad
from jax import numpy as jnp
# Jax要求grad函数输出结果为标量,所以要加一项求和
def mul(x, y):
f = x * y
return f.sum()
# 定义输入变量
x = jnp.array([1, 2, 3], jnp.float32)
y = 3.
# 定义偏函数和对应偏导数
partial_mul = partial(mul, y=y)
grad_mul = grad(partial_mul)
print (grad_mul(x))

执行结果如下:

[3. 3. 3.]

总结概要

本文介绍了在Python中使用偏函数partial的方法,并且介绍了两个使用partial函数的案例,分别是concurrent并行场景和基于jax的自动微分场景。在这些相关的场景下,我们用partial函数更多时候可以使得代码的可读性更好,在性能上其实并没有什么提升。如果不想使用partial函数,类似的功能也可以使用参考链接中所介绍的方法,实现一个装饰器,也可以做到一样的功能。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/partial.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://www.cnblogs.com/huyangblog/p/8999866.html

Python代码中的偏函数的更多相关文章

  1. Kivy A to Z -- 怎样从python代码中直接訪问Android的Service

    在Kivy中,通过pyjnius扩展能够间接调用Java代码,而pyjnius利用的是Java的反射机制.可是在Python对象和Java对象中转来转去总让人感觉到十分别扭.好在android提供了b ...

  2. pycharm运行Pytest,有没有将Pytest写入Python代码中的区别

    初学pytest. 将pytest写进Python代码中 不同运行方式都可正常运行     =======================**********************========= ...

  3. python代码中判断版本

    在python代码中判断python版本: if sys.version_info < (3, 0): lib.make_flows.argtypes = [c_char_p, c_char_p ...

  4. 如何在创建hive表格的python代码中导入外部文件

    业务场景大概是这样的,我要对用户博文进行分词(这个步骤可以看这篇文章如何在hive调用python的时候使用第三方不存在的库-how to use external python library in ...

  5. python代码中pass的用法

    我们有时会在方法中写一些注释代码,用来提示这个方法是干嘛的之类,看下面代码: class Game_object: def __init__(self, name): self.name = name ...

  6. 解决python代码中含有中文报错

    python中写入中文时报错如下图所示: 依照网上解决方法:在py文件中加入:#encoding=utf-8 然后继续报错如下图所示: 解决方法: 在py文件中加入: import sysreload ...

  7. Python代码中func(*args, **kwargs)

    这是Python函数可变参数 args及kwargs *args表示任何多个无名参数,它是一个tuple **kwargs表示关键字参数,它是一个dict 测试代码如下: def foo(*args, ...

  8. python代码中碰到的问题及解决

    一.针对raw_input输入的字符进行类型判断及转换: raw_input输入默认为字符,如果输入的是数字字符,想自动转换,即:输入为a,不做操作,如果输入为3,即转化为整数. 可利用try..ex ...

  9. python代码中使用settings

    在具体的Django应用中,通过引入 django.conf.settings 使用配置,例: from django.conf import settings settings.configure( ...

  10. 生成.m文件的python代码中出现的错误

    错误代码 import tempfile import subprocess import shlex import os import numpy as np import scipy.io scr ...

随机推荐

  1. GIT提交修改的项目到远程仓库

    1.在项目目录下右键选择Git Bash. 2.执行提交命令三部曲 git add . //文件-暂存区,即将所有新增的文件添加到提交索引中,,add后面是"空格 点"就表示当前目 ...

  2. MySQL数据库触发器讲解 [创建/删除/查询/select into]

    刚学习实例完mysql触发器, 前来分享学习经验.    菜鸟装逼,  老鸟勿喷 先来认识一下有关触发器的一些关键词.  在使用触发器时, 这些关键词将被用到, 请记下它们的模样和用途(意思) cre ...

  3. Linux虚拟机安装及下载

    centos 7操作系统下载及安装步骤 (仅供参考) 下载: 1.打开如下网站:先下载镜像文件 ping:https://www.centos.org/download/ 2.进入到如下界面 3.然后 ...

  4. 文心一言 VS 讯飞星火 VS chatgpt (91)-- 算法导论8.3 4题

    四.用go语言,说明如何在 O(n) 时间内,对 0 到 $n^3-1$ 区间内的n个整数进行排序. 文心一言: 在 O(n) 时间内对 0 到 n^3-1 区间内的 n 个整数进行排序,可以使用基数 ...

  5. 【RocketMQ】消息的消费总结

    消费者从Broker拉取到消息之后,会将消息提交到线程池中进行消费,RocketMQ消息消费是批量进行的,如果一批消息的个数小于预先设置的批量消费大小,直接构建消费请求ConsumeRequest将消 ...

  6. pci p2p

    概述 在2018年,针对pci支持p2pdma的驱动合入主线,没记错的话应该是4.20. 补丁如下: commit 52916982af48d9f9fc01ad825259de1eb3a9b25e A ...

  7. 【Azure Function App】Python Function调用Powershell脚本在Azure上执行失败的案例

    问题描述 编写Python Function,并且在Function中通过 subprocess  调用powershell.exe 执行 powershell脚本. import azure.fun ...

  8. How to install Django¶

    This document will get you up and running with Django. Install Python Being a Python Web framework, ...

  9. vue项目使用lodash节流防抖函数问题与解决

    背景 在lodash函数工具库中,防抖_.debounce和节流_.throttle函数在一些频繁触发的事件中比较常用. 防抖函数_.debounce(func, [wait=0], [options ...

  10. Linux: Authentication token is no longer valid

    遇见问题: [oracle@sxty-jkdb-184:/u01/rman]crontab -l Authentication token is no longer valid; new one re ...