diff --git a/tinygrad/ops.py b/tinygrad/ops.py index c99df8cce2..8ce2eb0869 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -531,15 +531,12 @@ class UPat(MathTrait): def any(*src): return UPatAny(src=src) @staticmethod - @functools.lru_cache(None) - def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UPat(dtype=dtype, name=name) + def var(name:Optional[str]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None): return UPat(dtype=dtype, name=name) @staticmethod - @functools.lru_cache(None) def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True): return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name) @staticmethod - @functools.lru_cache(None) - def const(dtype:Optional[DType], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b) + def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b) # copied from UOp def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,)) @@ -550,7 +547,7 @@ class UPat(MathTrait): @staticmethod def store(*src:UPat): return UPat(UOps.STORE, dtypes.void, src) - def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UPat.const(self.dtype, b) + def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UPat.const(self.dtype, cast(ConstType, b)) def alu(self, arg, *src:UPat): asrc = (self,)+src return UPat(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)