mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Fix tracemeta 0 (#13049)
* chore: tclesius branch resolved * fix: indentation --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,7 @@ import torch
|
||||
import unittest, copy, mmap, random, math, array
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _METADATA
|
||||
from tinygrad.helpers import getenv, temp, mv_address
|
||||
from tinygrad.helpers import Context, getenv, temp, mv_address
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -846,6 +846,16 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
#self.assertEqual(len(bw), 1)
|
||||
#self.assertEqual(bw[0].name, "sigmoid")
|
||||
|
||||
def test_tracemeta_0(self):
|
||||
with Context(TRACEMETA=0):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
y = Tensor.rand(3, requires_grad=True)
|
||||
out = (x.relu() * y.sigmoid()).sum()
|
||||
self.assertIsNone(out.uop.metadata)
|
||||
self.assertIsNone(out.uop.src[0].metadata)
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(si.metadata, ())
|
||||
|
||||
class TestIdxUpcast(unittest.TestCase):
|
||||
def _find_op(self, ast: UOp, op: Ops):
|
||||
if ast.op is op: return ast
|
||||
|
||||
@@ -172,7 +172,7 @@ class Tensor(OpMixin):
|
||||
|
||||
def _apply_uop(self, fxn:Callable, *x:Tensor, extra_args=(), **kwargs) -> Tensor:
|
||||
new_uop: UOp = fxn(*[t.uop for t in (self,)+x], *extra_args, **kwargs)
|
||||
if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = (metadata,)
|
||||
if (metadata:=_METADATA.get()) is not None and TRACEMETA >= 1: all_metadata[new_uop] = (metadata,)
|
||||
needs_input_grad = [t.requires_grad for t in (self,)+x]
|
||||
return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
|
||||
|
||||
@@ -4178,7 +4178,7 @@ _METADATA: _ContextVar[Metadata|None] = _ContextVar(default=None)
|
||||
|
||||
def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
|
||||
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
if _METADATA.get() is not None: return fn(*args, **kwargs)
|
||||
if TRACEMETA < 1 or _METADATA.get() is not None: return fn(*args, **kwargs)
|
||||
|
||||
if TRACEMETA >= 2:
|
||||
caller_frame = sys._getframe(frame := 1)
|
||||
|
||||
Reference in New Issue
Block a user