tf32 tc for nv and ptx (#8635)

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
ignaciosica
2025-01-17 22:43:57 -03:00
committed by GitHub
parent 5afb0a4a81
commit d2234e308a
7 changed files with 23 additions and 11 deletions

View File

@@ -1982,7 +1982,7 @@ class TestKernelOpts(unittest.TestCase):
Tensor.manual_seed(1552)
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
if tc.dtype_in == dtypes.bfloat16: continue
if tc.dtype_in != dtypes.half and tc.dtype_out != dtypes.half: continue
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
@@ -2009,7 +2009,7 @@ class TestKernelOpts(unittest.TestCase):
Tensor.manual_seed(1552)
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
# bf16 buffer returns float32 numpy outputs so test would fail. testing opt with half suffices.
if tc.dtype_in == dtypes.bfloat16: continue
if tc.dtype_in != dtypes.half and tc.dtype_out != dtypes.half: continue
a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
r = a.matmul(b, acc_dtype=tc.dtype_out)
(atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)