Run more webgpu tests (#8142)

This commit is contained in:
Ahmed Harmouche
2024-12-10 23:20:04 +01:00
committed by GitHub
parent ed7318a3f5
commit a8cfdc70ed
6 changed files with 11 additions and 4 deletions

View File

@@ -499,6 +499,8 @@ class TestMoveTensor(unittest.TestCase):
@given(strat.sampled_from([d0, d1]), strat.sampled_from([d0, d1]),
strat.sampled_from([dtypes.float16, dtypes.float32]), strat.sampled_from([True, False, None]))
def test_to_preserves(self, src, dest, dtype, requires_grad):
if not is_dtype_supported(dtype):
return
s = Tensor([1, 2, 3], device=src, dtype=dtype, requires_grad=requires_grad)
if requires_grad: s.sum().backward()
t = s.to(dest)