mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
also fixed many errors. it was not checking nested dirs. exclude autogen for now. can we use ruff for this?
35 lines
1.4 KiB
Python
35 lines
1.4 KiB
Python
import argparse
|
|
from extra.optimization.helpers import ast_str_to_lin
|
|
|
|
from tinygrad import dtypes
|
|
from tinygrad.helpers import BEAM, getenv
|
|
from tinygrad.device import Device, Compiled
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser(description="Run a search for the optimal opts for a kernel", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
|
|
parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
|
|
args = parser.parse_args()
|
|
|
|
device: Compiled = Device[Device.DEFAULT]
|
|
print(f"optimizing for {Device.DEFAULT}")
|
|
|
|
if args.ast is not None:
|
|
ast_strs = [args.ast]
|
|
elif args.file is not None:
|
|
with open(args.file, 'r') as file:
|
|
ast_strs = file.readlines()
|
|
|
|
for i, ast_str in enumerate(ast_strs):
|
|
print(f"optimizing {i}/{len(ast_strs)}\nast={ast_str}")
|
|
lin = ast_str_to_lin(ast_str, opts=device.renderer)
|
|
rawbufs = bufs_from_lin(lin)
|
|
lin = beam_search(lin, rawbufs, getenv("BEAM", 8), bool(getenv("BEAM_ESTIMATE", 1)))
|
|
|
|
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
|
|
print(f"final time {tm*1e6:9.0f} us: {lin.colored_shape()}")
|
|
print(lin.applied_opts)
|