print function timing with DEBUG=2 (#15468)
* add DEBUG=2 function timing
* remove those functions, they aren't useful
* fix spec

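The core of the change is in `_function.__call__`: record `time.perf_counter()` on entry, track nesting with a class-level `depth` counter, and print an indented per-call duration when DEBUG >= 2. A minimal standalone sketch of that pattern follows; the `timed` class and all names in it are illustrative, not tinygrad's API:

import time, functools

class timed:
  depth = 0  # class-level nesting counter, shared by every wrapped function
  def __init__(self, fxn): self.fxn = fxn
  def __get__(self, obj, objtype=None):
    # support decorating methods, as _function does
    return functools.partial(self.__call__, obj) if obj is not None else self
  def __call__(self, *args, **kwargs):
    st = time.perf_counter()
    timed.depth += 1
    try:
      ret = self.fxn(*args, **kwargs)
    finally:
      timed.depth -= 1
    # depth is already restored here, so outer calls print flush left;
    # the real code additionally gates this print on DEBUG >= 2
    print(" "*timed.depth + f"function in {(time.perf_counter()-st)*1000:8.2f} ms: {self.fxn.__name__}")
    return ret

@timed
def inner(): return sum(range(1000))

@timed
def outer(): return inner() + 1

outer()  # prints inner's line first (indented one space), then outer's
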
@@ -128,7 +128,6 @@ class TransformerBlock:
     self.ffn_up = nn.Linear(config.dim, config.hidden_dim, bias=False)
     self.ffn_down = nn.Linear(config.hidden_dim, config.dim, bias=False)
 
-  @function
   def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
     x_norm = self.attn_norm(x) # (B,T,D)
     q, k, v = self.attn_q(x_norm), self.attn_k(x_norm), self.attn_v(x_norm)

@@ -160,7 +159,6 @@ class TransformerBlock:
     attn = self.attn_output(attn)
     return x + attn
 
-  @function
   def _feed_forward(self, h: Tensor) -> Tensor:
     h_norm = self.ffn_norm(h)
     if hasattr(self, 'ffn_gate_exps'):

@@ -394,7 +392,7 @@ if __name__ == "__main__":
   raw_model = Tensor.from_url(models.get(args.model, args.model))
   model, kv = Transformer.from_gguf(raw_model, args.max_context)
   model_name = kv.get('general.name') or kv.get('general.basename') or args.model
-  print(f"using model {model_name} with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
+  print(f"using model \"{model_name}\" with {raw_model.nbytes():,} bytes and {sum(x.numel() for x in nn.state.get_parameters(model)):,} params")
   del raw_model
 
   # TODO: why this is required to free the RAM of the GGUF copy?

@@ -18,6 +18,8 @@ class ConstFloat(float):
     if isinstance(other, float) and math.isnan(self) and math.isnan(other): return True
     return float.__eq__(self, other)
   def __hash__(self): return hash(self.bits)
+  def __repr__(self): return f"ConstFloat({float.__repr__(self)})"
+  def __str__(self): return float.__repr__(self)
 
 class InvalidType:
   _instance: ClassVar[InvalidType|None] = None

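For context, ConstFloat is a float subclass whose equality treats NaN == NaN as True and whose hash goes through a bit pattern. A self-contained sketch: the `bits` property below is an assumption (one plausible reading of `hash(self.bits)` as the raw IEEE-754 bit pattern), while `__eq__`, `__hash__`, and the two added methods mirror the diff:

import math, struct

class ConstFloat(float):
  @property
  def bits(self) -> int:
    # assumed: the raw IEEE-754 bit pattern of the double
    return struct.unpack('<Q', struct.pack('<d', self))[0]
  def __eq__(self, other):
    if isinstance(other, float) and math.isnan(self) and math.isnan(other): return True
    return float.__eq__(self, other)
  def __hash__(self): return hash(self.bits)
  # the two methods this commit adds: a repr that names the wrapper,
  # and a str that prints like a plain float
  def __repr__(self): return f"ConstFloat({float.__repr__(self)})"
  def __str__(self): return float.__repr__(self)

assert ConstFloat(float('nan')) == ConstFloat(float('nan'))
print(repr(ConstFloat(1.5)))  # ConstFloat(1.5)
print(ConstFloat(1.5))        # 1.5
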
@@ -1,6 +1,6 @@
-import functools, itertools
+import functools, itertools, time
 from typing import Generic, TypeVar, Callable, cast, overload
-from tinygrad.helpers import Context, dedup, getenv
+from tinygrad.helpers import Context, dedup, getenv, DEBUG
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, PatternMatcher, UPat
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_state_dict

@@ -24,6 +24,7 @@ pm_ctx = PatternMatcher([
 
 ReturnType = TypeVar('ReturnType')
 class _function(Generic[ReturnType]):
+  depth = 0
   def __init__(self, fxn:Callable[..., ReturnType], *, precompile:bool, precompile_backward:bool, allow_implicit:bool, grad_fxn:Callable|None):
     self.fxn = fxn
     self.precompile = precompile

@@ -34,6 +35,8 @@ class _function(Generic[ReturnType]):
   def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
 
   def __call__(self, *args, **kwargs) -> ReturnType:
+    st = time.perf_counter()
+
     params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()
 
     # deduplicate input_uops, keeping the first occurrence index for each unique uop

@@ -42,7 +45,9 @@ class _function(Generic[ReturnType]):
     # disable realize/schedule while this is running
     # run it and do surgery later
     with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
+      _function.depth += 1
       ret = self.fxn(*args, **kwargs)
+      _function.depth -= 1
     if isinstance(ret, Tensor):
       uret = ret.uop
     elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):

@@ -77,6 +82,11 @@ class _function(Generic[ReturnType]):
 
     fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
                      precompile_backward=self.precompile_backward)
+
+    if DEBUG >= 2:
+      #signature = [(x._shape, x.dtype, x._device) for x in call_uops]
+      print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}") # with sig {signature}")
+
     if isinstance(ret, tuple):
       return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
     else:

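With DEBUG >= 2, each traced call now prints one line, indented by its nesting depth: the first eight hex digits of the function's UOp key, the elapsed wall-clock time, and the function name. Output is roughly of this shape (hash and timing below are invented for illustration):

function a1b2c3d4 in    12.34 ms: _attention
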
@@ -1567,7 +1567,7 @@ pm_pyrender_extra = PatternMatcher([
   (UPat(Ops.CONST, src=(UPat(Ops.UNIQUE, name="u"), UPat(Ops.DEVICE, name="d")), name="x"),
    lambda x,u,d: f"UOp.unique_const({x.dtype}, {x.arg}, device={repr(d.arg)}, unique={u.arg})"),
   (UPat(Ops.CONST, src=(UPat(Ops.DEVICE, name="d"),), name="x"), lambda x,d: f"UOp.const({x.dtype}, {x.arg}, device={repr(d.arg)})"),
-  (UPat(Ops.CONST, name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
+  (UPat(Ops.CONST, src=(), name="x"), lambda x: f"UOp.const({x.dtype}, {x.arg})"),
   (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x:
    f"UOp.variable(\"{x.arg[0]}\", {x.arg[1]}, {x.arg[2]}{', dtype='+str(x.dtype) if x.dtype is not dtypes.weakint else ''})"),
   (UPat((Ops.CAST, Ops.BITCAST), name="x"), lambda ctx,x: f"{ctx[x.src[0]]}.{x.op.name.lower()}({x.dtype})"),

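This last hunk is presumably the commit message's "fix spec": an unconstrained UPat(Ops.CONST, name="x") also matches CONSTs that carry UNIQUE/DEVICE sources, which the two patterns above it already render, while src=() restricts the fallback to source-less constants, matching the style of the DEFINE_VAR pattern below it.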