From 69856c8f38e799b12f50ecdafd3eaaba00abc590 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 9 Sep 2024 15:49:44 +0800 Subject: [PATCH] non-optional bounds (faster) [run_process_replay] --- tinygrad/ops.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 966658b0e6..b99235bdcb 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -324,6 +324,11 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST} END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)} +@functools.lru_cache(None) +def _min_bound(dtype:DType): return UOp.const(dtype.scalar(), dtypes.min(dtype)) +@functools.lru_cache(None) +def _max_bound(dtype:DType): return UOp.const(dtype.scalar(), dtypes.max(dtype)) + @dataclass(frozen=True, eq=False) class UOp(MathTrait): op: UOps @@ -407,18 +412,16 @@ class UOp(MathTrait): if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure @property - def vmin(self) -> UOp: - return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.min(cast(DType, self.dtype))) + def vmin(self) -> UOp: return self._min_max[0] @property - def vmax(self) -> UOp: - return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.max(cast(DType, self.dtype))) + def vmax(self) -> UOp: return self._min_max[1] @functools.cached_property - def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: + def _min_max(self) -> Tuple[UOp, UOp]: # NOTE: returned UOp is assumed to be CONST - if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None + if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else _max_bound(self.dtype) if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax # TODO: UOps.SPECIAL is UOps.DEFINE_VAR - if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None + if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else _max_bound(self.dtype) if self.op is UOps.CONST: return self, self if self.op is UOps.ALU and cast(DType, self.dtype).count == 1: s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)] @@ -435,7 +438,7 @@ class UOp(MathTrait): if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg)) if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg)) if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg