mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fuzz_linearizer: add --ast and --file params to read kernels (#3877)
also fix up ast_str_to_str to support the new tuple of LazyOps
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# stuff needed to unpack a kernel
|
||||
from typing import Tuple
|
||||
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -9,8 +10,8 @@ inf, nan = float('inf'), float('nan')
|
||||
|
||||
# kernel unpacker
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
def ast_str_to_ast(ast_str:str) -> LazyOp: return eval(ast_str)
|
||||
def ast_str_to_lin(ast_str:str, opts=None): return Linearizer(ast_str_to_ast(ast_str), opts=opts)
|
||||
def ast_str_to_ast(ast_str:str) -> Tuple[LazyOp,...]: return val if isinstance(val:=eval(ast_str), tuple) else (val,)
|
||||
def ast_str_to_lin(ast_str:str, opts=None): return Linearizer(*ast_str_to_ast(ast_str), opts=opts)
|
||||
def kern_str_to_lin(kern_str:str, opts=None):
|
||||
(ast, applied_opts,) = eval(kern_str)
|
||||
k = Linearizer(*ast, opts=opts)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import argparse
|
||||
from extra.optimization.helpers import ast_str_to_lin
|
||||
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import BEAM, getenv
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
|
||||
from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="Run a search for the optimal opts for a kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
@@ -19,15 +18,14 @@ if __name__ == '__main__':
|
||||
print(f"optimizing for {Device.DEFAULT}")
|
||||
|
||||
if args.ast is not None:
|
||||
asts = [eval(args.ast)]
|
||||
ast_strs = [args.ast]
|
||||
elif args.file is not None:
|
||||
with open(args.file, 'r') as file:
|
||||
lines = file.readlines()
|
||||
asts = [eval(line) for line in lines]
|
||||
ast_strs = file.readlines()
|
||||
|
||||
for ast in asts:
|
||||
print(f"optimizing ast={ast}")
|
||||
lin = Linearizer(ast, device.compiler.linearizer_opts)
|
||||
for ast_str in ast_strs:
|
||||
print(f"optimizing ast={ast_str}")
|
||||
lin = ast_str_to_lin(ast_str, opts=device.compiler.linearizer_opts)
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
lin = beam_search(lin, rawbufs, getenv("BEAM", 8), bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user