From a98511561c88e8101cbc825198ec333fec899ebd Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 15 Nov 2023 17:40:22 -0500 Subject: [PATCH] fuzz_linearizer same api for interpreted and compiled (#2320) --- test/external/fuzz_linearizer.py | 47 ++++++++++++-------------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index d65ba01ba2..66f91d6827 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -6,46 +6,35 @@ from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops from tinygrad.graph import print_tree from tinygrad.helpers import getenv -from tinygrad.ops import Device, Compiled, Interpreted +from tinygrad.ops import Device, Compiled, Interpreted, get_interpreted_fxn from tinygrad.lazy import vars_from_ast device = Device[Device.DEFAULT] -class LB: - # placeholder LazyBuffer - def __init__(self, rawbuf, dtype): - self.realized = rawbuf - self.output_buffer = rawbuf - self.dtype = dtype - def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None): if rawbufs is None: rawbufs = bufs_from_lin(lin) if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)} # TODO: images needs required_optimization - if isinstance(device, Compiled): - try: + try: + if isinstance(device, Compiled): prg = device.to_program(lin) - except: - print(lin.ast) - traceback.print_exc() - print("COMPILE FAILED!!") - return "COMPILE_ERROR" - try: - prg.exec(rawbufs, var_vals, force_wait=True) - except: - print(lin.ast) - traceback.print_exc() - print("EXEC FAILED!!") - return "EXEC_ERROR" - else: - try: - device.exec_ast(lin.ast, output=LB(rawbufs[0], rawbufs[0].dtype), inputs=[LB(buf, buf.dtype) for buf in rawbufs[1:]]) - except Exception as e: - print(lin.ast) - traceback.print_exc() - return str(type(e)) + else: + prg = get_interpreted_fxn(device.fxn_for_op, device.from_underlying, lin.ast) + except: + print(lin.ast) + traceback.print_exc() + print("COMPILE FAILED!!") + return "COMPILE_ERROR" + + try: + prg.exec(rawbufs, var_vals, force_wait=True) + except: + print(lin.ast) + traceback.print_exc() + print("EXEC FAILED!!") + return "EXEC_ERROR" return "PASS"