【代码阅读】PointNet++中ball query的CUDA实现
发布日期:2021-09-16 07:31:54 浏览次数:2 分类:技术文章

本文共 4057 字,大约阅读时间需要 13 分钟。

文章目录

本文为PointNet++ CUDA代码阅读系列的第三部分,其他详见:

(一)
(二)
(三)
(四)


CUDA代码要在pytorch中使用,必须设置好CUDA代码与python的接口,并用python编写pytorch中的模块,这两部分详见。本文直接看ball query的实现。

给定一个点云xyz,然后给定中心点new_xyz,给定半径和邻域内点的数量,Ball Query可以找出以new_xyz为中心的领域内包含的xyz中的点的下标。

直接看代码,仍然是用的代码。先看在python中定义的函数,在pointnet2_utils.py中:

class BallQuery(Function):    @staticmethod    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:        """        :param ctx:        :param radius: float, radius of the balls        :param nsample: int, maximum number of features in the balls        :param xyz: (B, N, 3) xyz coordinates of the features        :param new_xyz: (B, npoint, 3) centers of the ball query        :return:            idx: (B, npoint, nsample) tensor with the indicies of the features that form the query balls        """        assert new_xyz.is_contiguous()        assert xyz.is_contiguous()        B, N, _ = xyz.size()        npoint = new_xyz.size(1)        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)        return idx    @staticmethod    def backward(ctx, a=None):        return None, None, None, None

接着看,上述pointnet2.ball_query_wrapper对应的cpp代码,在src/ball_query.cpp中,其中传入的参数在python代码中已经解释的非常明白:

int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample,     at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
// b: batch_size, n: xyz点 CHECK_INPUT(new_xyz_tensor); CHECK_INPUT(xyz_tensor); const float *new_xyz = new_xyz_tensor.data
(); const float *xyz = xyz_tensor.data
(); int *idx = idx_tensor.data
(); cudaStream_t stream = THCState_getCurrentStream(state); ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream); return 1;}

接着看cuda代码,在src/ball_query_gpu.cu中,代码的详细解释我用注释的形式给出:

void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \    const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
// new_xyz: (B, M, 3) // xyz: (B, N, 3) // output: // idx: (B, M, nsample) cudaError_t err; dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row) dim3 threads(THREADS_PER_BLOCK); ball_query_kernel_fast<<
>>(b, n, m, radius, nsample, new_xyz, xyz, idx); // cudaDeviceSynchronize(); // for using printf in kernel function err = cudaGetLastError(); if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); exit(-1); }}__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
// new_xyz: (B, M, 3) // xyz: (B, N, 3) // output: // idx: (B, M, nsample) int bs_idx = blockIdx.y; // 找到对应的那个batch int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; // 找到new_xyz中的哪个点 if (bs_idx >= b || pt_idx >= m) return; new_xyz += bs_idx * m * 3 + pt_idx * 3; // new_xyz为指针,找到对应的点 xyz += bs_idx * n * 3; // xyz为指针,定位到当前的batch idx += bs_idx * m * nsample + pt_idx * nsample; // idx为指针,找到当前new_xyz对应的输出的idx的起始位置 float radius2 = radius * radius; float new_x = new_xyz[0]; // 选取这个线程处理的new_xyz点的x,y,z,作为中心 float new_y = new_xyz[1]; float new_z = new_xyz[2]; int cnt = 0; for (int k = 0; k < n; ++k) {
// 对xyz中的每个点进行遍历,看看是否在当前处理中心的邻域内 float x = xyz[k * 3 + 0]; float y = xyz[k * 3 + 1]; float z = xyz[k * 3 + 2]; float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); if (d2 < radius2){
if (cnt == 0){
for (int l = 0; l < nsample; ++l) {
// 如果是找到的第一个在邻域内的点,将后面所有nsample的idx先赋值为这个点的坐标 idx[l] = k; } } idx[cnt] = k; ++cnt; if (cnt >= nsample) break; } }}

转载地址:https://blog.csdn.net/wqwqqwqw1231/article/details/117449633 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:【代码阅读】PointNet++代码梳理
下一篇:【论文阅读】【二维目标检测】Generalized Focal Loss

发表评论

最新留言

路过按个爪印,很不错,赞一个!
[***.219.124.196]2024年03月29日 11时25分26秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章