【代码阅读】PointNet++中ball query的CUDA实现
发布日期:2021-09-16 07:31:54
浏览次数:2
分类:技术文章
本文共 4057 字,大约阅读时间需要 13 分钟。
文章目录
本文为PointNet++ CUDA代码阅读系列的第三部分,其他详见:
(一) (二) (三) (四)CUDA代码要在pytorch中使用,必须设置好CUDA代码与python的接口,并用python编写pytorch中的模块,这两部分详见。本文直接看ball query的实现。
给定一个点云xyz,然后给定中心点new_xyz,给定半径和邻域内点的数量,Ball Query可以找出以new_xyz为中心的领域内包含的xyz中的点的下标。
直接看代码,仍然是用的代码。先看在python中定义的函数,在pointnet2_utils.py中:
class BallQuery(Function): @staticmethod def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor: """ :param ctx: :param radius: float, radius of the balls :param nsample: int, maximum number of features in the balls :param xyz: (B, N, 3) xyz coordinates of the features :param new_xyz: (B, npoint, 3) centers of the ball query :return: idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls """ assert new_xyz.is_contiguous() assert xyz.is_contiguous() B, N, _ = xyz.size() npoint = new_xyz.size(1) idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx) return idx @staticmethod def backward(ctx, a=None): return None, None, None, None
接着看,上述pointnet2.ball_query_wrapper对应的cpp代码,在src/ball_query.cpp中,其中传入的参数在python代码中已经解释的非常明白:
int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) { // b: batch_size, n: xyz点 CHECK_INPUT(new_xyz_tensor); CHECK_INPUT(xyz_tensor); const float *new_xyz = new_xyz_tensor.data(); const float *xyz = xyz_tensor.data (); int *idx = idx_tensor.data (); cudaStream_t stream = THCState_getCurrentStream(state); ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); return 1;}
接着看cuda代码,在src/ball_query_gpu.cu中,代码的详细解释我用注释的形式给出:
void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \ const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) { // new_xyz: (B, M, 3) // xyz: (B, N, 3) // output: // idx: (B, M, nsample) cudaError_t err; dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) dim3 threads(THREADS_PER_BLOCK); ball_query_kernel_fast<<>>(b, n, m, radius, nsample, new_xyz, xyz, idx); // cudaDeviceSynchronize(); // for using printf in kernel function err = cudaGetLastError(); if (cudaSuccess != err) { fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); exit(-1); }}__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) { // new_xyz: (B, M, 3) // xyz: (B, N, 3) // output: // idx: (B, M, nsample) int bs_idx = blockIdx.y; // 找到对应的那个batch int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; // 找到new_xyz中的哪个点 if (bs_idx >= b || pt_idx >= m) return; new_xyz += bs_idx * m * 3 + pt_idx * 3; // new_xyz为指针,找到对应的点 xyz += bs_idx * n * 3; // xyz为指针,定位到当前的batch idx += bs_idx * m * nsample + pt_idx * nsample; // idx为指针,找到当前new_xyz对应的输出的idx的起始位置 float radius2 = radius * radius; float new_x = new_xyz[0]; // 选取这个线程处理的new_xyz点的x,y,z,作为中心 float new_y = new_xyz[1]; float new_z = new_xyz[2]; int cnt = 0; for (int k = 0; k < n; ++k) { // 对xyz中的每个点进行遍历,看看是否在当前处理中心的邻域内 float x = xyz[k * 3 + 0]; float y = xyz[k * 3 + 1]; float z = xyz[k * 3 + 2]; float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); if (d2 < radius2){ if (cnt == 0){ for (int l = 0; l < nsample; ++l) { // 如果是找到的第一个在邻域内的点,将后面所有nsample的idx先赋值为这个点的坐标 idx[l] = k; } } idx[cnt] = k; ++cnt; if (cnt >= nsample) break; } }}
转载地址:https://blog.csdn.net/wqwqqwqw1231/article/details/117449633 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
路过按个爪印,很不错,赞一个!
[***.219.124.196]2024年03月29日 11时25分26秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
xml中常用的转义符
2019-04-27
关于MSDK的几个难点
2019-04-27
使用UnityEditor做工具
2019-04-27
Visual Studio我常用的快捷键
2019-04-27
写C# dll供Unity调用
2019-04-27
Linux制作run安装包
2019-04-27
一分钟学会C#解析XML
2019-04-27
unity AssetBundle的资源管理
2019-04-27
【转】Unity中HideInInspector和SerializeField一起使用
2019-04-27
单例模板类
2019-04-27
Unity与java相互调用
2019-04-27
android截屏代码
2019-04-27
unity NGUI图文混排
2019-04-27
Unity项目优化
2019-04-27
Unity3D Shader 入门
2019-04-27
MSDK手Q邀请透传参数问题:url编解码与base64编解码
2019-04-27
svn提交的一个坑
2019-04-27
eclipse识别不了模拟器解决办法
2019-04-27
unity mesh合并
2019-04-27
谈谈类之间的关联关系与依赖关系
2019-04-27