use float for acc dtype for half tensor sum
We previously only upcast uint and int accumulators; half sums accumulated in half. Change the accumulator to float for precision, but cast the result back to half so the output dtype matches torch/jax.
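For illustration only (not part of this commit), here is a minimal numpy sketch of the precision problem a float accumulator avoids: a sequential float16 accumulator stops growing once the running total reaches 2048.0, because the float16 spacing at that magnitude is 2.0, whereas accumulating in float32 and casting the result back to half returns the exact sum.

import numpy as np

xs = np.ones(10000, dtype=np.float16)

acc_half = np.float16(0.0)
for v in xs:
  acc_half = np.float16(acc_half + v)    # simulate a half-precision accumulator
print(acc_half)                          # 2048.0: 2048.0 + 1.0 rounds back to 2048.0 in float16

acc_float = np.float32(0.0)
for v in xs:
  acc_float = acc_float + np.float32(v)  # accumulate in float32 instead
print(np.float16(acc_float))             # 10000.0: exact, and exactly representable in float16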
@@ -72,10 +72,16 @@ class TestLinearizer(unittest.TestCase):
     num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
     assert num_ops <= 0, "more load or alu uops than needed"
 
-  def test_tensor_cores(self):
-    if Device.DEFAULT not in tensor_cores:
-      self.skipTest("No tensor cores for device")
+  def test_sum_acc_dtype(self):
+    for tensor_dtype, acc_dtype in ((dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
+      a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
+      k = Linearizer(a.lazydata.schedule()[-1].ast)
+      k.linearize()
+      local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
+      assert local[0].dtype == acc_dtype
+
+  @unittest.skipUnless(Device.DEFAULT in tensor_cores, "No tensor cores for device")
+  def test_tensor_cores(self):
     for tc in tensor_cores[Device.DEFAULT]:
       if tc.arch is not None and tc.arch != os.uname().machine: continue
       a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in)
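As a rough standalone version of what the new test checks (the import paths below are assumptions for this revision of tinygrad and may differ), one can linearize a half-precision sum and inspect the dtype of its DEFINE_ACC uop, which should now be float rather than half:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad.codegen.linearizer import Linearizer, UOps

out = Tensor([1, 2, 3], dtype=dtypes.float16).sum()
k = Linearizer(out.lazydata.schedule()[-1].ast)   # linearize the scheduled sum kernel
k.linearize()
accs = [u for u in k.uops if u.uop == UOps.DEFINE_ACC]
print(accs[0].dtype)                              # expected: dtypes.float, even though the tensor is float16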
@@ -515,9 +515,12 @@ class Tensor:
     return ret if keepdim else ret.reshape(shape=shape)
 
   def sum(self, axis=None, keepdim=False):
-    output_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
-      least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else self.dtype
-    return self.cast(output_dtype)._reduce(mlops.Sum, axis, keepdim)
+    acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
+      least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
+      least_upper_dtype(self.dtype, dtypes.float)
+    # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
+    output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
+    return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
   def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))
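Sketching the resulting user-visible behavior of Tensor.sum under this change (same hedged import paths as above; the expected dtypes in the comments follow from the new sum() logic shown in the hunk):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

print(Tensor([1, 2, 3], dtype=dtypes.bool).sum().dtype)     # dtypes.int: bool promotes to int for the sum
print(Tensor([1, 2, 3], dtype=dtypes.int16).sum().dtype)    # dtypes.int: small ints upcast to int
print(Tensor([1, 2, 3], dtype=dtypes.float16).sum().dtype)  # dtypes.float16: accumulated in float, cast back to half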