print function timing with DEBUG=2 (#15468)

* add DEBUG=2 function timing

* remove those functions; they aren't useful

* fix spec
George Hotz, 2026-03-25 19:07:32 +08:00 (committed by GitHub)
parent e7f389efda
commit ae7090b13b
4 changed files with 16 additions and 6 deletions
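
At DEBUG >= 2 the function wrapper now prints how long each traced call took. A minimal standalone sketch of that pattern, using an illustrative toy decorator rather than tinygrad's actual _function class:

import os, time, functools

DEBUG = int(os.getenv("DEBUG", "0"))  # stand-in for tinygrad.helpers.DEBUG

def timed(fxn):
  # record a start time, run the function, print elapsed wall-clock ms when DEBUG >= 2
  @functools.wraps(fxn)
  def wrapper(*args, **kwargs):
    st = time.perf_counter()
    ret = fxn(*args, **kwargs)
    if DEBUG >= 2: print(f"function in {(time.perf_counter()-st)*1000:8.2f} ms: {fxn.__name__}")
    return ret
  return wrapper

@timed
def work(n): return sum(i*i for i in range(n))
work(100_000)  # run with DEBUG=2 in the environment to see one timing line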

View File

@@ -128,7 +128,6 @@ class TransformerBlock:
     self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
     self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)
 
-  @function
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     x_norm = self.attn_norm(x) # (B,T,D)
     q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)
@@ -160,7 +159,6 @@ class TransformerBlock:
     attn = self.attn_output(attn)
     return x + attn
 
-  @function
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
     if hasattr(self, 'ffn_gate_exps'):
@@ -394,7 +392,7 @@ if __name__ == "__main__":
   raw_model = Tensor.from_url(models.get(args.model, args.model))
   model, kv = Transformer.from_gguf(raw_model, args.max_context)
   model_name = kv.get('general.name') or kv.get('general.basename') or args.model
-  print(f"using model {model_name} with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
+  print(f"using model \"{model_name}\" with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
   del raw_model
   # TODO: why this is required to free the RAM of the GGUF copy?

View File

@@ -18,6 +18,8 @@ class ConstFloat(float):
     if isinstance(other, float) and math.isnan(self) and math.isnan(other): return True
     return float.__eq__(self, other)
   def __hash__(self): return hash(self.bits)
+  def __repr__(self): return f"ConstFloat({float.__repr__(self)})"
+  def __str__(self): return float.__repr__(self)
 
 class InvalidType:
   _instance: ClassVar[InvalidType|None] = None
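
The two added dunders give ConstFloat a readable rendering: repr shows the wrapper type while str prints like a plain float, alongside the existing NaN-aware equality and bits-based hash. A runnable sketch of the same idea, assuming bits holds the value's IEEE-754 bit pattern (how the real class populates it is outside this hunk):

import math, struct

class ConstFloat(float):
  def __new__(cls, value: float):
    self = super().__new__(cls, value)
    self.bits = struct.unpack("<Q", struct.pack("<d", value))[0]  # assumption: raw float64 bits
    return self
  def __eq__(self, other):
    # NaN != NaN under IEEE-754, but identical constants should compare equal for dedup
    if isinstance(other, float) and math.isnan(self) and math.isnan(other): return True
    return float.__eq__(self, other)
  def __hash__(self): return hash(self.bits)
  def __repr__(self): return f"ConstFloat({float.__repr__(self)})"  # the added repr
  def __str__(self): return float.__repr__(self)                    # the added str

print(repr(ConstFloat(1.5)))  # ConstFloat(1.5)
print(str(ConstFloat(1.5)))   # 1.5
print(ConstFloat(float("nan")) == ConstFloat(float("nan")))  # True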

View File

@@ -1,6 +1,6 @@
-import functools, itertools
+import functools, itertools, time
 from typing import Generic, TypeVar, Callable, cast, overload
-from tinygrad.helpers import Context, dedup, getenv
+from tinygrad.helpers import Context, dedup, getenv, DEBUG
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_state_dict
@@ -24,6 +24,7 @@ pm_ctx = PatternMatcher([
 ReturnType = TypeVar('ReturnType')
 class _function(Generic[ReturnType]):
+  depth = 0
   def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool, precompile_backward:bool, allow_implicit:bool, grad_fxn:Callable|None):
     self.fxn = fxn
     self.precompile = precompile
@@ -34,6 +35,8 @@ class _function(Generic[ReturnType]):
   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
   def __call__(self, *args, **kwargs) -> ReturnType:
+    st = time.perf_counter()
     params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
     # deduplicate input_uops, keeping the first occurrence index for each unique uop
@@ -42,7 +45,9 @@ class _function(Generic[ReturnType]):
     # disable realize/schedule while this is running
     # run it and do surgery later
     with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
+      _function.depth += 1
       ret = self.fxn(*args, **kwargs)
+      _function.depth -= 1
     if isinstance(ret, Tensor):
       uret = ret.uop
     elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):
@@ -77,6 +82,11 @@ class _function(Generic[ReturnType]):
     fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
                      precompile_backward=self.precompile_backward)
+    if DEBUG >= 2:
+      #signature = [(x._shape, x.dtype, x._device) for x in call_uops]
+      print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}") # with sig {signature}")
     if isinstance(ret, tuple):
       return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
     else:
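
The class-level depth counter is what makes nested @function calls print indented under their caller: each call bumps the counter before running its body and drops it afterwards, and since the print happens after the decrement, a callee prints one space deeper than its caller. A toy sketch of just that bookkeeping (illustrative, not the real class):

import time

class _timed:
  depth = 0  # class-level, shared across all instances, as in the diff
  def __init__(self, fxn): self.fxn = fxn
  def __call__(self, *args, **kwargs):
    st = time.perf_counter()
    _timed.depth += 1
    ret = self.fxn(*args, **kwargs)
    _timed.depth -= 1
    # printed after the decrement, so a callee (which ran while our increment
    # was active) prints one space deeper than we do
    print(" "*_timed.depth + f"function in {(time.perf_counter()-st)*1000:8.2f} ms: {self.fxn.__name__}")
    return ret

@_timed
def inner(): time.sleep(0.01)
@_timed
def outer(): inner(); inner()
outer()
# output shape (times illustrative):
#  function in    10.21 ms: inner
#  function in    10.18 ms: inner
# function in    20.55 ms: outer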

View File

@@ -1567,7 +1567,7 @@ pm_pyrender_extra = PatternMatcher([
   (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
    lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
-  (UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
+  (UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
   (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
    f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
   (UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),
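
The "fix spec" bullet is this last hunk: a UPat that leaves src unset does not check sources at all, so the old fallback matched a CONST with any sources and rendered it as a bare UOp.const(...), dropping whatever its sources encoded; src=() restricts the fallback to genuinely source-less consts, leaving the UNIQUE/DEVICE forms above to their own renderers. A sketch of the distinction, using only names this diff already imports:

from tinygrad.uop.ops import Ops, UPat

any_const  = UPat(Ops.CONST, name="x")          # src unset: sources are not checked
bare_const = UPat(Ops.CONST, src=(), name="x")  # src=(): matches only a CONST with no sources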