UOp bounds for max (#5820)

This commit is contained in:
chenyu
2024-07-30 17:54:44 -04:00
committed by GitHub
parent 3630208a01
commit d072e628da

View File

@@ -107,7 +107,7 @@ class UOp:
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
if self.op is UOps.CONST: return self, self
if self.op is UOps.ALU:
s0,s1,_ = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
return self.const(-s0.vmax.arg), self.const(-s0.vmin.arg)
if self.arg is BinaryOps.ADD: return self.const(s0.vmin.arg+s1.vmin.arg), self.const(s0.vmax.arg+s1.vmax.arg)
@@ -117,8 +117,8 @@ class UOp:
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST and s1.arg > 0:
return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST and s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
return None, None
@dataclass(frozen=True, repr=False) # reuse repr from UOp