From 722dd4276c16e8aee086272c4f6312c7fa2d5d63 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 17 Mar 2024 16:52:20 +0200 Subject: [PATCH] add outbufs info to CompiledASTRunner (#3781) * add outbufs * Revert "add outbufs" This reverts commit 5f4c0668f58bf0727e9ebb2231e55e0537165411. * simplify --- tinygrad/device.py | 6 +++--- tinygrad/features/jit.py | 7 +++---- tinygrad/features/search.py | 16 ++++++++-------- tinygrad/runtime/graph/hsa.py | 2 +- 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tinygrad/device.py b/tinygrad/device.py index 1ead1bdc64..455d4d052b 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -188,7 +188,7 @@ class Compiler: class CompiledASTRunner(JITRunner): def __init__(self, name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, - variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None): + variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1): super().__init__() if DEBUG >= 4: print(prg) if global_size is not None: global_size = global_size + [1]*(3-len(global_size)) @@ -197,7 +197,7 @@ class CompiledASTRunner(JITRunner): to_function_name(name), name, prg, device, global_size, local_size, True assert self.device.compiler is not None, "compiler is reuired to make an AST kernel" lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg) - self.lib, self.clprg = lib, self.device.runtime(self.name, lib) + self.lib, self.clprg, self.outcount = lib, self.device.runtime(self.name, lib), outcount self.vars: List[Variable] = [] if variables is None else variables self.op_estimate, self.mem_estimate = op_estimate, mem_estimate @@ -239,7 +239,7 @@ class Compiled: run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else [])) # NOTE: we use min here to ignore the indexing FLOPS ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size, - k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count)) + k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs)) return ret def get_linearizer(self, *ast:LazyOp) -> Linearizer: diff --git a/tinygrad/features/jit.py b/tinygrad/features/jit.py index 35834fb362..a5417fa8f8 100644 --- a/tinygrad/features/jit.py +++ b/tinygrad/features/jit.py @@ -167,11 +167,10 @@ class _CacheCollector: for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}" # Buffer optimization is allowed only for kernel operations. Avoids for copies (prevents parallelism) and syncs (incorrect buffer reuse). 
- allow_buffer_optimization = isinstance(prg, CompiledASTRunner) + if isinstance(prg, CompiledASTRunner): + for i in range(prg.outcount): self.placeholders[rawbufs[i]] = PlaceHolder(rawbufs[i]) - # NOTE: this is making an assumption that 0 is special - if len(rawbufs): self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) - self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) and allow_buffer_optimization else x for x in rawbufs])) + self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs])) def finish(self) -> List[JitItem]: if self.cache is None: return [] diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index e029253ac2..df5d2af72a 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -31,12 +31,12 @@ def _get_test_global_size(global_size, max_global_size, var_vals): break return test_global_size, factor -def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, +def _time_program(variables:List[Variable], outcount:int, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): factor = 1 if global_size is not None and max_global_size is not None: global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals) - try: car = CompiledASTRunner(name, "", rdev, global_size, local_size, variables=variables, precompiled=lib) + try: car = CompiledASTRunner(name, "", rdev, global_size, local_size, variables=variables, precompiled=lib, outcount=outcount) except AssertionError: return [math.inf] * cnt tms = [] for _ in range(cnt): @@ -47,11 +47,11 @@ def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_siz return tms def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]], - List[Variable]]: + List[Variable], int]: lin.linearize() src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops.uops) # NOTE: these all have the same name for deduping if DEBUG >= 5: print(src) - return compiler.compile(src), lin.global_size, lin.local_size, lin.uops.vars() + return compiler.compile(src), lin.global_size, lin.local_size, lin.uops.vars(), len(lin.outbufs) def _try_compile_linearized_w_idx(x, compiler:Compiler): try: return (x[0], _compile_linearizer(compiler, x[1], "test")) @@ -121,11 +121,11 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler) for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))): if proc is None: continue - lib, global_size, local_size, vars = proc + lib, global_size, local_size, vars, outcount = proc if lib in seen_libs: continue #print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault seen_libs.add(lib) - try: tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0) + try: tms = _time_program(vars, outcount, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0) except RuntimeError: continue # for runtime issues timed_lins.append((acted_lins[i], min(tms))) if DEBUG >= 2: print(f"\r{time.perf_counter() - 
st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501 @@ -165,8 +165,8 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, assert isinstance(dev, Compiled) and dev.compiler is not None var_vals = {k:(k.max+k.min)//2 for k in lin.ast[0].vars()} - lib, global_size, local_size, vars = _compile_linearizer(dev.compiler, lin) - tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501 + lib, global_size, local_size, vars, outcount = _compile_linearizer(dev.compiler, lin) + tms = _time_program(vars, outcount, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501 if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms) return min(tms) diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index ca369456d5..5de5f6e3a2 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -76,7 +76,7 @@ class HSAGraph(MultiDeviceJITGraph): for j,ji in enumerate(self.jit_cache): if isinstance(ji.prg, CompiledASTRunner): - wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=j, sync_with_aql_packets=False) + wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False) for i in range(0, len(wait_signals), 5): self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5]) self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
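
Note on the convention this patch introduces: consumers of CompiledASTRunner should treat the first `outcount` entries of rawbufs as the kernel's output buffers and the rest as inputs, rather than hardcoding buffer 0 as the only output (as hsa.py did with rawbufs[0:1] / rawbufs[1:], and jit.py did via its "0 is special" placeholder). The sketch below is a minimal, self-contained illustration of that split; FakeRunner and split_bufs are hypothetical stand-ins for CompiledASTRunner and the HSA graph's read/write bookkeeping, not part of tinygrad.

from typing import List, Tuple

class FakeRunner:
  """Hypothetical stand-in for CompiledASTRunner; only the new outcount field matters here."""
  def __init__(self, outcount:int=1): self.outcount = outcount

def split_bufs(prg:FakeRunner, rawbufs:List[str]) -> Tuple[List[str], List[str]]:
  # mirrors the hsa.py change: the first `outcount` buffers are written by the kernel,
  # the remaining buffers are only read (previously hardcoded as rawbufs[0:1] / rawbufs[1:])
  outs = prg.outcount
  return rawbufs[:outs], rawbufs[outs:]

if __name__ == "__main__":
  prg = FakeRunner(outcount=2)              # e.g. a kernel with two output buffers
  writes, reads = split_bufs(prg, ["out0", "out1", "in0", "in1"])
  assert writes == ["out0", "out1"] and reads == ["in0", "in1"]
  print("write set:", writes, "read set:", reads)

The same convention is what lets the _CacheCollector in jit.py create a PlaceHolder for every output buffer (for i in range(prg.outcount)) instead of assuming index 0 is special.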