all UOp methods need dtype [run_process_replay] (#6490)

* all UOp methods need dtype [run_process_replay]

* delete all type: ignores yay
This commit is contained in:
qazal
2024-09-12 13:38:14 +08:00
committed by GitHub
parent 76487a3533
commit e5e14fc4ef

View File

@@ -379,11 +379,11 @@ class UOp(MathTrait):
def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b)
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
def cast(self, dtype:DType): return type(self)(UOps.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar(), (self,), i)
@classmethod
def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
def load(cls, *src:UOp, dtype:DType): return cls(UOps.LOAD, dtype, src)
@classmethod
def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src)
def alu(self, arg, *src:UOp):
@@ -393,14 +393,14 @@ class UOp(MathTrait):
return type(self)(UOps.ALU, out_dtype, (self,)+src, arg)
@classmethod
@functools.lru_cache(None)
def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b)
def const(cls, dtype:DType, b:ConstType|Variable): return cls._const(dtype, b)
@classmethod
def _const(cls, dtype:Optional[DType], b:ConstType|Variable):
def _const(cls, dtype:DType, b:ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) # type: ignore
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
if dtype is not None and dtype != (sdtype := dtype.scalar()):
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
@property # parents with self
@@ -646,7 +646,7 @@ class UPat(MathTrait):
def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return type(self)(UOps.GEP, None, (self,), i)
@classmethod
def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore
def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src)
@classmethod
def store(cls, *src:UPat): return cls(UOps.STORE, dtypes.void, src)