mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
use functools.cache instead of lru_cache(None) [pr] (#9714)
* use functools.cache instead of lru_cache(None) [pr]
* more cache
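
Note: functools.cache was added in Python 3.9 and is simply lru_cache(maxsize=None) under a shorter name, so every hunk below is behavior-preserving; both spellings give an unbounded memoization cache keyed on the (hashable) call arguments. A minimal sketch of the equivalence (the function below is hypothetical, not from the tinygrad codebase):

import functools

# functools.cache (Python 3.9+) is an alias for lru_cache(maxsize=None):
# an unbounded memoization cache keyed on the call arguments.
@functools.cache
def squared(x: int) -> int:  # hypothetical function, only for illustration
  print("computing", x)
  return x * x

squared(4)  # prints "computing 4" and returns 16
squared(4)  # cache hit: returns 16 without recomputing
assert squared.cache_info().hits == 1 and squared.cache_info().maxsize is None
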
@@ -5,7 +5,7 @@ from PIL import Image
 import functools, pathlib
 from tinygrad.helpers import diskcache, getenv
 
-@functools.lru_cache(None)
+@functools.cache
 def get_imagenet_categories():
   ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
   return {v[0]: int(k) for k,v in ci.items()}
@@ -13,7 +13,7 @@ def get_imagenet_categories():
 if getenv("MNISTMOCK"):
   BASEDIR = pathlib.Path(__file__).parent / "mnist"
 
-  @functools.lru_cache(None)
+  @functools.cache
   def get_train_files():
     if not BASEDIR.exists():
       from extra.datasets.fake_imagenet_from_mnist import create_fake_mnist_imagenet
@@ -29,7 +29,7 @@ else:
     if not (files:=glob.glob(p:=str(BASEDIR / "train/*/*"))): raise FileNotFoundError(f"No training files in {p}")
     return files
 
-  @functools.lru_cache(None)
+  @functools.cache
   def get_val_files():
     if not (files:=glob.glob(p:=str(BASEDIR / "val/*/*"))): raise FileNotFoundError(f"No validation files in {p}")
     return files
@@ -15,11 +15,11 @@ BASEDIR = Path(__file__).parent / "kits19" / "data"
 TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
 VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"
 
-@functools.lru_cache(None)
+@functools.cache
 def get_train_files():
   return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])
 
-@functools.lru_cache(None)
+@functools.cache
 def get_val_files():
   data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/retired_benchmarks/unet3d/pytorch/evaluation_cases.txt").read_text()
   return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
@@ -89,7 +89,7 @@ required_input_python_consts: dict[str, tuple[int, ...]] = {
 }
 
 cache_misses = 0
-@functools.lru_cache(None)
+@functools.cache
 def _cached_to_python_const(t:Tensor):
   if t.dtype is dtypes.uint8: return t.data().tobytes()
   if 0 in t.shape: return []
@@ -150,7 +150,7 @@ def fill_scalar(x, y):
 @torch.library.impl("aten::_local_scalar_dense", "privateuseone")
 def _local_scalar_dense(tensor): return unwrap(tensor).item()
 
-@functools.lru_cache(None)
+@functools.cache
 def cached_to_movement_ops(shape, st) -> list:
   mops = to_movement_ops(st)
   if mops[0] == (MovementOps.RESHAPE, shape): mops = mops[1:]
@@ -123,7 +123,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
 # ***** optional patterns *****
 
 powers_of_two = {2**i:i for i in range(64)}
-@functools.lru_cache(None)
+@functools.cache
 def get_late_rewrite_patterns(ops, force_transcendental=False):
   pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
@@ -15,7 +15,7 @@ def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) ->
 def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
   return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
 
-@functools.lru_cache(None)
+@functools.cache
 def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
   return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
 
@@ -101,7 +101,7 @@ expander = PatternMatcher([
 ])
 
 def create_gate(root:UOp) -> UOp|None:
-  @functools.lru_cache(None)
+  @functools.cache
   def _gate_srcs(u:UOp, gate:UOp) -> UOp:
     if u.op is Ops.BARRIER: return u
     if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
@@ -575,7 +575,7 @@ class Kernel:
     return name + colored(num, 'BLACK')
 
   def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
-    @functools.lru_cache(None)
+    @functools.cache
     def fixup_ast(op:UOp) -> UOp:
       ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
       if op.op in GroupOp.Buffer and op in self.bufs:
@@ -15,12 +15,12 @@ class _Device:
   def __init__(self) -> None:
     self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
     self._opened_devices:set[str] = set()
-  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
+  @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
   def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
   # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
   def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
   def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
-  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
+  @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Compiled:
     cpn = multiprocessing.current_process().name
     assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
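
Note on the two method caches above: a cached method keys on self as well as the arguments, so the cache normally keeps every instance alive for the life of the process, which is what the pylint warning referenced in the comment guards against. The inline comment points out that _Device is a process-wide singleton, so only one self ever enters the cache. A minimal sketch of that pattern with a hypothetical singleton class:

import functools

class _Registry:  # hypothetical singleton standing in for tinygrad's _Device
  @functools.cache  # safe only because a single instance exists for the whole process
  def lookup(self, name: str) -> str:
    # the cache key is (self, name); with one instance the cache cannot pin
    # arbitrarily many objects the way it would on a normal class
    return name.upper()

REGISTRY = _Registry()  # module-level singleton
assert REGISTRY.lookup("metal") == "METAL"
assert REGISTRY.lookup("metal") == "METAL"  # second call is a cache hit
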
@@ -33,7 +33,7 @@ class DType(metaclass=DTypeMetaClass):
   def base(self): return self
   @property
   def vcount(self): return self.count
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vec(self, sz:int) -> DType:
     assert self.count == 1, f"can't vectorize {self} with size {sz}"
     if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
@@ -50,7 +50,7 @@ class PtrDType(DType):
   size: int = -1 # -1 is unlimited size
   @property
   def base(self): return self._base
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vec(self, sz:int) -> DType:
     assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
     if sz == 1: return self # sz=1 is a scalar
@@ -73,13 +73,13 @@ class ImageDType(PtrDType):
 
 class dtypes:
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
   @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
-  @functools.lru_cache(None)
+  @functools.cache
   def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
   def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
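
Note on decorator order in the dtypes hunk above: decorators apply bottom-up, so functools.cache memoizes the plain function first and @staticmethod then wraps the cached function (stacking them the other way only works on Python 3.10+, where staticmethod objects became callable). A small sketch of the same ordering with a hypothetical class:

import functools

class Palette:  # hypothetical class, only to show the decorator ordering
  @staticmethod      # applied last: exposes the cached function as a static method
  @functools.cache   # applied first: memoizes the plain function on its single argument
  def brightness(name: str) -> int:
    return {"black": 0, "white": 255}.get(name, 128)

assert Palette.brightness("white") == 255
assert Palette.brightness("white") == 255  # served from the cache, no instance involved
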
@@ -99,12 +99,12 @@ class dtypes:
     # TODO: should truncate here
     return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def min(dtype:DType):
     if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
     return -float("inf") if dtypes.is_float(dtype) else False
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def max(dtype:DType):
     if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
     return float("inf") if dtypes.is_float(dtype) else True
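
For reference, the integer min/max formulas above work out as follows for a dtype with k = itemsize*8 bits: min is 0 for unsigned and -2**(k-1) for signed, and max = 2**k - 1 + min, i.e. 2**k - 1 for unsigned and 2**(k-1) - 1 for signed. The snippet below just restates that arithmetic:

# worked examples of the integer min/max arithmetic (k = itemsize*8 bits):
#   int8  (signed):   min = -2**7 = -128, max = 2**8 - 1 + (-128) = 127
#   uint8 (unsigned): min = 0,            max = 2**8 - 1 + 0      = 255
for bits, signed, expected in [(8, True, (-128, 127)), (8, False, (0, 255)), (16, True, (-32768, 32767))]:
  lo = -2**(bits-1) if signed else 0
  hi = 2**bits - 1 + lo
  assert (lo, hi) == expected
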
@@ -165,10 +165,10 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes
   dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
   dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
 
-@functools.lru_cache(None)
+@functools.cache
 def _get_recursive_parents(dtype:DType) -> set[DType]:
   return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
-@functools.lru_cache(None)
+@functools.cache
 def least_upper_dtype(*ds:DType) -> DType:
   return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
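
Note: _get_recursive_parents above is a recursive walk of promo_lattice, and the unbounded cache memoizes it so each dtype's ancestor set is computed once per process; least_upper_dtype then intersects those cached sets. A toy sketch of the same memoized-recursion pattern (the lattice and names below are made up for illustration):

import functools

# hypothetical toy lattice, only to illustrate the memoized recursion used above
lattice = {"bool": ["int"], "int": ["float"], "float": []}

@functools.cache
def parents(t: str) -> frozenset[str]:
  # each node is computed once; later calls (including the recursive ones) are cache hits
  return frozenset({t}).union(*[parents(p) for p in lattice[t]])

def least_upper(*ts: str) -> str:
  common = frozenset.intersection(*[parents(t) for t in ts])
  # the most specific common supertype is the one with the largest ancestor set
  return max(common, key=lambda t: len(parents(t)))

assert parents("bool") == frozenset({"bool", "int", "float"})
assert least_upper("bool", "int") == "int"
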
@@ -210,12 +210,12 @@ def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] #
   import numpy as np
   return dtypes.fields()[np.dtype(npdtype).name]
 
-@functools.lru_cache(None)
+@functools.cache
 def _to_torch_dtype(dtype:DType) -> Optional['torch.dtype']: # type: ignore [name-defined] # noqa: F821
   import numpy as np, torch
   # NOTE: torch doesn't expose this mapping with a stable API
   try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
   except TypeError: return None
-@functools.lru_cache(None)
+@functools.cache
 def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
   return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
@@ -47,7 +47,7 @@ pm_gradient = PatternMatcher([
 
 # copied from tensor.py, get relevant toposort of gradients
 def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
-  @functools.lru_cache(None)
+  @functools.cache
   def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
   def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
     visited.add(node)
@@ -78,9 +78,9 @@ def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
 # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
 def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
 
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str, append_user:bool=False) -> str:
   return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
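
Note: because getenv above is cached, each (key, default) pair reads the environment once per process and coerces the result to the type of the default; later changes to os.environ are not seen by repeated calls. A short sketch restating the cached getenv from the hunk to show that behavior (the variable names are hypothetical):

import os, functools

@functools.cache
def getenv(key:str, default=0): return type(default)(os.getenv(key, default))  # restated from the hunk above

os.environ["SOME_FLAG"] = "2"                      # hypothetical variable name
assert getenv("SOME_FLAG") == 2                    # coerced to int because the default is an int
os.environ["SOME_FLAG"] = "5"
assert getenv("SOME_FLAG") == 2                    # cached: the new value is not re-read in this process
assert getenv("SOME_UNSET_FLAG", "off") == "off"   # str default -> str result
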
@@ -295,7 +295,7 @@ def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, cty
 def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
 def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
   return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
   class CStruct(ctypes.Structure):
     _pack_, _fields_ = 1, fields
@@ -707,7 +707,7 @@ def get_location() -> tuple[str, int]:
     frm = frm.f_back
   return frm.f_code.co_filename, frm.f_lineno
 
-@functools.lru_cache(None)
+@functools.cache
 def lines(fn) -> list[str]:
   with open(fn) as f: return f.readlines()
 
@@ -745,10 +745,10 @@ class UPat(MathTrait):
   def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
 
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
   @staticmethod
   def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
@@ -823,7 +823,7 @@ class PatternMatcher:
 
   def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
 
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
 
   def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
@@ -35,7 +35,7 @@ libobjc.sel_registerName.restype = objc_id
 libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
 libdispatch.dispatch_data_create.restype = objc_instance
 
-@functools.lru_cache(None)
+@functools.cache
 def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
   resname = libobjc.sel_registerName(selector.encode())
   sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
@@ -43,7 +43,7 @@ def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
   def _msg(ptr: objc_id, *args: Any) -> T: return sender(ptr, resname, *args)
   return _msg
 
-@functools.lru_cache(None)
+@functools.cache
 def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode())
 def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode()
 
@@ -29,7 +29,7 @@ def folded_upcast(u: UOp):
   with Context(TRACK_MATCH_STATS=0):
     return upcast(graph_rewrite(u, sym, {}))
 
-@functools.lru_cache(None)
+@functools.cache
 def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
   idx, valid = views[-1].to_indexed_uops(_idxs)
   for view in reversed(views[0:-1]):
@@ -37,7 +37,7 @@ def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...
     idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
   return idx, valid
 
-@functools.lru_cache(None)
+@functools.cache
 def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
   # NOTE: if a stride is not always valid, it will be None
   if len(views) == 1 and views[-1].mask is None: return views[-1].strides
@@ -6,17 +6,17 @@ from tinygrad.dtype import dtypes
 from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
 from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
 
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
   return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
 
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]:
   if not shape: return ()
   strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
   return canonicalize_strides(shape, strides)
 
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tuple[tuple[int, int], ...]]=None) -> tuple[tuple[int, int, int], ...]:
   # merge contiguous sub-parts or zero strided dims
   # any stride 0, masked from dim=1, or contiguous part is merged into next dim.
@@ -38,7 +38,7 @@ def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tup
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)
 
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
     -> Optional[tuple[tuple[sint, sint], ...]]:
   """Returns the new mask if reshape is possible, and None if not possible."""
@@ -98,14 +98,14 @@ class View:
         if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
     return iexpr, vexpr
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
     ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape])
     assert isinstance(ret, int), f"{ret=} is not int"
     return ret
 
   @staticmethod
-  @functools.lru_cache(maxsize=None)
+  @functools.cache
   def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
     if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
@@ -131,12 +131,12 @@ class View:
     contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
     return View(shape, strides, offset, mask, contiguous)
 
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vars(self) -> set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())
 
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def unbind(self) -> tuple[View, dict[Variable, int]]:
     var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
@@ -147,7 +147,7 @@ class View:
     new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
     return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def __add__(self, vm1:View) -> Optional[View]:
     vm2 = self
     if vm2.contiguous: return vm1
@@ -207,14 +207,14 @@ class View:
 
     return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
     ret = View.create(self.shape)
     if self.mask: ret = ret.shrink(self.mask)
     ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def minify(self):
     min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
     return nv if (nv := self.reshape(min_shape)) else self
@@ -228,7 +228,7 @@ class View:
     mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
     return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
@@ -239,14 +239,14 @@ class View:
       return self.__unsafe_resize(zvarg, mask=mask)
     return self
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def expand(self, new_shape: tuple[sint, ...]) -> View:
     if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
     # NOTE: does not check multiple of symbolic shape
@@ -257,19 +257,19 @@ class View:
       for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def permute(self, axis: tuple[int, ...]) -> View:
     assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
     return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                        tuple(self.mask[a] for a in axis) if self.mask is not None else None)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def flip(self, arg: tuple[bool, ...]) -> View:
     offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
     mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
     return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)
 
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self
 
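
Note on the View hunks above: functools.cache requires hashable arguments, and for these methods that includes self; this works because View is declared as a frozen dataclass in tinygrad, so instances are immutable, hashable value objects and equal views share a cache entry. A minimal sketch of caching a method on a hypothetical frozen dataclass:

import functools
from dataclasses import dataclass

@dataclass(frozen=True)  # frozen => immutable and hashable, so instances can be cache keys
class Box:               # hypothetical stand-in for an immutable View-like value object
  shape: tuple[int, ...]

  @functools.cache       # the cache also keeps these small value objects alive, which is fine here
  def size(self) -> int:
    out = 1
    for s in self.shape: out *= s
    return out

a, b = Box((2, 3)), Box((2, 3))
assert a.size() == b.size() == 6  # equal (and equal-hashing) instances share one cache entry
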
@@ -32,7 +32,7 @@ class GraphRewriteMetadata(TypedDict):
   kernel_code: str|None # optionally render the final kernel code
   name: str|None # optional name of the rewrite
 
-@functools.lru_cache(None)
+@functools.cache
 def render_program(k:Kernel): return k.opts.render(k.uops)
 def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
   return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),