mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 23:38:58 -05:00
non-optional bounds (faster) [run_process_replay]
This commit is contained in:
@@ -324,6 +324,11 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
|
||||
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def _min_bound(dtype:DType): return UOp.const(dtype.scalar(), dtypes.min(dtype))
|
||||
@functools.lru_cache(None)
|
||||
def _max_bound(dtype:DType): return UOp.const(dtype.scalar(), dtypes.max(dtype))
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class UOp(MathTrait):
|
||||
op: UOps
|
||||
@@ -407,18 +412,16 @@ class UOp(MathTrait):
|
||||
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
|
||||
return None # generic None if we aren't sure
|
||||
@property
|
||||
def vmin(self) -> UOp:
|
||||
return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.min(cast(DType, self.dtype)))
|
||||
def vmin(self) -> UOp: return self._min_max[0]
|
||||
@property
|
||||
def vmax(self) -> UOp:
|
||||
return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst_like(dtypes.max(cast(DType, self.dtype)))
|
||||
def vmax(self) -> UOp: return self._min_max[1]
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
|
||||
def _min_max(self) -> Tuple[UOp, UOp]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None
|
||||
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else _max_bound(self.dtype)
|
||||
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else _max_bound(self.dtype)
|
||||
if self.op is UOps.CONST: return self, self
|
||||
if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
|
||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||
@@ -435,7 +438,7 @@ class UOp(MathTrait):
|
||||
if s1.arg < 0: return self.sconst_like(-(s0.vmax.arg//-s1.arg)), self.sconst_like(-(s0.vmin.arg//-s1.arg))
|
||||
if self.arg is BinaryOps.MAX: return self.sconst_like(max(s0.vmin.arg, s1.vmin.arg)), self.sconst_like(max(s0.vmax.arg, s1.vmax.arg))
|
||||
if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, s0.vmax.arg<s1.vmin.arg), UOp.const(dtypes.bool, s0.vmin.arg<s1.vmax.arg))
|
||||
return None, None
|
||||
return _min_bound(self.dtype), _max_bound(self.dtype)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KernelInfo:
|
||||
|
||||
Reference in New Issue
Block a user