assign early folding (#8093)

* assign early folding [pr]

* move to to_si

* -

* fix generate_dataset

* diff too big

* no recreation, no diff

* gzip

* new sops from tiny10

* final try
This commit is contained in:
qazal
2024-12-07 11:02:55 +02:00
committed by GitHub
parent 00ac0db9d4
commit 07b6d5cf63
6 changed files with 8 additions and 7 deletions

Binary file not shown.

View File

@@ -5,9 +5,9 @@ from test.external.process_replay.process_replay import _pmap
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
def extract_ast(*args) -> bool:
def extract_ast(*args) -> None:
open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
return args[-1]
return None
if __name__ == "__main__":
_pmap("kernel", extract_ast)

View File

@@ -1,5 +1,6 @@
#!/bin/bash
export PAGE_SIZE=1
export PYTHONPATH=.
export LOGOPS=/tmp/ops
export RUN_PROCESS_REPLAY=1
rm $LOGOPS
@@ -24,5 +25,5 @@ JIT=2 BIG=1 MPS=1 python -m pytest test/test_speed_v_torch.py
# extract, sort and uniq
extra/optimization/extract_dataset.py
sort -u /tmp/ops > /tmp/sops
sort -u /tmp/ops > /tmp/sops
ls -lh /tmp/ops /tmp/sops

View File

@@ -54,6 +54,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
# try recreate
try:
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
if good is None: continue
except Exception as e:
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
for x in args[:-1]: logging.info(x)

View File

@@ -239,8 +239,6 @@ arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),

View File

@@ -178,6 +178,9 @@ check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()),
# Rewrite rules grouped under `to_si`; each entry pairs a UPat pattern with the handler applied on a match.
to_si = PatternMatcher([
# every VIEW op (bound as "x") is handed to _append_st_vars, which is defined outside this excerpt
(UPat(Ops.VIEW, name="x"), _append_st_vars),
# a SINK whose single source is a store into "b" of a GroupOp.Meta op "x" becomes "x" with "b" prepended to its sources
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
# CONTIGUOUS with exactly one source is replaced by that source
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
# ASSIGN with two sources is replaced by its second source; the first is matched but discarded
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda ctx,x: x),
])
# ** fusion
@@ -185,8 +188,6 @@ to_si = PatternMatcher([
lazy = PatternMatcher([
# gather the metadata for this kernel
(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None),
# don't need contiguous anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
])
# Single rule: a load from "b" (any second source) is looked up in ctx.sinked by "b".
# NOTE(review): .get returns None when "b" is absent — presumably that means "no rewrite"; confirm against PatternMatcher.
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])