主题列表:juejin, github, smartblue, cyanosis, channing-cyan, fancy, hydrogen, condensed-night-purple, greenwillow, v-green, vue-pro, healer-readable

贡献主题:https://github.com/xitu/juejin-markdown-themes

theme: smartblue

highlight:

在上一篇文章中已经讲解了Siamese Net的原理,和这种网络架构的关键——损失函数contrastive loss。现在我们来用pytorch来做一个简单的案例。经过这个案例,我个人的收获有到了以下的几点:

  • Siamese Net适合小数据集;
  • 目前Siamese Net用在分类任务(如果有朋友知道如何用在分割或者其他任务可以私信我,WX:cyx645016617)
  • Siamese Net的可解释性较好。

1 准备数据

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from sklearn.model_selection import train_test_split
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_train = pd.read_csv('../input/fashion-mnist_train.csv')
data_train.head()

这个数据文件是csv格式,第一列是类别,之后的784列其实好似28x28的像素值。

划分训练集和验证集,然后把数据转换成28x28的图片

X_full = data_train.iloc[:,1:]
y_full = data_train.iloc[:,:1]
x_train, x_test, y_train, y_test = train_test_split(X_full, y_full, test_size = 0.05)
x_train = x_train.values.reshape(-1, 28, 28, 1).astype('float32') / 255.
x_test = x_test.values.reshape(-1, 28, 28, 1).astype('float32') / 255.
y_train.label.unique()
>>> array([8, 9, 7, 6, 4, 2, 3, 1, 5, 0])

可以看到这个Fashion MNIST数据集中也是跟MNIST类似,划分了10个不同的类别。

  • 0 T-shirt/top
  • 1 Trouser
  • 2 Pullover
  • 3 Dress
  • 4 Coat
  • 5 Sandal
  • 6 Shirt
  • 7 Sneaker
  • 8 Bag
  • 9 Ankle boot
np.bincount(y_train.label.values),np.bincount(y_test.label.values)
>>> (array([4230, 4195, 4135, 4218, 4174, 4172, 4193, 4250, 4238, 4195]),
array([1770, 1805, 1865, 1782, 1826, 1828, 1807, 1750, 1762, 1805]))

可以看到,每个类别的数据还是非常均衡的。

2 构建Dataset和可视化

class mydataset(Dataset):
def __init__(self,x_data,y_data):
self.x_data = x_data
self.y_data = y_data.label.values
def __len__(self):
return len(self.x_data)
def __getitem__(self,idx):
img1 = self.x_data[idx]
y1 = self.y_data[idx]
if np.random.rand() < 0.5:
idx2 = np.random.choice(np.arange(len(self.y_data))[self.y_data==y1],1)
else:
idx2 = np.random.choice(np.arange(len(self.y_data))[self.y_data!=y1],1)
img2 = self.x_data[idx2[0]]
y2 = self.y_data[idx2[0]]
label = 0 if y1==y2 else 1
return img1,img2,label

关于torch.utils.data.Dataset的构建结构,我就不再赘述了,在之前的《小白学PyTorch》系列中已经讲解的很清楚啦。上面的逻辑就是,给定一个idx,然后我们先判断,这个数据是找两个同类别的图片还是两个不同类别的图片。50%的概率选择两个同类别的图片,然后最后输出的时候,输出这两个图片,然后再输出一个label,这个label为0的时候表示两个图片的类别是相同的,1表示两个图片的类别是不同的。这样就可以进行模型训练和损失函数的计算了。

train_dataset = mydataset(x_train,y_train)
train_dataloader = DataLoader(dataset = train_dataset,batch_size=8)
val_dataset = mydataset(x_test,y_test)
val_dataloader = DataLoader(dataset = val_dataset,batch_size=8)
for idx,(img1,img2,target) in enumerate(train_dataloader):
fig, axs = plt.subplots(2, img1.shape[0], figsize = (12, 6))
for idx,(ax1,ax2) in enumerate(axs.T):
ax1.imshow(img1[idx,:,:,0].numpy(),cmap='gray')
ax1.set_title('image A')
ax2.imshow(img2[idx,:,:,0].numpy(),cmap='gray')
ax2.set_title('{}'.format('same' if target[idx]==0 else 'different'))
break

这一段的代码就是对一个batch的数据进行一个可视化:

到目前位置应该没有什么问题把,有问题可以联系我讨论交流,WX:cyx645016617.我个人认为从交流中可以快速解决问题和进步。

3 构建模型

class siamese(nn.Module):
def __init__(self,z_dimensions=2):
super(siamese,self).__init__()
self.feature_net = nn.Sequential(
nn.Conv2d(1,4,kernel_size=3,padding=1,stride=1),
nn.ReLU(inplace=True),
nn.BatchNorm2d(4),
nn.Conv2d(4,4,kernel_size=3,padding=1,stride=1),
nn.ReLU(inplace=True),
nn.BatchNorm2d(4),
nn.MaxPool2d(2),
nn.Conv2d(4,8,kernel_size=3,padding=1,stride=1),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
nn.Conv2d(8,8,kernel_size=3,padding=1,stride=1),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
nn.MaxPool2d(2),
nn.Conv2d(8,1,kernel_size=3,padding=1,stride=1),
nn.ReLU(inplace=True)
)
self.linear = nn.Linear(49,z_dimensions)
def forward(self,x):
x = self.feature_net(x)
x = x.view(x.shape[0],-1)
x = self.linear(x)
return x

一个非常简单的卷积网络,输出的向量的维度就是z-dimensions的大小。

def contrastive_loss(pred1,pred2,target):
MARGIN = 2
euclidean_dis = F.pairwise_distance(pred1,pred2)
target = target.view(-1)
loss = (1-target)*torch.pow(euclidean_dis,2) + target * torch.pow(torch.clamp(MARGIN-euclidean_dis,min=0),2)
return loss

然后构建了一个contrastive loss的损失函数计算。

4 训练

model = siamese(z_dimensions=8).to(device)
# model.load_state_dict(torch.load('../working/saimese.pth'))
optimizor = torch.optim.Adam(model.parameters(),lr=0.001)
for e in range(10):
history = []
for idx,(img1,img2,target) in enumerate(train_dataloader):
img1 = img1.to(device)
img2 = img2.to(device)
target = target.to(device) pred1 = model(img1)
pred2 = model(img2)
loss = contrastive_loss(pred1,pred2,target) optimizor.zero_grad()
loss.backward()
optimizor.step() loss = loss.detach().cpu().numpy()
history.append(loss)
train_loss = np.mean(history)
history = []
with torch.no_grad():
for idx,(img1,img2,target) in enumerate(val_dataloader):
img1 = img1.to(device)
img2 = img2.to(device)
target = target.to(device) pred1 = model(img1)
pred2 = model(img2)
loss = contrastive_loss(pred1,pred2,target) loss = loss.detach().cpu().numpy()
history.append(loss)
val_loss = np.mean(history)
print(f'train_loss:{train_loss},val_loss:{val_loss}')

这里为了加快训练,我把batch-size增加到了128个,其他的并没有改变:

这是运行的10个epoch的结果,不要忘记把模型保存一下:

torch.save(model.state_dict(),'saimese.pth')

差不多是这个样子,然后看一看验证集的可视化效果,这里使用的是t-sne高位特征可视化的方法,其内核是PCA降维:

from sklearn import manifold
'''X是特征,不包含target;X_tsne是已经降维之后的特征'''
tsne = manifold.TSNE(n_components=2, init='pca', random_state=501)
X_tsne = tsne.fit_transform(X)
print("Org data dimension is {}. \
Embedded data dimension is {}".format(X.shape[-1], X_tsne.shape[-1])) x_min, x_max = X_tsne.min(0), X_tsne.max(0)
X_norm = (X_tsne - x_min) / (x_max - x_min) # 归一化
plt.figure(figsize=(8, 8))
for i in range(10):
plt.scatter(X_norm[y==i][:,0],X_norm[y==i][:,1],alpha=0.3,label=f'{i}')
plt.legend()

输入图像为:

可以看得出来,不同类别之间划分的是比较好的,可以看到不同类别之间的距离还是比较大的,比较明显的,甚至可以放下公众号的名字。这里使用的隐变量是8。

这里有一个问题,我内心已有答案不知大家的想法如何,假如我把z潜变量的维度直接改成2,这样就不需要使用tsne和pca的方法来降低维度就可以直接可视化,但是这样的话可视化的效果并不比从8降维到2来可视化的效果好,这是为什么呢?

提示:一方面在于维度过小导致信息的缺失,但是这个解释站不住脚,因为PCA其实等价于一个退化的线形层,所以PCA同样会造成这种缺失;我认为关键应该是损失函数中的欧式距离的计算,如果维度高,那么欧式距离就会偏大,这样需要相应的调整MARGIN的数值。

孪生网络入门(下) Siamese Net分类服装MNIST数据集(pytorch)的更多相关文章

  1. 孪生网络入门(上) Siamese Net及其损失函数

    最近在多个关键词(小数据集,无监督半监督,图像分割,SOTA模型)的范畴内,都看到了这样的一个概念,孪生网络,所以今天有空大概翻看了一下相关的经典论文和博文,之后做了一个简单的案例来强化理解.如果需要 ...

  2. 机器学习-MNIST数据集使用二分类

    一.二分类训练MNIST数据集练习 %matplotlib inlineimport matplotlibimport numpy as npimport matplotlib.pyplot as p ...

  3. Pytorch 入门之Siamese网络

    首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比 本文主要熟悉Pytorch大致流程,修改了读取数据部分.没有采用原作者的ImageFolder方法:   ...

  4. 孪生网络(Siamese Network)在句子语义相似度计算中的应用

    1,概述 在NLP中孪生网络基本是用来计算句子间的语义相似度的.其结构如下 在计算句子语义相似度的时候,都是以句子对的形式输入到网络中,孪生网络就是定义两个网络结构分别来表征句子对中的句子,然后通过曼 ...

  5. Linux网络栈下两层实现

    http://www.cnblogs.com/zmkeil/archive/2013/04/18/3029339.html 1.1简介 VLAN是网络栈的一个附加功能,且位于下两层.首先来学习Linu ...

  6. 源码分析——迁移学习Inception V3网络重训练实现图片分类

    1. 前言 近些年来,随着以卷积神经网络(CNN)为代表的深度学习在图像识别领域的突破,越来越多的图像识别算法不断涌现.在去年,我们初步成功尝试了图像识别在测试领域的应用:将网站样式错乱问题.无线领域 ...

  7. EcShop调用显示指定分类下的子分类方法

    ECSHOP首页默认的只有全部分类,还有循环大类以及下面小类的代码,貌似没有可以调用显示指定大类下的子分类代码.于是就有这个文章的产生了,下面由夏日博客来总结下网站建设过程中ECSHOP此类问题的网络 ...

  8. Pytorch入门下 —— 其他

    本节内容参照小土堆的pytorch入门视频教程. 现有模型使用和修改 pytorch框架提供了很多现有模型,其中torchvision.models包中有很多关于视觉(图像)领域的模型,如下图: 下面 ...

  9. 主机WIFI网络环境下,Linux虚拟机网络设置

    在主机使用WIFI网络环境下,怎么样进行虚拟机静态ip设置和连接互联网呢,原理什么太麻烦,另类的网络共享而已: 1.其实简单将网络连接模式设置成NAT模式即可. 2.虚拟网络编辑器依旧是桥接模式,选择 ...

随机推荐

  1. Zookeeper(4)---ZK集群部署和选举

    一.集群部署 1.准备三台机器,安装好ZK.强烈建议奇数台机器,因为zookeeper 通过判断大多数节点的存活来判断整个服务是否可用.3个节点,挂掉了2个表示整个集群挂掉,而用偶数4个,挂掉了2个也 ...

  2. slideUp和slideDown的区别

    slideUp():通过使用滑动效果,隐藏被选元素,如果元素已显示出来的话.语法:$(selector).slideUp(speed,callback).speed:可选,表示动画运行的时候.call ...

  3. 当年使用dpdk干的事

    mark一下  晚点上传 先不上传 ....0727

  4. 主动关闭 tcp fin-wait-2 time-wait 定时器

    后面整理相关信息 //后面整理相关信息 /* * This function implements the receiving procedure of RFC 793 for * all state ...

  5. http服务器文件名大小写忽略

    问题 文件从windows里面放到nginx里面去的时候,文件在windows下面是大小写忽略,也就是不论大小写都可以匹配的,而到linux下面的时候,因为linux是区分大小写的,也就是会出现无法忽 ...

  6. 划分问题(Java 动态规划)

    Description 给定一个正整数的集合A={a1,a2,-.,an},是否可以将其分割成两个子集合,使两个子集合的数加起来的和相等.例A = { 1, 3, 8, 4, 10} 可以分割:{1, ...

  7. Docker 初始

    1. Docker 是什么? 官网的介绍是"Docker is the world's leading software container platform." 官方给Docke ...

  8. Mac系统使用Parallels Desktop安装Win10

    1.Parallels Desktop破解版下载 2.原版Windows 10 2004 X64位 (原版安装)2020 11 Windows 系统镜像必须为原版,ghost版不行.亲测ghost版本 ...

  9. linux 身份鉴别口令复杂度整改

    口令复杂度: 1.首先安装apt install libpam-cracklib -y2.vim /etc/pam.d/common-password3.在第2步末尾添加password requis ...

  10. Mysql预处理语句prepare、execute、deallocate

    前言 做CTF题的时候遇到的所以参考资料学习一波.... MySQL的SQL预处理(Prepared) 一.SQL 语句的执行处理 1.即时 SQL 一条 SQL 在 DB 接收到最终执行完毕返回,大 ...