mirror of https://github.com/tinygrad/tinygrad.git
add more tests to test_function (#15003)
* add more tests to test_function
* add function to llm
* function decorator on llm
* works
* symbolic fixups
* minimum change
* implicit inputs
* don't actually update llama yet
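For orientation, a minimal usage sketch of the @function decorator this commit exercises, distilled from the new tests below (an explicit-args call, a captured "implicit" tensor, and a bound symbolic index); illustrative only, not part of the diff:

from tinygrad import Tensor
from tinygrad.function import function
from tinygrad.uop.ops import UOp

w = Tensor([10, 20, 30])   # captured tensor, picked up as an implicit input

@function
def f(a: Tensor, b: Tensor) -> Tensor: return a + b + w

print(f(Tensor([1, 2, 3]), Tensor([4, 5, 6])).numpy())  # [15 27 39]

# a bound symbolic Variable can be passed straight through the decorated function
table = Tensor([10, 20, 30, 40]).contiguous().realize()

@function
def g(x: Tensor, start_pos: int | UOp) -> Tensor: return x + table[start_pos]

v = UOp.variable("start_pos", 0, 3)
print(g(Tensor([1, 2, 3]), v.bind(0)).numpy())  # [11 12 13]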
@@ -1,3 +1,4 @@
import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor
@@ -9,8 +10,14 @@ class TestFunction(unittest.TestCase):

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()
    np.testing.assert_equal(f(a,b).numpy(), [5,7,9])

  def test_simple_same(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b

    a = Tensor([1,2,3])
    np.testing.assert_equal(f(a,a).numpy(), [2,4,6])

  def test_implicit(self):
    inp = Tensor([7,8,9])
@@ -19,8 +26,15 @@ class TestFunction(unittest.TestCase):

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()
    np.testing.assert_equal(f(a,b).numpy(), [12,15,18])

  def test_implicit_same_as_input(self):
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp

    a = Tensor([1,2,3])
    np.testing.assert_equal(f(a, inp).numpy(), [15,18,21])

  def test_implicit_2(self):
    inp = Tensor([7,8,9])
@@ -37,6 +51,84 @@ class TestFunction(unittest.TestCase):
    c = f(a,b)
    d = g(a,b)
    c.realize(d)
    np.testing.assert_equal(c.numpy(), [12,15,18])
    np.testing.assert_equal(d.numpy(), [12,15,19])

  def test_implicit_unrealized(self):
    inp = Tensor([1,2,3]) + Tensor([4,5,6])
    @function
    def f(a:Tensor) -> Tensor: return a + inp

    np.testing.assert_equal(f(Tensor([10,20,30])).numpy(), [15,27,39])

  def test_detach(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a.detach() + b

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    np.testing.assert_equal(f(a, b).numpy(), [5,7,9])

  def test_method(self):
    class Foo:
      def __init__(self): self.w = Tensor([10,20,30])
      @function
      def __call__(self, x:Tensor) -> Tensor: return x + self.w

    foo = Foo()
    np.testing.assert_equal(foo(Tensor([1,2,3])).numpy(), [11,22,33])

  def test_grad_gemm(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a @ b

    a = Tensor([[1.,2.],[3.,4.]], requires_grad=True)
    b = Tensor([[5.,6.],[7.,8.]], requires_grad=True)
    na, nb = a.numpy(), b.numpy()
    (f(a, b).contiguous() * b).sum().backward()
    # L = sum((a@b) * b), dL/d(a@b) = b, dL/da = b @ b^T, dL/db = a^T @ b + (a@b)
    np.testing.assert_allclose(a.grad.numpy(), nb @ nb.T)
    np.testing.assert_allclose(b.grad.numpy(), na.T @ nb + na @ nb)

  def test_grad_implicit(self):
    w = Tensor([1., 2., 3.], requires_grad=True)
    @function
    def f(x:Tensor) -> Tensor: return x * w

    x = Tensor([4., 5., 6.])
    f(x).sum().backward()
    np.testing.assert_allclose(w.grad.numpy(), [4., 5., 6.])

  def test_symbolic_index(self):
    from tinygrad.uop.ops import UOp
    table = Tensor([10,20,30,40]).contiguous().realize()
    @function
    def f(x:Tensor, start_pos:int|UOp) -> Tensor:
      return x + table[start_pos]

    v = UOp.variable("start_pos", 0, 3)
    np.testing.assert_equal(f(Tensor([1,2,3]), v.bind(0)).numpy(), [11,12,13])

  def test_nested_calls(self):
    w = Tensor([10., 20., 30.])
    @function
    def f(a:Tensor) -> Tensor: return a + w
    @function
    def g(a:Tensor) -> Tensor: return a * w

    a = Tensor([1., 2., 3.])
    np.testing.assert_allclose(g(f(a)).numpy(), [110., 440., 990.])

  def test_name(self):
    @function
    def f(a:Tensor) -> Tensor: return a + 1
    assert f(Tensor([1])).uop.arg.name.endswith("f")

  def test_method_name(self):
    class Foo:
      @function
      def __call__(self, x:Tensor) -> Tensor: return x + 1
    assert Foo()(Tensor([1])).uop.arg.name.endswith("Foo.__call__")

if __name__ == '__main__':
  unittest.main()
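The comment in test_grad_gemm states the gradient identities for L = sum((a@b) * b). A quick standalone numpy check of those closed forms (illustrative, independent of tinygrad):

import numpy as np

a = np.array([[1., 2.], [3., 4.]])
b = np.array([[5., 6.], [7., 8.]])
L = lambda a, b: ((a @ b) * b).sum()

eps = 1e-6
def num_grad(f, x):
  # central-difference gradient of f with respect to x
  g = np.zeros_like(x)
  for idx in np.ndindex(*x.shape):
    d = np.zeros_like(x); d[idx] = eps
    g[idx] = (f(x + d) - f(x - d)) / (2 * eps)
  return g

assert np.allclose(num_grad(lambda x: L(x, b), a), b @ b.T)          # dL/da = b @ b^T
assert np.allclose(num_grad(lambda x: L(a, x), b), a.T @ b + a @ b)  # dL/db = a^T @ b + a @ b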
@@ -1,6 +1,6 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@@ -144,6 +144,7 @@ class TransformerBlock:
    attn = self.attn_output(attn)
    return x + attn

  @function
  def _feed_forward(self, h: Tensor) -> Tensor:
    h_norm = self.ffn_norm(h)
    if hasattr(self, 'ffn_gate_exps'):
@@ -127,6 +127,7 @@ pm_replace_buf = PatternMatcher([

@track_rewrites(lambda _,ret: f"Process {pluralize('Buffer', len(ret[1]))}")
def transform_to_call(big_sink:UOp) -> tuple[UOp, dict[UOp, UOp]]:
  if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Tensor Graph")
  # uop list is a list in the original_sink graph and we can map to the tags later
  # here we build buffer map
  dont_realize = {Ops.CONST, Ops.BUFFER, Ops.BIND, Ops.DEFINE_VAR, Ops.AFTER}
@@ -1,20 +1,38 @@
import functools
from typing import Generic, TypeVar, Callable, cast
from dataclasses import dataclass, field
from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite
from tinygrad.helpers import Context, dedup, getenv
from tinygrad.uop.ops import UOp, Ops
from tinygrad.tensor import Tensor

@dataclass
class _ImplicitBufCtx:
  offset: int
  bufs: list[UOp] = field(default_factory=list)
def _srcs(u:UOp) -> tuple[UOp, ...]:
  """Get sources of a UOp, skipping src[0] of CALL nodes (other functions' bodies with their own PARAMs)."""
  return u.src[1:] if u.op is Ops.CALL else u.src

def _replace_implicit_buffer(ctx:_ImplicitBufCtx, b:UOp):
  if b not in ctx.bufs: ctx.bufs.append(b)
  return UOp.param(ctx.offset + ctx.bufs.index(b), b.dtype, b.shape, b._device)

pm_implicit = PatternMatcher([(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), _replace_implicit_buffer)])
def _find_implicit_inputs(uret:UOp) -> list[UOp]:
  """Find implicit inputs by starting at remaining BUFFERs and walking up to the branching point where PARAM-derived nodes meet."""
  all_nodes = list(uret.toposort())
  # build parent map, gating on src[0] of CALL nodes
  parents_of: dict[UOp, set[UOp]] = {}
  for u in all_nodes:
    for s in _srcs(u):
      parents_of.setdefault(s, set()).add(u)
  # mark which nodes have a PARAM in their subtree (bottom-up, toposort is already bottom-up)
  has_param: dict[UOp, bool] = {}
  for u in all_nodes:
    if u.op is Ops.PARAM: has_param[u] = True
    else: has_param[u] = any(has_param.get(s, False) for s in _srcs(u))
  # for each remaining BUFFER, walk up until we hit a node whose parent has PARAM in its subtree
  implicit_inputs: list[UOp] = []
  for buf in all_nodes:
    if buf.op is not Ops.BUFFER: continue
    cur = buf
    while True:
      ps = parents_of.get(cur, set())
      if not ps or any(has_param.get(p, False) for p in ps):
        implicit_inputs.append(cur)
        break
      cur = next(iter(ps))
  return dedup(implicit_inputs)

ReturnType = TypeVar('ReturnType')
class function(Generic[ReturnType]):
@@ -24,25 +42,30 @@ class function(Generic[ReturnType]):
  def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self

  def __call__(self, *args, **kwargs) -> ReturnType:
    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t).multibase
    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t)
                             for name,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]

    # deduplicate input_uops, keeping the first occurrence index for each unique uop
    unique_uops: list[UOp] = dedup(input_uops)

    # disable realize/schedule while this is running
    # run it and do surgery later
    with Context(ALLOW_DEVICE_USAGE=0):
    with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
      ret = self.fxn(*args, **kwargs)
    assert isinstance(ret, Tensor), "only supports one tensor return for now"

    # replace the known inputs with params
    # replace the known inputs with params (using deduplicated slots)
    subs = {}
    for i,x in enumerate(input_uops):
    for i,x in enumerate(unique_uops):
      # TODO: this can be better
      if x.op is Ops.BIND: subs[x] = UOp.param(i, x.dtype, x._shape, x._device, x._min_max)
      else: subs[x] = UOp.param(i, x.dtype, x._shape, x._device)
    uret = ret.uop.substitute(subs)

    # replace the implicit BUFFER inputs with params using graph_rewrite
    ctx = _ImplicitBufCtx(offset=len(input_uops))
    uret = graph_rewrite(uret, pm_implicit, ctx=ctx)
    # find implicit inputs by walking up from remaining BUFFERs to branching points
    implicit = _find_implicit_inputs(uret)
    for i,imp in enumerate(implicit):
      subs[imp] = UOp.param(len(unique_uops) + i, imp.dtype, imp._shape, imp._device)
    uret = ret.uop.substitute(subs)

    return cast(ReturnType, Tensor(uret.call(*input_uops, *ctx.bufs, name=self.fxn.__name__), device=ret.device))
    return cast(ReturnType, Tensor(uret.call(*unique_uops, *implicit, name=self.fxn.__qualname__), device=ret.device))
@@ -98,6 +98,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
  # split_reduceop
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),

  # remove DETACH/CONTIGUOUS_BACKWARD (TODO: this is copied in allocations)
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),

  # remove contiguous on movement ops before a copy on disk
  (UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
   lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
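A rough, self-contained analogue of the _find_implicit_inputs walk-up from the function.py hunk above, on a toy graph of plain strings (node names are invented for illustration; the real code walks UOps and gates on CALL sources):

# toy graph: node -> its sources. "param*" stand for explicit PARAM inputs,
# "buf" for a captured tensor's buffer, "buf_view" for e.g. a reshape of it.
srcs = {
  "add": ["param0", "buf_view"],
  "buf_view": ["buf"],
  "buf": [],
  "param0": [],
}
parents: dict[str, set[str]] = {}
for u, ss in srcs.items():
  for s in ss: parents.setdefault(s, set()).add(u)

def has_param(u: str) -> bool:
  # does u's subtree contain a PARAM?
  return u.startswith("param") or any(has_param(s) for s in srcs[u])

# walk up from the BUFFER until a parent already has a PARAM somewhere below it
cur = "buf"
while (ps := parents.get(cur, set())) and not any(has_param(p) for p in ps):
  cur = next(iter(ps))
print(cur)  # -> "buf_view": the whole captured subgraph becomes one implicit input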