From cdd71840c535e53b670653998405eb28ee5ee931 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 9 Sep 2024 15:54:53 +0800 Subject: [PATCH] pre-fetch min/max --- tinygrad/ops.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b99235bdcb..4dc3f7628b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -423,21 +423,22 @@ class UOp(MathTrait): # TODO: UOps.SPECIAL is UOps.DEFINE_VAR if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else _max_bound(self.dtype) if self.op is UOps.CONST: return self, self - if self.op is UOps.ALU and cast(DType, self.dtype).count == 1: - s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)] - if self.arg is BinaryOps.ADD: return self.sconst_like(s0.vmin.arg+s1.vmin.arg), self.sconst_like(s0.vmax.arg+s1.vmax.arg) - if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0): + if self.op is UOps.ALU and cast(DType, self.dtype).count == 1 and self.arg in BinaryOps: + s0_min, s0_max = self.src[0]._min_max + s1_min, s1_max = self.src[1]._min_max + if self.arg is BinaryOps.ADD: return self.sconst_like(s0_min.arg+s1_min.arg), self.sconst_like(s0_max.arg+s1_max.arg) + if self.arg is BinaryOps.MUL and (s0_min.arg >= 0 or s1_min.arg >= 0): # handle at lease one is non-negative - Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg) - Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg) + Lmin, Lmax = (s0_min.arg, s0_max.arg) if s1_min.arg >= 0 else (s0_max.arg, s0_min.arg) + Rmin, Rmax = (s1_min.arg, s1_max.arg) if s0_min.arg >= 0 else (s1_max.arg, s1_min.arg) assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}" return self.sconst_like(Lmin*Rmin), self.sconst_like(Lmax*Rmax) - if self.arg is BinaryOps.MOD and s1.vmin.arg > 0: return self.sconst_like(0), self.sconst_like(s1.vmax.arg-1) - if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: - if s1.arg > 0: return self.sconst_like(s0.vmin.arg//s1.arg), self.sconst_like(s0.vmax.arg//s1.arg) - if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg)) - if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg)) - if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg 0: return self.sconst_like(0), self.sconst_like(s1_max.arg-1) + if self.arg is BinaryOps.IDIV and (s1:=self.src[1]).op is UOps.CONST: + if s1.arg > 0: return self.sconst_like(s0_min.arg//s1.arg), self.sconst_like(s0_max.arg//s1.arg) + if s1.arg < 0: return self.sconst_like(-(s0_max.arg//-s1.arg)), self.sconst_like(-(s0_min.arg//-s1.arg)) + if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0_min.arg, s1_min.arg)), self.sconst_like(max(s0_max.arg, s1_max.arg)) + if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0_max.arg