From 99fc275c27aa77be50b74d8507589a2359835be6 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:43:20 +0300 Subject: [PATCH] UPat line savings [run_process_replay] (#5053) * line savings * move to new style --- tinygrad/codegen/linearizer.py | 3 --- tinygrad/codegen/uops.py | 21 +++++++++------------ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index f04cb34773..82f3e7fdc9 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -383,9 +383,6 @@ class Linearizer(Kernel): # save backups sts_backup, gfr_backup, upc_backup = self.sts[:], self.group_for_reduces, self.upcasted - # global uop cache - self.saved_exprs: Dict[Tuple, UOp] = dict() - # uops self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs) self.loop_uops: Dict[str, UOp] = {} diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 74c89ba2f9..1c4010c332 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -69,11 +69,11 @@ class UOp: @staticmethod def load(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(src)+tuple(kwargs.values())) @staticmethod - def store(*src:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(src)+tuple(kwargs.values())) + def store(*src:UOp, **kwargs): return UOp(UOps.STORE, None, tuple(src)+tuple(kwargs.values())) @staticmethod - def var(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name) + def var(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name) @staticmethod - def cvar(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name) + def cvar(name:Optional[str]=None, dtype:Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name) @functools.cached_property def parents(self) -> Set[UOp]: return set.union(set(self.src), *[x.parents for x in self.src]) @property # parents with self @@ -174,14 +174,12 @@ constant_folder = PatternMatcher([ UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse), # sum collapse to mul (with possible GEP) (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=(UPat(UOps.RANGE, name="loop"),)), - UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), - (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", - src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)), - UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), + UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), + (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)), + UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), # deal with UNMUL - (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), - UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]), - lambda c1,c2,v: v if c1.arg == c2.arg else None), + (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]), + lambda c1,c2,v: v if c1.arg == c2.arg else None), (UOp(UOps.UNMUL, src=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero), (UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.src[0].cast(root.dtype), unmul.src[1]))), # max on special can go away (TODO: special should be variable, same thing applies) @@ -222,8 +220,7 @@ constant_folder = PatternMatcher([ (UOp.var('x') * 0, lambda x: x if isinstance(x.arg, float) and math.isnan(x.arg) else UOp.const(x.dtype, 0)), (UOp.var('x') - UOp.var('x'), lambda x: UOp.const(x.dtype, 0)), # x-x -> 0 # ** load/store folding ** - (UPat(UOps.STORE, src=(UPat(name="buf"), UPat(name="idx"), - UPat(UOps.LOAD, src=(UPat(name="buf"), UPat(name="idx"))))), lambda buf, idx: UOp(UOps.NOOP)), + (UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.load(UOp.var("buf"), UOp.var("idx"))), lambda buf,idx:UOp(UOps.NOOP)), # ** two stage add/sub folding ** ((UOp.var('x') + UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, c2.arg]))), ((UOp.var('x') - UOp.cvar('c1')) + UOp.cvar('c2'), lambda x,c1,c2: x+UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c2.arg, -c1.arg]))),