mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
tracemeta fixups (#5904)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
__pycache__
|
||||
.venv/
|
||||
.venv-*/
|
||||
.vscode
|
||||
.DS_Store
|
||||
notebooks
|
||||
|
||||
@@ -3160,25 +3160,29 @@ def _metadata_wrapper(fn):
|
||||
def _wrapper(*args, **kwargs):
|
||||
if _METADATA.get() is not None: return fn(*args, **kwargs)
|
||||
|
||||
caller_frame = sys._getframe(frame := 1)
|
||||
caller_module = caller_frame.f_globals.get("__name__", None)
|
||||
caller_func = caller_frame.f_code.co_name
|
||||
if caller_module is None: return fn(*args, **kwargs)
|
||||
|
||||
# if its called from nn we want to step up frames until we are out of nn
|
||||
while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
|
||||
caller_frame = sys._getframe(frame := frame + 1)
|
||||
if TRACEMETA >= 2:
|
||||
caller_frame = sys._getframe(frame := 1)
|
||||
caller_module = caller_frame.f_globals.get("__name__", None)
|
||||
caller_func = caller_frame.f_code.co_name
|
||||
if caller_module is None: return fn(*args, **kwargs)
|
||||
|
||||
# if its called from a lambda in tinygrad we want to look two more frames up
|
||||
if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
|
||||
caller_module = caller_frame.f_globals.get("__name__", None)
|
||||
if caller_module is None: return fn(*args, **kwargs)
|
||||
caller_func = caller_frame.f_code.co_name
|
||||
caller_lineno = caller_frame.f_lineno
|
||||
# if its called from nn we want to step up frames until we are out of nn
|
||||
while caller_module.startswith("tinygrad.nn") and "optim" not in caller_module:
|
||||
caller_frame = sys._getframe(frame := frame + 1)
|
||||
caller_module = caller_frame.f_globals.get("__name__", None)
|
||||
if caller_module is None: return fn(*args, **kwargs)
|
||||
|
||||
token = _METADATA.set(Metadata(name=fn.__name__, caller=f"{caller_module}:{caller_lineno}::{caller_func}"))
|
||||
# if its called from a lambda in tinygrad we want to look two more frames up
|
||||
if caller_module.startswith("tinygrad") and caller_func == "<lambda>": caller_frame = sys._getframe(frame := frame + 2)
|
||||
caller_module = caller_frame.f_globals.get("__name__", None)
|
||||
if caller_module is None: return fn(*args, **kwargs)
|
||||
caller_func = caller_frame.f_code.co_name
|
||||
caller_lineno = caller_frame.f_lineno
|
||||
|
||||
caller = f"{caller_module}:{caller_lineno}::{caller_func}"
|
||||
else: caller = ""
|
||||
|
||||
token = _METADATA.set(Metadata(name=fn.__name__, caller=caller))
|
||||
ret = fn(*args, **kwargs)
|
||||
_METADATA.reset(token)
|
||||
return ret
|
||||
@@ -3186,5 +3190,5 @@ def _metadata_wrapper(fn):
|
||||
|
||||
if TRACEMETA >= 1:
|
||||
for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
|
||||
if name in ["__class__", "__init__", "__repr__", "backward", "sequential"]: continue
|
||||
if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
|
||||
setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
|
||||
|
||||
Reference in New Issue
Block a user