Mirror of https://github.com/tinygrad/tinygrad.git
update tests get_runner (#4522)
test/external/speed_compare_cuda_nv.py (vendored, 4 changed lines)
@@ -25,12 +25,12 @@ if __name__ == "__main__":
     # cuda compile
     culin = ast_str_to_lin(ast, opts=cudev.compiler.compiler_opts)
     culin.hand_coded_optimizations()
-    cuda_prg = cudev.to_program(culin)
+    cuda_prg = cudev.to_runner(culin)
     cubufs = bufs_from_lin(culin)

     nvlin = ast_str_to_lin(ast, opts=nvdev.compiler.compiler_opts)
     nvlin.hand_coded_optimizations()
-    nv_prg = nvdev.to_program(nvlin)
+    nv_prg = nvdev.to_runner(nvlin)
     nvbufs = bufs_from_lin(nvlin)

     # warmup
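Both runners still carry the rendered Program, so the script can sanity-check the two backends against each other before timing. A minimal sketch continuing from the hunk above (cuda_prg and nv_prg as built there; the .p attribute on a runner and the op/mem estimates on Program are the same ones the other hunks in this commit rely on):

# hedged sketch, not part of the diff: compare the two backends' rendered kernels
for name, prg in [("CUDA", cuda_prg), ("NV", nv_prg)]:
  src_lines = len(prg.p.src.splitlines())  # rendered kernel source lives on the Program
  print(f"{name}: {src_lines} lines, ~{prg.p.op_estimate} ops, ~{prg.p.mem_estimate} bytes moved")
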
@@ -1,9 +1,9 @@
 import unittest
 import time
 import numpy as np
-from tinygrad import Tensor, dtypes, Device
+from tinygrad import Tensor, dtypes
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.realize import lower_schedule_item, run_schedule

 class TestFusionOp(unittest.TestCase):
   def test_contiguous_add(self):
@@ -27,9 +27,9 @@ class TestFusionOp(unittest.TestCase):
     a = Tensor([1,2,3,4])
     for _ in range(24): a = a + a
     sched = create_schedule([a.lazydata], None)
-    ji = Device[Device.DEFAULT].get_runner(*sched[-1].ast)
+    ei = lower_schedule_item(sched[-1])
     self.assertLess(time.perf_counter()-st, 1.0)
-    assert len(ji.p.src.splitlines()) < 250
+    assert len(ei.prg.p.src.splitlines()) < 250

   def test_recursive_add_cmp(self):
     st = time.perf_counter()
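Outside the test, the new lowering path looks like the sketch below; it uses only names that appear in the hunks above (Tensor, create_schedule, lower_schedule_item) and the ei.prg.p layout the updated assertion relies on:

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item

a = Tensor([1,2,3,4])
for _ in range(24): a = a + a                 # deep add tree that should fuse into one kernel
sched = create_schedule([a.lazydata], None)   # schedule items; the last one is the big add kernel
ei = lower_schedule_item(sched[-1])           # replaces Device[Device.DEFAULT].get_runner(*si.ast)
print(len(ei.prg.p.src.splitlines()))         # kernel source now sits on the lowered item's runner Program
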
@@ -1,6 +1,7 @@
 import unittest
-from tinygrad import Tensor, Device
+from tinygrad import Tensor
 from tinygrad.engine.schedule import create_schedule
+from tinygrad.engine.realize import lower_schedule_item

 # TODO: can copy this in here when we remove it
 #from tinygrad.ops import get_lazyop_info
@@ -12,8 +13,8 @@ from tinygrad.engine.schedule import create_schedule

 def get_stats(x:Tensor):
   si = create_schedule([x.lazydata])[-1]
-  runner = Device[Device.DEFAULT].get_runner(*si.ast)
-  return runner.op_estimate, runner.mem_estimate
+  ei = lower_schedule_item(si)
+  return ei.prg.p.op_estimate, ei.prg.p.mem_estimate

 class TestUOpsStats(unittest.TestCase):
   def test_simple_add(self):
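As a usage note, the rewritten get_stats helper can be called on any lazy Tensor expression; a hedged example in the test file's context, where Tensor and get_stats are already in scope (the sizes are arbitrary):

# get_stats as defined in the hunk above
ops, mem = get_stats(Tensor.rand(16)*Tensor.rand(16) + Tensor.rand(16))
print(f"estimated {ops} ops, {mem} bytes moved")  # both estimates now come from ei.prg.p
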
@@ -181,43 +181,6 @@ class Runner:

 # **************** for Compiled Devices ****************

-def fake_renderer(name, uops): raise NotImplementedError("needs a renderer")
-
-@dataclass(frozen=True)
-class CompilerOptions:
-  device: str = ""
-  suffix: str = ""
-  # TODO: make this generic with a list of supported types
-  supports_float4: bool = True
-  has_local: bool = True
-  has_shared: bool = True
-  has_tensor_cores: bool = False
-  # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
-  global_max: Optional[List[int]] = None
-  local_max: Optional[List[int]] = None
-  shared_max: int = 32768
-  renderer: Callable = fake_renderer
-
-class Compiler:
-  compiler_opts: ClassVar[CompilerOptions]
-  def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
-  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
-  def compile_cached(self, src:str) -> bytes:
-    if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
-      lib = self.compile(src)
-      if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
-    return lib
-
-  def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
-    k.linearize()
-    info = get_lazyop_info(k.ast[0])
-    ops, mem = k.uops.flops_mem()
-    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
-    # NOTE: we use min here to ignore the indexing FLOPS
-    return Program(k.name, self.compiler_opts.renderer(to_function_name(k.name), k.uops),
-                   override_device if override_device else self.compiler_opts.device,
-                   k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
-
 @dataclass(frozen=True)
 class Program:
   name:str
@@ -246,6 +209,43 @@ class Program:
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else None
     return global_size, local_size

+def fake_renderer(name, uops): raise NotImplementedError("needs a renderer")
+
+@dataclass(frozen=True)
+class CompilerOptions:
+  device: str = ""
+  suffix: str = ""
+  # TODO: make this generic with a list of supported types
+  supports_float4: bool = True
+  has_local: bool = True
+  has_shared: bool = True
+  has_tensor_cores: bool = False
+  # NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
+  global_max: Optional[List[int]] = None
+  local_max: Optional[List[int]] = None
+  shared_max: int = 32768
+  renderer: Callable = fake_renderer
+
+  def to_program(self, k:Linearizer, override_device:Optional[str]=None) -> Program:
+    k.linearize()
+    info = get_lazyop_info(k.ast[0])
+    ops, mem = k.uops.flops_mem()
+    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
+    # NOTE: we use min here to ignore the indexing FLOPS
+    return Program(k.name, self.renderer(to_function_name(k.name), k.uops),
+                   override_device if override_device else self.device,
+                   k.global_size, k.local_size, k.uops, min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
+
+class Compiler:
+  compiler_opts: ClassVar[CompilerOptions]
+  def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
+  def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
+  def compile_cached(self, src:str) -> bytes:
+    if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
+      lib = self.compile(src)
+      if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
+    return lib
+
 class CompiledRunner(Runner):
   def __init__(self, p:Program, precompiled:Optional[bytes]=None):
     if DEBUG >= 4: print(p.src)
@@ -280,7 +280,7 @@ class Compiled:
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler if compiler else Compiler(), runtime, graph
   def synchronize(self): pass # override this in your device

-  def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.to_program(k, override_device=self.dname))
+  def to_runner(self, k:Linearizer) -> CompiledRunner: return CompiledRunner(self.compiler.compiler_opts.to_program(k, override_device=self.dname))

   def get_linearizer(self, *ast:LazyOp) -> Linearizer:
     if DEBUG >= 3:
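With to_program moved onto CompilerOptions, building a kernel by hand now goes device, then linearizer, then compiler_opts.to_program, then compile, which is exactly what Compiled.to_runner wraps. A minimal sketch under that assumption, using only calls visible in the hunks above (the schedule item is an arbitrary add):

from tinygrad import Tensor, Device
from tinygrad.engine.schedule import create_schedule

dev = Device[Device.DEFAULT]                              # a Compiled device
si = create_schedule([(Tensor.rand(16)+Tensor.rand(16)).lazydata])[-1]
lin = dev.get_linearizer(*si.ast)                         # Linearizer for that schedule item's AST
lin.hand_coded_optimizations()
p = dev.compiler.compiler_opts.to_program(lin, override_device=dev.dname)  # was dev.compiler.to_program(lin, ...)
lib = dev.compiler.compile_cached(p.src)                  # compile the rendered source; the disk cache is used unless DISABLE_COMPILER_CACHE is set
# dev.to_runner(lin) is the one-call equivalent, wrapping to_program's output in a CompiledRunner
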
@@ -51,7 +51,7 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) ->
   try:
     x[1].linearize()
     if len(x[1].uops.uops) >= getenv("BEAM_UOPS_MAX", 3000) > 0: raise RuntimeError("too many uops")
-    p = compiler.to_program(x[1])
+    p = compiler.compiler_opts.to_program(x[1])
     st = time.perf_counter()
     prog = compiler.compile(p.src)
     et = time.perf_counter() - st
@@ -174,7 +174,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,

   rawbufs = _ensure_buffer_alloc(rawbufs)
   var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()}
-  p = dev.compiler.to_program(lin)
+  p = dev.compiler.compiler_opts.to_program(lin)
   tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                       max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

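The same relocated to_program call sits behind time_linearizer, so benchmarking a hand-optimized kernel can stay at the search-helper level. A hedged sketch: the tinygrad.engine.search import path for time_linearizer and bufs_from_lin is an assumption (both helpers appear elsewhere in this commit), and the matmul is an arbitrary workload:

from tinygrad import Tensor, Device
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin  # assumed import path

dev = Device[Device.DEFAULT]
si = create_schedule([(Tensor.rand(256,256) @ Tensor.rand(256,256)).lazydata])[-1]
lin = dev.get_linearizer(*si.ast)
lin.hand_coded_optimizations()
rawbufs = bufs_from_lin(lin)                               # allocate Buffers matching the kernel's arguments
tm = time_linearizer(lin, rawbufs, allow_test_size=False)  # internally: compiler_opts.to_program + compile + _time_program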