update tests get_runner (#4522)

George Hotz, 2024-05-10 20:09:22 -07:00 (committed by GitHub)
parent a0448ff595, commit 827058f030
5 changed files with 50 additions and 49 deletions

View File

@@ -25,12 +25,12 @@ if __name__ == "__main__":
   # cuda compile
   culin = ast_str_to_lin(ast, opts=cudev.compiler.compiler_opts)
   culin.hand_coded_optimizations()
-  cuda_prg = cudev.to_program(culin)
+  cuda_prg = cudev.to_runner(culin)
   cubufs = bufs_from_lin(culin)
   nvlin = ast_str_to_lin(ast, opts=nvdev.compiler.compiler_opts)
   nvlin.hand_coded_optimizations()
-  nv_prg = nvdev.to_program(nvlin)
+  nv_prg = nvdev.to_runner(nvlin)
   nvbufs = bufs_from_lin(nvlin)
   # warmup

View File

@@ -1,9 +1,9 @@
 import unittest
 import time
 import numpy as np
-from tinygrad import Tensor, dtypes, Device
+from tinygrad import Tensor, dtypes
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.realize import lower_schedule_item, run_schedule

 class TestFusionOp(unittest.TestCase):
   def test_contiguous_add(self):
@@ -27,9 +27,9 @@ class TestFusionOp(unittest.TestCase):
     a = Tensor([1,2,3,4])
     for _ in range(24): a = a + a
     sched = create_schedule([a.lazydata], None)
-    ji = Device[Device.DEFAULT].get_runner(*sched[-1].ast)
+    ei = lower_schedule_item(sched[-1])
     self.assertLess(time.perf_counter()-st, 1.0)
-    assert len(ji.p.src.splitlines()) < 250
+    assert len(ei.prg.p.src.splitlines()) < 250

   def test_recursive_add_cmp(self):
     st = time.perf_counter()
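
For readers following the API change: the test no longer asks the device for a runner via get_runner; it lowers the schedule item and reads the generated Program off the lowered item. A minimal standalone sketch of that pattern, using only names visible in this diff (the tensor values are illustrative):

  from tinygrad import Tensor
  from tinygrad.engine.schedule import create_schedule
  from tinygrad.engine.realize import lower_schedule_item

  a = Tensor([1,2,3,4])
  for _ in range(24): a = a + a                 # build a deep, fuseable expression
  sched = create_schedule([a.lazydata], None)   # schedule without realizing
  ei = lower_schedule_item(sched[-1])           # lower the last schedule item
  print(len(ei.prg.p.src.splitlines()))         # generated kernel source is at ei.prg.p.src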

View File

@@ -1,6 +1,7 @@
 import unittest
-from tinygrad import Tensor, Device
+from tinygrad import Tensor
 from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import lower_schedule_item

 # TODO: can copy this in here when we remove it
 #from tinygrad.ops import get_lazyop_info
@@ -12,8 +13,8 @@ from tinygrad.engine.schedule import create_schedule
 def get_stats(x:Tensor):
   si = create_schedule([x.lazydata])[-1]
-  runner = Device[Device.DEFAULT].get_runner(*si.ast)
-  return runner.op_estimate, runner.mem_estimate
+  ei = lower_schedule_item(si)
+  return ei.prg.p.op_estimate, ei.prg.p.mem_estimate

 class TestUOpsStats(unittest.TestCase):
   def test_simple_add(self):
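
The same lowering path now feeds the stats helper: op and memory estimates come from the Program attached to the lowered item rather than from a runner returned by get_runner. A hedged usage sketch (the input tensors are illustrative; test_simple_add's actual inputs are not shown in this hunk):

  a, b = Tensor.empty(4), Tensor.empty(4)
  ops, mem = get_stats(a + b)   # reads ei.prg.p.op_estimate and ei.prg.p.mem_estimate under the hood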

View File

@@ -181,43 +181,6 @@ class Runner:
 # **************** for Compiled Devices ****************

-def fake_renderer(name, uops): raise NotImplementedError("needs a renderer")
-
-@dataclass(frozen=True)
-class CompilerOptions:
-  device: str = ""
-  suffix: str = ""
-  # TODO: make this generic with a list of supported types
-  supports_float4: bool = True
-  has_local: bool = True
-  has_shared: bool = True
-  has_tensor_cores: bool = False
-  # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
-  global_max: Optional[List[int]] = None
-  local_max: Optional[List[int]] = None
-  shared_max: int = 32768
-  renderer: Callable = fake_renderer
-
-class Compiler:
-  compiler_opts: ClassVar[CompilerOptions]
-  def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
-  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
-  def compile_cached(self, src:str) -> bytes:
-    if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
-      lib = self.compile(src)
-      if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
-    return lib
-
-  def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
-    k.linearize()
-    info = get_lazyop_info(k.ast[0])
-    ops, mem = k.uops.flops_mem()
-    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
-    # NOTE: we use min here to ignore the indexing FLOPS
-    return Program(k.name, self.compiler_opts.renderer(to_function_name(k.name), k.uops),
-                   override_device if override_device else self.compiler_opts.device,
-                   k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
-
 @dataclass(frozen=True)
 class Program:
   name:str
@@ -246,6 +209,43 @@ class Program:
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
     return global_size, local_size

+def fake_renderer(name, uops): raise NotImplementedError("needs a renderer")
+
+@dataclass(frozen=True)
+class CompilerOptions:
+  device: str = ""
+  suffix: str = ""
+  # TODO: make this generic with a list of supported types
+  supports_float4: bool = True
+  has_local: bool = True
+  has_shared: bool = True
+  has_tensor_cores: bool = False
+  # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
+  global_max: Optional[List[int]] = None
+  local_max: Optional[List[int]] = None
+  shared_max: int = 32768
+  renderer: Callable = fake_renderer
+
+  def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
+    k.linearize()
+    info = get_lazyop_info(k.ast[0])
+    ops, mem = k.uops.flops_mem()
+    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
+    # NOTE: we use min here to ignore the indexing FLOPS
+    return Program(k.name, self.renderer(to_function_name(k.name), k.uops),
+                   override_device if override_device else self.device,
+                   k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
+
+class Compiler:
+  compiler_opts: ClassVar[CompilerOptions]
+  def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
+  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
+  def compile_cached(self, src:str) -> bytes:
+    if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
+      lib = self.compile(src)
+      if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
+    return lib
+
 class CompiledRunner(Runner):
   def __init__(self, p:Program, precompiled:Optional[bytes]=None):
     if DEBUG >= 4: print(p.src)
@@ -280,7 +280,7 @@ class Compiled:
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
   def synchronize(self): pass # override this in your device
-  def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.to_program(k, override_device=self.dname))
+  def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.compiler_opts.to_program(k, override_device=self.dname))
   def get_linearizer(self, *ast:LazyOp) -> Linearizer:
     if DEBUG >= 3:
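
Taken together, the device.py move means CompilerOptions now owns Program construction while Compiler is reduced to compiling source. A hedged sketch of the resulting flow on a Compiled device, assembled from names in this diff (the device string and the schedule item `si` are illustrative placeholders):

  from tinygrad import Device

  dev = Device["CUDA"]                             # any Compiled device
  lin = dev.get_linearizer(*si.ast)                # Linearizer for a schedule item's ast
  lin.hand_coded_optimizations()
  p = dev.compiler.compiler_opts.to_program(lin)   # Program: source, launch dims, flop/mem estimates
  runner = dev.to_runner(lin)                      # CompiledRunner wrapping the same to_program call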

View File

@@ -51,7 +51,7 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) ->
   try:
     x[1].linearize()
     if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
-    p = compiler.to_program(x[1])
+    p = compiler.compiler_opts.to_program(x[1])
     st = time.perf_counter()
     prog = compiler.compile(p.src)
     et = time.perf_counter() - st
@@ -174,7 +174,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
   rawbufs = _ensure_buffer_alloc(rawbufs)
   var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
-  p = dev.compiler.to_program(lin)
+  p = dev.compiler.compiler_opts.to_program(lin)
   tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                       max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
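
In the search path, both the BEAM compile worker and time_linearizer now reach to_program through compiler_opts as well. A hedged sketch of timing a hand-optimized kernel with the updated helper (buffer setup borrowed from the first file in this commit; `dev` and `si` are placeholders as above, and treating the return value as a timing is an assumption):

  lin = dev.get_linearizer(*si.ast)
  lin.hand_coded_optimizations()
  rawbufs = bufs_from_lin(lin)        # allocate buffers matching the kernel's arguments
  tm = time_linearizer(lin, rawbufs, allow_test_size=True, cnt=10)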