diff --git a/test/test_custom_function.py b/test/test_custom_function.py
index 78c365e178..6ab4ea39bf 100644
--- a/test/test_custom_function.py
+++ b/test/test_custom_function.py
@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
     int idx = get_global_id(0);
     c[idx] = atan2(a[idx], b[idx]);
   }"""
-  CompiledASTRunner(None, "atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
+  CompiledASTRunner("atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
 
 def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer):
   ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)
diff --git a/test/test_uops.py b/test/test_uops.py
index 0794e59f52..8856d6d7de 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -13,7 +13,7 @@ from test.test_dtype import is_dtype_supported
 def _uops_to_prg(uops):
   src = Device[Device.DEFAULT].compiler.render("test", uops)
   has_local = Device[Device.DEFAULT].compiler.linearizer_opts.has_local
-  return CompiledASTRunner(None, "test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
+  return CompiledASTRunner("test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
 
 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg))
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index fff74403d9..69f0962e6d 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -33,9 +33,8 @@ def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]:
       deps.add(u)
   return deps
 
-UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_VAR}
+UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER}
 def remove_childless_uops(uops:List[UOp]) -> List[UOp]:
-  # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
   while 1:
     has_child: Set[UOp] = set()
     for ru in uops:
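For context, here is what dropping DEFINE_VAR from UOPS_W_SIDE_EFFECTS buys: the dead-code pass may now remove variable definitions nothing consumes. A minimal sketch of the idea, using a simplified stand-in UOp type (illustrative only, not tinygrad's real class):

from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Set, Tuple

class UOps(Enum):
  DEFINE_VAR = auto(); ALU = auto(); STORE = auto(); BARRIER = auto()

@dataclass(frozen=True)
class UOp:
  uop: UOps
  vin: Tuple["UOp", ...] = ()

UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER}  # DEFINE_VAR no longer pinned

def remove_childless_uops(uops: List[UOp]) -> List[UOp]:
  # iterate to a fixed point: a uop survives only if some surviving uop
  # consumes it or it has a real side effect
  while True:
    has_child: Set[UOp] = set()
    for u in uops:
      for v in u.vin: has_child.add(v)
    nu = [u for u in uops if u in has_child or u.uop in UOPS_W_SIDE_EFFECTS]
    if len(nu) == len(uops): return nu
    uops = nu

v = UOp(UOps.DEFINE_VAR)               # defined but never consumed: now dropped
s = UOp(UOps.STORE, (UOp(UOps.ALU),))  # STORE keeps its ALU input alive
assert remove_childless_uops([v, s.vin[0], s]) == [s.vin[0], s]

An unused DEFINE_VAR now disappears in this pass; the Variables a kernel actually needs are instead collected from the surviving uops, as the device.py and search.py changes below show.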
diff --git a/tinygrad/device.py b/tinygrad/device.py
index a8cdd3b7fa..6a929a458c 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -3,6 +3,7 @@ from collections import defaultdict
 from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
 import importlib, inspect, functools, pathlib, time, ctypes
 from tinygrad.dtype import DType, ImageDType
+from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
 from tinygrad.helpers import prod
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -40,7 +41,9 @@ Device = _Device()
 # **************** base Runner + helpers ****************
 
 class JITRunner:
-  def __init__(self): self.op_estimate, self.mem_estimate = 0, 0
+  def __init__(self):
+    self.op_estimate:sint = 0
+    self.mem_estimate:sint = 0
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
     var_vals = var_vals if var_vals is not None else {}
     from tinygrad.features.jit import CacheCollector
@@ -185,7 +188,8 @@ class Compiler:
     return lib
 
 class CompiledASTRunner(JITRunner):
-  def __init__(self, ast:Optional[LazyOp], name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, precompiled:Optional[bytes]=None):  # noqa: E501
+  def __init__(self, name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
+               variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None):
     super().__init__()
     if DEBUG >= 4: print(prg)
     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
@@ -195,12 +199,8 @@ class CompiledASTRunner(JITRunner):
     assert self.device.compiler is not None, "compiler is reuired to make an AST kernel"
     lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
     self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
-    self.vars: List[Variable] = []
-    if ast:
-      info = get_lazyop_info(ast)
-      self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-      self.vars = ast.vars()
-      assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
+    self.vars: List[Variable] = [] if variables is None else variables
+    self.op_estimate, self.mem_estimate = op_estimate, mem_estimate
 
   def launch_dims(self, var_vals):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
@@ -235,13 +235,14 @@ class Compiled:
   def to_program(self, k:Linearizer) -> CompiledASTRunner:
     assert self.compiler is not None, "compiler is required to run AST"
     k.linearize()
-    ret = CompiledASTRunner(k.ast, k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
+    info = get_lazyop_info(k.ast)
     from tinygrad.codegen.uops import uops_flops_mem
-    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     ops, mem = uops_flops_mem(k.uops)
+    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
-    ret.op_estimate = min(ret.op_estimate, ops * run_count)
-    ret.mem_estimate = min(ret.mem_estimate, mem * run_count)
+    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
+                            [x.arg for x in k.uops if x.uop is UOps.DEFINE_VAR],
+                            min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
     return ret
 
   def get_linearizer(self, ast:LazyOp) -> Linearizer:
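The two op_estimate inputs in to_program come from different levels: get_lazyop_info counts FLOPS on the AST (pure arithmetic, no indexing), while uops_flops_mem counts FLOPS per invocation of the generated program, index math included. Multiplying the latter by run_count therefore over-counts indexing work, and min keeps the tighter figure. A toy illustration with assumed numbers:

# toy numbers, purely illustrative of the min() heuristic in to_program
info_flops = 1_000_000       # AST-level FLOPS: arithmetic only, no indexing
uop_flops = 70               # FLOPS counted per invocation from the uops
run_count = 16_384           # prod(global_size + local_size)
op_estimate = min(info_flops, uop_flops * run_count)  # -> 1_000_000, dropping ~147k indexing FLOPS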
name="test"): # noqa: E501 +def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, + early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): factor = 1 if global_size is not None and max_global_size is not None: global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals) - try: car = CompiledASTRunner(ast, name, "", rdev, global_size, local_size, precompiled=lib) + try: car = CompiledASTRunner(name, "", rdev, global_size, local_size, variables=variables, precompiled=lib) except AssertionError: return [math.inf] * cnt tms = [] for _ in range(cnt): @@ -44,10 +46,11 @@ def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size, if early_stop is not None and early_stop < tms[-1]: break return tms -def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]: +def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]], + List[Variable]]: lin.linearize() src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping - return compiler.compile(src), lin.global_size, lin.local_size + return compiler.compile(src), lin.global_size, lin.local_size, [x.arg for x in lin.uops if x.uop is UOps.DEFINE_VAR] def _try_compile_linearized_w_idx(x, compiler:Compiler): try: return (x[0], _compile_linearizer(compiler, x[1], "test")) @@ -116,10 +119,10 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=Device[lin.opts.device].compiler) for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))): if proc is None: continue - lib, global_size, local_size = proc + lib, global_size, local_size, vars = proc if lib in seen_libs: continue seen_libs.add(lib) - tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0) + tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0) timed_lins.append((acted_lins[i], min(tms))) if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501 @@ -157,8 +160,8 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, assert isinstance(dev, Compiled) and dev.compiler is not None var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()} - lib, global_size, local_size = _compile_linearizer(dev.compiler, lin) - tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501 + lib, global_size, local_size, vars = _compile_linearizer(dev.compiler, lin) + tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501 if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms) return min(tms)