UOp compute min and max in one call [run_process_replay] (#5674)

easier to handle cases like *-1 that flip the bounds
This commit is contained in:
chenyu
2024-07-24 00:51:23 -04:00
committed by GitHub
parent 4e85761d40
commit d1d81b359f

View File

@@ -99,18 +99,17 @@ class UOp:
if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
return False # generic false if we aren't sure
@functools.cached_property
def vmax(self) -> UOp:
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.DEFINE_VAR: return self.src[1]
if self.op is UOps.SPECIAL and isinstance(self.arg[1], int): return self.const(self.arg[1])
if self.op is UOps.CONST: return self
return self.const(dtypes.max(cast(DType, self.dtype)))
def vmin(self) -> UOp: return self._min_max[0]
@functools.cached_property
def vmin(self) -> UOp:
if self.op is UOps.DEFINE_VAR: return self.src[0]
if self.op is UOps.SPECIAL: return self.const(0)
if self.op is UOps.CONST: return self
return self.const(dtypes.min(cast(DType, self.dtype)))
def vmax(self) -> UOp: return self._min_max[1]
@functools.cached_property
def _min_max(self):
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.DEFINE_VAR: return self.src[0], self.src[1]
if self.op is UOps.SPECIAL:
return self.const(0), self.const(self.arg[1]) if isinstance(self.arg[1], int) else self.const(dtypes.max(cast(DType, self.dtype)))
if self.op is UOps.CONST: return self, self
return self.const(dtypes.min(cast(DType, self.dtype))), self.const(dtypes.max(cast(DType, self.dtype)))
class UPat:
def __init__(self, op:Optional[Union[UOps, Set[UOps]]]=None, arg:Any=None, src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None,