P和C:位置注意力模块(PAM)与通道注意力模块(CAM)
import tensorflow as tf
import numpy as np
import math
import keras
from keras.layers import Conv2D,Reshape,Input
import numpy as np
import matplotlib.pyplot as plt

# Channel attention module (CAM) demo, channels-last layout.
#
# The input image A (N, H, W, C) is flattened to V (N, HW, C), which serves
# as query, key and value alike:
#   energy    = V^T V                                -> (N, C, C)
#   attention = softmax(rowmax(energy) - energy)     -> (N, C, C)
#   out       = gamma * reshape(V attention^T) + A   -> (N, H, W, C)
if __name__ == '__main__':
    gamma = 0.05

    # A single session, closed on exit (the original opened three and
    # never closed any of them).
    with tf.Session() as sess:
        file = tf.read_file('img.jpg')
        # decode_jpeg yields uint8; the matrix products below need floats.
        image = sess.run(tf.image.decode_jpeg(file)).astype(np.float32)
        x1 = np.expand_dims(image, axis=0)
        print("x1.shape:", x1.shape)
        m_batchsize, height, width, C = x1.shape

        # Flatten the spatial dims so each channel becomes a length-HW column.
        proj_value = x1.reshape(m_batchsize, height * width, C)

        # (N, C, HW) x (N, HW, C) -> (N, C, C): batched matrix product.
        energy = tf.matmul(proj_value, proj_value, transpose_a=True)
        # Subtract each entry from its row maximum, as in DANet's CAM. The
        # keepdims result broadcasts across the last axis, so no manual
        # concat-tiling is needed.
        energy_new = tf.reduce_max(energy, axis=-1, keepdims=True) - energy
        attention = tf.nn.softmax(energy_new, axis=-1)

        # out[n, p, i] = sum_j attention[n, i, j] * V[n, p, j]; with V on the
        # left in channels-last layout, attention must be transposed.
        out = tf.matmul(proj_value, attention, transpose_b=True)
        # Restore (N, H, W, C) so the residual is added element-wise to A.
        out = gamma * tf.reshape(out, (m_batchsize, height, width, C)) + x1

        energy, attention, out = sess.run([energy, attention, out])

    print("energy:", energy)
    print("attention:", attention)
    print("out2:", out.shape)
    print("out3:", out)
import tensorflow as tf
import numpy as np
import math
import keras
from keras.layers import Conv2D,Reshape,Input
from keras.regularizers import l2
from keras.layers.advanced_activations import ELU, LeakyReLU
from keras import Model
import cv2

# Position attention module (PAM), following DANet:
#   1. A (C x H x W) -> Conv + BN + ReLU -> B and C, both C x H x W.
#   2. Reshape B and C to C x N, where N = H x W.
#   3. Transpose B to B'.
#   4. S = softmax(B' x C) is the N x N (HW x HW) spatial attention map.
#   5. Per Eq. 1, s_ji measures the influence of position i on position j,
#   6. i.e. how strongly positions i and j are correlated (larger = more
#      similar).
#   7. A -> Conv + BN + ReLU -> D (C x H x W), reshaped to C x N.
#   8. D x S' -> C x H x W, referred to here as DS.
#   9. E = alpha * DS + A (element-wise sum, Eq. 2) is the final C x H x W
#      output.
#  10. alpha is initialized to 0 and gradually learns to assign more weight.
#
# inputs:
#     x: input feature maps (N x C x H x W)
# returns:
#     out: attention value + input feature
#     attention: N x (HxW) x (HxW)
#
# This demo differs from the description above in two ways. First, it uses
# 1x1 Conv2D + ReLU without BN. Second, it stores tensors channels-last, so
# each "C x N" matrix is held as (batch, HW, C) and the products are
# transposed accordingly.
if __name__ == '__main__':
    gamma = 0.05

    # A single session, closed on exit. Keras picks it up as the default
    # session while inside the `with` block.
    with tf.Session() as sess:
        file = tf.read_file('img.jpg')
        # decode_jpeg yields uint8; the convolutions and products need floats.
        image = sess.run(tf.image.decode_jpeg(file)).astype(np.float32)
        x1 = np.expand_dims(image, axis=0)
        print(x1.shape)
        # The channel count comes from the image instead of a hard-coded 3.
        batchsize, height, width, C = x1.shape

        inputs = Input(shape=(height, width, C))
        query_conv = Conv2D(1, (1, 1), activation='relu',
                            kernel_initializer='he_normal')(inputs)
        key_conv = Conv2D(1, (1, 1), activation='relu',
                          kernel_initializer='he_normal')(inputs)
        value_conv = Conv2D(C, (1, 1), activation='relu',
                            kernel_initializer='he_normal')(inputs)
        # B', C' and D' in channels-last form: (batch, HW, channels).
        proj_query = Reshape((height * width, 1))(query_conv)
        proj_key = Reshape((height * width, 1))(key_conv)
        proj_value = Reshape((height * width, C))(value_conv)

        # One model evaluates all three projections with the same weights.
        # predict() does not require compile().
        model = Model(inputs=[inputs],
                      outputs=[proj_query, proj_key, proj_value])
        proj_query, proj_key, proj_value = model.predict(x1)
        print("proj_query:", proj_query.shape)
        print("proj_key:", proj_key.shape)
        print("proj_value:", proj_value.shape)

        # B' x C => (batch, HW, HW), with energy[n, i, j] = q_i * k_j.
        # NOTE: this matrix has (H*W)^2 entries, so keep the test image small.
        energy = tf.matmul(proj_query, proj_key, transpose_b=True)
        # S = softmax(B' x C) over the last axis.
        attention = tf.nn.softmax(energy, axis=-1)
        # D x S' in channels-last form:
        # out[n, i, c] = sum_j S[n, i, j] * D[n, j, c].
        out = tf.matmul(attention, proj_value)
        # Back to (batch, H, W, C); the residual uses the float input.
        out = gamma * tf.reshape(out, (batchsize, height, width, C)) + x1

        attention, out = sess.run([attention, out])

    print("attention:", attention.shape)
    print("out1:", out.shape)
    print("out2:", type(out))
随机推荐
- MacOS:Django + Python3 + MySQL
Django Django是一个开放源代码的Web应用框架,由Python写成.采用了MVC的框架模式,即模型M,视图V和控制器C.它最初是被开发来用于管理劳伦斯出版集团旗下的一些以新闻内容为主的网站 ...
- ERROR 2002 (HY000): Can't connect to local MySQL server through socket '/var/lib/mysql/mysql.sock' (13)解答
我在使用mysql客户端连接我的mysql服务器的时候,出现了上述的问题.我的操作系统是ubuntu,安装版本是对应的64位服务器.我的服务器的启动方式是sudo service mysql sta ...
- python笔记(三)---文件读写、修改文件内容、处理json、函数
文件读写(一) #r 只读,打开文件不存在的话,会报错 #w 只写,会清空原来文件的内容 #a 追加写,不会清空,打开的文件不存在的话,也会帮你新建一个文件 print(f.read()) #获取到 ...
- Redis深入学习笔记(四)主从数据复制流程
主从节点的数据复制是Redis高可用和高负载的重要基础,本篇介绍数据的主从复制流程. 数据复制策略: 全量复制:一般用于初次复制场景,Redis早期支持的复制功能只有全量复制,它会把主节点全部数据一次 ...
- arcgis更改栅格数据范围
栅格数据范围默认为有效值的外接矩形范围,其行列数也是有效值的最大行号减去最小行号、最大列号减去最小列号. 通过使用extract by mask 工具可实现改变栅格数据范围. 使用过程中要修改环境功能中 ...
- 如何将composer设置为全局变量?
全局安装是将 Composer 安装到系统环境变量 PATH 所包含的路径下面,然后就能够在命令行窗口中直接执行 composer 命令了. Mac 或 Linux 系统: 打开命令行窗口并执行如下命 ...
- django 三种缓存模式的使用及注意点
Django 缓存模式的使用(主要针对RestFul设计模式的项目) 有三种模式: 全站使用缓存模式(整个项目每个接口都会使用缓存,缺点:所有接口都无法实时获取数据) 单独视图缓存模式(单个接口使用 ...
- Object详解(转)
Object类是Java中其他所有类的祖先,没有Object类Java面向对象无从谈起.作为其他所有类的基类,Object具有哪些属性和行为,是Java语言设计背后的思维体现. Object类位于ja ...
- 佳鑫:信息流广告CTR一样高,哪条文案转化率更好?
在优化信息流广告的过程中,你有没有遇到这样的帐户? 投了几个AB方案,点击率好不容易上去了,但转化率还是有的高.有的低! 这儿就有这么一个为难的案例: 一个广告主计划向有意愿在北京预订酒店的用户投放信 ...
- Java2E中的路径问题
本节主要介绍: 1.request.getContextPath()-----项目发布的根路径 2.request.getRealPath('t')----t目录在当前磁盘中的物理位置,包括盘符,文 ...