diff --git a/test/external/external_hip_compiler_bug.py b/test/external/external_hip_compiler_bug.py
index 2fbc6b370d..9dfcf01264 100644
--- a/test/external/external_hip_compiler_bug.py
+++ b/test/external/external_hip_compiler_bug.py
@@ -1,6 +1,6 @@
 # [, , ]
 from tinygrad import Device, dtypes
-from tinygrad.device import Buffer, CompiledASTRunner
+from tinygrad.device import Buffer, CompiledRunner
 import ctypes
 import gpuctypes.hip as hip
 
@@ -216,8 +216,8 @@ b2 = Buffer(dev, 9408, dtypes.float)
 print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
 print(hex(b1._buf.value))
 print(hex(b2._buf.value))
-#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
-prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
+#prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
+prg = CompiledRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
 print("compiled")
 prg([b0, b1, b2], {})
 print("ran")
diff --git a/test/external/speed_compare_cuda_ptx.py b/test/external/speed_compare_cuda_ptx.py
index a977f5b2a0..b24dbc9793 100644
--- a/test/external/speed_compare_cuda_ptx.py
+++ b/test/external/speed_compare_cuda_ptx.py
@@ -1,6 +1,6 @@
 import itertools
 from tinygrad import Device
-from tinygrad.device import CompiledASTRunner
+from tinygrad.device import CompiledRunner
 from tinygrad.helpers import to_function_name, getenv, colored
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
 from tinygrad.features.search import bufs_from_lin
@@ -38,7 +38,7 @@ if __name__ == "__main__":
     lin.linearize()
     ptx_src = ptx.render(to_function_name(lin.name), lin.uops)
     try:
-      ptx_prg = CompiledASTRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
+      ptx_prg = CompiledRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
     except RuntimeError:
       print("PTX FAIL")
       continue
diff --git a/test/test_custom_function.py b/test/test_custom_function.py
index 76d8c0a038..e69d76c62a 100644
--- a/test/test_custom_function.py
+++ b/test/test_custom_function.py
@@ -10,7 +10,7 @@ from tinygrad.dtype import dtypes
 # *** first, we implement the atan2 op at the lowest level ***
 # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
 from tinygrad.lazy import Buffer, create_lazybuffer
-from tinygrad.device import CompiledASTRunner, Device
+from tinygrad.device import CompiledRunner, Device
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
     int idx = get_global_id(0);
     c[idx] = atan2(a[idx], b[idx]);
   }"""
-  CompiledASTRunner("atan2_gpu", src, ret.device, global_size=[ret.size]).exec([ret, a, b])
+  CompiledRunner("atan2_gpu", src, ret.device, global_size=[ret.size]).exec([ret, a, b])
 
 def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer):
   ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)
diff --git a/test/test_uops.py b/test/test_uops.py
index ea9d96aa59..569e3dc3d9 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -4,7 +4,7 @@ import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import getenv
 from tinygrad.dtype import dtypes, DType, PtrDType
-from tinygrad.device import Buffer, Device, CompiledASTRunner
+from tinygrad.device import Buffer, Device, CompiledRunner
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.codegen.linearizer import UOps, UOp
@@ -14,7 +14,7 @@ from test.helpers import is_dtype_supported
 def _uops_to_prg(uops):
   src = Device[Device.DEFAULT].compiler.render("test", uops)
   has_local = Device[Device.DEFAULT].compiler.compiler_opts.has_local
-  return CompiledASTRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)
+  return CompiledRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)
 
 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg))
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 41f3a87ce3..01508a6f2c 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -158,7 +158,7 @@ class Compiler:
     if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
     return lib
 
-class CompiledASTRunner(Runner):
+class CompiledRunner(Runner):
   def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
                variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1):
     super().__init__()
@@ -211,14 +211,14 @@ class Compiled:
     self.dname, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler, runtime, graph
   def synchronize(self): pass # override this in your device
 
-  def to_program(self, k:Linearizer) -> CompiledASTRunner:
+  def to_program(self, k:Linearizer) -> CompiledRunner:
     assert self.compiler is not None, "compiler is required to run AST"
     k.linearize()
     info = get_lazyop_info(k.ast[0])
     ops, mem = k.uops.flops_mem()
     run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
-    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
+    ret = CompiledRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
                             k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count), outcount=len(k.outbufs))
     return ret
@@ -254,4 +254,4 @@ class Compiled:
     return k
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def get_runner(self, *ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(*ast))
+  def get_runner(self, *ast:LazyOp) -> CompiledRunner: return self.to_program(self.get_linearizer(*ast))
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 6cb4508e6b..126e01dc3e 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -4,7 +4,7 @@ import functools, itertools, operator
 from tinygrad.nn.state import get_parameters
 from tinygrad.dtype import DType
 from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
-from tinygrad.device import Compiled, Runner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
+from tinygrad.device import Compiled, Runner, CompiledRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
@@ -14,8 +14,8 @@ from tinygrad.engine.realize import ExecItem
 from weakref import ref, WeakKeyDictionary
 
 def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
-  return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0), \
-    functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0)
+  return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0), \
+    functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0)
 def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
   input_replace: Dict[Tuple[int, int], int] = {}
   for j,ji in enumerate(jit_cache):
@@ -24,9 +24,9 @@ def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer])
       input_replace[(j,i)] = input_rawbuffers.index(a)
   return input_replace
 def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
-  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
+  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
 def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
-  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
+  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ji.prg.vars]
 
 def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
@@ -51,7 +51,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
   for ji in jit_cache:
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
-    if isinstance(ji.prg, CompiledASTRunner): ji_graph_dev = ji.prg.device
+    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
     elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
       ji_graph_dev = Device[ji.rawbufs[0].device]
@@ -169,7 +169,7 @@ class _CacheCollector:
     for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
 
     # Buffer optimization is allowed only for kernel operations. Avoids for copies (prevents parallelism) and syncs (incorrect buffer reuse).
-    if isinstance(prg, CompiledASTRunner):
+    if isinstance(prg, CompiledRunner):
       for i in range(prg.outcount): self.placeholders[rawbufs[i]] = PlaceHolder(rawbufs[i])
     self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 04b945fa43..acbe70f1bd 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -1,7 +1,7 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from collections import defaultdict
-from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner, Compiler
+from tinygrad.device import Device, Compiled, Buffer, CompiledRunner, Compiler
 from tinygrad.ops import MemBuffer
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import ImageDType
@@ -35,7 +35,7 @@ def _time_program(variables:List[Variable], outcount:int, rdev:Compiled, lib:byt
   factor = 1
   if global_size is not None and max_global_size is not None:
     global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
-  try: car = CompiledASTRunner(name, "", rdev.dname, global_size, local_size, variables=variables, precompiled=lib, outcount=outcount)
+  try: car = CompiledRunner(name, "", rdev.dname, global_size, local_size, variables=variables, precompiled=lib, outcount=outcount)
   except AssertionError: return [math.inf] * cnt
   tms = []
   for _ in range(cnt):
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index 50bb388d0a..a9d43bc5aa 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -2,7 +2,7 @@ import ctypes, collections
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import init_c_var, GraphException, getenv
-from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
+from tinygrad.device import CompiledRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.realize import ExecItem
@@ -11,7 +11,7 @@ from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_wi
 class CUDAGraph(MultiDeviceJITGraph):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     # Check all jit items are compatible.
-    if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
+    if not all(isinstance(ji.prg, CompiledRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
 
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
@@ -27,7 +27,7 @@ class CUDAGraph(MultiDeviceJITGraph):
     self.cpu_buffers = []
 
     for j,ji in enumerate(self.jit_cache):
-      if isinstance(ji.prg, CompiledASTRunner):
+      if isinstance(ji.prg, CompiledRunner):
         global_size, local_size = ji.prg.launch_dims(var_vals)
 
         new_node = cuda.CUgraphNode()
@@ -81,12 +81,12 @@ class CUDAGraph(MultiDeviceJITGraph):
     # Update var_vals in the c_args struct.
     for j in self.jc_idxs_with_updatable_var_vals:
-      for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
         setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])
 
     # Update launch dims in the kern_params struct.
     for j in self.jc_idxs_with_updatable_launch_dims:
-      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))
+      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals))
 
     # Update graph nodes with the updated structs.
     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py
index 76bd1820ab..7cefd82c49 100644
--- a/tinygrad/runtime/graph/hsa.py
+++ b/tinygrad/runtime/graph/hsa.py
@@ -2,7 +2,7 @@ import ctypes, collections, time, itertools
 from typing import List, Any, Dict, cast, Optional, Union, Tuple
 from tinygrad.helpers import GraphException, init_c_var, round_up
 from tinygrad.buffer import Buffer, BufferOptions
-from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
+from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.engine.realize import ExecItem
@@ -36,7 +36,7 @@ class HSAGraph(MultiDeviceJITGraph):
     # Check all jit items are compatible.
     compiled_devices = set()
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledASTRunner): compiled_devices.add(ji.prg.device)
+      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
      elif isinstance(ji.prg, BufferXfer):
        for x in ji.rawbufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
      else: raise GraphException
@@ -47,13 +47,13 @@ class HSAGraph(MultiDeviceJITGraph):
     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledASTRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
 
     # Fill initial arguments.
     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
     for j,ji in enumerate(self.jit_cache):
-      if not isinstance(ji.prg, CompiledASTRunner): continue
+      if not isinstance(ji.prg, CompiledRunner): continue
       self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
       kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
       for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf)
@@ -75,7 +75,7 @@ class HSAGraph(MultiDeviceJITGraph):
     for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
 
     for j,ji in enumerate(self.jit_cache):
-      if isinstance(ji.prg, CompiledASTRunner):
+      if isinstance(ji.prg, CompiledRunner):
         wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False)
         for i in range(0, len(wait_signals), 5): self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
@@ -131,12 +131,12 @@ class HSAGraph(MultiDeviceJITGraph):
     # Update var_vals
     for j in self.jc_idxs_with_updatable_var_vals:
-      for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
         self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])
 
     # Update launch dims
     for j in self.jc_idxs_with_updatable_launch_dims:
-      gl, lc = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
+      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
       self.packets[j].workgroup_size_x = lc[0]
       self.packets[j].workgroup_size_y = lc[1]
       self.packets[j].workgroup_size_z = lc[2]
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index 6c5a9243e1..192fa12abf 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -2,7 +2,7 @@ from typing import List, Any, Dict, cast, Optional
 import Metal
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import dedup, unwrap2, GraphException
-from tinygrad.device import Buffer, CompiledASTRunner, update_stats
+from tinygrad.device import Buffer, CompiledRunner, update_stats
 from tinygrad.engine.realize import ExecItem
 from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
 from tinygrad.shape.symbolic import Variable
@@ -10,7 +10,7 @@ from tinygrad.runtime.ops_metal import MetalDevice, wait_check
 
 class MetalGraph:
   def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-    if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
+    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
@@ -30,7 +30,7 @@ class MetalGraph:
     if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
     all_resources = [self.int_buf] if len(var_vals) else []
     for j,ji in enumerate(self.jit_cache):
-      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
+      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
       descriptor = Metal.MTLComputePipelineDescriptor.new()
       descriptor.setComputeFunction_(prg.clprg.fxn)
       descriptor.setSupportIndirectCommandBuffers_(True)
@@ -62,7 +62,7 @@ class MetalGraph:
     for (j,i),input_idx in self.input_replace.items():
      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
     for j in self.jc_idx_with_updatable_launch_dims:
-      global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
+      global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
       self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
     for j, value in enumerate(var_vals.values()): self.int_buf_view[j] = value
     command_buffer = self.device.mtl_queue.commandBuffer()
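
For reference outside the patch: after this rename, hand-rolling a kernel launch looks the same as before, only the class name changes. The sketch below mirrors the atan2_gpu call site updated in test/test_custom_function.py; the add1_gpu kernel name and its OpenCL source are hypothetical, and it assumes a device whose compiler accepts OpenCL-style C source.

# minimal sketch, not part of the diff: a custom elementwise op driven through CompiledRunner,
# patterned on atan2_gpu above; "add1_gpu" and its kernel source are hypothetical
from tinygrad.device import Buffer, CompiledRunner

def add1_gpu(ret:Buffer, a:Buffer):
  src = """__kernel void add1_gpu(global float *out, global const float *a) {
    int idx = get_global_id(0);
    out[idx] = a[idx] + 1.0f;
  }"""
  # same layout as the updated call sites: name, source, device name, then launch dims;
  # a compiled binary can be supplied via precompiled= instead of source
  CompiledRunner("add1_gpu", src, ret.device, global_size=[ret.size]).exec([ret, a])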