diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 66f317d39a..e7c86718f9 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -379,11 +379,11 @@ class UOp(MathTrait): def sink(self, *srcs): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st) def const_like(self, b:ConstType|Variable): return type(self).const(self.dtype, b) - def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) - def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) + def cast(self, dtype:DType): return type(self)(UOps.CAST, dtype, (self,)) + def bitcast(self, dtype:DType): return type(self)(UOps.BITCAST, dtype, (self,)) def gep(self, i:int): return type(self)(UOps.GEP, self.dtype.scalar(), (self,), i) @classmethod - def load(cls, *src:UOp, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore + def load(cls, *src:UOp, dtype:DType): return cls(UOps.LOAD, dtype, src) @classmethod def store(cls, *src:UOp): return cls(UOps.STORE, dtypes.void, src) def alu(self, arg, *src:UOp): @@ -393,14 +393,14 @@ class UOp(MathTrait): return type(self)(UOps.ALU, out_dtype, (self,)+src, arg) @classmethod @functools.lru_cache(None) - def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(dtype, b) + def const(cls, dtype:DType, b:ConstType|Variable): return cls._const(dtype, b) @classmethod - def _const(cls, dtype:Optional[DType], b:ConstType|Variable): + def _const(cls, dtype:DType, b:ConstType|Variable): # TODO: fix dtype of b.max after Variable is just an UOp - if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) # type: ignore + if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) if dtype is not None and dtype != (sdtype := dtype.scalar()): return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count))) - return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore + return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) @functools.cached_property def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}} @property # parents with self @@ -646,7 +646,7 @@ class UPat(MathTrait): def bitcast(self, dtype=None): return type(self)(UOps.BITCAST, dtype, (self,)) def gep(self, i:int): return type(self)(UOps.GEP, None, (self,), i) @classmethod - def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) # type: ignore + def load(cls, *src:UPat, dtype:Optional[DType]=None): return cls(UOps.LOAD, dtype, src) @classmethod def store(cls, *src:UPat): return cls(UOps.STORE, dtypes.void, src)