tracemeta fixups (#5904)

This commit is contained in:
wozeparrot
2024-08-04 23:15:06 +00:00
committed by GitHub
parent adba5efc64
commit f33950f454
2 changed files with 21 additions and 16 deletions

1
.gitignore vendored
View File

@@ -1,5 +1,6 @@
__pycache__
.venv/
.venv-*/
.vscode
.DS_Store
notebooks

View File

@@ -3160,25 +3160,29 @@ def _metadata_wrapper(fn):
def _wrapper(*args, **kwargs):
if _METADATA.get() is not None: return fn(*args, **kwargs)
caller_frame = sys._getframe(frame := 1)
caller_module = caller_frame.f_globals.get("__name__", None)
caller_func = caller_frame.f_code.co_name
if caller_module is None: return fn(*args, **kwargs)
# if its called from nn we want to step up frames until we are out of nn
while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
caller_frame = sys._getframe(frame := frame + 1)
if TRACEMETA >= 2:
caller_frame = sys._getframe(frame := 1)
caller_module = caller_frame.f_globals.get("__name__", None)
caller_func = caller_frame.f_code.co_name
if caller_module is None: return fn(*args, **kwargs)
# if its called from a lambda in tinygrad we want to look two more frames up
if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
caller_module = caller_frame.f_globals.get("__name__", None)
if caller_module is None: return fn(*args, **kwargs)
caller_func = caller_frame.f_code.co_name
caller_lineno = caller_frame.f_lineno
# if its called from nn we want to step up frames until we are out of nn
while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
caller_frame = sys._getframe(frame := frame + 1)
caller_module = caller_frame.f_globals.get("__name__", None)
if caller_module is None: return fn(*args, **kwargs)
token = _METADATA.set(Metadata(name=fn.__name__, caller=f"{caller_module}:{caller_lineno}::{caller_func}"))
# if its called from a lambda in tinygrad we want to look two more frames up
if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
caller_module = caller_frame.f_globals.get("__name__", None)
if caller_module is None: return fn(*args, **kwargs)
caller_func = caller_frame.f_code.co_name
caller_lineno = caller_frame.f_lineno
caller = f"{caller_module}:{caller_lineno}::{caller_func}"
else: caller = ""
token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
ret = fn(*args, **kwargs)
_METADATA.reset(token)
return ret
@@ -3186,5 +3190,5 @@ def _metadata_wrapper(fn):
if TRACEMETA >= 1:
for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
if name in ["__class__", "__init__", "__repr__", "backward", "sequential"]: continue
if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))