pointnet.pytorch代码解析
pointnet.pytorch代码解析
代码运行
Training
cd utils
python train_classification.py --dataset <dataset path> --nepoch=<number epochs> --dataset_type <modelnet40 | shapenet>
python train_segmentation.py --dataset <dataset path> --nepoch=<number epochs>
运行结果
Classification on ShapeNet
epoch = 10 Overall Acc Original implementation N/A this implementation(无 feature transform) 95.6 this implementation(有 feature transform) 92.97 Segmentation on ShapeNet
dataset代码
读取的数据格式
ShapeNetDataset():默认读取分割数据,返回值d:点云个数*(点云数据ps,标签seg)
数据ps:torch.Size([2500, 3]) torch.FloatTensor ,一个点云有2500个点,每个点3个特征
标签seg:torch.Size([2500]) torch.LongTensor,每个点都有一个标签
代码及注释如下:if __name__ == '__main__':
dataset = sys.argv[1] # 运行命令中传入的第一个参数
datapath = sys.argv[2] # 运行命令中传入的第二个参数 if dataset == 'shapenet':
# 读取标签为Chair的分割数据
d = ShapeNetDataset(root = datapath, class_choice = ['Chair'])
print(len(d)) #2658,共有2658个Chair点云
ps, seg = d[0]
print(ps.size(), ps.type(), seg.size(),seg.type())
# torch.Size([2500, 3]) torch.FloatTensor ,第一个点云有2500个点,每个点3个特征
# torch.Size([2500]) torch.LongTensor,每个点都有一个标签 d = ShapeNetDataset(root = datapath, classification = True)
print(len(d))
ps, cls = d[0]
print(ps.size(), ps.type(), cls.size(),cls.type())
# torch.Size([2500, 3]) torch.FloatTensor torch.Size([1]) torch.LongTensor,每个点云一个标签
# get_segmentation_classes(datapath)
数据读取
model代码
网络整体结构
if __name__ == '__main__':
# input transform
sim_data = Variable(torch.rand(32,3,2500)) # 32个点云,3个特征,2500个点
trans = STN3d()
out = trans(sim_data) # stn torch.Size([32, 3, 3]),返回3x3的输入变换矩阵
print('stn', out.size())
print('loss', feature_transform_regularizer(out)) # feature transform
sim_data_64d = Variable(torch.rand(32, 64, 2500))
trans = STNkd(k=64)
out = trans(sim_data_64d) # stn64d torch.Size([32, 64, 64]),返回64x64的特征变换矩阵
print('stn64d', out.size())
print('loss', feature_transform_regularizer(out)) # global feat
pointfeat = PointNetfeat(global_feat=True)
out, _, _ = pointfeat(sim_data) # global feat torch.Size([32, 1024]),32个点云,每个有1024维全局特征
print('global feat', out.size()) # point feat
pointfeat = PointNetfeat(global_feat=False)
out, _, _ = pointfeat(sim_data) # point feat torch.Size([32, 1088, 2500]),2500个点,每个点有1024+64维特征
print('point feat', out.size()) # Classification
cls = PointNetCls(k = 5)
out, _, _ = cls(sim_data) # class torch.Size([32, 5]),global feat经过全连接层,得到在5个类别上的概率信息
print('class', out.size()) # Segmentation
seg = PointNetDenseCls(k = 3)
out, _, _ = seg(sim_data) # seg torch.Size([32, 2500, 3]),point feat经过一维卷积,得到在3个类别上概率信息
print('seg', out.size())
PointNetfeat特征提取网络
class PointNetfeat(nn.Module):
'''
点云的特征提取网络:global feature 和 point features
'''
def __init__(self, global_feat = True, feature_transform = False):
super(PointNetfeat, self).__init__()
self.stn = STN3d()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.global_feat = global_feat
self.feature_transform = feature_transform
if self.feature_transform:
self.fstn = STNkd(k=64) def forward(self, x):
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans) # 乘以3x3变换矩阵
x = x.transpose(2, 1)
x = F.relu(self.bn1(self.conv1(x))) if self.feature_transform: # 特征变换,64x64矩阵
trans_feat = self.fstn(x)
x = x.transpose(2,1)
x = torch.bmm(x, trans_feat)
x = x.transpose(2,1)
else:
trans_feat = None pointfeat = x # nx64的点特征
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x = torch.max(x, 2, keepdim=True)[0] # Maxpool
x = x.view(-1, 1024)
if self.global_feat:
return x, trans, trans_feat # x:mx1x1024的global feature,两个变换矩阵
else:
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans, trans_feat # global feature+point features = nx1088的点特征矩阵
pointnet.pytorch代码解析的更多相关文章
- 【论文笔记】AutoML for MCA on Mobile Devices——论文解读与代码解析
理论部分 方法介绍 本节将详细介绍AMC的算法流程.AMC旨在自动地找出每层的冗余参数. AMC训练一个强化学习的策略,对每个卷积层会给出其action(即压缩率),然后根据压缩率进行裁枝.裁枝后,A ...
- VBA常用代码解析
031 删除工作表中的空行 如果需要删除工作表中所有的空行,可以使用下面的代码. Sub DelBlankRow() DimrRow As Long DimLRow As Long Dimi As L ...
- [nRF51822] 12、基础实验代码解析大全 · 实验19 - PWM
一.PWM概述: PWM(Pulse Width Modulation):脉冲宽度调制技术,通过对一系列脉冲的宽度进行调制,来等效地获得所需要波形. PWM 的几个基本概念: 1) 占空比:占空比是指 ...
- [nRF51822] 11、基础实验代码解析大全 · 实验16 - 内部FLASH读写
一.实验内容: 通过串口发送单个字符到NRF51822,NRF51822 接收到字符后将其写入到FLASH 的最后一页,之后将其读出并通过串口打印出数据. 二.nRF51822芯片内部flash知识 ...
- [nRF51822] 10、基础实验代码解析大全 · 实验15 - RTC
一.实验内容: 配置NRF51822 的RTC0 的TICK 频率为8Hz,COMPARE0 匹配事件触发周期为3 秒,并使能了TICK 和COMPARE0 中断. TICK 中断中驱动指示灯D1 翻 ...
- [nRF51822] 9、基础实验代码解析大全 · 实验12 - ADC
一.本实验ADC 配置 分辨率:10 位. 输入通道:5,即使用输入通道AIN5 检测电位器的电压. ADC 基准电压:1.2V. 二.NRF51822 ADC 管脚分布 NRF51822 的ADC ...
- java集合框架之java HashMap代码解析
java集合框架之java HashMap代码解析 文章Java集合框架综述后,具体集合类的代码,首先以既熟悉又陌生的HashMap开始. 源自http://www.codeceo.com/arti ...
- Kakfa揭秘 Day8 DirectKafkaStream代码解析
Kakfa揭秘 Day8 DirectKafkaStream代码解析 今天让我们进入SparkStreaming,看一下其中重要的Kafka模块DirectStream的具体实现. 构造Stream ...
- linux内存管理--slab及其代码解析
Linux内核使用了源自于 Solaris 的一种方法,但是这种方法在嵌入式系统中已经使用了很长时间了,它是将内存作为对象按照大小进行分配,被称为slab高速缓存. 内存管理的目标是提供一种方法,为实 ...
随机推荐
- MySQL 架构|给你一个“上帝视角”
"我平时的工作就是 CRUD (增删改查)呀!我怎么提升自己的技术?"."平时开发我都是用开源的 MyBatis.Hibernate,连原生的 sql 我都没写过几行&q ...
- redis 客户端实现读写分离实现
背景 (1) redis单机的读写性能轻松上大几万,不过线上环境不会只部署光秃秃的一个节点,还是会配合 sentinel 再部署一个 slave作为高可用节点的: 但是standby的slave节点是 ...
- 6.11、制作windos虚拟机
1.下载kvm支持windows系统的驱动程序: cd /tmp/ wget https://fedorapeople.org/groups/virt/virtio-win/direct-downlo ...
- Redis 底层数据结构之String
文章参考:<Redis设计与实现>黄建宏 Redis 的 string 类型底层使用的是 SDS(动态字符串) 实现的, 具体数据结构如下: struct sdshdr { int len ...
- shell运维习题训练
注:初学shell,以下为本人自己写的答案,如果有更好的,请指教! 1. 求2个数之和: 2. 计算1-100的和 3. 将一目录下所有的文件的扩展名改为bak 4.编译并执行当前目录下的所有.c文件 ...
- P4480 「BJWC2018」「网络流与线性规划24题」餐巾计划问题
刷了n次用了奇淫技巧才拿到rk1,亥 这道题是网络流二十四题中「餐巾计划问题」的加强版. 于是怀着试一试的心情用费用流交了一发: 哇塞,过了9个点!(强烈谴责出题人用*造数据 下面是费用流解法简述: ...
- Ubuntu中Docker的安装与使用
Ubuntu中安装Docker 更新ubuntu的apt源索引 sudo apt-get update 2.安装包允许apt通过HTTPS使用仓库 sudo apt-get install \ apt ...
- C#版Nebula客户端编译
一.需求背景 从Nebula的Github上可以发现,Nebula为以下语言提供了客户端SDK: nebula-cpp nebula-java nebula-go nebula-python nebu ...
- python使用笔记18--写日志
1 import nnlog 2 import traceback 3 #level:输出日志级别,debug:把所有的日志都打印出来,info:打印info以上的日志, 4 # warning:打印 ...
- Scrapy框架安装与使用(基于windows系统)
"人生苦短,我用python".最近了解到一个很好的Spider框架--Scrapy,自己就按着官方文档装了一下,出了些问题,在这里记录一下,免得忘记. Scrapy的安装是基于T ...