mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
move metadata set to tensor [pr] (#9976)
* move metadata set to tensor [pr]
* only track that in tensor.py
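
Both before and after this change the mechanism is the same: a `contextvars.ContextVar` holds the ambient `Metadata`, and a weak-keyed side table (`all_metadata`) maps graph nodes to it; the PR moves the read from UOp construction (`UOpMetaClass.__call__`) into the Tensor layer (`Tensor._apply_uop`). A minimal self-contained sketch of the pattern, with a stand-in `Node` class and a simplified `Metadata` (not tinygrad's actual classes):

import contextvars, weakref
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Metadata:  # simplified: the real class also carries caller and backward info
  name: str
  backward: bool = False

class Node: pass  # stand-in for a UOp

# the ambient metadata for ops created in the current context
_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)

# side table from graph node to metadata; weak keys, so it never keeps a node alive
all_metadata: "weakref.WeakKeyDictionary[Node, Metadata]" = weakref.WeakKeyDictionary()

def apply_op() -> Node:
  node = Node()
  # after this PR, this read happens at the Tensor layer, not in UOpMetaClass
  if (metadata:=_METADATA.get()) is not None: all_metadata[node] = metadata
  return node

token = _METADATA.set(Metadata("relu"))
try: n = apply_op()
finally: _METADATA.reset(token)
assert all_metadata[n].name == "relu"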
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -3,7 +3,8 @@ import numpy as np
 import torch
 import unittest, copy, mmap, random, math, array
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.helpers import getenv, temp, _METADATA, mv_address
+from tinygrad.tensor import _METADATA
+from tinygrad.helpers import getenv, temp, mv_address
 from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
 from hypothesis import given, settings, strategies as strat
 from tinygrad.device import is_dtype_supported
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass
-import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
+import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib
 from dataclasses import dataclass
 from typing import Union, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic
 
@@ -126,7 +126,6 @@ class Metadata:
   def __hash__(self): return hash(self.name)
   def __repr__(self): return str(self) + (f" - {self.caller}" if self.caller else "")
   def __str__(self): return self.name + (" bw" if self.backward else "")
-_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
 
 # **************** global state Counters ****************
 
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -4,7 +4,7 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
 from enum import auto, IntEnum, Enum
 from dataclasses import dataclass, field
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate
-from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, _METADATA, flatten
+from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten
 from tinygrad.helpers import PICKLE_BUFFERS, dedup, cdiv, cmod
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
@@ -228,8 +228,6 @@ class UOpMetaClass(type):
     if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
     UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
     for s in src: s.children.add(ref)
-    # NOTE: this will soon be set by Tensor once we remove function.py
-    if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
     # NOTE: this value is set by pickle when pickling a realized tensor
     if _buffer is not None:
       assert op is Ops.BUFFER, f"trying to set Buffer {_buffer} for {op}"
@@ -238,7 +236,7 @@ class UOpMetaClass(type):
 
 # some uops map to other stuff
 buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # this maps BUFFER uops to their device Buffers
-all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
+all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary() # TODO: should this be here?
 
 def _toposort(u:UOp, cache:set[UOp]):
   if u in cache: return {}
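
`all_metadata` stays behind in ops.py as a weak-keyed side table; the new TODO questions its placement, not its mechanics. Because it is a `weakref.WeakKeyDictionary`, an entry vanishes as soon as its key UOp is garbage collected, so the table cannot leak metadata for dead graph nodes. A quick sketch of that lifetime behavior (hypothetical `Key` class; immediate reclamation is CPython refcounting, the `gc.collect()` is for portability):

import gc, weakref

class Key: pass  # hypothetical stand-in; the real keys are UOp instances

table: "weakref.WeakKeyDictionary[Key, str]" = weakref.WeakKeyDictionary()
k = Key()
table[k] = "some metadata"
assert len(table) == 1
del k         # drop the last strong reference to the key...
gc.collect()
assert len(table) == 0  # ...and the entry disappears with it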
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1,15 +1,15 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
+import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref, contextvars
 from contextlib import ContextDecorator
-from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar
+from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Optional
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
+from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
 from tinygrad.engine.multi import get_multi_map
 from tinygrad.gradient import compute_gradient
-from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
+from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element, all_metadata
 from tinygrad.spec import tensor_uop_spec, type_verify
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
@@ -47,6 +47,9 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str|None=None) -> Non
 
 # **** Tensor helper functions ****
 
+# this tracks the tensor.py METADATA
+_METADATA: contextvars.ContextVar[Optional[Metadata]] = contextvars.ContextVar("_METADATA", default=None)
+
 def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None) -> UOp:
   if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
   return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
@@ -177,6 +180,7 @@ class Tensor(SimpleMathTrait):
 
   def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
     new_uop: UOp = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+    if (metadata:=_METADATA.get()) is not None: all_metadata[new_uop] = metadata
     needs_input_grad = [t.requires_grad for t in (self,)+x]
     return Tensor(new_uop, device=new_uop.device, requires_grad=True if any(needs_input_grad) else None if None in needs_input_grad else False)
 
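
With the whole diff applied, tensor.py owns `_METADATA` and `Tensor._apply_uop` is its only reader. Whatever wants its ops tagged sets the context variable around the calls; in tinygrad this is driven by the `TRACEMETA`-gated machinery, but a hypothetical context manager (reusing `_METADATA` and `Metadata` from the sketch near the top) shows the shape of it:

import contextlib

@contextlib.contextmanager
def track_metadata(md):  # hypothetical helper, not a tinygrad API
  token = _METADATA.set(md)
  try: yield
  finally: _METADATA.reset(token)

# every op applied while the block is active gets an all_metadata entry, e.g.
# with track_metadata(Metadata("linear")):
#   y = x.matmul(w).relu()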