From ab7dfd637b31695bc332b76254f98dac4ead4b68 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 3 Jan 2024 18:56:42 -0800 Subject: [PATCH] use float for acc dtype for half tensor sum we previously only upcast uint and int, and half was using half for acc. change to acc in float for precision. but cast the result back to half to match torch/jax output dtype --- test/test_linearizer.py | 12 +++++++++--- tinygrad/tensor.py | 9 ++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 6f1031a99d..09d46983e5 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -72,10 +72,16 @@ class TestLinearizer(unittest.TestCase): num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]]) assert num_ops <= 0, "more load or alu uops than needed" - def test_tensor_cores(self): - if Device.DEFAULT not in tensor_cores: - self.skipTest("No tensor cores for device") + def test_sum_acc_dtype(self): + for tensor_dtype, acc_dtype in ((dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)): + a = Tensor([1, 2, 3], dtype=tensor_dtype).sum() + k = Linearizer(a.lazydata.schedule()[-1].ast) + k.linearize() + local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC] + assert local[0].dtype == acc_dtype + @unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device") + def test_tensor_cores(self): for tc in tensor_cores[Device.DEFAULT]: if tc.arch is not None and tc.arch != os.uname().machine: continue a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 29ce411c9f..1c716394c0 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -515,9 +515,12 @@ class Tensor: return ret if keepdim else ret.reshape(shape=shape) def sum(self, axis=None, keepdim=False): - output_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \ - least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else self.dtype - return self.cast(output_dtype)._reduce(mlops.Sum, axis, keepdim) + acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \ + least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \ + least_upper_dtype(self.dtype, dtypes.float) + # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc + output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype + return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype) def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim) def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))