From 2e1c81c23ffe9095af594fd9d1545d32fb34ea94 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 16 Mar 2026 16:40:14 +0800
Subject: [PATCH] allow_implicit to disable implicit params (#15291)

* allow_implicit to disable implicit params

* get both Tensor and UOp

* no implicits in llm
---
 tinygrad/apps/llm.py | 10 +++++-----
 tinygrad/function.py | 28 ++++++++++++++++------------
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 19b5169d18..cc43ef9302 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -117,7 +117,7 @@ class TransformerBlock:
     self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
     self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
 
-  @function(precompile=bool(getenv("PRECOMPILE", 0)))
+  @function(precompile=bool(getenv("PRECOMPILE", 0)), allow_implicit=False)
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     x_norm = self.attn_norm(x) # (B,T,D)
     q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
@@ -129,9 +129,8 @@ class TransformerBlock:
     v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
 
     if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
-    freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T]
-    q = apply_rope(q, freqs_cis)
-    k = apply_rope(k, freqs_cis)
+    q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
+    k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
 
     # TODO: fix assign to behave like this
     assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
@@ -151,7 +150,7 @@ class TransformerBlock:
     attn = self.attn_output(attn)
     return x + attn
 
-  @function(precompile=bool(getenv("PRECOMPILE", 0)))
+  @function(precompile=bool(getenv("PRECOMPILE", 0)), allow_implicit=False)
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
     if hasattr(self, 'ffn_gate_exps'):
@@ -168,6 +167,7 @@ class TransformerBlock:
       # TODO: how is the dtype of this determined?
       # NOTE: clone is used to promise the creation of a specific buffer
       self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).clone()
+      self.freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)
     return self._feed_forward(self._attention(x, start_pos)).contiguous()
 
 class Transformer:
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 94a3100b38..936a445933 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -3,6 +3,7 @@ from typing import Generic, TypeVar, Callable, cast, overload
 from tinygrad.helpers import Context, dedup, getenv
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
 from tinygrad.tensor import Tensor
+from tinygrad.nn.state import get_state_dict
 
 def add_to_ctx(ctx, x:UOp):
   ret = x.param_like(len(ctx))
@@ -19,21 +20,18 @@ pm_ctx = PatternMatcher([
 
 ReturnType = TypeVar('ReturnType')
 class _function(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False):
+  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True):
     self.fxn = fxn
     self.precompile = precompile
+    self.allow_implicit = allow_implicit
 
   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
 
   def __call__(self, *args, **kwargs) -> ReturnType:
-    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t)
-      for name,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]
-
-    # use the base
-    #input_uops = [x.multibase for x in input_uops]
+    params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
 
     # deduplicate input_uops, keeping the first occurrence index for each unique uop
-    call_uops: list[UOp] = dedup(input_uops)
+    call_uops: list[UOp] = dedup([(t.uop if isinstance(t, Tensor) else t) for t in params])
 
     # disable realize/schedule while this is running
     # run it and do surgery later
@@ -55,8 +53,14 @@ class _function(Generic[ReturnType]):
     #call_uops = [x.contiguous() for x in call_uops]
 
     # the BUFFERs that are left are the implicit inputs
+    num_explicit = len(call_uops)
     uret = graph_rewrite(uret, pm_ctx, call_uops, bottom_up=True, name="get_implicit_inputs")
     name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
+    if not self.allow_implicit:
+      implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
+      if implicit_buffers:
+        buf_strs = '\n '.join(f"{i}: dtype={b.dtype}, size={b.size}, device={b.device}" for i,b in enumerate(implicit_buffers))
+        raise RuntimeError(f"function {name} has {len(implicit_buffers)} implicit buffer(s), but allow_implicit=False\n {buf_strs}")
 
     # assign output
     #pbuffer = uret.param_like(len(call_uops))
@@ -73,9 +77,9 @@ class _function(Generic[ReturnType]):
 
 # overload signatures support both @function and @function(precompile=True) syntax
 @overload
-def function(fxn:Callable[..., ReturnType], *, precompile:bool=False) -> _function[ReturnType]: ...
+def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True) -> _function[ReturnType]: ...
 @overload
-def function(fxn:None=None, *, precompile:bool=False) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
-def function(fxn=None, *, precompile:bool=False):
-  if fxn is None: return lambda f: _function(f, precompile=precompile)
-  return _function(fxn, precompile=precompile)
+def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
+def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True):
+  if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit)
+  return _function(fxn, precompile=precompile, allow_implicit=allow_implicit)
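
Note (reviewer addition, not part of the patch): a minimal sketch of how allow_implicit=False is expected to behave. With this change, explicit inputs are gathered by get_state_dict over (args, kwargs), so any Tensor reachable from the arguments, including attributes of a passed-in object such as self, counts as explicit; a realized Tensor reached only through a global or a closure should surface as an implicit BUFFER during the graph_rewrite and hit the new RuntimeError. The import path for `function` and the names W, matmul_bad, and Layer are assumptions for illustration only.

from tinygrad import Tensor
from tinygrad.function import function  # assumed import path; the decorator is defined in tinygrad/function.py above

W = Tensor.randn(4, 4).realize()  # realized buffer living in a global, never passed as an argument

@function(allow_implicit=False)
def matmul_bad(x:Tensor) -> Tensor:
  return x @ W  # W is only reachable through the closure, so it should show up as an implicit BUFFER

class Layer:
  def __init__(self): self.weight = Tensor.randn(4, 4).realize()

  @function(allow_implicit=False)
  def forward(self, x:Tensor) -> Tensor:
    return x @ self.weight  # self is args[0], so get_state_dict finds self.weight as an explicit param

x = Tensor.ones(2, 4)
out = Layer().forward(x)  # expected to work: every buffer is reachable from the arguments
matmul_bad(x)             # expected to raise: "function matmul_bad has 1 implicit buffer(s), but allow_implicit=False"

This mirrors the llm.py change above: freqs_cis is now computed once in __call__ and stored on self, so get_state_dict picks it up as an explicit parameter of _attention instead of it being rebuilt inside the traced function.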