From 4dc9436d63c4d3f893295eaccd2d16cbb8e993b0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 12 Sep 2024 13:52:41 +0800 Subject: [PATCH] use more UPat.var and UPat.cvar [run_process_replay] (#6491) --- tinygrad/codegen/uopgraph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 2e042eed62..07e08bd355 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -361,8 +361,8 @@ constant_folder = PatternMatcher([ (UPat(UOps.SINK, name="root"), lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None), # ** move add consts to end (NOTE: this is still happening before constant folding) ** - (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(UOps.CONST, name='c1'), UPat(name='x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None), - (UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat(name='x'), UPat(UOps.CONST, name='c1'))), UPat(name='y')]), + (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar('c1'), UPat.var('x'))), lambda c1,x: x+c1 if x.op is not UOps.CONST else None), + (UPat(UOps.ALU, arg=BinaryOps.ADD, src=[UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var('x'), UPat.cvar('c1'))), UPat.var('y')]), lambda x,c1,y: (x+y)+c1), ])