【代码阅读】PointNet++代码梳理
发布日期:2021-09-16 07:31:55 浏览次数:2 分类:技术文章

本文共 1916 字,大约阅读时间需要 6 分钟。

文章目录

本文为PointNet++ CUDA代码阅读系列的第一部分,其他详见:

(一)
(二)
(三)
(四)


PointNet++的核心操作是SA层和FP层,这里就来梳理一下SA层和FP层都干了什么。

SA层

参数:降采样点的数量(npoints),邻域半径(radii),邻域内点的数量(nsample),MLP

输入:xyz,features

计算过程如下:

new_xyz_idx = FPS(xyz, npoints)  #使用FPS选出降采样点的下标,记作new_xyz_idxnew_xyz = gather(xyz, new_xyz_idx)  #根据下标选出降采样的点,记作new_xyzidx = ball_query(xyz, new_xyz, radii) #根据邻域半径radii,通过ball_query函数实现在xyz中属于new_xyz邻域内的点的下标,记作idxgrouped_xyz, grouped_feature = group(xyz, feature, idx) #根据下标,选取邻域内的xyz和相对应的featurenew_feature = torch.cat([grouped_xyz-new_xyz, grouped-feature]) # new_feature:(B, 3 + C, npoint, nsample)new_feature = MLP(new_feature)  # new_feature:(B, C', npoint, nsample)new_feature = max_pooling(new_feature, dim=3)  # new_feature:(B, C', npoint)

至此,SA层完成点的降采样和特征提取。那具体来看,其中有4个函数是用CUDA编写的,分别是FPS,gather,ball_query,group。在这四个函数中,FPS和ball_query是找下标,不对点和特征做操作,是无需传播梯度的。而gather和group则是根据下标从xyz中选点,这个其实用torch.gather也可以实现。

经过以上分析,比较重要的FPS和ball_query,另外两个就是选取点而已。

FP层

参数:MLP

输入:known,unknown,known_feature,unknown_feature,定义如下

:param unknown: (B, n, 3) tensor of the xyz positions of the unknown features    :param known: (B, m, 3) tensor of the xyz positions of the known features    :param unknow_feats: (B, C1, n) tensor of the features to be propigated to    :param known_feats: (B, C2, m) tensor of features to be propigated

简单来说,FP层是将global feature回传到原始点云的过程,所以know是靠近特征金字塔顶部的点,unknown是靠近特征金字塔底部的点,那known_feature是指known包含的特征,是global的feature,unknown_feature是unknow原本的特征是局部特征

计算过程如下:

dist, idx = three_nn(unknown, known)  # 对于unknown的每个点,找到其在known最近的3个点的距离和下标weight = f(dist)  # 利用距离计算权重interpolated_feats = three_interpolate(known_feats, idx, weight)new_features = torch.cat([interpolated_feats, unknow_feats], dim=1) # (B, C2 + C1, n)new_features = MLP(new_features) # (B, C2 + C1, n)

看计算过程,其中three_nn和three_interpolate是用CUDA编写的。three_nn是计算距离和下标,不需要梯度。three_interpolate相当于是在known先用idx进行一次gather,然后再和weight加权平均。所以重要的就是three_nn如何实现。

转载地址:https://blog.csdn.net/wqwqqwqw1231/article/details/117449965 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:【代码阅读】PointNet++中的Three_nn的CUDA实现
下一篇:【代码阅读】PointNet++中ball query的CUDA实现

发表评论

最新留言

网站不错 人气很旺了 加油
[***.192.178.218]2024年04月15日 15时09分19秒