pre-fetch min/max

This commit is contained in:
George Hotz
2024-09-09 15:54:53 +08:00
parent 69856c8f38
commit cdd71840c5

View File

@@ -423,21 +423,22 @@ class UOp(MathTrait):
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else _max_bound(self.dtype)
if self.op is UOps.CONST: return self, self
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
if self.arg is BinaryOps.ADD: return self.sconst_like(s0.vmin.arg+s1.vmin.arg), self.sconst_like(s0.vmax.arg+s1.vmax.arg)
if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1 and self.arg in BinaryOps:
s0_min, s0_max = self.src[0]._min_max
s1_min, s1_max = self.src[1]._min_max
if self.arg is BinaryOps.ADD: return self.sconst_like(s0_min.arg+s1_min.arg), self.sconst_like(s0_max.arg+s1_max.arg)
if self.arg is BinaryOps.MUL and (s0_min.arg >= 0 or s1_min.arg >= 0):
# handle the case where at least one operand is non-negative
Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
Lmin, Lmax = (s0_min.arg, s0_max.arg) if s1_min.arg >= 0 else (s0_max.arg, s0_min.arg)
Rmin, Rmax = (s1_min.arg, s1_max.arg) if s0_min.arg >= 0 else (s1_max.arg, s1_min.arg)
assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
return self.sconst_like(Lmin*Rmin), self.sconst_like(Lmax*Rmax)
if self.arg is BinaryOps.MOD and s1.vmin.arg > 0: return self.sconst_like(0), self.sconst_like(s1.vmax.arg-1)
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return self.sconst_like(s0.vmin.arg//s1.arg), self.sconst_like(s0.vmax.arg//s1.arg)
if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg))
if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg))
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.const(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
if self.arg is BinaryOps.MOD and s1_min.arg > 0: return self.sconst_like(0), self.sconst_like(s1_max.arg-1)
if self.arg is BinaryOps.IDIV and (s1:=self.src[1]).op is UOps.CONST:
if s1.arg > 0: return self.sconst_like(s0_min.arg//s1.arg), self.sconst_like(s0_max.arg//s1.arg)
if s1.arg < 0: return self.sconst_like(-(s0_max.arg//-s1.arg)), self.sconst_like(-(s0_min.arg//-s1.arg))
if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0_min.arg, s1_min.arg)), self.sconst_like(max(s0_max.arg, s1_max.arg))
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0_max.arg<s1_min.arg), UOp.const(dtypes.bool, s0_min.arg<s1_max.arg))
return _min_bound(self.dtype), _max_bound(self.dtype)
@dataclass(frozen=True)