Notes on the Code of Learning Memory-guided Normality

Core of the memory module

The core of the memory module lies in the following definition of the Memory class.

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Memory(nn.Module):
        def __init__(self, memory_size, feature_dim, key_dim, temp_update, temp_gather):
            super(Memory, self).__init__()
            # Constants
            self.memory_size = memory_size
            self.feature_dim = feature_dim
            self.key_dim = key_dim
            self.temp_update = temp_update
            self.temp_gather = temp_gather

        def hard_neg_mem(self, mem, i):
            similarity = torch.matmul(mem, torch.t(self.keys_var))
            similarity[:, i] = -1
            _, max_idx = torch.topk(similarity, 1, dim=1)
            return self.keys_var[max_idx]

        def random_pick_memory(self, mem, max_indices):
            m, d = mem.size()
            output = []
            for i in range(m):
                flattened_indices = (max_indices == i).nonzero()
                a, _ = flattened_indices.size()
                if a != 0:
                    number = np.random.choice(a, 1)
                    output.append(flattened_indices[number, 0])
                else:
                    output.append(-1)
            return torch.tensor(output)

        def get_update_query(self, mem, max_indices, update_indices, score, query, train):
            m, d = mem.size()
            if train:
                query_update = torch.zeros((m, d)).cuda()
                # random_update = torch.zeros((m,d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1) == i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0
                return query_update
            else:
                query_update = torch.zeros((m, d)).cuda()
                for i in range(m):
                    idx = torch.nonzero(max_indices.squeeze(1) == i)
                    a, _ = idx.size()
                    if a != 0:
                        query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                    else:
                        query_update[i] = 0
                return query_update

        def get_score(self, mem, query):
            bs, h, w, d = query.size()
            m, d = mem.size()
            score = torch.matmul(query, torch.t(mem))  # b X h X w X m
            score = score.view(bs * h * w, m)  # (b X h X w) X m
            score_query = F.softmax(score, dim=0)
            score_memory = F.softmax(score, dim=1)
            return score_query, score_memory

        def forward(self, query, keys, train=True):
            batch_size, dims, h, w = query.size()  # b X d X h X w
            query = F.normalize(query, dim=1)
            query = query.permute(0, 2, 3, 1)  # b X h X w X d
            # train
            if train:
                # losses
                separateness_loss, compactness_loss = self.gather_loss(query, keys, train)
                # read
                updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
                # update
                updated_memory = self.update(query, keys, train)
                return updated_query, updated_memory, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss
            # test
            else:
                # loss
                compactness_loss, query_re, top1_keys, keys_ind = self.gather_loss(query, keys, train)
                # read
                updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
                # update
                updated_memory = keys
                return updated_query, updated_memory, softmax_score_query, softmax_score_memory, query_re, top1_keys, keys_ind, compactness_loss

        def update(self, query, keys, train):
            batch_size, h, w, dims = query.size()  # b X h X w X d
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
            _, updating_indices = torch.topk(softmax_score_query, 1, dim=0)
            if train:
                query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train)
                updated_memory = F.normalize(query_update + keys, dim=1)
            else:
                query_update = self.get_update_query(keys, gathering_indices, updating_indices, softmax_score_query, query_reshape, train)
                updated_memory = F.normalize(query_update + keys, dim=1)
            return updated_memory.detach()

        def pointwise_gather_loss(self, query_reshape, keys, gathering_indices, train):
            n, dims = query_reshape.size()  # (b X h X w) X d
            loss_mse = torch.nn.MSELoss(reduction='none')
            pointwise_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
            return pointwise_loss

        def gather_loss(self, query, keys, train):
            batch_size, h, w, dims = query.size()  # b X h X w X d
            if train:
                loss = torch.nn.TripletMarginLoss(margin=1.0)
                loss_mse = torch.nn.MSELoss()
                softmax_score_query, softmax_score_memory = self.get_score(keys, query)
                query_reshape = query.contiguous().view(batch_size * h * w, dims)
                _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)  # 1st, 2nd closest memories
                pos = keys[gathering_indices[:, 0]]
                neg = keys[gathering_indices[:, 1]]
                top1_loss = loss_mse(query_reshape, pos.detach())
                gathering_loss = loss(query_reshape, pos.detach(), neg.detach())
                return gathering_loss, top1_loss
            else:
                loss_mse = torch.nn.MSELoss()
                softmax_score_query, softmax_score_memory = self.get_score(keys, query)
                query_reshape = query.contiguous().view(batch_size * h * w, dims)
                _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
                gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
                return gathering_loss, query_reshape, keys[gathering_indices].squeeze(1).detach(), gathering_indices[:, 0]

        def read(self, query, updated_memory):
            batch_size, h, w, dims = query.size()  # b X h X w X d
            softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory)  # (b X h X w) X d
            updated_query = torch.cat((query_reshape, concat_memory), dim=1)  # (b X h X w) X 2d
            updated_query = updated_query.view(batch_size, h, w, 2 * dims)
            updated_query = updated_query.permute(0, 3, 1, 2)
            return updated_query, softmax_score_query, softmax_score_memory
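
Before moving on, here is a brief usage sketch. The hyperparameter values are hypothetical, chosen only for illustration; note that the memory items are not module parameters, but plain tensors created outside the class and passed into every call:

    import torch
    import torch.nn.functional as F

    # Hypothetical hyperparameters: 10 memory slots and 512-d keys/features.
    memory = Memory(memory_size=10, feature_dim=512, key_dim=512,
                    temp_update=0.1, temp_gather=0.1)
    # Memory items: randomly initialized and L2-normalized per slot.
    m_items = F.normalize(torch.rand(10, 512), dim=1)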

The update process

The get_update_query(self, mem, max_indices, update_indices, score, query, train) function is called to compute \(query\_update = \sum_{k \in U_t^m} v_t^{\prime k,m} q_t^k\), where \(U_t^m\) is the set of queries whose nearest memory item is the \(m\)-th one.

It then computes \(f(p^m + query\_update)\).

The paper describes \(f\) as the L2 norm; in the code it is realized with F.normalize.
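
A minimal vectorized sketch of this update rule follows. The tensors are toys and the names nearest, mask, and v are mine, not from the repo; the weights are renormalized by the per-slot maximum, as in the code:

    import torch
    import torch.nn.functional as F

    # Toy sizes, all values hypothetical: m memory slots, n queries of dimension d.
    m, d, n = 3, 4, 5
    keys = F.normalize(torch.rand(m, d), dim=1)    # memory items p^m
    query = F.normalize(torch.rand(n, d), dim=1)   # queries q_t^k
    score_query = F.softmax(torch.matmul(query, keys.t()), dim=0)  # v_t^{k,m}

    nearest = torch.matmul(query, keys.t()).argmax(dim=1)  # nearest item per query
    mask = F.one_hot(nearest, m).float()                   # n x m indicator of membership in U_t^m
    v = (score_query / score_query.max(dim=0, keepdim=True).values) * mask  # v'^{k,m}_t

    query_update = torch.matmul(v.t(), query)                 # sum over k in U_t^m of v' * q
    updated_memory = F.normalize(keys + query_update, dim=1)  # f = L2 normalization per slot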

Now look at the actual definition of the get_update_query function:

    def get_update_query(self, mem, max_indices, update_indices, score, query, train):
        m, d = mem.size()
        if train:
            query_update = torch.zeros((m, d)).cuda()
            # random_update = torch.zeros((m,d)).cuda()
            for i in range(m):
                idx = torch.nonzero(max_indices.squeeze(1) == i)
                a, _ = idx.size()
                if a != 0:
                    query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                else:
                    query_update[i] = 0
            return query_update
        else:
            query_update = torch.zeros((m, d)).cuda()
            for i in range(m):
                idx = torch.nonzero(max_indices.squeeze(1) == i)
                a, _ = idx.size()
                if a != 0:
                    query_update[i] = torch.sum(((score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1)), dim=0)
                else:
                    query_update[i] = 0
            return query_update

Within this definition, the key point is the computation of \(v_t^{\prime k,m}\). The code implements it as score[idx,i] / torch.max(score[:,i]), i.e. each matching weight is renormalized by the largest weight assigned to memory item i. To go one step further we need the computation of \(v_t^{k,m}\); like \(w_t^{k,m}\), it is a matching weight, and the paper computes both via the get_score function, defined as follows:

    def get_score(self, mem, query):
        # compute the matching weights $w_t^{k,m}$ and $v_t^{k,m}$
        bs, h, w, d = query.size()
        m, d = mem.size()
        score = torch.matmul(query, torch.t(mem))  # b X h X w X m
        score = score.view(bs * h * w, m)  # (b X h X w) X m
        score_query = F.softmax(score, dim=0)
        score_memory = F.softmax(score, dim=1)
        return score_query, score_memory

This implements the weight computation from the paper: score_memory (softmax along the memory axis) gives the reading weights \(w_t^{k,m}\), while score_query (softmax along the query axis) gives the updating weights \(v_t^{k,m}\).
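
A small sanity-check sketch (toy tensors of my own, not from the repo) makes the two normalization directions explicit:

    import torch
    import torch.nn.functional as F

    score = torch.rand(6, 3)                # (b*h*w) x m similarity scores
    score_query = F.softmax(score, dim=0)   # v_t^{k,m}: normalized over queries, columns sum to 1
    score_memory = F.softmax(score, dim=1)  # w_t^{k,m}: normalized over memory items, rows sum to 1
    print(score_query.sum(dim=0))           # ones, one per memory slot
    print(score_memory.sum(dim=1))          # ones, one per query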

The read process

    def read(self, query, updated_memory):
        # the read step
        batch_size, h, w, dims = query.size()  # b X h X w X d
        softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)
        # weighted average of the memory items using the matching weights
        concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory)  # (b X h X w) X d
        # concatenate the retrieved memory with the query
        updated_query = torch.cat((query_reshape, concat_memory), dim=1)  # (b X h X w) X 2d
        updated_query = updated_query.view(batch_size, h, w, 2 * dims)
        updated_query = updated_query.permute(0, 3, 1, 2)
        return updated_query, softmax_score_query, softmax_score_memory

The key steps are annotated with comments in the code above.
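
As a hedged usage sketch (the sizes and temperature values are hypothetical, and the Memory class from above is assumed to be in scope), reading doubles the channel dimension, since the retrieved memory is concatenated onto the query:

    import torch
    import torch.nn.functional as F

    mem = Memory(memory_size=10, feature_dim=512, key_dim=512,
                 temp_update=0.1, temp_gather=0.1)  # hypothetical values
    keys = F.normalize(torch.rand(10, 512), dim=1)  # memory items
    query = torch.rand(2, 32, 32, 512)              # b x h x w x d (already permuted)
    updated_query, w_q, w_m = mem.read(query, keys)
    print(updated_query.shape)                      # torch.Size([2, 1024, 32, 32])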

The forward process

    # losses
    separateness_loss, compactness_loss = self.gather_loss(query, keys, train)
    # read
    updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
    # update
    updated_memory = self.update(query, keys, train)
    return updated_query, updated_memory, softmax_score_query, softmax_score_memory, separateness_loss, compactness_loss

This is the training branch of forward: it first computes the two losses via gather_loss, then calls the read and update functions in turn.
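
A sketch of calling the full forward pass in training mode (again with hypothetical sizes; a GPU is required here because get_update_query allocates its buffer with .cuda()):

    import torch
    import torch.nn.functional as F

    mem = Memory(memory_size=10, feature_dim=512, key_dim=512,
                 temp_update=0.1, temp_gather=0.1)  # hypothetical values
    keys = F.normalize(torch.rand(10, 512), dim=1).cuda()
    fea = torch.rand(2, 512, 32, 32).cuda()         # b x d x h x w, e.g. an encoder output
    (updated_query, updated_memory, w_q, w_m,
     separateness_loss, compactness_loss) = mem(fea, keys, train=True)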

The definition of the overall loss function deserves explanation: \(L = L_{rec} + \lambda_c L_{compact} + \lambda_s L_{separate}\).

The compactness and separateness terms are implemented in the code by the gather_loss function.

    def gather_loss(self, query, keys, train):
        batch_size, h, w, dims = query.size()  # b X h X w X d
        if train:
            loss = torch.nn.TripletMarginLoss(margin=1.0)
            # triplet loss, the main ingredient of the feature separateness loss
            loss_mse = torch.nn.MSELoss()
            # mean squared error, used for the feature compactness loss
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)  # 1st, 2nd closest memories
            pos = keys[gathering_indices[:, 0]]
            neg = keys[gathering_indices[:, 1]]
            top1_loss = loss_mse(query_reshape, pos.detach())
            gathering_loss = loss(query_reshape, pos.detach(), neg.detach())
            return gathering_loss, top1_loss
        else:
            loss_mse = torch.nn.MSELoss()
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
            gathering_loss = loss_mse(query_reshape, keys[gathering_indices].squeeze(1).detach())
            return gathering_loss, query_reshape, keys[gathering_indices].squeeze(1).detach(), gathering_indices[:, 0]
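
Finally, a hedged sketch of how the three terms could be combined into the total objective. The weights lam_c and lam_s and the dummy tensors are hypothetical stand-ins; the paper tunes \(\lambda_c\) and \(\lambda_s\) per setting:

    import torch

    # Hypothetical stand-ins for the model's predicted frame, the target frame,
    # and the two losses returned by gather_loss in training mode.
    output = torch.rand(2, 3, 256, 256, requires_grad=True)
    target = torch.rand(2, 3, 256, 256)
    compactness_loss = torch.tensor(0.5, requires_grad=True)
    separateness_loss = torch.tensor(0.8, requires_grad=True)

    lam_c, lam_s = 0.01, 0.01                      # hypothetical loss weights
    rec_loss = torch.mean((output - target) ** 2)  # L_rec: frame reconstruction/prediction error
    total_loss = rec_loss + lam_c * compactness_loss + lam_s * separateness_loss
    total_loss.backward()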
