Knn0827 #3172

Open
wants to merge 2 commits into base: main
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
@@ -8,11 +8,11 @@ using namespace std;
void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2) {
// transpose known from [B, N, 3] to [B, 3, N]
- at::Tensor source = xyz.transpose(1, 2).contiguous();
+ at::Tensor source = xyz.transpose(2, 1).contiguous();
at::Tensor target = new_xyz.contiguous();

bool is_from_knn = true;
- EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
+ EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, idx, dist2);
}

void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
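For reference, the two knn_npu.cpp edits are layout-preserving: transpose(1, 2) and transpose(2, 1) swap the same pair of dimensions, so the source tensor handed to aclnnKnn keeps the [B, 3, N] layout; the dropped nsample argument matches the updated aclnnKnn call used in the other kernels, with the top-k selection now done in Python (see mmcv/ops/knn.py further down). A minimal CPU-only sketch of the transpose equivalence (plain PyTorch, no NPU assumed):

import torch

# Minimal sketch: transpose(1, 2) and transpose(2, 1) swap the same pair of
# dimensions, so both produce the [B, 3, N] layout expected by the kernel.
xyz = torch.randn(4, 128, 3)            # [B, N, 3]
a = xyz.transpose(1, 2).contiguous()    # [B, 3, N]
b = xyz.transpose(2, 1).contiguous()    # [B, 3, N]
assert a.shape == (4, 3, 128) and torch.equal(a, b)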
29 changes: 19 additions & 10 deletions mmcv/ops/csrc/pytorch/npu/three_interpolate_npu.cpp
@@ -12,17 +12,21 @@ void three_interpolate_forward_npu(int b, int c, int m, int n,
TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
"three_interpolate_forward ascend only support fp32 and fp16.");

- auto point_c_trans = points.transpose(1, 2);
-
+ auto point_c_trans = points.transpose(1, 2).to(at::kFloat);
+ auto weight_cast = weight.to(at::kFloat);
+ auto out_cast = out.to(at::kFloat);
OpCommand cmd;
cmd.Name("ThreeInterpolate")
.Input(point_c_trans)
.Input(idx)
- .Input(weight)
- .Output(out)
+ .Input(weight_cast)
+ .Output(out_cast)
.Run();

- auto output = out.view({b, n, c}).transpose(1, 2);
+ if (originDtype == at::kHalf) {
+   out_cast = out_cast.to(at::kHalf);
+ }
+ auto output = out_cast.view({b, n, c}).transpose(1, 2);
auto res = output.contiguous();
out.copy_(res);
}
@@ -34,12 +38,17 @@ void three_interpolate_backward_npu(int b, int c, int n, int m,
TORCH_CHECK((originDtype == at::kFloat || originDtype == at::kHalf),
"three_interpolate_backward ascend only support fp32 and fp16.");

- auto grad_x = at::unsqueeze(grad_out, 3);
- auto grad_y = at::unsqueeze(grad_points, 3);
-
- EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y);
+ auto grad_x = at::unsqueeze(grad_out, 3).to(at::kFloat);
+ auto grad_y = at::unsqueeze(grad_points, 3).to(at::kFloat);
+ auto weight_cast = weight.to(at::kFloat);
+ EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight_cast, m,
+              grad_y);

- auto output = at::squeeze(grad_y, 3);
+ auto grad_y_cast = grad_y;
+ if (originDtype == at::kHalf) {
+   grad_y_cast = grad_y.to(at::kHalf);
+ }
+ auto output = at::squeeze(grad_y_cast, 3);
auto res = output.contiguous();
grad_points.copy_(res);
}
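The three_interpolate changes follow one pattern: when the original dtype is fp16, cast the inputs to fp32, run the NPU op (the ThreeInterpolate OpCommand in the forward, aclnnThreeInterpolateBackward in the backward), then cast the result back to fp16 before copying it into the output tensor. As a reminder of what the forward op computes, here is a hedged CPU reference in plain PyTorch (assuming the usual PointNet++ semantics, where each output feature is the weighted sum of the features of its three nearest known points); it only illustrates the math, not the NPU code path:

import torch

def three_interpolate_ref(points, idx, weight):
    # points: [B, C, M] features of the known points
    # idx:    [B, N, 3] indices of the 3 nearest known points
    # weight: [B, N, 3] interpolation weights
    # returns [B, C, N] interpolated features
    B, C, M = points.shape
    N = idx.shape[1]
    # Gather the features of the 3 neighbours -> [B, C, N, 3]
    gathered = torch.gather(
        points.unsqueeze(2).expand(B, C, N, M), 3,
        idx.long().unsqueeze(1).expand(B, C, N, 3))
    # Weighted sum over the 3 neighbours.
    return (gathered * weight.unsqueeze(1)).sum(dim=3)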
13 changes: 2 additions & 11 deletions mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
@@ -7,21 +7,12 @@ using namespace std;

void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
const Tensor known, Tensor dist2, Tensor idx) {
- // transpose known [B, N, 3] -> [B, 3, N]
- at::Tensor source = known.transpose(1, 2).contiguous();
+ at::Tensor source = known.contiguous();
at::Tensor target = unknown.contiguous();
- auto originDtype = source.scalar_type();
- if (originDtype == at::kHalf) {
-   source = source.to(at::kFloat);
-   target = target.to(at::kFloat);
- }

bool is_from_knn = false;
- uint32_t nsample = 3;
- EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
- if (originDtype == at::kHalf) {
-   dist2 = dist2.to(at::kHalf);
- }
+ EXEC_NPU_CMD(aclnnKnn, source, target, is_from_knn, dist2);
}

void three_nn_forward_impl(int b, int n, int m, const Tensor unknown,
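With this change the fp16 handling and the [B, N, 3] to [B, 3, N] transpose move out of three_nn_npu.cpp and into mmcv/ops/three_nn.py (shown below), and aclnnKnn is called without the fixed nsample = 3: the kernel fills the full distance matrix and the top-3 selection happens in Python. A rough CPU reference of the intended result (assumption: smallest-3 Euclidean distances, matching the sqrt taken in the new Python branch):

import torch

def three_nn_ref(target, source):
    # target: [B, N, 3] query points, source: [B, M, 3] known points
    # returns (dist, idx), both [B, N, 3]
    d = torch.cdist(target, source)                                # [B, N, M]
    dist, idx = torch.topk(d, 3, dim=2, largest=False, sorted=True)
    return dist, idx.int()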
17 changes: 17 additions & 0 deletions mmcv/ops/knn.py
@@ -61,6 +61,23 @@ def forward(ctx,

B, npoint, _ = center_xyz.shape
N = xyz.shape[1]
+ if xyz.device.type == 'npu':
+     dist = center_xyz.new_zeros((B, npoint, N)).float()
+     ext_module.knn_forward(
+         xyz,
+         center_xyz,
+         torch.Tensor([]).npu(),
+         dist,
+         b=B,
+         n=N,
+         m=npoint,
+         nsample=k)
+     dist2, idx = torch.topk(dist, k, dim=2, largest=False, sorted=True)
+     zeros_idx = torch.zeros(
+         xyz.shape[0], center_xyz.shape[1], k, dtype=torch.int32).npu()
+     idx.where(dist2 >= 1e10, zeros_idx)
+     idx = idx.transpose(2, 1).contiguous()  # [B, k, npoint]
+     return idx.int()

idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
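The new NPU branch asks knn_forward for the full [B, npoint, N] distance matrix, selects the k smallest distances with torch.topk, masks indices whose distance looks invalid (>= 1e10), and returns the usual [B, k, npoint] int32 layout. A pure-PyTorch sketch of the same selection (illustration only; it assumes squared Euclidean distances and skips the invalid-distance masking):

import torch

def knn_ref(k, xyz, center_xyz):
    # xyz: [B, N, 3] points, center_xyz: [B, npoint, 3] query centers
    # returns idx: [B, k, npoint] int32 indices into xyz
    d2 = torch.cdist(center_xyz, xyz) ** 2          # [B, npoint, N]
    _, idx = torch.topk(d2, k, dim=2, largest=False, sorted=True)
    return idx.transpose(2, 1).contiguous().int()   # [B, k, npoint]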
15 changes: 15 additions & 0 deletions mmcv/ops/three_nn.py
@@ -34,6 +34,21 @@ def forward(ctx: Any, target: torch.Tensor,

B, N, _ = target.size()
m = source.size(1)
+ if source.device.type == 'npu':
+     # strict to fp32
+     source = source.transpose(2, 1).contiguous()
+     dtype_ = source.dtype
+     if dtype_ == torch.float16:
+         target = target.float()
+         source = source.float()
+     dist = target.new_empty(B, N, m)
+     ext_module.three_nn_forward(
+         target, source, dist, torch.Tensor([]).npu(), b=B, n=N, m=m)
+     dist2, idx = torch.topk(dist, 3, dim=2, largest=False, sorted=True)
+     dist2 = torch.sqrt(dist2)
+     if dtype_ == torch.float16:
+         dist2 = dist2.half()
+     return dist2, idx.int()
dist2 = target.new_empty(B, N, 3)
idx = target.new_empty(B, N, 3, dtype=torch.int32)

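Both new Python branches (and the three_interpolate kernels above) share the same dtype discipline: if the inputs are fp16, compute in fp32 and cast the result back at the end. A tiny generic sketch of that pattern (run_in_fp32 is a hypothetical helper shown for illustration, not part of mmcv):

import torch

def run_in_fp32(fn, *tensors):
    # Hypothetical helper: if any input is fp16, compute in fp32 and cast
    # the result back, mirroring the dtype_ == torch.float16 branches above.
    half = any(t.dtype == torch.float16 for t in tensors)
    out = fn(*[t.float() if half else t for t in tensors])
    return out.half() if half else out

# Usage example with a plain CPU op:
y = run_in_fp32(torch.sqrt, torch.rand(4, 8, dtype=torch.float16))
assert y.dtype == torch.float16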