allow_implicit to disable implicit params (#15291)

* allow_implicit to disable implicit params

* get both Tensor and UOp

* no implicits in llm
This commit is contained in:
George Hotz
2026-03-16 16:40:14 +08:00
committed by GitHub
parent a0d1444790
commit 2e1c81c23f
2 changed files with 21 additions and 17 deletions

View File

@@ -117,7 +117,7 @@ class TransformerBlock:
self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)
@function(precompile=bool(getenv("PRECOMPILE", 0)))
@function(precompile=bool(getenv("PRECOMPILE", 0)), allow_implicit=False)
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
x_norm = self.attn_norm(x) # (B,T,D)
q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
@@ -129,9 +129,8 @@ class TransformerBlock:
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T]
q = apply_rope(q, freqs_cis)
k = apply_rope(k, freqs_cis)
q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
# TODO: fix assign to behave like this
assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
@@ -151,7 +150,7 @@ class TransformerBlock:
attn = self.attn_output(attn)
return x + attn
@function(precompile=bool(getenv("PRECOMPILE", 0)))
@function(precompile=bool(getenv("PRECOMPILE", 0)), allow_implicit=False)
def _feed_forward(self, h: Tensor) -> Tensor:
h_norm = self.ffn_norm(h)
if hasattr(self, 'ffn_gate_exps'):
@@ -168,6 +167,7 @@ class TransformerBlock:
# TODO: how is the dtype of this determined?
# NOTE: clone is used to promise the creation of a specific buffer
self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).clone()
self.freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)
return self._feed_forward(self._attention(x, start_pos)).contiguous()
class Transformer:

View File

@@ -3,6 +3,7 @@ from typing import Generic, TypeVar, Callable, cast, overload
from tinygrad.helpers import Context, dedup, getenv
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_state_dict
def add_to_ctx(ctx, x:UOp):
ret = x.param_like(len(ctx))
@@ -19,21 +20,18 @@ pm_ctx = PatternMatcher([
ReturnType = TypeVar('ReturnType')
class _function(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False):
def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True):
self.fxn = fxn
self.precompile = precompile
self.allow_implicit = allow_implicit
def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
def __call__(self, *args, **kwargs) -> ReturnType:
input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t)
for name,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]
# use the base
#input_uops = [x.multibase for x in input_uops]
params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
# deduplicate the input uops (Tensors are unwrapped to their uop), keeping the first occurrence of each unique uop
call_uops: list[UOp] = dedup(input_uops)
call_uops: list[UOp] = dedup([(t.uop if isinstance(t, Tensor) else t) for t in params])
# disable realize/schedule while this is running
# run it and do surgery later
@@ -55,8 +53,14 @@ class _function(Generic[ReturnType]):
#call_uops = [x.contiguous() for x in call_uops]
# the BUFFERs that are left are the implicit inputs
num_explicit = len(call_uops)
uret = graph_rewrite(uret, pm_ctx, call_uops, bottom_up=True, name="get_implicit_inputs")
name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
if not self.allow_implicit:
implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
if implicit_buffers:
buf_strs = '\n '.join(f"{i}: dtype={b.dtype}, size={b.size}, device={b.device}" for i,b in enumerate(implicit_buffers))
raise RuntimeError(f"function {name} has {len(implicit_buffers)} implicit buffer(s), but allow_implicit=False\n {buf_strs}")
# assign output
#pbuffer = uret.param_like(len(call_uops))
@@ -73,9 +77,9 @@ class _function(Generic[ReturnType]):
# overload signatures support both @function and @function(precompile=True) syntax
@overload
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False) -> _function[ReturnType]: ...
def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True) -> _function[ReturnType]: ...
@overload
def function(fxn:None=None, *, precompile:bool=False) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False):
if fxn is None: return lambda f: _function(f, precompile=precompile)
return _function(fxn, precompile=precompile)
def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True):
if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit)
return _function(fxn, precompile=precompile, allow_implicit=allow_implicit)