mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-03 19:25:06 -05:00
fuzz_linearizer same api for interpreted and compiled (#2320)
This commit is contained in:
47
test/external/fuzz_linearizer.py
vendored
47
test/external/fuzz_linearizer.py
vendored
@@ -6,46 +6,35 @@ from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops
|
||||
from tinygrad.graph import print_tree
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.ops import Device, Compiled, Interpreted
|
||||
from tinygrad.ops import Device, Compiled, Interpreted, get_interpreted_fxn
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
|
||||
device = Device[Device.DEFAULT]
|
||||
|
||||
class LB:
|
||||
# placeholder LazyBuffer
|
||||
def __init__(self, rawbuf, dtype):
|
||||
self.realized = rawbuf
|
||||
self.output_buffer = rawbuf
|
||||
self.dtype = dtype
|
||||
|
||||
|
||||
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
|
||||
if rawbufs is None: rawbufs = bufs_from_lin(lin)
|
||||
if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)}
|
||||
|
||||
# TODO: images needs required_optimization
|
||||
if isinstance(device, Compiled):
|
||||
try:
|
||||
try:
|
||||
if isinstance(device, Compiled):
|
||||
prg = device.to_program(lin)
|
||||
except:
|
||||
print(lin.ast)
|
||||
traceback.print_exc()
|
||||
print("COMPILE FAILED!!")
|
||||
return "COMPILE_ERROR"
|
||||
try:
|
||||
prg.exec(rawbufs, var_vals, force_wait=True)
|
||||
except:
|
||||
print(lin.ast)
|
||||
traceback.print_exc()
|
||||
print("EXEC FAILED!!")
|
||||
return "EXEC_ERROR"
|
||||
else:
|
||||
try:
|
||||
device.exec_ast(lin.ast, output=LB(rawbufs[0], rawbufs[0].dtype), inputs=[LB(buf, buf.dtype) for buf in rawbufs[1:]])
|
||||
except Exception as e:
|
||||
print(lin.ast)
|
||||
traceback.print_exc()
|
||||
return str(type(e))
|
||||
else:
|
||||
prg = get_interpreted_fxn(device.fxn_for_op, device.from_underlying, lin.ast)
|
||||
except:
|
||||
print(lin.ast)
|
||||
traceback.print_exc()
|
||||
print("COMPILE FAILED!!")
|
||||
return "COMPILE_ERROR"
|
||||
|
||||
try:
|
||||
prg.exec(rawbufs, var_vals, force_wait=True)
|
||||
except:
|
||||
print(lin.ast)
|
||||
traceback.print_exc()
|
||||
print("EXEC FAILED!!")
|
||||
return "EXEC_ERROR"
|
||||
|
||||
return "PASS"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user