diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 7438dfded7..4e673de4fc 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -128,7 +128,6 @@ class TransformerBlock:
     self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
     self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)
 
-  @function
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     x_norm = self.attn_norm(x) # (B,T,D)
     q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
@@ -160,7 +159,6 @@ class TransformerBlock:
     attn = self.attn_output(attn)
     return x + attn
 
-  @function
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
     if hasattr(self, 'ffn_gate_exps'):
@@ -394,7 +392,7 @@ if __name__ == "__main__":
   raw_model = Tensor.from_url(models.get(args.model, args.model))
   model, kv = Transformer.from_gguf(raw_model, args.max_context)
   model_name = kv.get('general.name') or kv.get('general.basename') or args.model
-  print(f"using model {model_name} with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
+  print(f"using model \"{model_name}\" with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
   del raw_model # TODO: why this is required to free the RAM of the GGUF copy?
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 75a16869d7..975a5ef06b 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -18,6 +18,8 @@ class ConstFloat(float):
     if isinstance(other, float) and math.isnan(self) and math.isnan(other): return True
     return float.__eq__(self, other)
   def __hash__(self): return hash(self.bits)
+  def __repr__(self): return f"ConstFloat({float.__repr__(self)})"
+  def __str__(self): return float.__repr__(self)
 
 class InvalidType:
   _instance: ClassVar[InvalidType|None] = None
diff --git a/tinygrad/function.py b/tinygrad/function.py
index 495bcd8748..87586443ef 100644
--- a/tinygrad/function.py
+++ b/tinygrad/function.py
@@ -1,6 +1,6 @@
-import functools, itertools
+import functools, itertools, time
 from typing import Generic, TypeVar, Callable, cast, overload
-from tinygrad.helpers import Context, dedup, getenv
+from tinygrad.helpers import Context, dedup, getenv, DEBUG
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_state_dict
@@ -24,6 +24,7 @@ pm_ctx = PatternMatcher([
 
 ReturnType = TypeVar('ReturnType')
 class _function(Generic[ReturnType]):
+  depth = 0
   def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool, precompile_backward:bool, allow_implicit:bool, grad_fxn:Callable|None):
     self.fxn = fxn
     self.precompile = precompile
@@ -34,6 +35,8 @@
   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
 
   def __call__(self, *args, **kwargs) -> ReturnType:
+    st = time.perf_counter()
+
     params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
     # deduplicate input_uops, keeping the first occurrence index for each unique uop
@@ -42,7 +45,9 @@
     # disable realize/schedule while this is running
     # run it and do surgery later
     with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
+      _function.depth += 1
       ret = self.fxn(*args, **kwargs)
+      _function.depth -= 1
     if isinstance(ret, Tensor): uret = ret.uop
     elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):
@@ -77,6 +82,11 @@ class _function(Generic[ReturnType]):
     fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile, precompile_backward=self.precompile_backward)
+
+    if DEBUG >= 2:
+      #signature = [(x._shape, x.dtype, x._device) for x in call_uops]
+      print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}") # with sig {signature}")
+
     if isinstance(ret, tuple): return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
     else:
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 8b25492cfe..260ca8ff22 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -1567,7 +1567,7 @@ pm_pyrender_extra = PatternMatcher([
   (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
    lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
-  (UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
+  (UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
   (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
   (UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
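
Reviewer notes (standalone sketches, not part of the patch):

The `ConstFloat` change in `tinygrad/dtype.py` splits `__repr__` (unambiguous, for debugging) from `__str__` (plain float text, for rendered output). A minimal sketch of the same pattern, using a hypothetical `MyFloat` rather than tinygrad's class:

```python
class MyFloat(float):
  # repr: wrap in the class name so debug output is unambiguous
  def __repr__(self): return f"MyFloat({float.__repr__(self)})"
  # str: fall back to the plain float text for rendered output
  def __str__(self): return float.__repr__(self)

x = MyFloat(1.5)
print(repr(x))  # MyFloat(1.5)
print(str(x))   # 1.5
```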
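The `tinygrad/function.py` change adds a class-level `depth` counter to `_function` so that, at `DEBUG >= 2`, timing lines from nested `@function` calls print indented under their caller. A self-contained sketch of that counter pattern under hypothetical names (`timed`, `inner`, `outer` are illustrations, not tinygrad API; the sketch adds a `try/finally` the patch does not have):

```python
import time

class timed:
  depth = 0  # class-level counter shared by all decorated functions, mirrors _function.depth

  def __init__(self, fxn): self.fxn = fxn

  def __call__(self, *args, **kwargs):
    st = time.perf_counter()
    timed.depth += 1      # entering one nesting level
    try:
      ret = self.fxn(*args, **kwargs)
    finally:
      timed.depth -= 1    # back out before printing, so the indent equals the caller's depth
    print(" "*timed.depth + f"function in {(time.perf_counter()-st)*1000:8.2f} ms: {self.fxn.__name__}")
    return ret

@timed
def inner(): time.sleep(0.01)

@timed
def outer(): inner()

outer()
# inner prints first, indented one space; outer prints last at column 0,
# the same shape as the DEBUG >= 2 output added in the patch
```

Note that a decorator implemented as a plain class only works on free functions; the patch's `_function` also defines `__get__`, which is what lets it wrap methods such as `_attention` and `_feed_forward`.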