fuzz_linearizer same api for interpreted and compiled (#2320)

This commit is contained in:
chenyu
2023-11-15 17:40:22 -05:00
committed by GitHub
parent 294e71de15
commit a98511561c

View File

@@ -6,46 +6,35 @@ from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops
from tinygrad.graph import print_tree
from tinygrad.helpers import getenv
from tinygrad.ops import Device, Compiled, Interpreted
from tinygrad.ops import Device, Compiled, Interpreted, get_interpreted_fxn
from tinygrad.lazy import vars_from_ast
device = Device[Device.DEFAULT]
class LB:
# placeholder LazyBuffer
def __init__(self, rawbuf, dtype):
self.realized = rawbuf
self.output_buffer = rawbuf
self.dtype = dtype
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)}
# TODO: images needs required_optimization
if isinstance(device, Compiled):
try:
try:
if isinstance(device, Compiled):
prg = device.to_program(lin)
except:
print(lin.ast)
traceback.print_exc()
print("COMPILE FAILED!!")
return "COMPILE_ERROR"
try:
prg.exec(rawbufs, var_vals, force_wait=True)
except:
print(lin.ast)
traceback.print_exc()
print("EXEC FAILED!!")
return "EXEC_ERROR"
else:
try:
device.exec_ast(lin.ast, output=LB(rawbufs[0], rawbufs[0].dtype), inputs=[LB(buf, buf.dtype) for buf in rawbufs[1:]])
except Exception as e:
print(lin.ast)
traceback.print_exc()
return str(type(e))
else:
prg = get_interpreted_fxn(device.fxn_for_op, device.from_underlying, lin.ast)
except:
print(lin.ast)
traceback.print_exc()
print("COMPILE FAILED!!")
return "COMPILE_ERROR"
try:
prg.exec(rawbufs, var_vals, force_wait=True)
except:
print(lin.ast)
traceback.print_exc()
print("EXEC FAILED!!")
return "EXEC_ERROR"
return "PASS"