From 67b036bdfda1edd158fc0e52e87dc818ae391b2d Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 24 Jul 2024 01:30:32 -0400 Subject: [PATCH] generic UOp max folding (#5675) --- tinygrad/codegen/uopgraph.py | 5 ++--- tinygrad/codegen/uops.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 4b663d930a..80ffc5a9d7 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -193,8 +193,9 @@ constant_folder = PatternMatcher([ arg=ReduceOps.SUM).name("reduce_allow_any_len"), index_collapse), # other arange folders (UOp.cvar("c1") - (UOp.var("x") + UOp.cvar("c2")), lambda c1, c2, x: (c1-c2)-x), # c1 - (x + c2) -> (c1-c2) - x + # max folding + (UOp.max(UOp.var('x'), UOp.var('y')), lambda x,y: x if x.vmin.arg >= y.vmax.arg else y if x.vmax.arg <= y.vmin.arg else None), # max on special can go away (TODO: special should be variable, same thing applies) - (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[1]-1) <= c.arg else None), (UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')+UOp.cvar('c2')), lambda c,s,c2: (s+c2) if 0 >= c.arg else None), # TODO: generic (UOp.max(UOp.cvar('c'), -(UOp(UOps.SPECIAL).name('s')+UOp.cvar('c2'))), lambda c,s,c2: -(s+c2) if -(s.arg[1]-1+c2.arg) >= c.arg else None), # max on range can go away (ugh: copy of SPECIAL, and with/without const) @@ -213,8 +214,6 @@ constant_folder = PatternMatcher([ # a DEFINE_ACC without inputs is a const + GEP on a const is the const (UOp(UOps.DEFINE_ACC, src=(UOp.cvar(),)).name("root"), lambda root: UOp.cast(root.src[0], root.dtype)), (UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: root.const(x.arg)), - # max -2147483648 - (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x), # a conditional with the same results either way is a noop, also fold const conditionals (UOp.var().where(UOp.var("val"), UOp.var("val")), lambda val: val), (UOp.cvar('gate').where(UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1), diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index f02a6bb05b..f26b7e6775 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -103,11 +103,11 @@ class UOp: @functools.cached_property def vmax(self) -> UOp: return self._min_max[1] @functools.cached_property - def _min_max(self): + def _min_max(self) -> Tuple[UOp, UOp]: # TODO: UOps.SPECIAL is UOps.DEFINE_VAR if self.op is UOps.DEFINE_VAR: return self.src[0], self.src[1] if self.op is UOps.SPECIAL: - return self.const(0), self.const(self.arg[1]) if isinstance(self.arg[1], int) else self.const(dtypes.max(cast(DType, self.dtype))) + return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else self.const(dtypes.max(cast(DType, self.dtype))) if self.op is UOps.CONST: return self, self return self.const(dtypes.min(cast(DType, self.dtype))), self.const(dtypes.max(cast(DType, self.dtype)))