Skip to content

Commit

Permalink
lint error
Browse files Browse the repository at this point in the history
  • Loading branch information
lihao7212148 committed Oct 23, 2023
1 parent 02b9fc8 commit aadb7fc
Showing 1 changed file with 30 additions and 37 deletions.
67 changes: 30 additions & 37 deletions tests/test_ops/test_three_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,28 +72,28 @@ def test_three_interpolate(dtype, device):
], [
2.2060e-01, 3.4110e-01, 3.4110e-01, 2.2060e-01, 2.2060e-01, 2.1380e-01
]],
[[
8.1773e-01, 9.5440e-01, 2.4532e+00,
8.1773e-01, 8.1773e-01, 1.1359e+00
],
[
8.4689e-01, 1.9176e+00, 1.4715e+00,
8.4689e-01, 8.4689e-01, 1.3079e+00
],
[
6.9473e-01, 2.7440e-01, 2.0842e+00,
6.9473e-01, 6.9473e-01, 7.8619e-01
],
[
7.6789e-01, 1.5063e+00, 1.6209e+00,
7.6789e-01, 7.6789e-01, 1.1562e+00
],
[
3.8760e-01, 1.0300e-02, 8.3569e-09,
3.8760e-01, 3.8760e-01, 1.9723e-01
]]],
dtype=dtype,
device=device)
[[
8.1773e-01, 9.5440e-01, 2.4532e+00,
8.1773e-01, 8.1773e-01, 1.1359e+00
],
[
8.4689e-01, 1.9176e+00, 1.4715e+00,
8.4689e-01, 8.4689e-01, 1.3079e+00
],
[
6.9473e-01, 2.7440e-01, 2.0842e+00,
6.9473e-01, 6.9473e-01, 7.8619e-01
],
[
7.6789e-01, 1.5063e+00, 1.6209e+00,
7.6789e-01, 7.6789e-01, 1.1562e+00
],
[
3.8760e-01, 1.0300e-02, 8.3569e-09,
3.8760e-01, 3.8760e-01, 1.9723e-01
]]],
dtype=dtype,
device=device)

assert torch.allclose(output, expected_output, 1e-3, 1e-4)

Expand Down Expand Up @@ -148,24 +148,16 @@ def torch_type_trans(dtype):
return np.float64


@pytest.mark.parametrize('dtype', [
torch.half,
torch.float
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float])
@pytest.mark.parametrize('device', [
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('shape', [
(2, 5, 6, 6),
(10, 10, 10, 10),
(20, 21, 13, 4),
(2, 10, 2, 18),
(10, 602, 910, 200),
(600, 100, 300, 101)
])
@pytest.mark.parametrize('shape', [(2, 5, 6, 6), (10, 10, 10, 10),
(20, 21, 13, 4), (2, 10, 2, 18),
(10, 602, 910, 200), (600, 100, 300, 101)])
def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
bs = shape[0]
cs = shape[1]
Expand All @@ -175,13 +167,14 @@ def test_three_interpolate_npu_dynamic_shape(dtype, device, shape):
features = np.random.uniform(-10.0, 10.0,
(bs, cs, ms).astype(torch_type_trans(dtype)))
idx = np.random.uniform(0, ms, size=(bs, ns, 3), dtype=np.int32)
weight = np.random.uniform(-10.0, 10.0 (bs, ns, 3)
).astype(torch_type_trans(dtype))
weight = np.random.uniform(-10.0,
10.0 (bs, ns,
3)).astype(torch_type_trans(dtype))

features_npu = torch.tensor(features, dtype=dtype).to(device)
idx_npu = torch.tensor(idx, dtype=torch.int32).to(device)
weight_npu = torch.tensor(weight, dtype=dtype).to(device)

expected_output = three_interpolate_forward_gloden(features, idx, weight)
output = three_interpolate(features_npu, idx_npu, weight_npu)
assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)
assert np.allclose(output.cpu().numpy(), expected_output, 1e-3, 1e-4)

0 comments on commit aadb7fc

Please sign in to comment.