tensorflow expand_dims和squeeze
有时我们会碰到升维或降维的需求,比如现在有一个图像样本,形状是 [height, width, channels],我们需要把它输入到已经训练好的模型中做分类,而模型定义的输入变量是一个batch,即形状为 [batch_size, height, width, channels],这时就需要升维了。tensorflow提供了一个方便的升维函数:expand_dims,参数定义如下:
tf.expand_dims(input, axis=None, name=None, dim=None)
参数说明:
input:待升维的tensor
axis:插入新维度的索引位置
name:输出tensor名称
dim: 一般不用
import tensorflow as tf sess = tf.Session() t = tf.constant([1, 2, 3], dtype=tf.int32) t.get_shape()
# TensorShape([Dimension(3)]) tf.expand_dims(t, 0).get_shape()
# TensorShape([Dimension(1), Dimension(3)]) tf.expand_dims(t, 1).get_shape()
# TensorShape([Dimension(3), Dimension(1)])
squeeze正好执行相反的操作:删除大小是1的维度
tf.squeeze(input, squeeze_dims=None, name=None)
input: 待降维的张量
sequeeze_dims: list[int]类型,表示需要删除的维度索引。默认为[],即删除所以大小为1的维度
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
shape(squeeze(t)) ==> [2, 3]
Or, to remove specific size 1 dimensions: # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
在处理tensor的时候合理使用这两个函数,能极大的提高效率。例如处理输入样本、执行向量与矩阵的点乘等情况。
参考:https://blog.csdn.net/qq_31780525/article/details/72280284
tensorflow expand_dims和squeeze的更多相关文章
- tensorflow之tf.squeeze()
tf.squeeze()函数的作用是从tensor中删除所有大小(szie)是1的维度. 给定丈量输入, 此操作返回的是相同类型的张量, 并删除所有尺寸为1的维度.如果不想删除所有尺寸为1的维度, 可 ...
- Tensorflow从0到1(3)之实战传统机器算法
计算图中的操作 import numpy as np import tensorflow as tf sess = tf.Session() x_vals = np.array([1., 3., 5. ...
- TensorFlow机器学习实战指南之第二章
一.计算图中的操作 在这个例子中,我们将结合前面所学的知识,传入一个列表到计算图中的操作,并打印返回值: 声明张量和占位符.这里,创建一个numpy数组,传入计算图操作: import tensorf ...
- TF常用知识
命名空间及变量共享 # coding=utf-8 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt; ...
- tensorflow 基本函数(1.tf.split, 2.tf.concat,3.tf.squeeze, 4.tf.less_equal, 5.tf.where, 6.tf.gather, 7.tf.cast, 8.tf.expand_dims, 9.tf.argmax, 10.tf.reshape, 11.tf.stack, 12tf.less, 13.tf.boolean_mask
1. tf.split(3, group, input) # 拆分函数 3 表示的是在第三个维度上, group表示拆分的次数, input 表示输入的值 import tensorflow ...
- tensorflow 笔记14:tf.expand_dims和tf.squeeze函数
tf.expand_dims和tf.squeeze函数 一.tf.expand_dims() Function tf.expand_dims(input, axis=None, name=None, ...
- tf.expand_dims和tf.squeeze函数
from http://blog.csdn.net/qq_31780525/article/details/72280284 tf.expand_dims() Function tf.expand_d ...
- (转)Image Segmentation with Tensorflow using CNNs and Conditional Random Fields
Daniil's blog Machine Learning and Computer Vision artisan. About/ Blog/ Image Segmentation with Ten ...
- TensorFlow和最近发布的slim
笔者将和大家分享一个结合了TensorFlow和最近发布的slim库的小应用,来实现图像分类.图像标注以及图像分割的任务,围绕着slim展开,包括其理论知识和应用场景. 之前自己尝试过许多其它的库,比 ...
随机推荐
- leetcode--js--Two Sum
问题描述: 给定一个整数数列,找出其中和为特定值的那两个数. 你可以假设每个输入都只会有一种答案,同样的元素不能被重用. 示例: 给定 nums = [2, 7, 11, 15], target = ...
- SAP 对HU做货物移动报错-Only 0 serial numbers entered instead of 30 -
SAP 对HU做货物移动报错-Only 0 serial numbers entered instead of 30 - 元旦刚过,就收到客户的业务人员报错说,当其对HU做转库(同一个公司代码下工厂到 ...
- 原创:mysql5 还原至mysql 8.0.11数据库链接配置提示错误(修改内容有三处
原创:mysql5 还原至mysql 8.0.11数据库链接配置提示错误改有三: a) mysql 连接jar包版修改 b)类路径修改 c)配置连接池地址修改 因版本升级,首先要修改 1:mysql- ...
- Elasticsearch编程操作
1.创建工程导入依赖 <dependency> <groupId>org.elasticsearch</groupId> <artifactId>ela ...
- Github上优秀的.NET Core项目
Github上优秀的.NET Core开源项目的集合.内容包括:库.工具.框架.模板引擎.身份认证.数据库.ORM框架.图片处理.文本处理.机器学习.日志.代码分析.教程等. Github地址:htt ...
- MySql概述及入门(五)
MySql概述及入门(五) MySQL集群搭建之读写分离 读写分离的理解 为解决单数据库节点在高并发.高压力情况下出现的性能瓶颈问题,读写分离的特性包括会话不开启事务,读语句直接发送到 salve 执 ...
- 《手把手教你构建自己的 Linux 系统》学习笔记(4)
汇编链接器(Binutils) 这是一个软件包,这个软件包其实是一个工具集,里面含有了大量的用于汇编程序活着读取二进制文件相关的程序. CC 它是一条命令的别名,这条命令的作用是使用 GCC 的 C ...
- Node.js---起步
1.下载--安装 2.创建js文件 var http = require('http'); var url=require('url'); //引入url 模块,帮助解析 var querystrin ...
- StringBuilder的性能
1.新创建一个对象 long startTimeA = System.currentTimeMillis(); StringBuilder sb = null; for (int i = 1 ...
- MySQL第六课
SELECT [DISTINCT] * /{字段名1,字段名2,字段名3,.........} FROM 表名 [WHERE 条件表达式1] [GROUP BY 字段名[HAVING 条件表达 ...