Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Fix tracemeta 0 (#13049)
* chore: tclesius branch resolved
* fix: indentation

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
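What changed: with TRACEMETA=0, metadata was still being attached to UOps, because neither `_apply_uop` nor `_metadata_wrapper` checked the flag. From the hunks below, the levels are: 0 disables tracing entirely, >= 1 records which op created each UOp, and >= 2 additionally captures the caller frame. A minimal repro of the fixed behavior (every name used here appears in the diff below):

from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(TRACEMETA=0):
  out = Tensor.rand(3).relu().sum()
  print(out.uop.metadata)  # None once the flag is respected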
@@ -3,7 +3,7 @@ import torch
 import unittest, copy, mmap, random, math, array
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _METADATA
-from tinygrad.helpers import getenv, temp, mv_address
+from tinygrad.helpers import Context, getenv, temp, mv_address
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
 from hypothesis import given, settings, strategies as strat
 from tinygrad.device import is_dtype_supported
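The only change in this hunk is importing `Context`, which the new test uses to flip TRACEMETA. For readers unfamiliar with tinygrad's helpers, here is a rough sketch of the ContextVar/Context pattern these flags use; the names match `tinygrad.helpers`, but the implementation shown is an illustrative assumption, not the real one:

class ContextVar:
  _cache: dict = {}
  def __init__(self, key, default):
    self.key, self.value = key, default
    ContextVar._cache[key] = self
  # comparison operators let code write `TRACEMETA >= 1` directly on the var
  def __ge__(self, x): return self.value >= x
  def __lt__(self, x): return self.value < x

class Context:
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
    # save current values, then override for the duration of the block
    self.old = {k: ContextVar._cache[k].value for k in self.kwargs}
    for k, v in self.kwargs.items(): ContextVar._cache[k].value = v
  def __exit__(self, *exc):
    for k, v in self.old.items(): ContextVar._cache[k].value = v

TRACEMETA = ContextVar("TRACEMETA", 1)
with Context(TRACEMETA=0):
  assert TRACEMETA < 1   # flag is off inside the block
assert TRACEMETA >= 1    # and restored on exit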
@@ -846,6 +846,16 @@ class TestTensorMetadata(unittest.TestCase):
     #self.assertEqual(len(bw), 1)
     #self.assertEqual(bw[0].name, "sigmoid")
 
+  def test_tracemeta_0(self):
+    with Context(TRACEMETA=0):
+      x = Tensor.rand(3, requires_grad=True)
+      y = Tensor.rand(3, requires_grad=True)
+      out = (x.relu() * y.sigmoid()).sum()
+      self.assertIsNone(out.uop.metadata)
+      self.assertIsNone(out.uop.src[0].metadata)
+      si = out.schedule()[-1]
+      self.assertEqual(si.metadata, ())
+
 class TestIdxUpcast(unittest.TestCase):
   def _find_op(self, ast: UOp, op: Ops):
     if ast.op is op: return ast
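To run just the new test (the hunk presumably lives in test/test_tensor.py, judging by the imports in the first hunk):

python -m pytest test/test_tensor.py -k test_tracemeta_0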
@@ -172,7 +172,7 @@ class Tensor(OpMixin):
 
   def _apply_uop(self, fxn:Callable, *x:Tensor, extra_args=(), **kwargs) -> Tensor:
     new_uop: UOp = fxn(*[t.uop for t in (self,)+x], *extra_args, **kwargs)
-    if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
+    if (metadata:=_METADATA.get()) is not None and TRACEMETA >= 1: all_metadata[new_uop] = (metadata,)
     needs_input_grad = [t.requires_grad for t in (self,)+x]
     return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
 
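One subtlety of this guard: the walrus assignment sits first, so `_METADATA.get()` is still evaluated even when TRACEMETA is 0 and its result is simply discarded; putting `TRACEMETA >= 1` first would short-circuit the call away. A tiny standalone demonstration of the ordering (plain Python, no tinygrad imports; `get` stands in for `_METADATA.get`):

calls = []
def get():
  calls.append("get")
  return "meta"

TRACEMETA = 0
if (metadata := get()) is not None and TRACEMETA >= 1: pass
assert calls == ["get"]   # get() still ran; only the second clause failed

calls.clear()
if TRACEMETA >= 1 and (metadata := get()) is not None: pass
assert calls == []        # reordered, the flag short-circuits the call away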
@@ -4178,7 +4178,7 @@ _METADATA: _ContextVar[Metadata|None] = _ContextVar(default=None)
 
 def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
   def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
-    if _METADATA.get() is not None: return fn(*args, **kwargs)
+    if TRACEMETA < 1 or _METADATA.get() is not None: return fn(*args, **kwargs)
 
     if TRACEMETA >= 2:
       caller_frame = sys._getframe(frame := 1)
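For context, `_metadata_wrapper` is the decorator that records which user-level Tensor op is executing, so `_apply_uop` can attach that Metadata to the UOps it creates; the new first line makes it a pure passthrough when tracing is off. A self-contained sketch of the flow under stated assumptions (only the guard line and the frame lookup appear in this hunk; the Metadata fields and the set/reset bookkeeping are inferred):

import sys
from contextvars import ContextVar
from dataclasses import dataclass

@dataclass(frozen=True)
class Metadata:           # assumed fields; tinygrad's real Metadata may differ
  name: str
  caller: str

_METADATA: ContextVar = ContextVar("_METADATA", default=None)
TRACEMETA = 1             # stands in for the real ContextVar flag

def _metadata_wrapper(fn):
  def _wrapper(*args, **kwargs):
    # new early-out: tracing disabled, or already inside a traced call
    if TRACEMETA < 1 or _METADATA.get() is not None: return fn(*args, **kwargs)
    caller = sys._getframe(1).f_code.co_name if TRACEMETA >= 2 else ""
    token = _METADATA.set(Metadata(fn.__name__, caller))
    try: return fn(*args, **kwargs)      # every op fn runs under this Metadata
    finally: _METADATA.reset(token)      # restored even if fn raises
  return _wrapper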