add more tests to test_function (#15003)

* add more tests to test_function

* add function to llm

* function decorator on llm

* works

* symbolic fixups

* minimum change

* implicit inputs

* don't actually update llama yet
George Hotz
2026-02-25 18:42:06 +08:00
committed by GitHub
parent d941dd5aeb
commit 68831cd852
5 changed files with 145 additions and 25 deletions
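
For orientation (illustrative, not part of the diff): the @function decorator wraps a Tensor-returning function so its body is captured behind a single CALL. Explicit Tensor/UOp arguments become PARAMs; this commit adds tests for the decorator, reworks how tensors captured from the enclosing scope are picked up as implicit inputs, and applies the decorator inside the llm example. A minimal usage sketch, mirroring the tests below and assuming only the API they exercise:

# minimal sketch, not part of the diff -- mirrors test_implicit below
from tinygrad import Tensor
from tinygrad.function import function

inp = Tensor([7,8,9])                                  # captured from the enclosing scope -> implicit input
@function
def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp    # explicit args a, b become PARAMs
print(f(Tensor([1,2,3]), Tensor([4,5,6])).numpy())     # [12 15 18]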


@@ -1,3 +1,4 @@
import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor
@@ -9,8 +10,14 @@ class TestFunction(unittest.TestCase):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
c = f(a,b)
c.realize()
np.testing.assert_equal(f(a,b).numpy(), [5,7,9])
def test_simple_same(self):
@function
def f(a:Tensor, b:Tensor) -> Tensor: return a+b
a = Tensor([1,2,3])
np.testing.assert_equal(f(a,a).numpy(), [2,4,6])
def test_implicit(self):
inp = Tensor([7,8,9])
@@ -19,8 +26,15 @@ class TestFunction(unittest.TestCase):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
c = f(a,b)
c.realize()
np.testing.assert_equal(f(a,b).numpy(), [12,15,18])
def test_implicit_same_as_input(self):
inp = Tensor([7,8,9])
@function
def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp
a = Tensor([1,2,3])
np.testing.assert_equal(f(a, inp).numpy(), [15,18,21])
def test_implicit_2(self):
inp = Tensor([7,8,9])
@@ -37,6 +51,84 @@ class TestFunction(unittest.TestCase):
c = f(a,b)
d = g(a,b)
c.realize(d)
np.testing.assert_equal(c.numpy(), [12,15,18])
np.testing.assert_equal(d.numpy(), [12,15,19])
def test_implicit_unrealized(self):
inp = Tensor([1,2,3]) + Tensor([4,5,6])
@function
def f(a:Tensor) -> Tensor: return a + inp
np.testing.assert_equal(f(Tensor([10,20,30])).numpy(), [15,27,39])
def test_detach(self):
@function
def f(a:Tensor, b:Tensor) -> Tensor: return a.detach() + b
a = Tensor([1,2,3])
b = Tensor([4,5,6])
np.testing.assert_equal(f(a, b).numpy(), [5,7,9])
def test_method(self):
class Foo:
def __init__(self): self.w = Tensor([10,20,30])
@function
def __call__(self, x:Tensor) -> Tensor: return x + self.w
foo = Foo()
np.testing.assert_equal(foo(Tensor([1,2,3])).numpy(), [11,22,33])
def test_grad_gemm(self):
@function
def f(a:Tensor, b:Tensor) -> Tensor: return a @ b
a = Tensor([[1.,2.],[3.,4.]], requires_grad=True)
b = Tensor([[5.,6.],[7.,8.]], requires_grad=True)
na, nb = a.numpy(), b.numpy()
(f(a, b).contiguous() * b).sum().backward()
# L = sum((a@b) * b), dL/d(a@b) = b, dL/da = b @ b^T, dL/db = a^T @ b + (a@b)
np.testing.assert_allclose(a.grad.numpy(), nb @ nb.T)
np.testing.assert_allclose(b.grad.numpy(), na.T @ nb + na @ nb)
def test_grad_implicit(self):
w = Tensor([1., 2., 3.], requires_grad=True)
@function
def f(x:Tensor) -> Tensor: return x * w
x = Tensor([4., 5., 6.])
f(x).sum().backward()
np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])
def test_symbolic_index(self):
from tinygrad.uop.ops import UOp
table = Tensor([10,20,30,40]).contiguous().realize()
@function
def f(x:Tensor, start_pos:int|UOp) -> Tensor:
return x + table[start_pos]
v = UOp.variable("start_pos", 0, 3)
np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])
def test_nested_calls(self):
w = Tensor([10., 20., 30.])
@function
def f(a:Tensor) -> Tensor: return a + w
@function
def g(a:Tensor) -> Tensor: return a * w
a = Tensor([1., 2., 3.])
np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.])
def test_name(self):
@function
def f(a:Tensor) -> Tensor: return a + 1
assert f(Tensor([1])).uop.arg.name.endswith("f")
def test_method_name(self):
class Foo:
@function
def __call__(self, x:Tensor) -> Tensor: return x + 1
assert Foo()(Tensor([1])).uop.arg.name.endswith("Foo.__call__")
if __name__ == '__main__':
unittest.main()


@@ -1,6 +1,6 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@@ -144,6 +144,7 @@ class TransformerBlock:
attn = self.attn_output(attn)
return x + attn
@function
def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h)
if hasattr(self, 'ffn_gate_exps'):


@@ -127,6 +127,7 @@ pm_replace_buf = PatternMatcher([
@track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Tensor Graph")
# the uop list lives in the original_sink graph, so we can map back to the tags later
# here we build the buffer map
dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}


@@ -1,20 +1,38 @@
import functools
from typing import Generic, TypeVar, Callable, cast
from dataclasses import dataclass, field
from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import Context, dedup, getenv
from tinygrad.uop.ops import UOp, Ops
from tinygrad.tensor import Tensor
@dataclass
class _ImplicitBufCtx:
offset: int
bufs: list[UOp] = field(default_factory=list)
def _srcs(u:UOp) -> tuple[UOp, ...]:
"""Get sources of a UOp, skipping src[0] of CALL nodes (other functions' bodies with their own PARAMs)."""
return u.src[1:] if u.op is Ops.CALL else u.src
def _replace_implicit_buffer(ctx:_ImplicitBufCtx, b:UOp):
if b not in ctx.bufs: ctx.bufs.append(b)
return UOp.param(ctx.offset + ctx.bufs.index(b), b.dtype, b.shape, b._device)
pm_implicit = PatternMatcher([(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), _replace_implicit_buffer)])
def _find_implicit_inputs(uret:UOp) -> list[UOp]:
"""Find implicit inputs by starting at remaining BUFFERs and walking up to the branching point where PARAM-derived nodes meet."""
all_nodes = list(uret.toposort())
# build parent map, gating on src[0] of CALL nodes
parents_of: dict[UOp, set[UOp]] = {}
for u in all_nodes:
for s in _srcs(u):
parents_of.setdefault(s, set()).add(u)
# mark which nodes have a PARAM in their subtree (bottom-up, toposort is already bottom-up)
has_param: dict[UOp, bool] = {}
for u in all_nodes:
if u.op is Ops.PARAM: has_param[u] = True
else: has_param[u] = any(has_param.get(s, False) for s in _srcs(u))
# for each remaining BUFFER, walk up until we hit a node whose parent has PARAM in its subtree
implicit_inputs: list[UOp] = []
for buf in all_nodes:
if buf.op is not Ops.BUFFER: continue
cur = buf
while True:
ps = parents_of.get(cur, set())
if not ps or any(has_param.get(p, False) for p in ps):
implicit_inputs.append(cur)
break
cur = next(iter(ps))
return dedup(implicit_inputs)
ReturnType = TypeVar('ReturnType')
class function(Generic[ReturnType]):
@@ -24,25 +42,30 @@ class function(Generic[ReturnType]):
def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
def __call__(self, *args, **kwargs) -> ReturnType:
input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t).multibase
input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t)
for name,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]
# deduplicate input_uops, keeping the first occurrence index for each unique uop
unique_uops: list[UOp] = dedup(input_uops)
# disable realize/schedule while this is running
# run it and do surgery later
with Context(ALLOW_DEVICE_USAGE=0):
with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
ret = self.fxn(*args, **kwargs)
assert isinstance(ret, Tensor), "only supports one tensor return for now"
# replace the known inputs with params
# replace the known inputs with params (using deduplicated slots)
subs = {}
for i,x in enumerate(input_uops):
for i,x in enumerate(unique_uops):
# TODO: this can be better
if x.op is Ops.BIND: subs[x] = UOp.param(i, x.dtype, x._shape, x._device, x._min_max)
else: subs[x] = UOp.param(i, x.dtype, x._shape, x._device)
uret = ret.uop.substitute(subs)
# replace the implicit BUFFER inputs with params using graph_rewrite
ctx = _ImplicitBufCtx(offset=len(input_uops))
uret = graph_rewrite(uret, pm_implicit, ctx=ctx)
# find implicit inputs by walking up from remaining BUFFERs to branching points
implicit = _find_implicit_inputs(uret)
for i,imp in enumerate(implicit):
subs[imp] = UOp.param(len(unique_uops) + i, imp.dtype, imp._shape, imp._device)
uret = ret.uop.substitute(subs)
return cast(ReturnType, Tensor(uret.call(*input_uops, *ctx.bufs, name=self.fxn.__name__), device=ret.device))
return cast(ReturnType, Tensor(uret.call(*unique_uops, *implicit, name=self.fxn.__qualname__), device=ret.device))
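
For orientation (not part of the diff): in the rewritten __call__ above, explicit Tensor/UOp arguments are deduplicated into PARAM slots, the implicit inputs found by _find_implicit_inputs are appended after them, and the resulting CALL is named with the function's __qualname__. A small sketch of the observable behavior, using only what the tests in this commit exercise:

# illustrative sketch, not part of the diff
from tinygrad import Tensor
from tinygrad.function import function

w = Tensor([1., 2., 3.])
@function
def scale(x:Tensor) -> Tensor: return x * w   # w is not an argument: it is found as an implicit input
out = scale(Tensor([4., 5., 6.]))             # x's uop fills PARAM 0; w's uop is appended to the CALL after it
assert out.uop.arg.name.endswith("scale")     # the CALL carries the function's __qualname__
print(out.numpy())                            # [ 4. 10. 18.]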


@@ -98,6 +98,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# split_reduceop
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
# remove DETACH/CONTIGUOUS_BACKWARD (TODO: this is copied in allocations)
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# remove contiguous on movement ops before a copy on disk
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),