fixup ast in kernel to be MetaOps.SINK [run_process_replay] (#5424)

* fixup ast in kernel to be MetaOps.SINK [run_process_replay]

* fix tests

* fix more tests
This commit is contained in:
George Hotz
2024-07-12 14:01:03 -07:00
committed by GitHub
parent b055ece550
commit 94599c0637
8 changed files with 35 additions and 28 deletions

View File

@@ -87,10 +87,10 @@ def compare_linearizer(lin: Linearizer, rawbufs=None, var_vals=None, ground_trut
if var_vals is None:
# TODO: handle symbolic max case
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast[0].vars()}
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()}
if ground_truth is None and not has_bf16:
unoptimized = Linearizer(*lin.ast)
unoptimized = Linearizer(lin.ast)
unoptimized.required_optimizations()
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
return ("BASELINE_ERROR", rawbufs, var_vals, ground_truth,)
@@ -121,7 +121,7 @@ def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2):
SEED = getenv("SEED", 42)
random.seed(SEED)
np.random.seed(SEED)
for op in lin.ast: print_tree(op)
print_tree(lin.ast)
print(lin.colored_shape())
seen_uops = {}
last_lins = [lin]
@@ -178,8 +178,8 @@ def fuzz_linearizer(lin: Linearizer, rtol=1e-2, atol=1e-2):
return failures
def _is_simple(lin: Linearizer) -> bool:
if len(lin.ast) > 1: return False
ast:LazyOp = lin.ast[0]
if len(lin.ast.src) > 1: return False
ast:LazyOp = lin.ast.src[0]
if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True
return False

View File

@@ -17,7 +17,7 @@ for offset in tqdm(range(0, row_count, page_size)):
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache}):
# try linearize
try:
k = Linearizer(*ast, opts=opts)
k = Linearizer(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
good_src = k.opts.render(name, k.linearize().uops)
except Exception as e: