1. Introduction


In this work, inspired by metric learning based on deep neural features and memory augment neural networks, authors propose matching networks that map a small labelled support set and an unlabelled example to its label. Then they define one-shot learning problems on vision and language tasks and obtain an improving one-shot accuracy on ImageNet and Omnight. The novelty of their work is twofold: at the modeling level, and at the training procedure.

2. Model


Their non-parametric approach to solving one-shot is based on two components. First, the model architecture follows recent advances in neural networks augmented with memory. Given a support set $S$, the model difines a function $c_S$(or classifier) for each $S$ Sencond, we employ a training strategy which is tailored for one-shot learning from the support set $S$

2.1 Model Architecture

Matching Networks are able to produce sensible test labels for unobserved classes without any changes to the network. We wish to map from a support set of $k$ examples of images-label pairs $S={(x_i,y_i)}_{i=1}^k$ to a classfier $c_S(\hat{x})$ which,given a test example $\hat{x}$, defines a probability distribution over outputs $\hat{y}$. Furthmore, difine the mapping $S\rightarrow c_S(\hat{x})$ to be $P(\hat{y} \mid \hat{x},S)$ where $P$ is parameterised by a neural network. Thus, When given a new support set of examples $S'$ from which to one-shot learn, we simply use the parametric neural network defined by $P$ to make predictions about the appropriate label $\hat{y}$ for each test example $\hat{x}$: $P(\hat{y} \mid \hat{x},S')$. In general, our predicted output class for a given input unseen example $\hat{x}$ and a support set $S$ becomes $arg \max_y P(y\mid \hat{x},S)$. The model in its simplest form computes $\hat{y}$ as follows:

$$ \hat{y}=\sum_{i=1}^k a(\hat{x},x_i)y_i $$

where $x_i,y_i$ are the samples and labels from the support set $S=\{(x_i,y_i)\}_{i=1}^k$, and $a$ is an attention mechanism. Here,the attention kernel function is the softmax over the cosine distance. $$ a(\hat{x},x_i)=\frac{e^{c(f(\hat{x}),g(x_i))}}{\sum_{j=1}^k e^{c(f(\hat{x}),g(x_j))}} $$ where embeding functions $f$ and $g$ are, actually, appropriate neural networks to embed $\hat{x}$ and $x_i$

2.2 Training Strategy

Let us define a tast $T$ as distribution over possible label sets $L$. To form an “episode” to compute gradients and update our model, we first sample $L$ from $T$(e.g.,$L$ could be the label set {cats; dogs}). We then use $L$ to sample the support set $S$ and a batch $B$ (i.e., both $S$ and $B$ are labelled examples of cats and dogs). The Matching Net is then trained to minimise the error predicting the labels in the batch B conditioned on the support set $S$. This is a form of meta-learning since the training procedure explicitly learns to learn from a given support set to minimise a loss over a batch. More precisely, the Matching Nets training objective is as follows:

$$ \theta = arg\max_{\theta}E_{L\sim T}\Big[E_{S\sim L,B\sim L}\Big[\sum_{(x,y)\in B}\log P_{\theta}(y\mid x,S)\Big]\Big] $$

Training $\theta$ with this objective function yields a model which works well when sampling $S'\sim T'$ from a different distribution of novel labels

3. Appendix


3.1 The Fully Conditional Embedding $f$

The embedding function for an example $\hat{x}$ in the batch $B$ is as follows:

$$ f(\hat{x},S)=attLSTM(f'(\hat{x}),g(S),K) $$

where $f'$ is a neural network. $K$ is the number of "processing" steps following work. $g(S)$ represents the embedding function $g$ applied to each element $x_i$ from the set $S$. Thus, the state after $k$ processing steps is as follows:

$$ \hat{h}_k,c_k = LSTM(f'(\hat{x}),[h_{k-1},r_{k-1}],c_{k-1}) $$

$$ h_k = \hat{h}_k+f'(\hat{x}) $$

$$ r_{k-1}=\sum_{i=1}^{|S|}a(h_{k-1},g(x_i))g(x_i) $$

$$ a(h_{k-1},g(x_i))=softmax(h_{k-1}^Tg(x_i)) $$

3.2 The Fully Conditional Embedding $g$

The encoding function for the elements in the support set $S$, $g(x_i,S)$ as a bidirectional LSTM. Let g'(x_i) be a neural network, then we difine $g(x_i,S)=\vec{h}_i+h_i^{\leftarrow}+g'(x_i)$ with:

$$ \vec{h}_i,\vec{c}_i=LSTM(g'(x_i),\vec{h}_{i-1},\vec{c}_{i-1}) $$

$$ h_i^{\leftarrow},c_i^{\leftarrow}=LSTM(g'(x_i),h_{i+1}^{\leftarrow},c_{i+1}^{\leftarrow}) $$

Reference: https://arxiv.org/abs/1606.04080

Matching Networks for One Shot Learning的更多相关文章

  1. (转)Paper list of Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning

    Meta Learning/ Learning to Learn/ One Shot Learning/ Lifelong Learning 2018-08-03 19:16:56 本文转自:http ...

  2. Multi-attention Network for One Shot Learning

    Multi-attention Network for One Shot Learning 2018-05-15 22:35:50  本文的贡献点在于: 1. 表明类别标签信息对 one shot l ...

  3. (六)6.11 Neurons Networks implements of self-taught learning

    在machine learning领域,更多的数据往往强于更优秀的算法,然而现实中的情况是一般人无法获取大量的已标注数据,这时候可以通过无监督方法获取大量的未标注数据,自学习( self-taught ...

  4. 论文笔记系列-Speeding Up Automatic Hyperparameter Optimization of Deep Neural Networks by Extrapolation of Learning Curves

    I. 背景介绍 1. 学习曲线(Learning Curve) 我们都知道在手工调试模型的参数的时候,我们并不会每次都等到模型迭代完后再修改超参数,而是待模型训练了一定的epoch次数后,通过观察学习 ...

  5. CS229 6.11 Neurons Networks implements of self-taught learning

    在machine learning领域,更多的数据往往强于更优秀的算法,然而现实中的情况是一般人无法获取大量的已标注数据,这时候可以通过无监督方法获取大量的未标注数据,自学习( self-taught ...

  6. 零样本学习 - (Zero shot learning,ZSL)

    https://zhuanlan.zhihu.com/p/41846072 https://zhuanlan.zhihu.com/p/38418698 https://zhuanlan.zhihu.c ...

  7. 18 Issues in Current Deep Reinforcement Learning from ZhiHu

    深度强化学习的18个关键问题 from: https://zhuanlan.zhihu.com/p/32153603 85 人赞了该文章 深度强化学习的问题在哪里?未来怎么走?哪些方面可以突破? 这两 ...

  8. (zhuan) Where can I start with Deep Learning?

    Where can I start with Deep Learning? By Rotek Song, Deep Reinforcement Learning/Robotics/Computer V ...

  9. Few-Shot/One-Shot Learning

    Few-Shot/One-Shot Learning指的是小样本学习,目的是克服机器学习中训练模型需要海量数据的问题,期望通过少量数据即可获得足够的知识. Matching Networks for ...

随机推荐

  1. 关于Verilog中begin-end & fork-join

     转载:http://blog.sina.com.cn/s/blog_6c7b6f030101cpgt.html begin-end and fork-join are used to combi ...

  2. Mysql临时文件目录控制

    查看mysql的log-error日志发现如下错误: ERROR 3 (HY000): Error writing file '/tmp/MYbEd05t' (Errcode: 28) 这是由于mys ...

  3. ThinkPHP5.*版本发布安全更新

    2018 年 12 月 9 日 发布 本次版本更新主要涉及一个安全更新,由于框架对控制器名没有进行足够的检测会导致在没有开启强制路由的情况下可能的getshell漏洞,受影响的版本包括5.0和5.1版 ...

  4. sed -i命令详解

    [root@www ~]# sed [-nefr] [动作] 选项与参数: -n :使用安静(silent)模式.在一般 sed 的用法中,所有来自 STDIN 的数据一般都会被列出到终端上.但如果加 ...

  5. ADB server didn't ACK failed to start daemon 5037

    错误信息: C:\Users\lizy>adb devices adb devicesadb server is out of date.  killing... ADB server didn ...

  6. [UnityAPI]SerializedObject类 & SerializedProperty类

    以Image类为例 1.MyImage.cs using UnityEngine; using UnityEngine.UI; public class MyImage : Image { ; pro ...

  7. scrollview嵌套recyclerview卡顿现象

    方式一xml: android:nestedScrollingEnabled="false" <android.support.v7.widget.RecyclerView ...

  8. android toolbar使用记录

    1.打开Project structure,选择app modules,切换到Dependencies添加com.android.support.design.26.0.0.alpha1 2.在lay ...

  9. windows注册表解析说明

    https://www.cnblogs.com/wfq9330/p/9176654.html

  10. tp5框架中jquery+ajax分页

    jaxa分页,点击按钮直接替换数据, //php代码$page=Request::instance()->param("page"); $page = empty($page ...