tensorflow.py_func是TensorFlow1.x版本下的函数,在TensorFlow.2.x已经不建议使用了,但是依然可以通过tf.compat.v1.py_func的方式来进行调用。

可以说TensorFlow1.x下的py_func函数在TensorFlow2.x下除了通过tf.compat.v1.py_func的方式来进行调用就再也没有等价的使用方法了,具体可以看TensorFlow2.x的API文档:

https://tensorflow.google.cn/api_docs/python/tf/compat/v1/py_func

----------------------------------------------------------

这里需要着重说明一点,很多人认为TensorFlow1.x中的tf.py_func等价于TensorFlow2.x中的tf.py_function,其实不然。在TensorFlow2.x中除了对tf.py_func进行v1版本保留和兼容的tf.compat.v1.py_func,其实是没有完全同tf.py_func等价的函数。如果说在 TensorFlow2.x中 比较相近的函数应该是tf.numpy_function而不是tf.py_function。

在TensorFlow2.x中对tf.numpy_function的解释:

https://tensorflow.google.cn/api_docs/python/tf/numpy_function

在TensorFlow2.x中对tf.py_function的解释:

https://tensorflow.google.cn/api_docs/python/tf/py_function

----------------------------------------------------------

TensorFlow2.x中 tf.numpy_function 和 TensorFlow1.x中 tf.py_functf.compat.v1.py_func)中唯一的区别是:

tf.py_func中是可以设置函数是否考虑状态的,而tf.numpy_function中是必须要考虑状态的(没有定义不考虑状态的设置)。

This name was deprecated and removed in TF2, but tf.numpy_function is a near-exact replacement, just drop the stateful argument (all tf.numpy_function calls are considered stateful).

=============================================

对TensorFlow1.x中 tf.py_func进行下一步解释:

tf.py_func其实是将python函数包装成TensorFlow的一个操作operation,tf.py_func的输入可以是numpy,可以是tensor,也可以是Variable,其输入只能是tensor。

tf.py_func定义的操作是属于TensorFlow的计算图的,在定义tf.py_func时是不会具体执行的,只有在具体的tf.Session中还可以执行,但是tf.py_func并不同于其他的TensorFlow的operation,因为tf.py_func定义的操作是运行在python空间下的而不是运行在TensorFlow空间下的。

tf.py_func定义后,在session中运行时的基本原理就是将输入的变量(不论是tensor还是numpy.array)转换为python空间下的numpy.array变量,在经过numpy运算后在将获得的numpy.array结果转换为tensor,给到TensorFlow的计算图。

其实,tf.py_func的功能完全可以手动实现类似的,就是手动的把tensor变量转为numpy.array,然后运算好后把结果手动转为tensor,tf.py_func最大的好处就是把这一过程给自动化了,不过随之也使这个运算过程变得难以理解了。从tf.py_func的原理我们就可以知道,虽然tf.py_func可以作为TensorFlow计算图的一部分挂在计算图上,但是由于其本质是将TensorFlow空间变量转为python空间变量后经过运算再转为TensorFlow空间变量,中间经过了命名空间和运算空间的转换,因此tf.py_func是不可以进行梯度反传的,或许我们更可以把这个操作看做是一种简易的为TensorFlow提供支持的python库。

================================================

2022年10月13日更新

如果tf.py_func包装的python函数的参数是string类型,那么传到包装的函数内时会被自动转为bytes类型,也就是string变bytes,这一点需要注意,否则真的是不知道什么地方报错的。

例子:

import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
print(a, b)
return np.array(len(a+b), dtype=np.float32) x = "abc"
y = "bde"
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
print("="*30, "result:")
print(ans) print(sess.run(ans))

运行结果:

可以看到,由tensorflow空间传参到python空间会自动的将string类型转为bytes类型。

在python3.x版本中,可以使用bytes.decode()的方法将传入的bytes类型转会string类型,具体:

修改后的代码:

import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
print(a, b)
a = a.decode()
b = b.decode()
print(a, b)
return np.array(len(a+b), dtype=np.float32) x = "abc"
y = "bde"
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
print("="*30, "result:")
print(ans) print(sess.run(ans))

重点部分:

================================================

一些例子:

以下代码均为TensorFlow1.x版本:

import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+1, b+1 x = np.array([1.0,2.0,3.0], dtype=np.float32)
y = np.array([4.0,5.0,6.0], dtype=np.float32) ans=tf.py_func(fun, (x, y), (tf.float32, tf.float32), name="ab_python") print("="*30, "result:")
print(ans)
print(sess.run(ans))

运行结果:

可以看到,tf.py_func的执行其实是为TensorFlow定义了一个operation,而tf.py_func所包装的python函数只有在TensorFlow执行计算图的时候才会被真正执行。

tf.py_func为包装的python函数所传入的参数可以是numpy.array类型,也可以是tensor类型,也可以是Variable类型,但是不管在tf.py_func中传入的参数是什么类型,最后传入到所包装的python函数中都会被转为numpy.array类型,而包装后的函数在session开始执行后所返回给计算图的数据类型也会被转换为tensor类型。

--------------------------------------------------------

包装的参数为tensor:

import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+1, b+1 x = tf.constant([1.0,2.0,3.0], dtype=np.float32)
y = tf.constant([4.0,5.0,6.0], dtype=np.float32)
ans =tf.py_func(fun, (x, y), (tf.float32, tf.float32), name="ab_op") print("="*30, "result:")
print(ans)
print("session is running!!!")
print(sess.run(ans))

运行结果:

--------------------------------------------------------

包装的参数为Variable:

import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+1, b+1 x = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
ans = tf.py_func(fun, (x, y), (tf.float32, tf.float32), name="ab_op")
print("="*30, "result:")
print(ans) sess.run(tf.global_variables_initializer())
print(sess.run(ans))

运行结果:

---------------------------

包装的参数为Variable:

import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+b x = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
print("="*30, "result:")
print(ans) sess.run(tf.global_variables_initializer())
print(sess.run(ans))

运行结果:

---------------------------------------------

包装的参数为Variable,求反传梯度报错:

import tensorflow as tf
import numpy as np sess = tf.Session() def fun(a, b):
print("+"*30)
print("function fun excute!!!")
return a+b x = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
x2 = tf.Variable([1.0,2.0,3.0], dtype=np.float32)
y2 = tf.Variable([4.0,5.0,6.0], dtype=np.float32)
ans = tf.py_func(fun, (x, y), (tf.float32, ), name="ab_op")
ans2 = x2 + y2
print("="*30, "result:")
print(ans) sess.run(tf.global_variables_initializer())
print(sess.run(ans)) op2 = tf.gradients(ans2, (x2, y2))
print("Ops2 Gradients: \n", sess.run(op2)) op = tf.gradients(ans, (x, y))
print("Ops Gradients: \n", sess.run(op))

运行结果:

证明:

tf.py_func包装后的函数是不可以进行反传的。

其实,tf.py_func就是在tensorflow计算图执行的时候调用python代码,而调用python代码时运行在python的代码空间中,自然是不支持反传的。

==================================================

tf.py_func的一些使用笔记——TensorFlow1.x的更多相关文章

  1. tf.py_func

    在 faster  rcnn的tensorflow 实现中看到这个函数 rois,rpn_scores=tf.py_func(proposal_layer,[rpn_cls_prob,rpn_bbox ...

  2. Tensorflow之调试(Debug) && tf.py_func()

    Tensorflow之调试(Debug)及打印变量 tensorflow调试tfdbg 几种常用方法: 1.通过Session.run()获取变量的值 2.利用Tensorboard查看一些可视化统计 ...

  3. 使用多块GPU进行训练 1.slim.arg_scope(对于同等类型使用相同操作) 2.tf.name_scope(定义名字的范围) 3.tf.get_variable_scope().reuse_variable(参数的复用) 4.tf.py_func(构造函数)

    1. slim.arg_scope(函数, 传参) # 对于同类的函数操作,都传入相同的参数 from tensorflow.contrib import slim as slim import te ...

  4. tf.contrib.layers.fully_connected参数笔记

    tf.contrib.layers.fully_connected 添加完全连接的图层. tf.contrib.layers.fully_connected(    inputs,    num_ou ...

  5. tf.split函数的用法(tensorflow1.13.0)

    tf.split(input, num_split, dimension): dimension指输入张量的哪一个维度,如果是0就表示对第0维度进行切割:num_split就是切割的数量,如果是2就表 ...

  6. TensorFlow学习笔记(一):数据操作指南

    扩充 TensorFlow tf.tile 对数据进行扩充操作 import tensorflow as tf temp = tf.tile([1,2,3],[2]) temp2 = tf.tile( ...

  7. tf.data

    以往的TensorFLow模型数据的导入方法可以分为两个主要方法,一种是使用feed_dict另外一种是使用TensorFlow中的Queues.前者使用起来比较灵活,可以利用Python处理各种输入 ...

  8. tf调试函数

    Tensorflow之调试(Debug)及打印变量   参考资料:https://wookayin.github.io/tensorflow-talk-debugging 几种常用方法: 1.通过Se ...

  9. R2CNN项目部分代码学习

    首先放出大佬的项目地址:https://github.com/yangxue0827/R2CNN_FPN_Tensorflow 那么从输入的数据开始吧,输入的数据要求为tfrecord格式的数据集,好 ...

  10. tensorflow_目标识别object_detection_api,RuntimeError: main thread is not in main loop,fig = plt.figure(frameon=False)_tkinter.TclError: no display name and no $DISPLAY environment variable

    最近在使用目标识别api,但是报错了: File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/script_o ...

随机推荐

  1. 关于编译告警 C4819 的完整解决方案 - The file contains a character that cannot be represented in the current code page (number). Save the file in Unicode format to prevent data loss.

    引言 今天迁移开发环境的时候遇到一个问题,同样的操作系统和 Visual Studio 版本,原始开发环境一切正常,但是迁移后 VS 出现了 C4819 告警,上网查了中文的一些博客,大部分涵盖几种解 ...

  2. 用 Visual C++ 2022 和 CMake 编译 CUnit 静态库

    准备工作 源代码获取 CUnit 是知名的 C 语言单元测框架,其源代码最初发布在 sourceforge 上,网址为:https://sourceforge.net/projects/cunit/ ...

  3. Python str 转 b’二进制串

    用raw_unique_escape来编码无义意的二进制串 bytes(strtext, encoding='raw_unique_escape')

  4. HarmonyOS SDK助力鸿蒙原生应用“易感知、易理解、易操作”

    6月21-23日,华为开发者大会(HDC 2024)盛大开幕.6月23日上午,<HarmonyOS开放能力,使能应用原生易用体验>分论坛成功举办,大会邀请了多位华为技术专家深度解读如何通过 ...

  5. .NET个人博客-使用Back进行消息推送

    使用Back推送消息到你的iPhone 前言 我的好友看了我的博客,给我提了个需求,让我搞个网站通知,我开始以为就是评论回复然后发送邮件通知.不过他告诉我网站通知是,当有人评论或者留言后,会通知到我这 ...

  6. Linux C 读写超过2G的大文件 注意事项

    背景 在项目中做大文件的增量读写,遇到了问题: fopen : Value too large for defined data type. 习惯性地根据这个提示查阅的有关资料显示: 1)工具链太老了 ...

  7. 在C#中使用RabbitMQ做个简单的发送邮件小项目

    在C#中使用RabbitMQ做个简单的发送邮件小项目 前言 好久没有做项目了,这次做一个发送邮件的小项目.发邮件是一个比较耗时的操作,之前在我的个人博客里面回复评论和友链申请是会通过发送邮件来通知对方 ...

  8. 解决BitBucket仓库较大拉取失败,使用SSH拉取

    HTTPS 拉取 如果使用的是https拉取,可使用以下命令尝试,如果还是失败,可使用 ssh 拉取 git clone --depth=1 xxxx.git --depth=1:拉取最近1次提交记录 ...

  9. koa web框架入门

    1.在hello-koa这个目录下创建一个package.json,这个文件描述了我们的hello-koa工程会用到哪些包.完整的文件内容如下: { "name": "h ...

  10. P3938

    斐波那契 题意描述 输入 5 1 1 2 3 5 7 7 13 4 12 输出 1 1 2 2 4 点拨 根据题目去找规律,每一个儿子与父亲结点具有斐波那契数的规律,我们只需要每次找到该数在斐波那契数 ...