leftover lru_cache on UPat [pr] (#7257)

* leftover lru_cache on UPat [pr]

* fix mypy
This commit is contained in:
George Hotz
2024-10-24 15:11:24 +07:00
committed by GitHub
parent 532b7b018c
commit a7be9dfd71

View File

@@ -531,15 +531,12 @@ class UPat(MathTrait):
def any(*src): return UPatAny(src=src)
@staticmethod
@functools.lru_cache(None)
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name)
def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name)
@staticmethod
@functools.lru_cache(None)
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name)
@staticmethod
@functools.lru_cache(None)
def const(dtype:Optional[DType], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
# copied from UOp
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
@@ -550,7 +547,7 @@ class UPat(MathTrait):
@staticmethod
def store(*src:UPat): return UPat(UOps.STORE, dtypes.void, src)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UPat.const(self.dtype, b)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, arg, *src:UPat):
asrc = (self,)+src
return UPat(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)