use functools.cache instead of lru_cache(None) [pr] (#9714)

* use functools.cache instead of lru_cache(None) [pr]

* more cache
Author: George Hotz
Date: 2025-04-03 11:47:13 +08:00
Committed by: GitHub
commit 5c7b549eab (parent bbd13191f4)
16 changed files with 54 additions and 54 deletions
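
Note (not part of the commit): functools.cache, added in Python 3.9, is simply lru_cache(maxsize=None) under a shorter name, so every hunk below is a spelling change with identical unbounded-cache behavior. A minimal sketch of the equivalence (the fib functions are made up for illustration):

import functools

@functools.cache
def fib(n: int) -> int:
  # each n is computed once; without the cache this is exponential
  return n if n < 2 else fib(n - 1) + fib(n - 2)

@functools.lru_cache(None)  # the old spelling removed throughout this diff
def fib_old(n: int) -> int:
  return n if n < 2 else fib_old(n - 1) + fib_old(n - 2)

assert fib(90) == fib_old(90)
print(fib.cache_info())  # both spellings expose cache_info() and cache_clear()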

View File

@@ -5,7 +5,7 @@ from PIL import Image
 import functools, pathlib
 from tinygrad.helpers import diskcache, getenv
-@functools.lru_cache(None)
+@functools.cache
 def get_imagenet_categories():
   ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
   return {v[0]: int(k) for k,v in ci.items()}
@@ -13,7 +13,7 @@ def get_imagenet_categories():
 if getenv("MNISTMOCK"):
   BASEDIR = pathlib.Path(__file__).parent / "mnist"
-  @functools.lru_cache(None)
+  @functools.cache
   def get_train_files():
     if not BASEDIR.exists():
       from extra.datasets.fake_imagenet_from_mnist import create_fake_mnist_imagenet
@@ -29,7 +29,7 @@ else:
     if not (files:=glob.glob(p:=str(BASEDIR / "train/*/*"))): raise FileNotFoundError(f"No training files in {p}")
     return files
-  @functools.lru_cache(None)
+  @functools.cache
   def get_val_files():
     if not (files:=glob.glob(p:=str(BASEDIR / "val/*/*"))): raise FileNotFoundError(f"No validation files in {p}")
     return files
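
Aside (not from the diff): the dataset helpers above take no arguments, so the cache memoizes a single result and the glob runs once per process; every later call returns the same mutable list. A small sketch with a made-up get_files helper:

import functools, glob

@functools.cache
def get_files():
  print("globbing...")          # runs only on the first call
  return glob.glob("*.py")

a = get_files()                 # prints "globbing..."
b = get_files()                 # served from the cache
assert a is b                   # the same list object is returned every time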

View File

@@ -15,11 +15,11 @@ BASEDIR = Path(__file__).parent / "kits19" / "data"
 TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
 VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"
-@functools.lru_cache(None)
+@functools.cache
 def get_train_files():
   return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])
-@functools.lru_cache(None)
+@functools.cache
 def get_val_files():
   data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/retired_benchmarks/unet3d/pytorch/evaluation_cases.txt").read_text()
   return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])

View File

@@ -89,7 +89,7 @@ required_input_python_consts: dict[str, tuple[int, ...]] = {
 }
 cache_misses = 0
-@functools.lru_cache(None)
+@functools.cache
 def _cached_to_python_const(t:Tensor):
   if t.dtype is dtypes.uint8: return t.data().tobytes()
   if 0 in t.shape: return []

View File

@@ -150,7 +150,7 @@ def fill_scalar(x, y):
 @torch.library.impl("aten::_local_scalar_dense", "privateuseone")
 def _local_scalar_dense(tensor): return unwrap(tensor).item()
-@functools.lru_cache(None)
+@functools.cache
 def cached_to_movement_ops(shape, st) -> list:
   mops = to_movement_ops(st)
   if mops[0] == (MovementOps.RESHAPE, shape): mops = mops[1:]

View File

@@ -123,7 +123,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
 # ***** optional patterns *****
 powers_of_two = {2**i:i for i in range(64)}
-@functools.lru_cache(None)
+@functools.cache
 def get_late_rewrite_patterns(ops, force_transcendental=False):
   pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
     ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]

View File

@@ -15,7 +15,7 @@ def _expand_arg_to_idx(args:tuple[tuple[int, int], ...], rpk:dict[int, int]) ->
 def _choices_from_args(args:tuple[tuple[int, int], ...]) -> list[dict[int, int]]:
   return [dict(x) for x in itertools.product(*[zip(itertools.repeat(axis), range(m)) for axis,m in args])]
-@functools.lru_cache(None)
+@functools.cache
 def _swizzle_args(cargs:tuple[tuple[int, int], ...], eargs:tuple[tuple[int, int], ...], exclude_args:tuple[int, ...]) -> list[int]:
   return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
@@ -101,7 +101,7 @@ expander = PatternMatcher([
 ])
 def create_gate(root:UOp) -> UOp|None:
-  @functools.lru_cache(None)
+  @functools.cache
   def _gate_srcs(u:UOp, gate:UOp) -> UOp:
     if u.op is Ops.BARRIER: return u
     if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:

View File

@@ -575,7 +575,7 @@ class Kernel:
     return name + colored(num, 'BLACK')
   def get_optimized_ast(self, name_override:Optional[str]=None) -> UOp:
-    @functools.lru_cache(None)
+    @functools.cache
     def fixup_ast(op:UOp) -> UOp:
       ret = op.replace(src=tuple(fixup_ast(x) for x in op.src))
       if op.op in GroupOp.Buffer and op in self.bufs:

View File

@@ -15,12 +15,12 @@ class _Device:
   def __init__(self) -> None:
     self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
     self._opened_devices:set[str] = set()
-  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
+  @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
   def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
   # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
   def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
   def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
-  @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
+  @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Compiled:
     cpn = multiprocessing.current_process().name
     assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
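
Aside (not from the diff): the pylint comment survives because caching a bound method keys the cache on (self, args) and keeps self alive inside the wrapper; that is only acceptable here because _Device is a singleton. A sketch of the same pattern with a hypothetical Registry class:

import functools

class Registry:
  @functools.cache              # key is (self, name); self stays referenced by the cache
  def lookup(self, name: str) -> str:
    print(f"opening {name}")
    return name.upper()

r = Registry()
r.lookup("metal")               # prints "opening metal"
r.lookup("metal")               # cache hit: same (r, "metal") key, no print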

View File

@@ -33,7 +33,7 @@ class DType(metaclass=DTypeMetaClass):
   def base(self): return self
   @property
   def vcount(self): return self.count
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vec(self, sz:int) -> DType:
     assert self.count == 1, f"can't vectorize {self} with size {sz}"
     if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
@@ -50,7 +50,7 @@ class PtrDType(DType):
   size: int = -1 # -1 is unlimited size
   @property
   def base(self): return self._base
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vec(self, sz:int) -> DType:
     assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
     if sz == 1: return self # sz=1 is a scalar
@@ -73,13 +73,13 @@ class ImageDType(PtrDType):
 class dtypes:
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def is_float(x: DType) -> bool: return x.scalar() in dtypes.floats or isinstance(x, ImageDType)
   @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
-  @functools.lru_cache(None)
+  @functools.cache
   def is_int(x: DType) -> bool: return x.scalar() in dtypes.ints
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def is_unsigned(x: DType) -> bool: return x.scalar() in dtypes.uints
   @staticmethod
   def is_bool(x: DType) -> bool: return x.scalar() == dtypes.bool
@@ -99,12 +99,12 @@ class dtypes:
     # TODO: should truncate here
     return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def min(dtype:DType):
     if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
     return -float("inf") if dtypes.is_float(dtype) else False
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def max(dtype:DType):
     if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
     return float("inf") if dtypes.is_float(dtype) else True
@@ -165,10 +165,10 @@ promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes
   dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
   dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
-@functools.lru_cache(None)
+@functools.cache
 def _get_recursive_parents(dtype:DType) -> set[DType]:
   return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
-@functools.lru_cache(None)
+@functools.cache
 def least_upper_dtype(*ds:DType) -> DType:
   return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
 def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
@@ -210,12 +210,12 @@ def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] #
   import numpy as np
   return dtypes.fields()[np.dtype(npdtype).name]
-@functools.lru_cache(None)
+@functools.cache
 def _to_torch_dtype(dtype:DType) -> Optional['torch.dtype']: # type: ignore [name-defined] # noqa: F821
   import numpy as np, torch
   # NOTE: torch doesn't expose this mapping with a stable API
   try: return torch.from_numpy(np.array([], dtype=_to_np_dtype(dtype))).dtype
   except TypeError: return None
-@functools.lru_cache(None)
+@functools.cache
 def _from_torch_dtype(torchdtype:'torch.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
   return {v:k for k in dtypes.all if (v:=_to_torch_dtype(k)) is not None}[torchdtype]
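
Aside (not from the diff): in the dtypes helpers the decorator order matters. @staticmethod sits above @functools.cache, so the cache wraps the bare function and the key is just the DType argument, which works because DType instances are hashable. A sketch with hypothetical Dt/Checks/is_wide names:

import functools
from dataclasses import dataclass

@dataclass(frozen=True)           # frozen dataclasses are hashable, so they can be cache keys
class Dt:
  name: str
  itemsize: int

class Checks:
  @staticmethod                   # staticmethod outermost: the cache sees only the dt argument
  @functools.cache
  def is_wide(dt: Dt) -> bool:
    print(f"computing for {dt.name}")
    return dt.itemsize >= 4

Checks.is_wide(Dt("float32", 4))  # prints once
Checks.is_wide(Dt("float32", 4))  # equal instances hash alike -> cache hit, no print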

View File

@@ -47,7 +47,7 @@ pm_gradient = PatternMatcher([
 # copied from tensor.py, get relevant toposort of gradients
 def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
-  @functools.lru_cache(None)
+  @functools.cache
   def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
   def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
     visited.add(node)
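
Aside (not from the diff): is_in_target_path is defined inside _deepwalk, so a fresh cache is created per call and discarded with it, scoping the memoization to one traversal. A sketch of the same pattern with a made-up count_paths:

import functools

def count_paths(graph: dict[str, list[str]], root: str) -> int:
  @functools.cache              # new cache per count_paths call, freed when it returns
  def walk(node: str) -> int:
    children = graph.get(node, [])
    return 1 if not children else sum(walk(c) for c in children)
  return walk(root)

g = {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}
print(count_paths(g, "a"))      # 2: walk("d") is computed once and reused via the cache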

View File

@@ -78,9 +78,9 @@ def pluralize(st:str, cnt:int): return f"{cnt} {st}"+('' if cnt == 1 else 's')
 # for length N coefficients `p`, returns p[0] * x**(N-1) + p[1] * x**(N-2) + ... + p[-2] * x + p[-1]
 def polyN(x:T, p:list[float]) -> T: return functools.reduce(lambda acc,c: acc*x+c, p, 0.0) # type: ignore
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str, append_user:bool=False) -> str:
   return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
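
Aside (not from the diff): because getenv itself is cached, the environment is read once per (key, default) for the life of the process; later changes to os.environ are ignored unless the cache is cleared. functools.cache keeps lru_cache's cache_clear() and cache_info():

import functools, os

@functools.cache
def getenv(key: str, default=0): return type(default)(os.getenv(key, default))

os.environ["DEBUG"] = "2"
print(getenv("DEBUG"))          # 2, read from the environment once
os.environ["DEBUG"] = "5"
print(getenv("DEBUG"))          # still 2: the cached value wins
getenv.cache_clear()
print(getenv("DEBUG"))          # 5 after clearing the cache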
@@ -295,7 +295,7 @@ def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, cty
 def mv_address(mv): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
 def to_char_p_p(options: list[bytes], to_type=ctypes.c_char):
   return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def init_c_struct_t(fields: tuple[tuple[str, ctypes._SimpleCData], ...]):
   class CStruct(ctypes.Structure):
     _pack_, _fields_ = 1, fields

View File

@@ -707,7 +707,7 @@ def get_location() -> tuple[str, int]:
     frm = frm.f_back
   return frm.f_code.co_filename, frm.f_lineno
-@functools.lru_cache(None)
+@functools.cache
 def lines(fn) -> list[str]:
   with open(fn) as f: return f.readlines()
@@ -745,10 +745,10 @@ class UPat(MathTrait):
   def or_casted(self, name:str|None=None): return UPat.any(self if name is None else self.named(name), UPat(Ops.CAST, name=name, src=(self,)))
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def var(name:Optional[str]=None, dtype:Optional[Union[DType, tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
   @staticmethod
-  @functools.lru_cache(None)
+  @functools.cache
   def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name)
   @staticmethod
   def const(dtype:Optional[Union[DType, tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
@@ -823,7 +823,7 @@ class PatternMatcher:
   def __reduce__(self): return PatternMatcher, ([(x,deconstruct_function(fxn) if fxn.__name__ == "<lambda>" else fxn) for x,fxn in self.patterns],)
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
   def rewrite(self, uop:UOp, ctx=None) -> UOp|None:
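
Aside (not from the diff): caching __add__ means composing the same two matchers repeatedly returns the identical merged object rather than rebuilding it. A sketch with a hypothetical Rules class standing in for PatternMatcher:

import functools

class Rules:
  def __init__(self, patterns): self.patterns = tuple(patterns)
  @functools.cache              # same (self, more) pair -> the merged object is built once
  def __add__(self, more: "Rules") -> "Rules": return Rules(self.patterns + more.patterns)

base, extra = Rules(["a"]), Rules(["b"])
assert (base + extra) is (base + extra)   # identical object, not just an equal one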

View File

@@ -35,7 +35,7 @@ libobjc.sel_registerName.restype = objc_id
 libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
 libdispatch.dispatch_data_create.restype = objc_instance
-@functools.lru_cache(None)
+@functools.cache
 def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
   resname = libobjc.sel_registerName(selector.encode())
   sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
@@ -43,7 +43,7 @@ def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
   def _msg(ptr: objc_id, *args: Any) -> T: return sender(ptr, resname, *args)
   return _msg
-@functools.lru_cache(None)
+@functools.cache
 def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode())
 def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode()

View File

@@ -29,7 +29,7 @@ def folded_upcast(u: UOp):
   with Context(TRACK_MATCH_STATS=0):
     return upcast(graph_rewrite(u, sym, {}))
-@functools.lru_cache(None)
+@functools.cache
 def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...]]=None) -> tuple[UOp, UOp]:
   idx, valid = views[-1].to_indexed_uops(_idxs)
   for view in reversed(views[0:-1]):
@@ -37,7 +37,7 @@ def views_to_indexed_uops(views: tuple[View, ...], _idxs:Optional[tuple[UOp, ...
     idx, valid = view.to_indexed_uops([sint_to_uop(i) for i in unravel(view.shape, idx)], valid)
   return idx, valid
-@functools.lru_cache(None)
+@functools.cache
 def views_to_real_strides(views: tuple[View, ...], ignore_valid=False) -> tuple[Optional[sint], ...]:
   # NOTE: if a stride is not always valid, it will be None
   if len(views) == 1 and views[-1].mask is None: return views[-1].strides
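
Aside (not from the diff): these helpers can be cached only because their arguments are hashable tuples of (hashable) View objects; functools.cache, like lru_cache, raises TypeError on unhashable arguments such as lists. A small sketch with a made-up total_size:

import functools

@functools.cache
def total_size(shape: tuple[int, ...]) -> int:
  out = 1
  for s in shape: out *= s
  return out

print(total_size((4, 8)))       # 32, computed and cached
print(total_size((4, 8)))       # cache hit: equal tuples hash the same
try:
  total_size([4, 8])            # lists are unhashable, so they cannot be cache keys
except TypeError as e:
  print("unhashable:", e)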

View File

@@ -6,17 +6,17 @@ from tinygrad.dtype import dtypes
 from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
 from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def canonicalize_strides(shape:tuple[sint, ...], strides:tuple[sint, ...]) -> tuple[sint, ...]:
   return tuple(0 if s == 1 else st for s, st in zip(shape, strides))
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def strides_for_shape(shape:tuple[sint, ...]) -> tuple[sint, ...]:
   if not shape: return ()
   strides = tuple(itertools.accumulate(reversed(shape[1:]), operator.mul, initial=1))[::-1]
   return canonicalize_strides(shape, strides)
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tuple[tuple[int, int], ...]]=None) -> tuple[tuple[int, int, int], ...]:
   # merge contiguous sub-parts or zero strided dims
   # any stride 0, masked from dim=1, or contiguous part is merged into next dim.
@@ -38,7 +38,7 @@ def merge_dims(shape:tuple[int, ...], strides:tuple[int, ...], mask:Optional[tup
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
   return tuple(ret)
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def _reshape_mask(_mask:Optional[tuple[tuple[sint, sint], ...]], old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) \
   -> Optional[tuple[tuple[sint, sint], ...]]:
   """Returns the new mask if reshape is possible, and None if not possible."""
@@ -98,14 +98,14 @@ class View:
         if resolve(m[1] != sh): vexpr = vexpr * (idx < m[1])
     return iexpr, vexpr
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def size(self) -> int:
     ret = prod([x.vmax if isinstance(x, UOp) else x for x in self.shape])
     assert isinstance(ret, int), f"{ret=} is not int"
     return ret
   @staticmethod
-  @functools.lru_cache(maxsize=None)
+  @functools.cache
   def create(shape:tuple[sint, ...], strides:Optional[tuple[sint, ...]]=None, offset:sint=0, mask:Optional[tuple[tuple[sint, sint], ...]]=None):
     if not all(s >= 0 for s in shape): raise ValueError(f"Trying to create View with negative dimension: {shape=}")
     strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
@@ -131,12 +131,12 @@ class View:
     contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
     return View(shape, strides, offset, mask, contiguous)
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def vars(self) -> set[Variable]:
     flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
     return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, UOp)], set())
-  @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def unbind(self) -> tuple[View, dict[Variable, int]]:
     var_unboundvar_val = [(v, v.unbind()) for v in self.vars()]
     unbound_vars = {v:uv for v,(uv,_) in var_unboundvar_val}
@@ -147,7 +147,7 @@ class View:
     new_mask = tuple((substitute(x[0]), substitute(x[1])) for x in self.mask) if self.mask is not None else None
     return View.create(new_shape, new_strides, new_offset, new_mask), dict(x[1] for x in var_unboundvar_val)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def __add__(self, vm1:View) -> Optional[View]:
     vm2 = self
     if vm2.contiguous: return vm1
@@ -207,14 +207,14 @@ class View:
     return View.create(vm1.shape, tuple(strides), sum(o * s for o, s in zip(origin, vm2.strides)) + vm2.offset)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def invert(self, out_shape:tuple[sint, ...]) -> Optional[View]:
     ret = View.create(self.shape)
     if self.mask: ret = ret.shrink(self.mask)
     ret = ret.flip(tuple(x < 0 for x in self.strides)).permute(argsort(tuple(-x if x > 0 else x for x in self.strides)))
     return ret if prod(ret.shape) == prod(out_shape) else None # don't support shrink, expand, or stride != (-1, 1)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def minify(self):
     min_shape = tuple(x[0] for x in merge_dims(self.shape, self.strides, self.mask))
     return nv if (nv := self.reshape(min_shape)) else self
@@ -228,7 +228,7 @@ class View:
       mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
     return View.create(tuple([y-x for x,y in arg]), self.strides, self.offset+offset, mask)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def pad(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
@@ -239,14 +239,14 @@ class View:
       return self.__unsafe_resize(zvarg, mask=mask)
     return self
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def shrink(self, arg: tuple[tuple[sint, sint], ...]) -> View:
     assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
     # NOTE: not checking for symbolic arg
     for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
     return self.__unsafe_resize(arg)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def expand(self, new_shape: tuple[sint, ...]) -> View:
     if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
     # NOTE: does not check multiple of symbolic shape
@@ -257,19 +257,19 @@ class View:
       for m,s,ns in zip(self.mask, self.shape, new_shape)]) if self.mask else None
     return View.create(new_shape, self.strides, self.offset, mask)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def permute(self, axis: tuple[int, ...]) -> View:
     assert sorted(axis) == list(range(len(self.shape))), f"invalid permutation {axis} of len {len(self.shape)}"
     return View.create(tuple(self.shape[a] for a in axis), tuple(self.strides[a] for a in axis), self.offset,
                        tuple(self.mask[a] for a in axis) if self.mask is not None else None)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def flip(self, arg: tuple[bool, ...]) -> View:
     offset = sum((s-1)*z for s,z,f in zip(self.shape, self.strides, arg) if f)
     mask = tuple((s-my,s-mx) if f else (mx,my) for (mx,my),s,f in zip(self.mask, self.shape, arg)) if self.mask is not None else None
     return View.create(self.shape, tuple(-z if f else z for z,f in zip(self.strides, arg)), self.offset+offset, mask)
-  @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
+  @functools.cache # pylint: disable=method-cache-max-size-none
   def reshape(self, new_shape: tuple[sint, ...]) -> Optional[View]:
     if self.shape == new_shape: return self

View File

@@ -32,7 +32,7 @@ class GraphRewriteMetadata(TypedDict):
   kernel_code: str|None # optionally render the final kernel code
   name: str|None # optional name of the rewrite
-@functools.lru_cache(None)
+@functools.cache
 def render_program(k:Kernel): return k.opts.render(k.uops)
 def to_metadata(k:Any, v:TrackedGraphRewrite) -> GraphRewriteMetadata:
   return {"loc":v.loc, "match_count":len(v.matches), "code_line":lines(v.loc[0])[v.loc[1]-1].strip(),