move metadata set to tensor [pr] (#9976)

* move metadata set to tensor [pr]

* only track that in tensor.py

George Hotz
2025-04-22 12:30:35 +01:00
committed by GitHub
parent f6271515fe
commit e358e0a0c6
4 changed files with 13 additions and 11 deletions
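
In short: the hook that records which high-level op created a UOp moves out of UOpMetaClass.__call__ in ops.py (which fired on every UOp construction, keyed off a ContextVar defined in helpers.py) and into Tensor._apply_uop in tensor.py. A minimal sketch of the layout after this change, using simplified stand-ins rather than the real tinygrad classes:

```python
# Sketch of the post-change layout (simplified stand-ins, not the real tinygrad classes).
import contextvars, weakref
from typing import Optional

class Metadata:
  def __init__(self, name:str): self.name = name

# lives in tensor.py after this change; helpers.py no longer defines it
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
# still lives in ops.py; maps UOps to the Metadata that created them
all_metadata: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

class UOp: pass  # stand-in for the interned tinygrad UOp

def _apply_uop(fxn, *uops) -> UOp:
  new_uop = fxn(*uops)
  # the single remaining hook: only tensor-level ops get tagged now
  if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = metadata
  return new_uop
```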

test/test_tensor.py

@@ -3,7 +3,8 @@ import numpy as np
 import torch
 import unittest, copy, mmap, random, math, array
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.helpers import getenv, temp, _METADATA, mv_address
+from tinygrad.tensor import _METADATA
+from tinygrad.helpers import getenv, temp, mv_address
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
 from hypothesis import given, settings, strategies as strat
 from tinygrad.device import is_dtype_supported

tinygrad/helpers.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
-import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
+import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib
 from dataclasses import dataclass
 from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
@@ -126,7 +126,6 @@ class Metadata:
   def __hash__(self): return hash(self.name)
   def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
   def __str__(self): return self.name + (" bw" if self.backward else "")
-_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
 # **************** global state Counters ****************

tinygrad/ops.py

@@ -4,7 +4,7 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
 from enum import auto, IntEnum, Enum
 from dataclasses import dataclass, field
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
-from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
+from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
 from tinygrad.helpers import PICKLE_BUFFERS, dedup, cdiv, cmod
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
@@ -228,8 +228,6 @@ class UOpMetaClass(type):
     if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
     UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
     for s in src: s.children.add(ref)
-    # NOTE: this will soon be set by Tensor once we remove function.py
-    if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
     # NOTE: this value is set by pickle when pickling a realized tensor
     if _buffer is not None:
       assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
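
Context for this hunk: UOpMetaClass.__call__ runs for every UOp construction and interns UOps through a weak cache, which is exactly why setting metadata here tagged far more UOps than just tensor-level ops. A rough sketch of that interning pattern (hypothetical Node class, not the real metaclass):

```python
import weakref

class InternMeta(type):
  ucache: dict = {}
  def __call__(cls, *key):
    # return the cached instance if a weakly-held one is still alive
    if (wret:=InternMeta.ucache.get(key)) is not None and (ret:=wret()) is not None: return ret
    InternMeta.ucache[key] = weakref.ref(created:=super().__call__(*key))
    return created

class Node(metaclass=InternMeta):
  def __init__(self, op, arg=None): self.op, self.arg = op, arg

assert Node("ADD", 1) is Node("ADD", 1)  # identical keys intern to one object
```
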
@@ -238,7 +236,7 @@ class UOpMetaClass(type):
 # some uops map to other stuff
 buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
-all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
+all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary() # TODO: should this be here?
 def _toposort(u:UOp, cache:set[UOp]):
   if u in cache: return {}
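
all_metadata stays a WeakKeyDictionary (the new TODO only questions where it should live), so tagging a UOp never keeps it alive: the entry dies with its key. A plain-Python illustration of that property:

```python
import weakref

class Obj: pass  # any weakly-referenceable object

meta: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
o = Obj()
meta[o] = "created by softmax"
assert len(meta) == 1
del o                       # drop the last strong reference
assert len(meta) == 0       # entry vanished with it (immediate under CPython refcounting)
```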

tinygrad/tensor.py

@@ -1,15 +1,15 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref, contextvars
 from contextlib import ContextDecorator
-from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Optional
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
+from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
 from tinygrad.engine.multi import get_multi_map
 from tinygrad.gradient import compute_gradient
-from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
+from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element, all_metadata
 from tinygrad.spec import tensor_uop_spec, type_verify
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
@@ -47,6 +47,9 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> Non
 # **** Tensor helper functions ****
+# this tracks the tensor.py METADATA
+_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
 def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None) -> UOp:
   if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
   return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
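
Nothing in this hunk shows who sets _METADATA; presumably the TRACEMETA wrapping of Tensor methods (imported above) does. A hedged sketch of the set/reset discipline such a wrapper would follow; track_metadata is a hypothetical name, not tinygrad's actual helper:

```python
import contextvars, functools
from typing import Optional

class Metadata:  # simplified stand-in for tinygrad's Metadata
  def __init__(self, name:str): self.name = name
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)

def track_metadata(fn):
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    if _METADATA.get() is not None: return fn(*args, **kwargs)  # keep the outermost op's name
    token = _METADATA.set(Metadata(fn.__name__))
    try: return fn(*args, **kwargs)
    finally: _METADATA.reset(token)
  return wrapper
```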
@@ -177,6 +180,7 @@ class Tensor(SimpleMathTrait):
   def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
     new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+    if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = metadata
     needs_input_grad = [t.requires_grad for t in (self,)+x]
     return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
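
With the hook in _apply_uop, the metadata for a tensor's underlying UOp can be read back through all_metadata. A hypothetical debugging snippet against the modules in this diff (may print None when TRACEMETA is disabled):

```python
from tinygrad import Tensor
from tinygrad.ops import all_metadata

t = Tensor.ones(4) + Tensor.ones(4)   # an op that goes through _apply_uop
print(all_metadata.get(t.lazydata))   # e.g. "add", if metadata was captured
```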