diff --git a/test/unit/test_function.py b/test/unit/test_function.py
index 735c505bec..54edf99637 100644
--- a/test/unit/test_function.py
+++ b/test/unit/test_function.py
@@ -1,3 +1,4 @@
+import numpy as np
 import unittest
 from tinygrad.function import function
 from tinygrad import Tensor
@@ -9,8 +10,14 @@ class TestFunction(unittest.TestCase):
 
     a = Tensor([1,2,3])
     b = Tensor([4,5,6])
-    c = f(a,b)
-    c.realize()
+    np.testing.assert_equal(f(a,b).numpy(), [5,7,9])
+
+  def test_simple_same(self):
+    @function
+    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
+
+    a = Tensor([1,2,3])
+    np.testing.assert_equal(f(a,a).numpy(), [2,4,6])
 
   def test_implicit(self):
     inp = Tensor([7,8,9])
@@ -19,8 +26,15 @@ class TestFunction(unittest.TestCase):
 
     a = Tensor([1,2,3])
     b = Tensor([4,5,6])
-    c = f(a,b)
-    c.realize()
+    np.testing.assert_equal(f(a,b).numpy(), [12,15,18])
+
+  def test_implicit_same_as_input(self):
+    inp = Tensor([7,8,9])
+    @function
+    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
+
+    a = Tensor([1,2,3])
+    np.testing.assert_equal(f(a, inp).numpy(), [15,18,21])
 
   def test_implicit_2(self):
     inp = Tensor([7,8,9])
@@ -37,6 +51,84 @@ class TestFunction(unittest.TestCase):
     c = f(a,b)
     d = g(a,b)
     c.realize(d)
+    np.testing.assert_equal(c.numpy(), [12,15,18])
+    np.testing.assert_equal(d.numpy(), [12,15,19])
+
+  def test_implicit_unrealized(self):
+    inp = Tensor([1,2,3]) + Tensor([4,5,6])
+    @function
+    def f(a:Tensor) -> Tensor: return a + inp
+
+    np.testing.assert_equal(f(Tensor([10,20,30])).numpy(), [15,27,39])
+
+  def test_detach(self):
+    @function
+    def f(a:Tensor, b:Tensor) -> Tensor: return a.detach() + b
+
+    a = Tensor([1,2,3])
+    b = Tensor([4,5,6])
+    np.testing.assert_equal(f(a, b).numpy(), [5,7,9])
+
+  def test_method(self):
+    class Foo:
+      def __init__(self): self.w = Tensor([10,20,30])
+      @function
+      def __call__(self, x:Tensor) -> Tensor: return x + self.w
+
+    foo = Foo()
+    np.testing.assert_equal(foo(Tensor([1,2,3])).numpy(), [11,22,33])
+
+  def test_grad_gemm(self):
+    @function
+    def f(a:Tensor, b:Tensor) -> Tensor: return a @ b
+
+    a = Tensor([[1.,2.],[3.,4.]], requires_grad=True)
+    b = Tensor([[5.,6.],[7.,8.]], requires_grad=True)
+    na, nb = a.numpy(), b.numpy()
+    (f(a, b).contiguous() * b).sum().backward()
+    # L = sum((a@b) * b), dL/d(a@b) = b, dL/da = b @ b^T, dL/db = a^T @ b + (a@b)
+    np.testing.assert_allclose(a.grad.numpy(), nb @ nb.T)
+    np.testing.assert_allclose(b.grad.numpy(), na.T @ nb + na @ nb)
+
+  def test_grad_implicit(self):
+    w = Tensor([1., 2., 3.], requires_grad=True)
+    @function
+    def f(x:Tensor) -> Tensor: return x * w
+
+    x = Tensor([4., 5., 6.])
+    f(x).sum().backward()
+    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])
+
+  def test_symbolic_index(self):
+    from tinygrad.uop.ops import UOp
+    table = Tensor([10,20,30,40]).contiguous().realize()
+    @function
+    def f(x:Tensor, start_pos:int|UOp) -> Tensor:
+      return x + table[start_pos]
+
+    v = UOp.variable("start_pos", 0, 3)
+    np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])
+
+  def test_nested_calls(self):
+    w = Tensor([10., 20., 30.])
+    @function
+    def f(a:Tensor) -> Tensor: return a + w
+    @function
+    def g(a:Tensor) -> Tensor: return a * w
+
+    a = Tensor([1., 2., 3.])
+    np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.])
+
+  def test_name(self):
+    @function
+    def f(a:Tensor) -> Tensor: return a + 1
+    assert f(Tensor([1])).uop.arg.name.endswith("f")
+
+  def test_method_name(self):
+    class Foo:
+      @function
+      def __call__(self, x:Tensor) -> Tensor: return x + 1
+    assert Foo()(Tensor([1])).uop.arg.name.endswith("Foo.__call__")
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index dc709344e7..0e9c7270c5 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
-from tinygrad import Tensor, nn, UOp, TinyJit, getenv
+from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
 from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
 from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
 
@@ -144,6 +144,7 @@ class TransformerBlock:
     attn = self.attn_output(attn)
     return x + attn
 
+  @function
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
     if hasattr(self, 'ffn_gate_exps'):
diff --git a/tinygrad/engine/allocations.py b/tinygrad/engine/allocations.py
index 57041c6217..d7a1f5b238 100644
--- a/tinygrad/engine/allocations.py
+++ b/tinygrad/engine/allocations.py
@@ -127,6 +127,7 @@ pm_replace_buf = PatternMatcher([
 
 @track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
 def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
+  if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Tensor Graph")
   # uop list is a list in the original_sink graph and we can map to the tags later
   # here we build buffer map
   dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 11398e8f8e..149725a379 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -1,20 +1,38 @@
 import functools
 from typing import Generic, TypeVar, Callable, cast
-from dataclasses import dataclass, field
-from tinygrad.helpers import Context
-from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite
+from tinygrad.helpers import Context, dedup, getenv
+from tinygrad.uop.ops import UOp, Ops
 from tinygrad.tensor import Tensor
 
-@dataclass
-class _ImplicitBufCtx:
-  offset: int
-  bufs: list[UOp] = field(default_factory=list)
+def _srcs(u:UOp) -> tuple[UOp, ...]:
+  """Get sources of a UOp, skipping src[0] of CALL nodes (other functions' bodies with their own PARAMs)."""
+  return u.src[1:] if u.op is Ops.CALL else u.src
 
-def _replace_implicit_buffer(ctx:_ImplicitBufCtx, b:UOp):
-  if b not in ctx.bufs: ctx.bufs.append(b)
-  return UOp.param(ctx.offset + ctx.bufs.index(b), b.dtype, b.shape, b._device)
-
-pm_implicit = PatternMatcher([(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), _replace_implicit_buffer)])
+def _find_implicit_inputs(uret:UOp) -> list[UOp]:
+  """Find implicit inputs by starting at remaining BUFFERs and walking up to the branching point where PARAM-derived nodes meet."""
+  all_nodes = list(uret.toposort())
+  # build parent map, gating on src[0] of CALL nodes
+  parents_of: dict[UOp, set[UOp]] = {}
+  for u in all_nodes:
+    for s in _srcs(u):
+      parents_of.setdefault(s, set()).add(u)
+  # mark which nodes have a PARAM in their subtree (bottom-up, toposort is already bottom-up)
+  has_param: dict[UOp, bool] = {}
+  for u in all_nodes:
+    if u.op is Ops.PARAM: has_param[u] = True
+    else: has_param[u] = any(has_param.get(s, False) for s in _srcs(u))
+  # for each remaining BUFFER, walk up until we hit a node whose parent has PARAM in its subtree
+  implicit_inputs: list[UOp] = []
+  for buf in all_nodes:
+    if buf.op is not Ops.BUFFER: continue
+    cur = buf
+    while True:
+      ps = parents_of.get(cur, set())
+      if not ps or any(has_param.get(p, False) for p in ps):
+        implicit_inputs.append(cur)
+        break
+      cur = next(iter(ps))
+  return dedup(implicit_inputs)
 
 ReturnType = TypeVar('ReturnType')
 class function(Generic[ReturnType]):
@@ -24,25 +42,30 @@ class function(Generic[ReturnType]):
 
   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
 
   def __call__(self, *args, **kwargs) -> ReturnType:
-    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t).multibase
+    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t)
       for name,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]
+    # deduplicate input_uops, keeping the first occurrence index for each unique uop
+    unique_uops: list[UOp] = dedup(input_uops)
+
     # disable realize/schedule while this is running
     # run it and do surgery later
-    with Context(ALLOW_DEVICE_USAGE=0):
+    with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
       ret = self.fxn(*args, **kwargs)
     assert isinstance(ret, Tensor), "only supports one tensor return for now"
 
-    # replace the known inputs with params
+    # replace the known inputs with params (using deduplicated slots)
     subs = {}
-    for i,x in enumerate(input_uops):
+    for i,x in enumerate(unique_uops):
       # TODO: this can be better
       if x.op is Ops.BIND: subs[x] = UOp.param(i, x.dtype, x._shape, x._device, x._min_max)
       else: subs[x] = UOp.param(i, x.dtype, x._shape, x._device)
     uret = ret.uop.substitute(subs)
 
-    # replace the implicit BUFFER inputs with params using graph_rewrite
-    ctx = _ImplicitBufCtx(offset=len(input_uops))
-    uret = graph_rewrite(uret, pm_implicit, ctx=ctx)
+    # find implicit inputs by walking up from remaining BUFFERs to branching points
+    implicit = _find_implicit_inputs(uret)
+    for i,imp in enumerate(implicit):
+      subs[imp] = UOp.param(len(unique_uops) + i, imp.dtype, imp._shape, imp._device)
+    uret = ret.uop.substitute(subs)
 
-    return cast(ReturnType, Tensor(uret.call(*input_uops, *ctx.bufs, name=self.fxn.__name__), device=ret.device))
+    return cast(ReturnType, Tensor(uret.call(*unique_uops, *implicit, name=self.fxn.__qualname__), device=ret.device))
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index f274bae424..58dcc0d187 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -98,6 +98,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
   # split_reduceop
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
 
+  # remove DETACH/CONTIGUOUS_BACKWARD (TODO: this is copied in allocations)
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
+
   # remove contiguous on movement ops before a copy on disk
   (UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
    lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
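
A minimal usage sketch of the reworked @function decorator, pieced together from the tests above. scale_shift and its values are hypothetical; the behavior it exercises (explicit Tensor/UOp arguments become PARAMs, closed-over tensors like w are discovered by _find_implicit_inputs and appended as extra PARAMs, and the CALL is named via fxn.__qualname__) is what the patch implements.

import numpy as np
from tinygrad import Tensor
from tinygrad.function import function

w = Tensor([10., 20., 30.])   # not an argument: captured as an implicit input

@function
def scale_shift(x: Tensor) -> Tensor:
  # traced once with device usage disabled; x becomes PARAM 0,
  # w's BUFFER is found by _find_implicit_inputs and becomes the next PARAM
  return x * w + w

out = scale_shift(Tensor([1., 2., 3.]))
np.testing.assert_allclose(out.numpy(), [20., 60., 120.])
assert out.uop.arg.name.endswith("scale_shift")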
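
For intuition, a toy rendition of the walk-up loop in _find_implicit_inputs (the node names and parent map are made up): from each leftover BUFFER, follow parents until some parent already has a PARAM in its subtree, and capture the node just below that meeting point, so the whole BUFFER-to-view chain travels with the call.

# hypothetical chain: buf -> view -> add -> out, where add also consumes PARAM-derived data
parents_of = {"buf": {"view"}, "view": {"add"}, "add": {"out"}}
has_param = {"add": True, "out": True}   # nodes with a PARAM somewhere below them

cur = "buf"
while True:
  ps = parents_of.get(cur, set())
  if not ps or any(has_param.get(p, False) for p in ps):
    print("implicit input:", cur)   # -> view: captured just below the meeting point
    break
  cur = next(iter(ps))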