call always has tuple (#15297)
* call always has tuple
* fix pre-commit and simplify
* update
* fix
* move that assert
* tuple
* fix multi
* cleanups
* fix merge
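In short: after this change a CALL UOp always has dtypes.void, any value-producing body is wrapped in a TUPLE, and results are read back through GETTUPLE (see the UOp.call and Tensor.call hunks below). The following is a minimal sketch of the multi-output pattern the new tests exercise, assembled from those tests; the import path for `function` is an assumption, as this diff does not show it.

import numpy as np
from tinygrad import Tensor
from tinygrad import function  # assumed import path, not shown in this diff

devs = ("CPU:0", "CPU:1")

@function
def f(x: Tensor):
  # two outputs: the body becomes a TUPLE, the CALL stays void,
  # and each result Tensor is read back through GETTUPLE
  return (x + 1, x * 2)

a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
t1, t2 = f(a)
ref = np.arange(8, dtype=np.float32).reshape(4, 2)
np.testing.assert_allclose(t1.numpy(), ref + 1)
np.testing.assert_allclose(t2.numpy(), ref * 2)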
@@ -5,7 +5,7 @@ from tinygrad.runtime.autogen import llvm
 from tinygrad.runtime.support.elf import elf_loader

 ARCH_TO_TARGET:dict[str, list[str]] = {
-  "rdna3":["gfx1100"],
+  "rdna3":["gfx1100", "gfx1151"],
   "rdna4":["gfx1200", "gfx1201"],
   "cdna":["gfx950", "gfx942"],
 }

@@ -245,6 +245,78 @@ class TestCallSchedule(unittest.TestCase):
     out = f(a) + 2
     np.testing.assert_allclose(out.numpy(), np.arange(8, dtype=np.float32).reshape(4, 2) + 3)

+class TestCallMultiSharded(unittest.TestCase):
+  # TODO: multi-output + sharded needs per-device CALL execution, which requires reworking how MULTI propagates through TUPLE bodies
+  def test_tuple_sharded(self):
+    """multi-output function with sharded input"""
+    devs = ("CPU:0", "CPU:1")
+    @function
+    def f(x:Tensor): return (x + 1, x * 2)
+    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
+    t1, t2 = f(a)
+    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
+    np.testing.assert_allclose(t1.numpy(), ref + 1)
+    np.testing.assert_allclose(t2.numpy(), ref * 2)
+
+  def test_tuple_sharded_precompile(self):
+    """multi-output precompiled function with sharded input"""
+    devs = ("CPU:0", "CPU:1")
+    @function(precompile=True)
+    def f(x:Tensor): return (x + 1, x * 2)
+    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
+    t1, t2 = f(a)
+    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
+    np.testing.assert_allclose(t1.numpy(), ref + 1)
+    np.testing.assert_allclose(t2.numpy(), ref * 2)
+
+  def test_tuple_sharded_different_axis(self):
+    """multi-output function where outputs have different sharding: one reduces on sharded axis, one doesn't"""
+    devs = ("CPU:0", "CPU:1")
+    @function
+    def f(x:Tensor): return (x.sum(axis=0), x.sum(axis=1))
+    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
+    t1, t2 = f(a)
+    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
+    np.testing.assert_allclose(t1.numpy(), ref.sum(axis=0))
+    np.testing.assert_allclose(t2.numpy(), ref.sum(axis=1))
+
+  def test_tuple_sharded_different_ops(self):
+    """multi-output function with different operations per output"""
+    devs = ("CPU:0", "CPU:1")
+    @function
+    def f(x:Tensor, y:Tensor): return (x + y, x * y)
+    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
+    b = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0) + 1
+    t1, t2 = f(a, b)
+    ref_a = np.arange(8, dtype=np.float32).reshape(4, 2)
+    ref_b = ref_a + 1
+    np.testing.assert_allclose(t1.numpy(), ref_a + ref_b)
+    np.testing.assert_allclose(t2.numpy(), ref_a * ref_b)
+
+  def test_tuple_sharded_mixed_use(self):
+    """multi-output sharded results used in further computation"""
+    devs = ("CPU:0", "CPU:1")
+    @function
+    def f(x:Tensor): return (x + 1, x * 2)
+    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
+    t1, t2 = f(a)
+    out = (t1 + t2).sum()
+    ref = np.arange(8, dtype=np.float32).reshape(4, 2)
+    np.testing.assert_allclose(out.numpy(), ((ref + 1) + (ref * 2)).sum())
+
+  def test_tuple_sharded_outputs_different_axis(self):
+    """multi-output function where the two outputs are sharded on different axes"""
+    devs = ("CPU:0", "CPU:1")
+    @function
+    def f(x:Tensor, y:Tensor): return (x + 1, y + 2)
+    a = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=0)
+    b = Tensor.arange(8).reshape(4, 2).float().shard(devs, axis=1)
+    t1, t2 = f(a, b)
+    ref_a = np.arange(8, dtype=np.float32).reshape(4, 2)
+    ref_b = np.arange(8, dtype=np.float32).reshape(4, 2)
+    np.testing.assert_allclose(t1.numpy(), ref_a + 1)
+    np.testing.assert_allclose(t2.numpy(), ref_b + 2)
+
   def test_call_reduce_sharded(self):
     devs = ("CPU:0", "CPU:1")
     a = Tensor.ones(10, 10).shard(devs, axis=0)

@@ -165,13 +165,13 @@ class TestFunction(unittest.TestCase):
   def test_name(self):
     @function
     def f(a:Tensor) -> Tensor: return a + 1
-    assert f(Tensor([1])).uop.arg.name.endswith("f")
+    assert f(Tensor([1])).uop.src[0].arg.name.endswith("f")

   def test_method_name(self):
     class Foo:
       @function
       def __call__(self, x:Tensor) -> Tensor: return x + 1
-    assert Foo()(Tensor([1])).uop.arg.name.endswith("Foo.__call__")
+    assert Foo()(Tensor([1])).uop.src[0].arg.name.endswith("Foo.__call__")

   def test_callable_instance(self):
     class Foo:

@@ -180,7 +180,7 @@ class TestFunction(unittest.TestCase):
     foo = Foo()
     f = function(foo)
     np.testing.assert_equal(f(Tensor([1,2,3])).numpy(), [11,22,33])
-    assert f(Tensor([1,2,3])).uop.arg.name.endswith("Foo")
+    assert f(Tensor([1,2,3])).uop.src[0].arg.name.endswith("Foo")

   def test_iadd(self):
     @function

@@ -369,11 +369,11 @@ class TestFunctionTuple(unittest.TestCase):
   def test_grad_tuple_precompile(self): self.test_grad_tuple(True)

   def test_grad_fxn_tuple(self):
-    # grad_fxn for tuple: ctx is a TUPLE UOp with one element per output
-    def grad_fxn(ctx:UOp, call:UOp):
-      # f(u1, u2) = (u1+1, u2+2), ctx.src = (d_out0, d_out1)
+    # grad_fxn for tuple: receives one gradient per output as positional args
+    def grad_fxn(d_out0:UOp, d_out1:UOp, call:UOp):
+      # f(u1, u2) = (u1+1, u2+2)
       # df/du1 = d_out0, df/du2 = d_out1
-      return (ctx.src[0], ctx.src[1])
+      return (d_out0, d_out1)

     x = Tensor.ones(3, requires_grad=True).contiguous()
     y = Tensor.ones(3, requires_grad=True).contiguous()

@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
 from tinygrad.uop.ops import UOp, UPat, PatternMatcher, Ops, GroupOp, graph_rewrite, track_rewrites
-from tinygrad.dtype import dtypes, ImageDType
+from tinygrad.dtype import ImageDType
 from tinygrad.helpers import prod, DEBUG, VIZ, pluralize, all_int

 @dataclass

@@ -92,25 +92,25 @@ def contiguous_mops_to_view(c:UOp, src:UOp):
 def transform_precompiled_call(c:UOp) -> UOp|None:
   if not c.arg.precompile: return None
   if c.src[0].op is Ops.SINK: return None
+  assert c.src[0].op is Ops.TUPLE, f"expected TUPLE body for precompiled call, got {c.src[0].op}"
   input_buffers = tuple(x.contiguous() if x.op not in {Ops.AFTER, Ops.BIND} else x for x in c.src[1:])

   # add the outputs to the call
-  srcs = c.src[0].src if c.src[0].op is Ops.TUPLE else (c.src[0],)
-  resolved = [c.gettuple(i) if c.src[0].op is Ops.TUPLE else c for i in range(len(srcs))]
+  srcs = c.src[0].src
+  resolved = [c.gettuple(i) for i in range(len(srcs))]
   outs = tuple(_buffer_like(r) for r in resolved)
   targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))]
   fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])

   # create the new thing for the big graph
-  new_call = c.replace(src=(fxn, *input_buffers, *outs), dtype=dtypes.void, tag=None)
+  new_call = c.replace(src=(fxn, *input_buffers, *outs), tag=None)
   rets = tuple(o.after(new_call) for o in outs)

   # if the CALL has symbolic shapes, shrink the max-sized output to the actual symbolic shape
   # NOTE: must use resolved shapes from the CALL (which substitutes PARAMs with external args), not raw body shapes
   rets = tuple(r.shrink_to(rs.shape) for r,rs in zip(rets, resolved))

-  # return tuple if tuple
-  return UOp.maketuple(*rets) if c.src[0].op is Ops.TUPLE else rets[0]
+  return UOp.maketuple(*rets)

 # NOTE: adding rules to here is bad. these all need to run before the schedule cache
 pm_early_transform_tensor_graph = PatternMatcher([

@@ -74,7 +74,7 @@ class _function(Generic[ReturnType]):
     if isinstance(ret, tuple):
       return cast(ReturnType, tuple(Tensor(fret.gettuple(i), device=fret.device) for i in range(len(ret))))
     else:
-      return cast(ReturnType, Tensor(fret, device=fret.device))
+      return cast(ReturnType, Tensor(fret.gettuple(0), device=fret.device))

 # overload signatures support both @function and @function(precompile=True) syntax
 @overload

@@ -14,22 +14,21 @@ def reduce_gradient(ctx:UOp, ret:UOp, op:Ops):
   if op == Ops.MUL: return (broadcast_to_input(ctx * ret) / ret.src[0],)

 def call_gradient(ctx:UOp, k:UOp) -> tuple[UOp|None, ...]:
-  if k.arg.grad_fxn is not None: return (None,) + k.arg.grad_fxn(ctx, k)
   fxn, args = k.src[0], k.src[1:]
+  if k.arg.grad_fxn is not None:
+    return (None,) + (k.arg.grad_fxn(*ctx.src, call=k) if ctx.op is Ops.TUPLE else k.arg.grad_fxn(ctx, k))
+  assert fxn.op is Ops.TUPLE, f"expected TUPLE body for gradient, got {fxn.op}"
   params = {x.arg:x for x in fxn.toposort(enter_calls=False) if x.op == Ops.PARAM}
-  if fxn.op is Ops.TUPLE:
-    grad_args = ctx.src
-    root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args)))
-  else:
-    grad_args = (ctx,)
-    root_grad = ctx.param_like(len(args))
+  grad_args = ctx.src
+  root_grad = UOp(Ops.TUPLE, src=tuple(g.param_like(len(args) + i) for i, g in enumerate(grad_args)))
   grads = compute_gradient(fxn, root_grad, set(params.values()))
   ret: list[UOp|None] = [None]
   for i in range(len(args)):
     if (p:=params.get(i, None)) is not None and p in grads:
       # TODO: compact the args and remove unused ones
       assert not grads[p].op_in_backward_slice_with_self(Ops.BUFFER), "BUG: BUFFER in backward slice of grad"
-      ret.append(grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward))
+      bwd_call = grads[p].call(*args, *grad_args, name=(k.arg.name or "")+f"_backward_{i}", precompile=k.arg.precompile_backward)
+      ret.append(bwd_call.gettuple(0))
     else:
       ret.append(None)
   return tuple(ret)

@@ -116,6 +116,11 @@ def rewrite_into_call(call:UOp):
   if not should_resolve_call(call): return None
   new_body = graph_rewrite(call.src[0], multi_pm, name="subcall")
   new_args = tuple(a.src[0] if a.op is Ops.MULTI else a for a in call.src[1:])
+  # after multi resolution, TUPLE elements may be MULTI — strip MULTI from body, create per-shard CALL, wrap each GETTUPLE in its own MULTI
+  assert new_body.op is Ops.TUPLE
+  if any(s.op is Ops.MULTI for s in new_body.src):
+    shard_call = call.replace(src=(UOp.maketuple(*[s.src[0] if s.op is Ops.MULTI else s for s in new_body.src]),)+new_args)
+    return UOp.maketuple(*[shard_call.gettuple(i).multi(s.axis) if s.op is Ops.MULTI else shard_call.gettuple(i) for i, s in enumerate(new_body.src)])
   return call.replace(src=(new_body,)+new_args)

 def param_to_multi(p:UOp):

@@ -137,12 +142,20 @@ multi_pm = PatternMatcher([
   (UPat(Ops.COPY, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device"))), copy_multi),
   (UPat(Ops.ALLREDUCE, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.DEVICE, name="device")), name="red"),
    lambda multi,device,red: multi.src[0].allreduce(red.arg, device).multi(axis=multi.axis)),

+  # resolve TUPLE+GETTUPLE (needed in multi)
+  (UPat(Ops.GETTUPLE, src=(UPat(Ops.TUPLE, name="t"),), name="g"), lambda g,t: t.src[g.arg]),
+  # GETTUPLE on MULTI: passthrough MULTI (e.g. when CALL was replaced by MULTI(GETTUPLE(...)))
+  (UPat(Ops.GETTUPLE, src=(UPat(Ops.MULTI, name="multi"),), name="g"),
+   lambda g, multi: multi.src[0].gettuple(g.arg).multi(multi.axis) if multi.src[0].op in {Ops.CALL, Ops.TUPLE}
+   else multi),
+  # rewrite into calls explicitly for MULTI
   (UPat(Ops.CALL, name="call"), rewrite_into_call),
   (UPat((Ops.CALL, Ops.AFTER, Ops.STORE), src=(UPat(Ops.MULTI, name="multi"), ), name="root", allow_any_len=True), passthrough_multi),
-  # we just remove the MULTI from CALLs with dtypes.void and assume they are handled by the user for custom kernels
+  # we just remove the MULTI from non-value-producing CALLs (custom kernels, etc.) — TUPLE body CALLs are handled by rewrite_into_call
   (UPat(Ops.CALL, dtype=dtypes.void, name="root", custom_early_reject=set([Ops.MULTI])), lambda root:
-    UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)),
+    UOp(root.op, root.dtype, tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src), root.arg)
+    if root.src[0].op is not Ops.TUPLE else None),
   (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH, Ops.CONTIGUOUS_BACKWARD),
     src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
 ])+replace_allreduce

@@ -242,7 +242,8 @@ class Tensor(OpMixin):
     param = UOp.param(slot, self.dtype, self.shape, self.device)
     return Tensor(param, device=self.device)
   def call(self, *lst:Tensor, fxn:Tensor|UOp, grad_fxn:Callable|None=None) -> Tensor:
-    return Tensor((fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn), device=self.device)
+    fret = (fxn.uop if isinstance(fxn, Tensor) else fxn).call(*[t.uop for t in (self,)+lst], grad_fxn=grad_fxn)
+    return Tensor(fret.gettuple(0), device=self.device)

   def custom_kernel(self, *lst:Tensor, fxn:Callable, grad_fxn:Callable|None=None) -> list[Tensor]:
     """

@@ -213,10 +213,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
        return None

      case Ops.GETTUPLE:
-        # GETTUPLE extracts from a TUPLE
+        # GETTUPLE extracts from a TUPLE (possibly through a CALL)
        in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
        assert in_tuple.op is Ops.TUPLE
-        return in_tuple.src[self.arg]._shape
+        inner_shape = in_tuple.src[self.arg]._shape
+        if inner_shape is None: return None
+        # if through a CALL, substitute internal PARAMs in the shape with corresponding args
+        if self.src[0].op is Ops.CALL:
+          return tuple(graph_rewrite(s, _pm_resolve_params, self.src[0].src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
+        return inner_shape

      case Ops.CAST:
        # when PTX casts from ptr to non ptr, remove the shape

@@ -248,11 +253,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
        return self.src[0]._shape

-      case Ops.CALL:
-        inner_shape = self.src[0]._shape
-        if inner_shape is None: return None
-        # substitute internal PARAMs in the shape with corresponding args
-        return tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
+      case Ops.CALL: return None

      # TODO: disallow shape changing bitcast
      case Ops.BITCAST:

@@ -535,6 +536,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     # COPY removes axis. TODO: add more tests for this, and consider MSELECT/MSTACK
     if self.op is Ops.COPY: return None
     if self.op is Ops.MULTI: return self.arg
+    # GETTUPLE: axis comes from the specific TUPLE element, not src[0]
+    if self.op is Ops.GETTUPLE:
+      in_tuple = self.src[0].src[0] if self.src[0].op is Ops.CALL else self.src[0]
+      return in_tuple.src[self.arg].axis if in_tuple.op is Ops.TUPLE else None
     # PARAM: axis is stored as a MULTI source
     if self.op is Ops.PARAM:
       for s in self.src:

@@ -918,10 +923,13 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     if self.axis is not None: p = p.replace(src=p.src + (UOp(Ops.MULTI, arg=self.axis),))
     return p

+  _NO_TUPLE_WRAP = {Ops.SINK, Ops.PROGRAM, Ops.LINEAR, Ops.COPY, Ops.BUFFER_VIEW, Ops.CUSTOM_FUNCTION, Ops.TUPLE}
   def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(),
            name:str|None=None, precompile:bool=False, precompile_backward:bool=False) -> UOp:
     assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
-    return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
+    # value-producing bodies are always wrapped in TUPLE so CALL dtype is always void
+    body = self if self.op in UOp._NO_TUPLE_WRAP else UOp.maketuple(self)
+    return UOp(Ops.CALL, dtypes.void, (body,)+srcs, CallInfo(grad_fxn, metadata, name, precompile, precompile_backward))
   def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
     contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
     placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]

@@ -132,8 +132,8 @@ _tensor_spec = PatternMatcher([
   # AFTER if things were kernelized
   (UPat(Ops.AFTER, src=(UPat((Ops.BUFFER, Ops.AFTER)),), allow_any_len=True), lambda: True),

-  # allow CALL/PARAM/CUSTOM_FUNCTION
-  (UPat(Ops.CALL, src=(UPat(name="f"),), name="c", allow_any_len=True), lambda c,f: c.dtype == f.dtype),
+  # allow CALL/PARAM/CUSTOM_FUNCTION — CALL dtype is always void
+  (UPat(Ops.CALL, dtypes.void), lambda: True),
   (UPat(Ops.PARAM), lambda: True),
   (UPat(Ops.CUSTOM_FUNCTION, name="x"), lambda x: isinstance(x.arg, str)),

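Note on custom gradients for multi-output functions: per the updated test_grad_fxn_tuple hunk above, grad_fxn now receives one upstream gradient per output as positional arguments plus the CALL UOp through the call argument, rather than a single TUPLE ctx. A condensed, illustrative restatement of that test's callback:

from tinygrad.uop.ops import UOp

# for f(u1, u2) = (u1 + 1, u2 + 2): df/du1 = d_out0, df/du2 = d_out1
def grad_fxn(d_out0: UOp, d_out1: UOp, call: UOp) -> tuple[UOp, UOp]:
  return (d_out0, d_out1)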