non-optional bounds (faster) [run_process_replay]

This commit is contained in:
George Hotz
2024-09-09 15:49:44 +08:00
parent c5bae55ec8
commit 69856c8f38

View File

@@ -324,6 +324,11 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
# Keyed by UOps.IF / UOps.RANGE; the second entry of each value is the corresponding
# closing op (UOps.ENDIF / UOps.ENDRANGE).
# NOTE(review): the first entry (STORE / ASSIGN) presumably names the op associated with
# that kind of block -- confirm against the code that consumes this table.
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
@functools.lru_cache(maxsize=None)
def _min_bound(dtype:DType):
  """Build the default lower-bound constant for `dtype`.

  The result is `UOp.const` of `dtypes.min(dtype)`, typed as `dtype.scalar()`.
  The unbounded lru_cache means each distinct `dtype` is computed once and the
  same UOp object is handed back on every later call.
  """
  scalar_dtype = dtype.scalar()
  lowest = dtypes.min(dtype)
  return UOp.const(scalar_dtype, lowest)
@functools.lru_cache(maxsize=None)
def _max_bound(dtype:DType):
  """Build the default upper-bound constant for `dtype`.

  The result is `UOp.const` of `dtypes.max(dtype)`, typed as `dtype.scalar()`.
  The unbounded lru_cache means each distinct `dtype` is computed once and the
  same UOp object is handed back on every later call.
  """
  scalar_dtype = dtype.scalar()
  highest = dtypes.max(dtype)
  return UOp.const(scalar_dtype, highest)
@dataclass(frozen=True, eq=False)
class UOp(MathTrait):
op: UOps
@@ -407,18 +412,16 @@ class UOp(MathTrait):
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
@property
def vmin(self) -> UOp:
return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.min(cast(DType, self.dtype)))
def vmin(self) -> UOp: return self._min_max[0]
@property
def vmax(self) -> UOp:
return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.max(cast(DType, self.dtype)))
def vmax(self) -> UOp: return self._min_max[1]
@functools.cached_property
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
def _min_max(self) -> Tuple[UOp, UOp]:
# NOTE: returned UOp is assumed to be CONST
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else _max_bound(self.dtype)
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else _max_bound(self.dtype)
if self.op is UOps.CONST: return self, self
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
@@ -435,7 +438,7 @@ class UOp(MathTrait):
if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg))
if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg))
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.const(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
return None, None
return _min_bound(self.dtype), _max_bound(self.dtype)
@dataclass(frozen=True)
class KernelInfo: