diff --git a/test/null/test_schedule.py b/test/null/test_schedule.py index 6c38de9c73..d5365b6d9d 100644 --- a/test/null/test_schedule.py +++ b/test/null/test_schedule.py @@ -1,7 +1,6 @@ # schedule tests that pass on NULL backend (no copyout needed) import gc, unittest, time from tinygrad import nn, dtypes, Device, Tensor -from tinygrad.device import is_dtype_supported from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat from tinygrad.helpers import DEBUG, GlobalCounters, Context from tinygrad.engine.realize import CompiledRunner, run_schedule @@ -510,7 +509,6 @@ class TestSchedule(unittest.TestCase): d = (a+b).reshape(16,1) check_schedule(d, 0, [c]) - @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_multi_permute_should_collapse(self): a = Tensor.empty(4,4,4,4) b = Tensor.empty(16) @@ -746,7 +744,6 @@ class TestSchedule(unittest.TestCase): out1 = out0[0] + Tensor.empty(1, ) check_schedule([r, out0, out1], 3) - @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_softmax_upcast(self): # input half, softmax in float Tensor.manual_seed(0) @@ -754,15 +751,10 @@ class TestSchedule(unittest.TestCase): out = x.softmax(dtype=dtypes.float) sched = out.schedule() self.assertEqual(len(sched), 3) - self.assertEqual(sched[0].bufs[0].dtype, dtypes.float) - - # input float, softmax in float - Tensor.manual_seed(0) - x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize() - out = x.softmax(dtype=dtypes.float) - sched = out.schedule() - self.assertEqual(len(sched), 3) - self.assertEqual(sched[0].bufs[0].dtype, dtypes.float) + # max reduction stays in input dtype (no numerical loss), upcast happens after subtracting max + self.assertEqual(sched[0].bufs[0].dtype, dtypes.half) + self.assertEqual(sched[1].bufs[0].dtype, dtypes.float) + self.assertEqual(sched[2].bufs[0].dtype, dtypes.float) def test_softmax_backward(self): Tensor.manual_seed(0) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index c1270c25d1..fb53210c04 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -2,6 +2,7 @@ from typing import cast import math, dataclasses, itertools from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite from tinygrad.helpers import argsort +from tinygrad.dtype import sum_acc_dtype def reduce_gradient(ctx:UOp, ret:UOp, op:Ops): def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape) @@ -64,7 +65,8 @@ pm_gradient = PatternMatcher([ (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)), (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)), (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)), - (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n)), None)), + (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: + (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape,ret.shape)) if s!=n)).cast(ctx.dtype), None)), (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)), (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)), (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)), diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py index e4be275b9c..d650aa3331 100644 --- a/tinygrad/mixin/__init__.py +++ b/tinygrad/mixin/__init__.py @@ -2,7 +2,7 @@ from typing import Self from tinygrad.mixin.elementwise import ElementwiseMixin from tinygrad.mixin.movement import MovementMixin from tinygrad.uop.ops import _broadcast_shape -from tinygrad.dtype import least_upper_dtype, sum_acc_dtype +from tinygrad.dtype import least_upper_dtype class OpMixin(ElementwiseMixin, MovementMixin): @@ -10,6 +10,4 @@ class OpMixin(ElementwiseMixin, MovementMixin): if not isinstance(y, type(self)): y = self.ufix(y) x, y = (self, y) if not reverse else (y, self) out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype) - # NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum - return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \ - y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype) + return x._broadcast_to(out_shape).cast(out_dtype), y._broadcast_to(out_shape).cast(out_dtype)