[BACKEND] Remove dependency between NVGPU and TritonNvidiaGPU (#2282)

This commit is contained in:
Thomas Raoux
2023-09-12 11:02:20 -07:00
committed by GitHub
parent 37f12497b0
commit 994f7e4460
9 changed files with 11 additions and 4 deletions

View File

@@ -36,7 +36,7 @@ def test_op(M, N, dtype, mode):
x.grad = None
th_y.backward(dy)
th_dx = x.grad.clone()
if dtype == 'float16':
if dtype == torch.float16:
torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001)
else:
torch.testing.assert_close(th_dx, tt_dx)