mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move EXPAND dtype cast back to gradient.py (#15481)
only a concern for gradient, not mixin
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# schedule tests that pass on NULL backend (no copyout needed)
|
||||
import gc, unittest, time
|
||||
from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.helpers import DEBUG, GlobalCounters, Context
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
@@ -510,7 +509,6 @@ class TestSchedule(unittest.TestCase):
|
||||
d = (a+b).reshape(16,1)
|
||||
check_schedule(d, 0, [c])
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_multi_permute_should_collapse(self):
|
||||
a = Tensor.empty(4,4,4,4)
|
||||
b = Tensor.empty(16)
|
||||
@@ -746,7 +744,6 @@ class TestSchedule(unittest.TestCase):
|
||||
out1 = out0[0] + Tensor.empty(1, )
|
||||
check_schedule([r, out0, out1], 3)
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
|
||||
def test_softmax_upcast(self):
|
||||
# input half, softmax in float
|
||||
Tensor.manual_seed(0)
|
||||
@@ -754,15 +751,10 @@ class TestSchedule(unittest.TestCase):
|
||||
out = x.softmax(dtype=dtypes.float)
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 3)
|
||||
self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)
|
||||
|
||||
# input float, softmax in float
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(4, 12, 64, 64, dtype=dtypes.float).realize()
|
||||
out = x.softmax(dtype=dtypes.float)
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 3)
|
||||
self.assertEqual(sched[0].bufs[0].dtype, dtypes.float)
|
||||
# max reduction stays in input dtype (no numerical loss), upcast happens after subtracting max
|
||||
self.assertEqual(sched[0].bufs[0].dtype, dtypes.half)
|
||||
self.assertEqual(sched[1].bufs[0].dtype, dtypes.float)
|
||||
self.assertEqual(sched[2].bufs[0].dtype, dtypes.float)
|
||||
|
||||
def test_softmax_backward(self):
|
||||
Tensor.manual_seed(0)
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import cast
|
||||
import math, dataclasses, itertools
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, all_metadata, graph_rewrite
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import sum_acc_dtype
|
||||
|
||||
def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
|
||||
def broadcast_to_input(x): return x.reshape(x.shape+(1,)*(len(ret.src[0].shape)-len(x.shape))).expand(ret.src[0].shape)
|
||||
@@ -64,7 +65,8 @@ pm_gradient = PatternMatcher([
|
||||
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
|
||||
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
|
||||
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape), None)),
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape, ret.shape)) if s!=n)), None)),
|
||||
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
|
||||
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD,tuple(i for i,(s,n) in enumerate(zip(ret.src[0].shape,ret.shape)) if s!=n)).cast(ctx.dtype), None)),
|
||||
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
|
||||
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Self
|
||||
from tinygrad.mixin.elementwise import ElementwiseMixin
|
||||
from tinygrad.mixin.movement import MovementMixin
|
||||
from tinygrad.uop.ops import _broadcast_shape
|
||||
from tinygrad.dtype import least_upper_dtype, sum_acc_dtype
|
||||
from tinygrad.dtype import least_upper_dtype
|
||||
|
||||
|
||||
class OpMixin(ElementwiseMixin, MovementMixin):
|
||||
@@ -10,6 +10,4 @@ class OpMixin(ElementwiseMixin, MovementMixin):
|
||||
if not isinstance(y, type(self)): y = self.ufix(y)
|
||||
x, y = (self, y) if not reverse else (y, self)
|
||||
out_shape, out_dtype = _broadcast_shape(x.shape, y.shape), least_upper_dtype(x.dtype, y.dtype)
|
||||
# NOTE: the backward cast is no-op in forward and uses sum_acc_dtype in the backward sum
|
||||
return x.cast(sum_acc_dtype(x.dtype))._broadcast_to(out_shape).cast(x.dtype).cast(out_dtype), \
|
||||
y.cast(sum_acc_dtype(y.dtype))._broadcast_to(out_shape).cast(y.dtype).cast(out_dtype)
|
||||
return x._broadcast_to(out_shape).cast(out_dtype), y._broadcast_to(out_shape).cast(out_dtype)
|
||||
|
||||
Reference in New Issue
Block a user