pointnet.pytorch代码解析

代码运行

Training

cd utils
python train_classification.py --dataset <dataset path> --nepoch=<number epochs> --dataset_type <modelnet40 | shapenet>
python train_segmentation.py --dataset <dataset path> --nepoch=<number epochs>

运行结果

  1. Classification on ShapeNet

    epoch = 10 Overall Acc
    Original implementation N/A
    this implementation(无 feature transform) 95.6
    this implementation(有 feature transform) 92.97
  2. Segmentation on ShapeNet

dataset代码

  1. 读取的数据格式

    ShapeNetDataset():默认读取分割数据,返回值d:点云个数*(点云数据ps,标签seg)

    数据ps:torch.Size([2500, 3]) torch.FloatTensor ,一个点云有2500个点,每个点3个特征

    标签seg:torch.Size([2500]) torch.LongTensor,每个点都有一个标签

    代码及注释如下:

    if __name__ == '__main__':
    dataset = sys.argv[1] # 运行命令中传入的第一个参数
    datapath = sys.argv[2] # 运行命令中传入的第二个参数 if dataset == 'shapenet':
    # 读取标签为Chair的分割数据
    d = ShapeNetDataset(root = datapath, class_choice = ['Chair'])
    print(len(d)) #2658,共有2658个Chair点云
    ps, seg = d[0]
    print(ps.size(), ps.type(), seg.size(),seg.type())
    # torch.Size([2500, 3]) torch.FloatTensor ,第一个点云有2500个点,每个点3个特征
    # torch.Size([2500]) torch.LongTensor,每个点都有一个标签 d = ShapeNetDataset(root = datapath, classification = True)
    print(len(d))
    ps, cls = d[0]
    print(ps.size(), ps.type(), cls.size(),cls.type())
    # torch.Size([2500, 3]) torch.FloatTensor torch.Size([1]) torch.LongTensor,每个点云一个标签
    # get_segmentation_classes(datapath)
  2. 数据读取

model代码

  1. 网络整体结构

    if __name__ == '__main__':
    # input transform
    sim_data = Variable(torch.rand(32,3,2500)) # 32个点云,3个特征,2500个点
    trans = STN3d()
    out = trans(sim_data) # stn torch.Size([32, 3, 3]),返回3x3的输入变换矩阵
    print('stn', out.size())
    print('loss', feature_transform_regularizer(out)) # feature transform
    sim_data_64d = Variable(torch.rand(32, 64, 2500))
    trans = STNkd(k=64)
    out = trans(sim_data_64d) # stn64d torch.Size([32, 64, 64]),返回64x64的特征变换矩阵
    print('stn64d', out.size())
    print('loss', feature_transform_regularizer(out)) # global feat
    pointfeat = PointNetfeat(global_feat=True)
    out, _, _ = pointfeat(sim_data) # global feat torch.Size([32, 1024]),32个点云,每个有1024维全局特征
    print('global feat', out.size()) # point feat
    pointfeat = PointNetfeat(global_feat=False)
    out, _, _ = pointfeat(sim_data) # point feat torch.Size([32, 1088, 2500]),2500个点,每个点有1024+64维特征
    print('point feat', out.size()) # Classification
    cls = PointNetCls(k = 5)
    out, _, _ = cls(sim_data) # class torch.Size([32, 5]),global feat经过全连接层,得到在5个类别上的概率信息
    print('class', out.size()) # Segmentation
    seg = PointNetDenseCls(k = 3)
    out, _, _ = seg(sim_data) # seg torch.Size([32, 2500, 3]),point feat经过一维卷积,得到在3个类别上概率信息
    print('seg', out.size())
  2. PointNetfeat特征提取网络

    class PointNetfeat(nn.Module):
    '''
    点云的特征提取网络:global feature 和 point features
    '''
    def __init__(self, global_feat = True, feature_transform = False):
    super(PointNetfeat, self).__init__()
    self.stn = STN3d()
    self.conv1 = torch.nn.Conv1d(3, 64, 1)
    self.conv2 = torch.nn.Conv1d(64, 128, 1)
    self.conv3 = torch.nn.Conv1d(128, 1024, 1)
    self.bn1 = nn.BatchNorm1d(64)
    self.bn2 = nn.BatchNorm1d(128)
    self.bn3 = nn.BatchNorm1d(1024)
    self.global_feat = global_feat
    self.feature_transform = feature_transform
    if self.feature_transform:
    self.fstn = STNkd(k=64) def forward(self, x):
    n_pts = x.size()[2]
    trans = self.stn(x)
    x = x.transpose(2, 1)
    x = torch.bmm(x, trans) # 乘以3x3变换矩阵
    x = x.transpose(2, 1)
    x = F.relu(self.bn1(self.conv1(x))) if self.feature_transform: # 特征变换,64x64矩阵
    trans_feat = self.fstn(x)
    x = x.transpose(2,1)
    x = torch.bmm(x, trans_feat)
    x = x.transpose(2,1)
    else:
    trans_feat = None pointfeat = x # nx64的点特征
    x = F.relu(self.bn2(self.conv2(x)))
    x = self.bn3(self.conv3(x))
    x = torch.max(x, 2, keepdim=True)[0] # Maxpool
    x = x.view(-1, 1024)
    if self.global_feat:
    return x, trans, trans_feat # x:mx1x1024的global feature,两个变换矩阵
    else:
    x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
    return torch.cat([x, pointfeat], 1), trans, trans_feat # global feature+point features = nx1088的点特征矩阵

pointnet.pytorch代码解析的更多相关文章

  1. 【论文笔记】AutoML for MCA on Mobile Devices——论文解读与代码解析

    理论部分 方法介绍 本节将详细介绍AMC的算法流程.AMC旨在自动地找出每层的冗余参数. AMC训练一个强化学习的策略,对每个卷积层会给出其action(即压缩率),然后根据压缩率进行裁枝.裁枝后,A ...

  2. VBA常用代码解析

    031 删除工作表中的空行 如果需要删除工作表中所有的空行,可以使用下面的代码. Sub DelBlankRow() DimrRow As Long DimLRow As Long Dimi As L ...

  3. [nRF51822] 12、基础实验代码解析大全 · 实验19 - PWM

    一.PWM概述: PWM(Pulse Width Modulation):脉冲宽度调制技术,通过对一系列脉冲的宽度进行调制,来等效地获得所需要波形. PWM 的几个基本概念: 1) 占空比:占空比是指 ...

  4. [nRF51822] 11、基础实验代码解析大全 · 实验16 - 内部FLASH读写

     一.实验内容: 通过串口发送单个字符到NRF51822,NRF51822 接收到字符后将其写入到FLASH 的最后一页,之后将其读出并通过串口打印出数据. 二.nRF51822芯片内部flash知识 ...

  5. [nRF51822] 10、基础实验代码解析大全 · 实验15 - RTC

    一.实验内容: 配置NRF51822 的RTC0 的TICK 频率为8Hz,COMPARE0 匹配事件触发周期为3 秒,并使能了TICK 和COMPARE0 中断. TICK 中断中驱动指示灯D1 翻 ...

  6. [nRF51822] 9、基础实验代码解析大全 · 实验12 - ADC

    一.本实验ADC 配置 分辨率:10 位. 输入通道:5,即使用输入通道AIN5 检测电位器的电压. ADC 基准电压:1.2V. 二.NRF51822 ADC 管脚分布 NRF51822 的ADC ...

  7. java集合框架之java HashMap代码解析

     java集合框架之java HashMap代码解析 文章Java集合框架综述后,具体集合类的代码,首先以既熟悉又陌生的HashMap开始. 源自http://www.codeceo.com/arti ...

  8. Kakfa揭秘 Day8 DirectKafkaStream代码解析

    Kakfa揭秘 Day8 DirectKafkaStream代码解析 今天让我们进入SparkStreaming,看一下其中重要的Kafka模块DirectStream的具体实现. 构造Stream ...

  9. linux内存管理--slab及其代码解析

    Linux内核使用了源自于 Solaris 的一种方法,但是这种方法在嵌入式系统中已经使用了很长时间了,它是将内存作为对象按照大小进行分配,被称为slab高速缓存. 内存管理的目标是提供一种方法,为实 ...

随机推荐

  1. 树莓派FRP内网穿透及自启动

    内网穿透的步骤和文件存档 实验室在远方部署了电脑主机来采集数据和图片,每次去调试会很麻烦,因而使用FRP内网穿透使得我们可以在实验室访问主机. 主要功能 实现远程可访问和开机自启FRP程序服务 安装和 ...

  2. Java语言实现二维码的生成

    众所周知,现在生活中二维码已经是无处不见.走在街道上,随处可见广告标语旁有二维码,手机上QQ,微信加个好友都能通过二维码的方式,我不知道是什么时候兴起的二维码浪潮,但是我知道,这在我小时候可是见不到的 ...

  3. 电容三点式振荡电路详解及Multisim实例仿真

    电容三点式振荡器也称考毕兹(Colpitts,也叫科耳皮兹)振荡器,是三极管自激LC振荡器的一种,因振荡回路中两个串联电容的三个端分别与三极管的三个极相接而得名,适合于高频振荡输出的电路形式之一.电容 ...

  4. php 安装 yii 报错: phpunit/phpunit 4.8.32 requires ext-dom *

    php 安装 yii 报错: phpunit/phpunit 4.8.32 requires ext-dom * 我的版本是7.0,以7.0为例演示. 先装这两个拓展试试: sudo apt-get ...

  5. 流程自动化RPA,Power Automate Desktop系列 - 批量备份Git仓库做好灾备

    一.背景 打个比如,你在Github上的代码库需要批量的定时备案到本地的Gitlab上,以便Github不能访问时,可以继续编写,这时候我们可以基于Power Automate Desktop来实现一 ...

  6. 4、mysql登录密码修改和找回

    操作适合5.1-5.5:当前的环境是5.5的环境: 4.1.mysql启动的原理: mysqld_safe -> my.cnf ->mysql.sock http://blog.51cto ...

  7. vue3,后台管理列表页面各组件之间的状态关系

    技术栈 vite2 vue 3.0.5 vue-router 4.0.6 vue-data-state 0.1.1 element-plus 1.0.2-beta.39 前情回顾 表单控件 查询控件 ...

  8. hdu 4686 Arc of Dream 自己推 矩阵快速幂

    A.mat[0][0] = 1, A.mat[0][1] = 1, A.mat[0][2] = 0, A.mat[0][3] = 0, A.mat[0][4] = 0; A.mat[1][0] = 0 ...

  9. Binding(五):多路绑定

    Binding不止能绑定一个源,它还能绑定多个源,这就是我们这节要讲的多路绑定:MultiBinding. 使用多路绑定跟一般的绑定还是有区别的,首先它并不能很好的在标记扩展中使用,另外,使用多路绑定 ...

  10. jce-jdk13-120.jar

    jce-jdk13-120.jar https://files.cnblogs.com/files/blogs/692137/jce-jdk13-120.rar