Fix tracemeta 0 (#13049)

* chore: tclesius branch resolved

* fix: indentation

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Ayman Jabr
2025-11-13 20:07:11 +03:00
committed by GitHub
parent 7e0aaadecd
commit 256f81bb02
2 changed files with 13 additions and 3 deletions

View File

@@ -3,7 +3,7 @@ import torch
import unittest, copy, mmap, random, math, array import unittest, copy, mmap, random, math, array
from tinygrad import Tensor, Device, dtypes from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _METADATA from tinygrad.tensor import _METADATA
from tinygrad.helpers import getenv, temp, mv_address from tinygrad.helpers import Context, getenv, temp, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat from hypothesis import given, settings, strategies as strat
from tinygrad.device import is_dtype_supported from tinygrad.device import is_dtype_supported
@@ -846,6 +846,16 @@ class TestTensorMetadata(unittest.TestCase):
#self.assertEqual(len(bw), 1) #self.assertEqual(len(bw), 1)
#self.assertEqual(bw[0].name, "sigmoid") #self.assertEqual(bw[0].name, "sigmoid")
def test_tracemeta_0(self):
with Context(TRACEMETA=0):
x = Tensor.rand(3, requires_grad=True)
y = Tensor.rand(3, requires_grad=True)
out = (x.relu() * y.sigmoid()).sum()
self.assertIsNone(out.uop.metadata)
self.assertIsNone(out.uop.src[0].metadata)
si = out.schedule()[-1]
self.assertEqual(si.metadata, ())
class TestIdxUpcast(unittest.TestCase): class TestIdxUpcast(unittest.TestCase):
def _find_op(self, ast: UOp, op: Ops): def _find_op(self, ast: UOp, op: Ops):
if ast.op is op: return ast if ast.op is op: return ast

View File

@@ -172,7 +172,7 @@ class Tensor(OpMixin):
def _apply_uop(self, fxn:Callable, *x:Tensor, extra_args=(), **kwargs) -> Tensor: def _apply_uop(self, fxn:Callable, *x:Tensor, extra_args=(), **kwargs) -> Tensor:
new_uop: UOp = fxn(*[t.uop for t in (self,)+x], *extra_args, **kwargs) new_uop: UOp = fxn(*[t.uop for t in (self,)+x], *extra_args, **kwargs)
if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,) if (metadata:=_METADATA.get()) is not None and TRACEMETA >= 1: all_metadata[new_uop] = (metadata,)
needs_input_grad = [t.requires_grad for t in (self,)+x] needs_input_grad = [t.requires_grad for t in (self,)+x]
return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False) return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
@@ -4178,7 +4178,7 @@ _METADATA: _ContextVar[Metadata|None] = _ContextVar(default=None)
def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]: def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T: def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if _METADATA.get() is not None: return fn(*args, **kwargs) if TRACEMETA < 1 or _METADATA.get() is not None: return fn(*args, **kwargs)
if TRACEMETA >= 2: if TRACEMETA >= 2:
caller_frame = sys._getframe(frame := 1) caller_frame = sys._getframe(frame := 1)