mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Fix speed compare script (#4581)
* Fix speed compare script * Update speed_compare_cuda_ptx.py * Update speed_compare_cuda_ptx.py * Remove unused function
This commit is contained in:
32
test/external/speed_compare_cuda_ptx.py
vendored
32
test/external/speed_compare_cuda_ptx.py
vendored
@@ -1,10 +1,10 @@
|
||||
import itertools
|
||||
from tinygrad import Device
|
||||
from tinygrad.device import CompiledRunner
|
||||
from tinygrad.helpers import to_function_name, getenv, colored
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
||||
from tinygrad.features.search import bufs_from_lin
|
||||
from tinygrad.runtime.ops_cuda import PTXCompiler
|
||||
from tinygrad.runtime.ops_cuda import PTXCompiler, PTXRenderer, CUDACompiler
|
||||
|
||||
# move to helpers?
|
||||
def colorize_float(x):
|
||||
@@ -15,8 +15,10 @@ def colorize_float(x):
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds(filter_reduce=False, filter_novariable=True)
|
||||
# no bfloat16 for ptx at the moment
|
||||
ast_strs = [x for x in ast_strs if "dtypes.bfloat16" not in x]
|
||||
dev = Device["CUDA"]
|
||||
ptx = PTXCompiler(dev.arch)
|
||||
ptx = PTXRenderer(dev.arch)
|
||||
|
||||
# NUM=112 python3 test/external/speed_compare_cuda_ptx.py
|
||||
|
||||
@@ -26,24 +28,26 @@ if __name__ == "__main__":
|
||||
average_tm_cuda, average_tm_ptx = 0, 0
|
||||
for num,ast in enumerate(ast_strs):
|
||||
# cuda compile
|
||||
dev.compiler = CUDACompiler(dev.arch)
|
||||
lin = ast_str_to_lin(ast, opts=dev.renderer)
|
||||
lin.hand_coded_optimizations()
|
||||
cuda_prg = dev.to_runner(lin)
|
||||
cuda_prg = CompiledRunner(lin.to_program())
|
||||
|
||||
bufs = bufs_from_lin(lin)
|
||||
|
||||
# ptx compile
|
||||
lin = ast_str_to_lin(ast, opts=ptx.compiler_opts)
|
||||
dev.compiler = PTXCompiler(dev.arch)
|
||||
lin = ast_str_to_lin(ast, opts=ptx)
|
||||
lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
ptx_src = ptx.render(to_function_name(lin.name), lin.uops)
|
||||
try:
|
||||
ptx_prg = CompiledRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
|
||||
except RuntimeError:
|
||||
print("PTX FAIL")
|
||||
continue
|
||||
ptx_prg = CompiledRunner(lin.to_program())
|
||||
|
||||
# warmup
|
||||
cuda_prg(bufs, {}, wait=True)
|
||||
try:
|
||||
cuda_prg(bufs, {}, wait=True)
|
||||
except RuntimeError:
|
||||
print("cuda failed ast:", num)
|
||||
continue
|
||||
ptx_prg(bufs, {}, wait=True)
|
||||
|
||||
tm_cuda, tm_ptx = [], []
|
||||
@@ -56,7 +60,7 @@ if __name__ == "__main__":
|
||||
print(f"{average_tm_ptx/average_tm_cuda:5.2f}x -- {num:4d} {colorize_float(ratio)} {min(tm_ptx)*1e6:7.2f} us", lin.name)
|
||||
if ratio > 1.5:
|
||||
def fix(x): return x.replace('\t', ' ').strip()
|
||||
ll1, ll2 = cuda_prg.lib.decode().split('\n'), ptx_src.split('\n')
|
||||
ll1, ll2 = cuda_prg.lib.decode().split('\n'), ptx_prg.lib.decode().split('\n')
|
||||
if single != -1:
|
||||
for ln, (l1, l2) in enumerate(itertools.zip_longest(ll1, ll2, fillvalue='')):
|
||||
print(f"{ln:5d} | {fix(l1):80s} | {fix(l2):80s}")
|
||||
|
||||
Reference in New Issue
Block a user