move EXPAND dtype cast back to gradient.py (#15481)

the accumulation-dtype cast is only a concern for the gradient, not for the mixin
chenyu
2026-03-25 19:25:26 -04:00
committed by GitHub
parent 9d2d0774b4
commit 7c8f992894
3 changed files with 9 additions and 17 deletions
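The gradient of an EXPAND sums the incoming gradient over the broadcast axes, and that sum wants a wider accumulator: summing many half values directly stalls once the running total grows, because fp16 can no longer represent total + x. The parent commit threaded this accumulation cast through the mixin's forward broadcast; this commit moves it back into the EXPAND gradient rule so the forward graph stays in the input dtype. A minimal numpy sketch (not tinygrad code) of the precision issue:

import numpy as np

g = np.full(4096, np.float16(0.1))  # stand-in for the gradient of an expanded half tensor

total = np.float16(0.0)
for v in g: total = total + v  # accumulate in half: the total stalls near 256.0
print(total)

print(np.float16(g.astype(np.float32).sum()))  # accumulate in float, cast back: ~409.5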

View File

@@ -1,7 +1,6 @@
 # schedule tests that pass on NULL backend (no copyout needed)
 import gc, unittest, time
 from tinygrad import nn, dtypes, Device, Tensor
-from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
 from tinygrad.helpers import DEBUG, GlobalCounters, Context
 from tinygrad.engine.realize import CompiledRunner, run_schedule
@@ -510,7 +509,6 @@ class TestSchedule(unittest.TestCase):
     d = (a+b).reshape(16,1)
     check_schedule(d, 0, [c])
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_multi_permute_should_collapse(self):
     a = Tensor.empty(4,4,4,4)
     b = Tensor.empty(16)
@@ -746,7 +744,6 @@ class TestSchedule(unittest.TestCase):
     out1 = out0[0] + Tensor.empty(1, )
     check_schedule([r, out0, out1], 3)
 
-  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_softmax_upcast(self):
     # input half, softmax in float
     Tensor.manual_seed(0)
@@ -754,15 +751,10 @@ class TestSchedule(unittest.TestCase):
     out = x.softmax(dtype=dtypes.float)
     sched = out.schedule()
     self.assertEqual(len(sched), 3)
-    self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)
-
-    # input float, softmax in float
-    Tensor.manual_seed(0)
-    x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize()
-    out = x.softmax(dtype=dtypes.float)
-    sched = out.schedule()
-    self.assertEqual(len(sched), 3)
-    self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)
+    # max reduction stays in input dtype (no numerical loss), upcast happens after subtracting max
+    self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
+    self.assertEqual(sched[1].bufs[0].dtype, dtypes.float)
+    self.assertEqual(sched[2].bufs[0].dtype, dtypes.float)
 
   def test_softmax_backward(self):
     Tensor.manual_seed(0)
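The rewritten test pins down the scheduling consequence: with a half input and softmax(dtype=dtypes.float), the leading max reduction stays in half (a max loses no precision), and the upcast to float only happens from the subtract/exp/sum onward. A rough sketch of what the three kernels produce, assuming a device with half support:

from tinygrad import Tensor, dtypes

x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.half).realize()
sched = x.softmax(dtype=dtypes.float).schedule()
for si in sched: print(si.bufs[0].dtype)
# per the assertions above: half (max reduce), float (exp + sum), float (final divide)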

View File

@@ -2,6 +2,7 @@ from typing import cast
 import math, dataclasses, itertools
 from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
 from tinygrad.helpers import argsort
+from tinygrad.dtype import sum_acc_dtype
 
 def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
   def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
@@ -64,7 +65,8 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
   (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
   (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
-  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n)), None)),
+  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
+    (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape,ret.shape)) if s!=n)).cast(ctx.dtype), None)),
   (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
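Outside the PatternMatcher machinery, the new EXPAND rule is equivalent to the hand-written sketch below; expand_gradient is a hypothetical helper expressed with Tensor ops instead of UOp.r:

from tinygrad import Tensor, dtypes
from tinygrad.dtype import sum_acc_dtype

def expand_gradient(grad:Tensor, in_shape, out_shape) -> Tensor:
  # axes along which the forward EXPAND broadcast the input (size 1 -> size n)
  axes = tuple(i for i,(s,n) in enumerate(zip(in_shape, out_shape)) if s != n)
  # sum in the accumulation dtype (half -> float), then cast back to the gradient dtype
  return grad.cast(sum_acc_dtype(grad.dtype)).sum(axis=axes, keepdim=True).cast(grad.dtype)

# the gradient of expanding (4, 1) to (4, 4) sums 4 halves per row, accumulating in float
g = expand_gradient(Tensor.ones(4, 4, dtype=dtypes.half), (4, 1), (4, 4))
print(g.dtype, g.shape)  # dtypes.half (4, 1)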

View File

@@ -2,7 +2,7 @@ from typing import Self
 from tinygrad.mixin.elementwise import ElementwiseMixin
 from tinygrad.mixin.movement import MovementMixin
 from tinygrad.uop.ops import _broadcast_shape
-from tinygrad.dtype import least_upper_dtype, sum_acc_dtype
+from tinygrad.dtype import least_upper_dtype
 
 class OpMixin(ElementwiseMixin, MovementMixin):
@@ -10,6 +10,4 @@ class OpMixin(ElementwiseMixin, MovementMixin):
     if not isinstance(y, type(self)): y = self.ufix(y)
     x, y = (self, y) if not reverse else (y, self)
     out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype)
-    # NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
-    return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \
-      y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype)
+    return x._broadcast_to(out_shape).cast(out_dtype), y._broadcast_to(out_shape).cast(out_dtype)
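With the revert, the mixin's broadcast pairing is purely structural again: both operands expand to the common shape and cast once to the promoted dtype, with no sum_acc_dtype round-trip in the forward graph. A small sketch of what a binary op now sees:

from tinygrad import Tensor, dtypes
from tinygrad.dtype import least_upper_dtype

a = Tensor.ones(4, 1, dtype=dtypes.half)
b = Tensor.ones(1, 4, dtype=dtypes.float)
print(least_upper_dtype(a.dtype, b.dtype))  # half and float promote to float
# both sides broadcast to (4, 4) and cast to float; the backward accumulation
# dtype is now handled by the EXPAND rule in gradient.py instead
print((a + b).dtype, (a + b).shape)  # dtypes.float (4, 4)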