mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-27 07:48:07 -05:00
replace RANGE max fold with generic max fold (#5676)
This commit is contained in:
@@ -199,9 +199,7 @@ constant_folder = PatternMatcher([
|
||||
(UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')+UOp.cvar('c2')), lambda c,s,c2: (s+c2) if 0 >= c.arg else None), # TODO: generic
|
||||
(UOp.max(UOp.cvar('c'), -(UOp(UOps.SPECIAL).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.arg[1]-1+c2.arg) >= c.arg else None),
|
||||
# max on range can go away (ugh: copy of SPECIAL, and with/without const)
|
||||
(UOp.max(UOp.cvar('c'), UOp(UOps.RANGE).name('s')), lambda c,s: s if s.src[0].arg >= c.arg else None), # TODO: generic
|
||||
(UOp.max(UOp.cvar('c'), UOp(UOps.RANGE).name('s')+UOp.cvar('c2')), lambda c,s,c2: (s+c2) if s.src[0].arg >= c.arg else None), # TODO: generic
|
||||
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s'))), lambda c,s: -s if -(s.src[1].arg-1) >= c.arg else None),
|
||||
(UOp.max(UOp.cvar('c'), -(UOp(UOps.RANGE).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.src[1].arg-1+c2.arg) >= c.arg else None),
|
||||
# const rules
|
||||
(UOp(UOps.GEP, src=(UOp.cvar("c"),)).name("root"), lambda root, c: root.const(c.arg)),
|
||||
|
||||
@@ -105,10 +105,14 @@ class UOp:
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[UOp, UOp]:
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.DEFINE_VAR: return self.src[0], self.src[1]
|
||||
if self.op in (UOps.DEFINE_VAR, UOps.RANGE):
|
||||
return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else self.const(dtypes.max(cast(DType, self.dtype)))
|
||||
if self.op is UOps.SPECIAL:
|
||||
return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else self.const(dtypes.max(cast(DType, self.dtype)))
|
||||
if self.op is UOps.CONST: return self, self
|
||||
if self.op is UOps.ALU and self.arg is UnaryOps.NEG and self.dtype != dtypes.bool:
|
||||
nmin, nmax = self.src[0]._min_max
|
||||
return self.const(-nmax.arg), self.const(-nmin.arg)
|
||||
return self.const(dtypes.min(cast(DType, self.dtype))), self.const(dtypes.max(cast(DType, self.dtype)))
|
||||
|
||||
class UPat:
|
||||
|
||||
Reference in New Issue
Block a user