fuzz_linearizer: add --ast and --file params to read kernels (#3877)

also fix up ast_str_to_str to support the new tuple of LazyOps
This commit is contained in:
Francis Lam
2024-03-22 11:27:40 -07:00
committed by GitHub
parent c5467e5bd6
commit 5587594a00
4 changed files with 33 additions and 15 deletions

View File

@@ -1,4 +1,5 @@
# stuff needed to unpack a kernel
from typing import Tuple
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.dtype import dtypes
@@ -9,8 +10,8 @@ inf, nan = float('inf'), float('nan')
# kernel unpacker
from tinygrad.codegen.linearizer import Linearizer
def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str)
def ast_str_to_lin(ast_str:str, opts=None): return Linearizer(ast_str_to_ast(ast_str), opts=opts)
def ast_str_to_ast(ast_str:str) -> Tuple[LazyOp,...]: return val if isinstance(val:=eval(ast_str), tuple) else (val,)
def ast_str_to_lin(ast_str:str, opts=None): return Linearizer(*ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
k = Linearizer(*ast, opts=opts)

View File

@@ -1,13 +1,12 @@
import argparse
from extra.optimization.helpers import ast_str_to_lin
from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Run a search for the optimal opts for a kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -19,15 +18,14 @@ if __name__ == '__main__':
print(f"optimizing for {Device.DEFAULT}")
if args.ast is not None:
asts = [eval(args.ast)]
ast_strs = [args.ast]
elif args.file is not None:
with open(args.file, 'r') as file:
lines = file.readlines()
asts = [eval(line) for line in lines]
ast_strs = file.readlines()
for ast in asts:
print(f"optimizing ast={ast}")
lin = Linearizer(ast, device.compiler.linearizer_opts)
for ast_str in ast_strs:
print(f"optimizing ast={ast_str}")
lin = ast_str_to_lin(ast_str, opts=device.compiler.linearizer_opts)
rawbufs = bufs_from_lin(lin)
lin = beam_search(lin, rawbufs, getenv("BEAM", 8), bool(getenv("BEAM_ESTIMATE", 1)))