use float for acc dtype for half tensor sum

we previously only upcast uint and int, and half was using half for acc.
change to acc in float for precision. but cast the result back to half to match torch/jax output dtype
This commit is contained in:
chenyu
2024-01-03 18:56:42 -08:00
parent 6fa285b943
commit ab7dfd637b
2 changed files with 15 additions and 6 deletions

View File

@@ -72,10 +72,16 @@ class TestLinearizer(unittest.TestCase):
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
assert num_ops <= 0, "more load or alu uops than needed"
def test_tensor_cores(self):
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")
def test_sum_acc_dtype(self):
for tensor_dtype, acc_dtype in ((dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(a.lazydata.schedule()[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
@unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device")
def test_tensor_cores(self):
for tc in tensor_cores[Device.DEFAULT]:
if tc.arch is not None and tc.arch != os.uname().machine: continue
a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in)

View File

@@ -515,9 +515,12 @@ class Tensor:
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False):
output_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else self.dtype
return self.cast(output_dtype)._reduce(mlops.Sum, axis, keepdim)
acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
least_upper_dtype(self.dtype, dtypes.float)
# cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))