From e196640d715c5cd8a1a2d2e153eecd22a3af290d Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 23 Jul 2024 19:50:59 -0400
Subject: [PATCH] more generic lt folding (#5665)

---
 tinygrad/codegen/lowerer.py  | 3 ++-
 tinygrad/codegen/uopgraph.py | 6 +++---
 tinygrad/codegen/uops.py     | 6 ++++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 5cbb738070..4a1f564895 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -17,7 +17,8 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.bigint, sel
                     DivNode: lambda self, ops, ctx: self.a.render(ops, ctx)//variable_to_uop(self.b, ctx),
                     ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
                     LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
-                    Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else UOp(UOps.DEFINE_VAR, dtypes.int32, (), self),
+                    Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
+                      UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max)), self),
                     SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
                     AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 5478ee8cf2..74f3653619 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -213,7 +213,7 @@ constant_folder = PatternMatcher([
   (UOp(UOps.GEP, src=(UOp.cvar("x"),)).name("root"), lambda root,x: root.const(x.arg)),
   # max -2147483648
   (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
-  # bool < False is always false, True < bool is always false
+  # bool < False is always false, True < bool is always false # TODO: replace these with generic cmp
   (UOp.var().lt(UOp.const(dtypes.bool, False)), lambda: UOp.const(dtypes.bool, False)),
   (UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
   # a conditional with the same results either way is a noop, also fold const conditionals
@@ -239,8 +239,8 @@ constant_folder = PatternMatcher([
   (UOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
   (UOp.var('x') - UOp.var('x'), lambda x: x.const(0)), # x-x -> 0
   # lt folding
-  (UOp.var('x').lt(UOp.cvar('c')),
-   lambda x,c: UOp.const(dtypes.bool, True) if x.vmax.arg < c.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= c.arg else None),
+  (UOp.var('x').lt(UOp.var('y')),
+   lambda x,y: UOp.const(dtypes.bool, True) if x.vmax.arg < y.vmin.arg else UOp.const(dtypes.bool, False) if x.vmin.arg >= y.vmax.arg else None),
   # ** load/store folding **
   (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)),
   # ** two stage add/sub folding **
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index a0690c7881..58d7c3692d 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -73,8 +73,8 @@ class UOp:
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def _const(dtype:Optional[DType], b:ConstType|Variable):
-    # TODO: min/max for const Variable?
-    if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (), b)
+    # TODO: fix dtype of b.max after Variable is just an UOp
+    if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   @staticmethod
   def alu(arg, *src:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else src[-1].dtype, src, arg)
@@ -101,10 +101,12 @@ class UOp:
   @functools.cached_property
   def vmax(self) -> UOp:
     if self.op is UOps.DEFINE_VAR: return self.src[1]
+    if self.op is UOps.CONST: return self
     return UOp.const(dtypes.float, math.inf)
   @functools.cached_property
   def vmin(self) -> UOp:
     if self.op is UOps.DEFINE_VAR: return self.src[0]
+    if self.op is UOps.CONST: return self
     return UOp.const(dtypes.float, -math.inf)

 class UPat:
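
What the patch does, in one line: the old rule could only fold x < const, while the new rule folds x < y for any two operands whose value ranges cannot overlap (True if x.vmax < y.vmin, False if x.vmin >= y.vmax). Below is a minimal standalone sketch of that interval logic; the Expr class and fold_lt are illustrative stand-ins for UOp's vmin/vmax machinery, not tinygrad's API:

import math
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class Expr:
  vmin: float = -math.inf  # conservative lower bound on the runtime value
  vmax: float = math.inf   # conservative upper bound on the runtime value

def fold_lt(x: Expr, y: Expr) -> Optional[bool]:
  """Fold x < y to a constant when the ranges cannot overlap, else None."""
  if x.vmax < y.vmin: return True    # every possible x is below every possible y
  if x.vmin >= y.vmax: return False  # no possible x is below any possible y
  return None                        # ranges overlap: leave the compare alone

# the old rule only handled a constant rhs; a constant c is the range [c, c],
# so the generic rule subsumes it
gidx = Expr(vmin=0, vmax=15)                 # e.g. a loop index known to be 0..15
assert fold_lt(gidx, Expr(16, 16)) is True   # gidx < 16 always holds
assert fold_lt(gidx, Expr(0, 0)) is False    # gidx < 0 never holds
assert fold_lt(gidx, Expr(0, 31)) is None    # overlapping ranges: cannot fold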
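
The rest of the patch is the plumbing that gives the rule something to read: DEFINE_VAR now carries its (min, max) bounds as CONST sources (in both lowerer.py and _const), and a CONST reports itself as its own vmin/vmax. A toy sketch of that shape, with a hypothetical Node class in place of the real UOp (which has more ops and fields):

import math

class Node:
  def __init__(self, op, arg=None, src=()):
    self.op, self.arg, self.src = op, arg, src
  @property
  def vmin(self):
    if self.op == "DEFINE_VAR": return self.src[0]  # min bound stored as src[0]
    if self.op == "CONST": return self              # a const bounds itself
    return Node("CONST", -math.inf)                 # anything else: unbounded
  @property
  def vmax(self):
    if self.op == "DEFINE_VAR": return self.src[1]  # max bound stored as src[1]
    if self.op == "CONST": return self
    return Node("CONST", math.inf)

# a variable i with 0 <= i <= 9, built the way the patched _const builds one
i = Node("DEFINE_VAR", "i", (Node("CONST", 0), Node("CONST", 9)))
assert i.vmin.arg == 0 and i.vmax.arg == 9
# the generic rule can now fold (i < 10) to True, since i.vmax.arg (9) < 10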