From 3ff03be413e9b8b0721fde659b68b759a4383843 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 17 Mar 2026 10:58:46 +0800 Subject: [PATCH] call always has tuple (#15297) * call always has tuple * fix pre-commit and simplify * update * fix * move that assert * tuple * fix multi * cleanups * fix merge --- test/amd/helpers.py | 2 +- test/unit/test_call.py | 72 ++++++++++++++++++++++++++++++++++ test/unit/test_function.py | 14 +++---- tinygrad/engine/allocations.py | 12 +++--- tinygrad/function.py | 2 +- tinygrad/gradient.py | 15 ++++--- tinygrad/schedule/multi.py | 17 +++++++- tinygrad/tensor.py | 3 +- tinygrad/uop/ops.py | 24 ++++++++---- tinygrad/uop/spec.py | 4 +- 10 files changed, 129 insertions(+), 36 deletions(-) diff --git a/test/amd/helpers.py b/test/amd/helpers.py index 367268924e..1ed4a8b5ad 100644 --- a/test/amd/helpers.py +++ b/test/amd/helpers.py @@ -5,7 +5,7 @@ from tinygrad.runtime.autogen import llvm from tinygrad.runtime.support.elf import elf_loader ARCH_TO_TARGET:dict[str, list[str]] = { - "rdna3":["gfx1100"], + "rdna3":["gfx1100", "gfx1151"], "rdna4":["gfx1200", "gfx1201"], "cdna":["gfx950", "gfx942"], } diff --git a/test/unit/test_call.py b/test/unit/test_call.py index d7c1cc2873..a889be4a4e 100644 --- a/test/unit/test_call.py +++ b/test/unit/test_call.py @@ -245,6 +245,78 @@ class TestCallSchedule(unittest.TestCase): out = f(a) + 2 np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3) +class TestCallMultiSharded(unittest.TestCase): + # TODO: multi-output + sharded needs per-device CALL execution, which requires reworking how MULTI propagates through TUPLE bodies + def test_tuple_sharded(self): + """multi-output function with sharded input""" + devs = ("CPU:0", "CPU:1") + @function + def f(x:Tensor): return (x + 1, x * 2) + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + t1, t2 = f(a) + ref = np.arange(8, dtype=np.float32).reshape(4, 2) + np.testing.assert_allclose(t1.numpy(), ref + 1) + np.testing.assert_allclose(t2.numpy(), ref * 2) + + def test_tuple_sharded_precompile(self): + """multi-output precompiled function with sharded input""" + devs = ("CPU:0", "CPU:1") + @function(precompile=True) + def f(x:Tensor): return (x + 1, x * 2) + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + t1, t2 = f(a) + ref = np.arange(8, dtype=np.float32).reshape(4, 2) + np.testing.assert_allclose(t1.numpy(), ref + 1) + np.testing.assert_allclose(t2.numpy(), ref * 2) + + def test_tuple_sharded_different_axis(self): + """multi-output function where outputs have different sharding: one reduces on sharded axis, one doesn't""" + devs = ("CPU:0", "CPU:1") + @function + def f(x:Tensor): return (x.sum(axis=0), x.sum(axis=1)) + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + t1, t2 = f(a) + ref = np.arange(8, dtype=np.float32).reshape(4, 2) + np.testing.assert_allclose(t1.numpy(), ref.sum(axis=0)) + np.testing.assert_allclose(t2.numpy(), ref.sum(axis=1)) + + def test_tuple_sharded_different_ops(self): + """multi-output function with different operations per output""" + devs = ("CPU:0", "CPU:1") + @function + def f(x:Tensor, y:Tensor): return (x + y, x * y) + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + b = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + 1 + t1, t2 = f(a, b) + ref_a = np.arange(8, dtype=np.float32).reshape(4, 2) + ref_b = ref_a + 1 + np.testing.assert_allclose(t1.numpy(), ref_a + ref_b) + np.testing.assert_allclose(t2.numpy(), ref_a * ref_b) + + def test_tuple_sharded_mixed_use(self): + """multi-output sharded results used in further computation""" + devs = ("CPU:0", "CPU:1") + @function + def f(x:Tensor): return (x + 1, x * 2) + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + t1, t2 = f(a) + out = (t1 + t2).sum() + ref = np.arange(8, dtype=np.float32).reshape(4, 2) + np.testing.assert_allclose(out.numpy(), ((ref + 1) + (ref * 2)).sum()) + + def test_tuple_sharded_outputs_different_axis(self): + """multi-output function where the two outputs are sharded on different axes""" + devs = ("CPU:0", "CPU:1") + @function + def f(x:Tensor, y:Tensor): return (x + 1, y + 2) + a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + b = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=1) + t1, t2 = f(a, b) + ref_a = np.arange(8, dtype=np.float32).reshape(4, 2) + ref_b = np.arange(8, dtype=np.float32).reshape(4, 2) + np.testing.assert_allclose(t1.numpy(), ref_a + 1) + np.testing.assert_allclose(t2.numpy(), ref_b + 2) + def test_call_reduce_sharded(self): devs = ("CPU:0", "CPU:1") a = Tensor.ones(10, 10).shard(devs, axis=0) diff --git a/test/unit/test_function.py b/test/unit/test_function.py index 5774d5bbe1..d5cbc8f192 100644 --- a/test/unit/test_function.py +++ b/test/unit/test_function.py @@ -165,13 +165,13 @@ class TestFunction(unittest.TestCase): def test_name(self): @function def f(a:Tensor) -> Tensor: return a + 1 - assert f(Tensor([1])).uop.arg.name.endswith("f") + assert f(Tensor([1])).uop.src[0].arg.name.endswith("f") def test_method_name(self): class Foo: @function def __call__(self, x:Tensor) -> Tensor: return x + 1 - assert Foo()(Tensor([1])).uop.arg.name.endswith("Foo.__call__") + assert Foo()(Tensor([1])).uop.src[0].arg.name.endswith("Foo.__call__") def test_callable_instance(self): class Foo: @@ -180,7 +180,7 @@ class TestFunction(unittest.TestCase): foo = Foo() f = function(foo) np.testing.assert_equal(f(Tensor([1,2,3])).numpy(), [11,22,33]) - assert f(Tensor([1,2,3])).uop.arg.name.endswith("Foo") + assert f(Tensor([1,2,3])).uop.src[0].arg.name.endswith("Foo") def test_iadd(self): @function @@ -369,11 +369,11 @@ class TestFunctionTuple(unittest.TestCase): def test_grad_tuple_precompile(self): self.test_grad_tuple(True) def test_grad_fxn_tuple(self): - # grad_fxn for tuple: ctx is a TUPLE UOp with one element per output - def grad_fxn(ctx:UOp, call:UOp): - # f(u1, u2) = (u1+1, u2+2), ctx.src = (d_out0, d_out1) + # grad_fxn for tuple: receives one gradient per output as positional args + def grad_fxn(d_out0:UOp, d_out1:UOp, call:UOp): + # f(u1, u2) = (u1+1, u2+2) # df/du1 = d_out0, df/du2 = d_out1 - return (ctx.src[0], ctx.src[1]) + return (d_out0, d_out1) x = Tensor.ones(3, requires_grad=True).contiguous() y = Tensor.ones(3, requires_grad=True).contiguous() diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py index 7fa477156c..67f1934b51 100644 --- a/tinygrad/engine/allocations.py +++ b/tinygrad/engine/allocations.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites -from tinygrad.dtype import dtypes, ImageDType +from tinygrad.dtype import ImageDType from tinygrad.helpers import prod, DEBUG, VIZ, pluralize, all_int @dataclass @@ -92,25 +92,25 @@ def contiguous_mops_to_view(c:UOp, src:UOp): def transform_precompiled_call(c:UOp) -> UOp|None: if not c.arg.precompile: return None if c.src[0].op is Ops.SINK: return None + assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled call, got {c.src[0].op}" input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:]) # add the outputs to the call - srcs = c.src[0].src if c.src[0].op is Ops.TUPLE else (c.src[0],) - resolved = [c.gettuple(i) if c.src[0].op is Ops.TUPLE else c for i in range(len(srcs))] + srcs = c.src[0].src + resolved = [c.gettuple(i) for i in range(len(srcs))] outs = tuple(_buffer_like(r) for r in resolved) targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))] fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)]) # create the new thing for the big graph - new_call = c.replace(src=(fxn, *input_buffers, *outs), dtype=dtypes.void, tag=None) + new_call = c.replace(src=(fxn, *input_buffers, *outs), tag=None) rets = tuple(o.after(new_call) for o in outs) # if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape # NOTE: must use resolved shapes from the CALL (which substitutes PARAMs with external args), not raw body shapes rets = tuple(r.shrink_to(rs.shape) for r,rs in zip(rets, resolved)) - # return tuple if tuple - return UOp.maketuple(*rets) if c.src[0].op is Ops.TUPLE else rets[0] + return UOp.maketuple(*rets) # NOTE: adding rules to here is bad. these all need to run before the schedule cache pm_early_transform_tensor_graph = PatternMatcher([ diff --git a/tinygrad/function.py b/tinygrad/function.py index 79d9743288..25504b94f3 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -74,7 +74,7 @@ class _function(Generic[ReturnType]): if isinstance(ret, tuple): return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret)))) else: - return cast(ReturnType, Tensor(fret, device=fret.device)) + return cast(ReturnType, Tensor(fret.gettuple(0), device=fret.device)) # overload signatures support both @function and @function(precompile=True) syntax @overload diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 6772b391fa..a961454587 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -14,22 +14,21 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops): if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],) def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]: - if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k) fxn, args = k.src[0], k.src[1:] + if k.arg.grad_fxn is not None: + return (None,) + (k.arg.grad_fxn(*ctx.src, call=k) if ctx.op is Ops.TUPLE else k.arg.grad_fxn(ctx, k)) + assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}" params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM} - if fxn.op is Ops.TUPLE: - grad_args = ctx.src - root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args))) - else: - grad_args = (ctx,) - root_grad = ctx.param_like(len(args)) + grad_args = ctx.src + root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args))) grads = compute_gradient(fxn, root_grad, set(params.values())) ret: list[UOp|None] = [None] for i in range(len(args)): if (p:=params.get(i, None)) is not None and p in grads: # TODO: compact the args and remove unused ones assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad" - ret.append(grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward)) + bwd_call = grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward) + ret.append(bwd_call.gettuple(0)) else: ret.append(None) return tuple(ret) diff --git a/tinygrad/schedule/multi.py b/tinygrad/schedule/multi.py index 435363bfcd..afaf6b480f 100644 --- a/tinygrad/schedule/multi.py +++ b/tinygrad/schedule/multi.py @@ -116,6 +116,11 @@ def rewrite_into_call(call:UOp): if not should_resolve_call(call): return None new_body = graph_rewrite(call.src[0], multi_pm, name="subcall") new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:]) + # after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard CALL, wrap each GETTUPLE in its own MULTI + assert new_body.op is Ops.TUPLE + if any(s.op is Ops.MULTI for s in new_body.src): + shard_call = call.replace(src=(UOp.maketuple(*[s.src[0] if s.op is Ops.MULTI else s for s in new_body.src]),)+new_args) + return UOp.maketuple(*[shard_call.gettuple(i).multi(s.axis) if s.op is Ops.MULTI else shard_call.gettuple(i) for i, s in enumerate(new_body.src)]) return call.replace(src=(new_body,)+new_args) def param_to_multi(p:UOp): @@ -137,12 +142,20 @@ multi_pm = PatternMatcher([ (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi), (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"), lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)), + + # resolve TUPLE+GETTUPLE (needed in multi) + (UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]), + # GETTUPLE on MULTI: passthrough MULTI (e.g. when CALL was replaced by MULTI(GETTUPLE(...))) + (UPat(Ops.GETTUPLE, src=(UPat(Ops.MULTI, name="multi"),), name="g"), + lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.CALL, Ops.TUPLE} + else multi), # rewrite into calls explicitly for MULTI (UPat(Ops.CALL, name="call"), rewrite_into_call), (UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi), - # we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels + # we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call (UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root: - UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)), + UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg) + if root.src[0].op is not Ops.TUPLE else None), (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), ])+replace_allreduce diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index a79d35552e..88e111e62a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -242,7 +242,8 @@ class Tensor(OpMixin): param = UOp.param(slot, self.dtype, self.shape, self.device) return Tensor(param, device=self.device) def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor: - return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device) + fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn) + return Tensor(fret.gettuple(0), device=self.device) def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]: """ diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index a9ada0dd2a..df618277af 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -213,10 +213,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return None case Ops.GETTUPLE: - # GETTUPLE extracts from a TUPLE + # GETTUPLE extracts from a TUPLE (possibly through a CALL) in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0] assert in_tuple.op is Ops.TUPLE - return in_tuple.src[self.arg]._shape + inner_shape = in_tuple.src[self.arg]._shape + if inner_shape is None: return None + # if through a CALL, substitute internal PARAMs in the shape with corresponding args + if self.src[0].op is Ops.CALL: + return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape) + return inner_shape case Ops.CAST: # when PTX casts from ptr to non ptr, remove the shape @@ -248,11 +253,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass): case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END: return self.src[0]._shape - case Ops.CALL: - inner_shape = self.src[0]._shape - if inner_shape is None: return None - # substitute internal PARAMs in the shape with corresponding args - return tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape) + case Ops.CALL: return None # TODO: disallow shape changing bitcast case Ops.BITCAST: @@ -535,6 +536,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass): # COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK if self.op is Ops.COPY: return None if self.op is Ops.MULTI: return self.arg + # GETTUPLE: axis comes from the specific TUPLE element, not src[0] + if self.op is Ops.GETTUPLE: + in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0] + return in_tuple.src[self.arg].axis if in_tuple.op is Ops.TUPLE else None # PARAM: axis is stored as a MULTI source if self.op is Ops.PARAM: for s in self.src: @@ -918,10 +923,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass): if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),)) return p + _NO_TUPLE_WRAP = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION, Ops.TUPLE} def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp: assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}" - return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward)) + # value-producing bodies are always wrapped in TUPLE so CALL dtype is always void + body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self) + return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward)) def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]: contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs) placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)] diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py index 05af174412..2a219925a3 100644 --- a/tinygrad/uop/spec.py +++ b/tinygrad/uop/spec.py @@ -132,8 +132,8 @@ _tensor_spec = PatternMatcher([ # AFTER if things were kernelized (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True), - # allow CALL/PARAM/CUSTOM_FUNCTION - (UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype), + # allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void + (UPat(Ops.CALL, dtypes.void), lambda: True), (UPat(Ops.PARAM), lambda: True), (UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),