Files
tinygrad/tinygrad/function.py
George Hotz fe3ee8c27e fix symbolic shapes in calls (#15021)
* fix symbolic shapes in calls

* fix after in the big graph

* real tests
2026-02-26 17:17:18 +08:00

63 lines
2.4 KiB
Python

import functools
from typing import Generic, TypeVar, Callable, cast
from tinygrad.helpers import Context, dedup, getenv
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
from tinygrad.tensor import Tensor
def add_to_ctx(ctx:list[UOp], x:UOp):
  """
  Register `x` in `ctx` and return the placeholder that stands in for it.

  The placeholder is `x.param_like(i)` where `i` is the length of `ctx` before
  the append, i.e. the index `x` occupies in `ctx` after this call.

  NOTE: `ctx` is mutated in place; callers rely on the appended entries.
  """
  # compute the placeholder first: the slot number must be the pre-append length
  ret = x.param_like(len(ctx))
  ctx.append(x)
  return ret
# Rewrite rules used to collect the inputs that were not passed explicitly.
# The rewrite context (`ctx`) is the list of call inputs: every UOp matched here is
# appended to that list by add_to_ctx and replaced by the value add_to_ctx returns.
pm_ctx = PatternMatcher([
# any BUFFER or BIND that is matched is always turned into an input
(UPat((Ops.BUFFER, Ops.BIND), name="x"), add_to_ctx),
# ASSIGN / CONTIGUOUS is turned into an input only when op_in_backward_slice_with_self(Ops.PARAM)
# is false for it; otherwise the rule returns None
# NOTE(review): None presumably means "no rewrite for this match" -- confirm against PatternMatcher
(UPat((Ops.ASSIGN, Ops.CONTIGUOUS), name="x"),
lambda ctx,x: add_to_ctx(ctx,x) if not x.op_in_backward_slice_with_self(Ops.PARAM) else None),
])
ReturnType = TypeVar('ReturnType')
class function(Generic[ReturnType]):
  """
  Wrapper around a callable that returns a single Tensor.

  Calling the wrapper runs `fxn` once, replaces the Tensor/UOp arguments in the
  resulting graph with `param_like` placeholders, collects the remaining inputs
  with `pm_ctx`, and returns a new Tensor built from `uret.call(*inputs, name=...)`.

  The wrapper is also a descriptor, so it can be used on methods: attribute access
  through an instance binds that instance as the first positional argument.
  """
  def __init__(self, fxn:Callable[..., ReturnType]):
    self.fxn = fxn

  def __get__(self, obj, objtype=None):
    # accessed on the class (obj is None): hand back the wrapper itself
    # accessed on an instance: pre-bind the instance as the first argument
    return functools.partial(self.__call__, obj) if obj is not None else self

  def __call__(self, *args, **kwargs) -> ReturnType:
    """
    Run the wrapped callable and return its result as a Tensor wrapping a call UOp.

    Only top-level positional/keyword arguments that are Tensor or UOp count as
    explicit inputs; positional ones come first, then keyword ones in sorted key order.

    Raises AssertionError if the wrapped callable does not return a Tensor.
    """
    # positional args in order, then kwargs values ordered by key for a stable slot order
    candidates = [*args, *(kwargs[k] for k in sorted(kwargs))]
    input_uops: list[UOp] = [t.uop if isinstance(t, Tensor) else t for t in candidates if isinstance(t, (Tensor, UOp))]

    # use the base
    #input_uops = [x.multibase for x in input_uops]

    # deduplicate input_uops, keeping the first occurrence index for each unique uop
    call_uops: list[UOp] = dedup(input_uops)

    # disable realize/schedule while this is running
    # run it and do surgery later
    with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
      ret = self.fxn(*args, **kwargs)
    assert isinstance(ret, Tensor), "only supports one tensor return for now"

    # replace the known inputs with params (using deduplicated slots)
    subs = {x: x.param_like(i) for i,x in enumerate(call_uops)}
    uret = ret.uop.substitute(subs)

    # add contiguous to call_uops
    #call_uops = [x.contiguous() for x in call_uops]

    # the BUFFERs that are left are the implicit inputs
    # NOTE: call_uops is the rewrite context, so pm_ctx appends the implicit inputs to it in place
    uret = graph_rewrite(uret, pm_ctx, call_uops, bottom_up=True, name="get_implicit_inputs")

    # callables without __qualname__ (e.g. callable instances) fall back to their type's name
    name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__

    # assign output
    #pbuffer = uret.param_like(len(call_uops))
    #assigned = pbuffer.assign(uret).sink()
    #buffer = UOp.new_buffer(pbuffer.device, pbuffer.size, pbuffer.dtype).reshape(uret.shape)
    #call = assigned.call(*call_uops, buffer, name=name)
    #ret = buffer.after(call)

    ret = uret.call(*call_uops, name=name)
    return cast(ReturnType, Tensor(ret, device=ret.device))