allow_implicit to disable implicit params (#15291)
* allow_implicit to disable implicit params
* get both Tensor and UOp
* no implicits in llm
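In short: `@function` traces the wrapped callable once, and any realized buffer the trace touches that is not reachable from the call's arguments becomes an implicit parameter; `allow_implicit=False` turns such a capture into an immediate error. A minimal sketch of the intended behavior under the patch below (`SCALE`, `Model`, and the shapes are invented for illustration, and the decorator's import is elided because this diff does not show its module name):

```python
from tinygrad import Tensor
# `function` is the decorator patched below; its import path is not shown in this diff.

SCALE = Tensor([2.0]).realize()  # module-level global: unreachable from (args, kwargs)

class Model:
  def __init__(self):
    self.w = Tensor.ones(4)  # reachable through `self`, so collected as an explicit param

  @function(allow_implicit=False)
  def fwd(self, x:Tensor) -> Tensor:
    # x and self.w are found by get_state_dict((args, kwargs), ...) and stay explicit;
    # SCALE's backing buffer survives the rewrite as an Ops.BUFFER, i.e. an implicit input
    return x * SCALE + self.w

m = Model()
m.fwd(Tensor.zeros(4))  # -> RuntimeError: function Model.fwd has 1 implicit buffer(s), ...
```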
@@ -117,7 +117,7 @@ class TransformerBlock:
     self.ffn_up = nn.Linear(dim, hidden_dim, bias=False)
     self.ffn_down = nn.Linear(hidden_dim, dim, bias=False)

-  @function(precompile=bool(getenv("PRECOMPILE", 0)))
+  @function(precompile=bool(getenv("PRECOMPILE", 0)), allow_implicit=False)
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     x_norm = self.attn_norm(x) # (B,T,D)
     q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
@@ -129,9 +129,8 @@ class TransformerBlock:
     v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
     if self.qk_norm == self.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)

-    freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T]
-    q = apply_rope(q, freqs_cis)
-    k = apply_rope(k, freqs_cis)
+    q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
+    k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])

     # TODO: fix assign to behave like this
     assigned_kv = self.cache_kv.uop.after(self.cache_kv[:, :, :, start_pos:start_pos+T, :].uop.assign(Tensor.stack(k, v).contiguous().uop))
@@ -151,7 +150,7 @@ class TransformerBlock:
     attn = self.attn_output(attn)
     return x + attn

-  @function(precompile=bool(getenv("PRECOMPILE", 0)))
+  @function(precompile=bool(getenv("PRECOMPILE", 0)), allow_implicit=False)
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
     if hasattr(self, 'ffn_gate_exps'):
@@ -168,6 +167,7 @@ class TransformerBlock:
       # TODO: how is the dtype of this determined?
       # NOTE: clone is used to promise the creation of a specific buffer
       self.cache_kv = Tensor.zeros(2, x.shape[0], self.n_kv_heads, self.max_context, self.head_dim, device=x.device).clone()
+      self.freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)
     return self._feed_forward(self._attention(x, start_pos)).contiguous()

 class Transformer:
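An interpretive gloss on the rope change above (not text from the commit): precomputing `freqs_cis` once and storing it on `self` makes the table ordinary module state that `get_state_dict` discovers through `self`, instead of something re-derived inside every traced call; per the commit message, this is part of keeping the llm functions free of implicit inputs. Schematically:

```python
# before: the table is rebuilt inside the traced method on every call
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context, self.rope_theta)[start_pos:start_pos+T]
q, k = apply_rope(q, freqs_cis), apply_rope(k, freqs_cis)

# after: self.freqs_cis is precomputed module state, found via `self`, and only sliced here
q = apply_rope(q, self.freqs_cis[start_pos:start_pos+T])
k = apply_rope(k, self.freqs_cis[start_pos:start_pos+T])
```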
@@ -3,6 +3,7 @@ from typing import Generic, TypeVar, Callable, cast, overload
 from tinygrad.helpers import Context, dedup, getenv
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
 from tinygrad.tensor import Tensor
+from tinygrad.nn.state import get_state_dict

 def add_to_ctx(ctx, x:UOp):
   ret = x.param_like(len(ctx))
@@ -19,21 +20,18 @@ pm_ctx = PatternMatcher([

 ReturnType = TypeVar('ReturnType')
 class _function(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False):
+  def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True):
     self.fxn = fxn
     self.precompile = precompile
+    self.allow_implicit = allow_implicit

   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self

   def __call__(self, *args, **kwargs) -> ReturnType:
-    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t)
-      for name,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]
-
-    # use the base
-    #input_uops = [x.multibase for x in input_uops]
+    params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()

     # deduplicate input_uops, keeping the first occurrence index for each unique uop
-    call_uops: list[UOp] = dedup(input_uops)
+    call_uops: list[UOp] = dedup([(t.uop if isinstance(t, Tensor) else t) for t in params])

     # disable realize/schedule while this is running
     # run it and do surgery later
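This hunk implements the commit's "get both Tensor and UOp" bullet: instead of scanning only the top-level positional and keyword arguments, `get_state_dict` walks the whole `(args, kwargs)` structure (including, via object attributes, the tensors hanging off `self`) and collects both `Tensor` and `UOp` leaves. A gloss on what the two new lines yield for a call shaped like `_attention(self, x, start_pos)` (the call shape is an assumption; the variable names are from the hunk):

```python
# params: every Tensor/UOp reachable from (args, kwargs), e.g. x, start_pos,
# and the weights/caches reachable through `self`
params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
# Tensors contribute their backing .uop; bare UOps like start_pos pass through unchanged
call_uops = dedup([(t.uop if isinstance(t, Tensor) else t) for t in params])
```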
@@ -55,8 +53,14 @@ class _function(Generic[ReturnType]):
     #call_uops = [x.contiguous() for x in call_uops]

     # the BUFFERs that are left are the implicit inputs
+    num_explicit = len(call_uops)
     uret = graph_rewrite(uret, pm_ctx, call_uops, bottom_up=True, name="get_implicit_inputs")
     name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
+    if not self.allow_implicit:
+      implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
+      if implicit_buffers:
+        buf_strs = '\n '.join(f"{i}: dtype={b.dtype}, size={b.size}, device={b.device}" for i,b in enumerate(implicit_buffers))
+        raise RuntimeError(f"function {name} has {len(implicit_buffers)} implicit buffer(s), but allow_implicit=False\n {buf_strs}")

     # assign output
     #pbuffer = uret.param_like(len(call_uops))
@@ -73,9 +77,9 @@ class _function(Generic[ReturnType]):

 # overload signatures support both @function and @function(precompile=True) syntax
 @overload
-def function(fxn:Callable[..., ReturnType], *, precompile:bool=False) -> _function[ReturnType]: ...
+def function(fxn:Callable[..., ReturnType], *, precompile:bool=False, allow_implicit:bool=True) -> _function[ReturnType]: ...
 @overload
-def function(fxn:None=None, *, precompile:bool=False) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
-def function(fxn=None, *, precompile:bool=False):
-  if fxn is None: return lambda f: _function(f, precompile=precompile)
-  return _function(fxn, precompile=precompile)
+def function(fxn:None=None, *, precompile:bool=False, allow_implicit:bool=True) -> Callable[[Callable[..., ReturnType]], _function[ReturnType]]: ...
+def function(fxn=None, *, precompile:bool=False, allow_implicit:bool=True):
+  if fxn is None: return lambda f: _function(f, precompile=precompile, allow_implicit=allow_implicit)
+  return _function(fxn, precompile=precompile, allow_implicit=allow_implicit)
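For completeness, both call forms the overloads above cover, with the new keyword threaded through (a usage sketch; the function bodies are placeholders and the import of `function` is elided as before):

```python
@function                  # bare form: the callable itself is passed as fxn
def double(x:Tensor) -> Tensor: return x * 2

@function(precompile=True, allow_implicit=False)  # keyword form: returns a decorator
def strict_double(x:Tensor) -> Tensor: return x * 2
```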