diff --git a/extra/backends/graph_hip.py b/extra/backends/graph_hip.py
index a18e6925c9..ddcb3d58b1 100644
--- a/extra/backends/graph_hip.py
+++ b/extra/backends/graph_hip.py
@@ -12,7 +12,7 @@ class HIPGraph(CUDAGraph):
   def __del__(self):
     if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
     if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance))
-  def set_device(self): hip_set_device(self.device)
+  def set_device(self): hip_set_device(self.dev)
   def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
   def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
   def graph_instantiate(self, graph):
diff --git a/extra/backends/hsa_graph.py b/extra/backends/hsa_graph.py
index 3603bb9421..193a10a7c0 100644
--- a/extra/backends/hsa_graph.py
+++ b/extra/backends/hsa_graph.py
@@ -32,7 +32,7 @@ class HSAGraph(MultiGraphRunner):
     # Check all jit items are compatible.
     compiled_devices = set()
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
+      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.dev)
       elif isinstance(ji.prg, BufferXfer):
         for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
       else: raise GraphException
@@ -43,15 +43,15 @@ class HSAGraph(MultiGraphRunner):
     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}

     # Fill initial arguments.
     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
     for j,ji in enumerate(self.jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
-      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
-      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev])
+      kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
       for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
       for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
@@ -72,13 +72,13 @@ class HSAGraph(MultiGraphRunner):
       if isinstance(ji.prg, CompiledRunner):
         wait_signals = self.access_resources(ji.bufs, ji.prg.p.outs, new_dependency=j, sync_with_aql_packets=False)
         for i in range(0, len(wait_signals), 5):
-          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
-        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
+          self.virt_aql_queues[ji.prg.dev].submit_barrier(wait_signals[i:i+5])
+        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.dev].write_addr)
         sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
-        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
+        self.virt_aql_queues[ji.prg.dev].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
                                                          ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
-        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+        if PROFILE: self.profile_info[ji.prg.dev].append((sync_signal, ji.prg.clprg.name, False))
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
diff --git a/extra/backends/ops_hip.py b/extra/backends/ops_hip.py
index ea8139924a..96e56d6c8e 100644
--- a/extra/backends/ops_hip.py
+++ b/extra/backends/ops_hip.py
@@ -142,22 +142,22 @@ class HIPAllocator(LRUAllocator):

 class HIPSyncEvent(Runner):
   def __init__(self, lb):
-    self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
+    self.lb, self.dev, self.device = lb, cast(HIPDevice, Device[lb.device]), lb.device
     super().__init__()
   def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
     to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0
-    hip_set_device(self.device.device)
+    hip_set_device(self.dev.device)
     check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
-    update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
+    update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.device)

 class HIPWaitEvent(Runner):
   def __init__(self, device):
-    self.device, self.dname = cast(HIPDevice, Device[device]), device
+    self.dev, self.device = cast(HIPDevice, Device[device]), device
     super().__init__()
   def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
-    hip_set_device(self.device.device)
+    hip_set_device(self.dev.device)
     check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF))
-    update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.dname)
+    update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.device)

 if getenv("HIPCPU"):
   rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
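Throughout this patch the old attribute pair (dname: the device string, device: a property resolving to the Compiled object) becomes (device: the string, dev: the resolved object); the HIPSyncEvent/HIPWaitEvent constructors above are the same rename applied by hand. A minimal sketch of the convention, using hypothetical ToyDevice/ToyRunner stand-ins rather than tinygrad's real classes:

  # illustrative only: ToyDevice, ToyRunner, and REGISTRY are made-up names
  class ToyDevice:
    def __init__(self, name:str): self.name = name

  REGISTRY = {"HIP:0": ToyDevice("HIP:0"), "HIP:1": ToyDevice("HIP:1")}

  class ToyRunner:
    def __init__(self, device:str):
      self.device = device         # string name, e.g. "HIP:1" (the field renamed from dname)
      self.dev = REGISTRY[device]  # resolved runtime object (previously held by .device)

  r = ToyRunner("HIP:1")
  assert r.device == "HIP:1" and r.dev is REGISTRY["HIP:1"]

Keeping the string under .device matches how callers already index Device[...], while .dev is the live handle used for queue submission.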
ctypes.CDLL("/usr/local/lib/libremu.so") diff --git a/extra/gemm/triton_nv_matmul.py b/extra/gemm/triton_nv_matmul.py index 003d0e516b..cdae1c3cbe 100644 --- a/extra/gemm/triton_nv_matmul.py +++ b/extra/gemm/triton_nv_matmul.py @@ -85,7 +85,7 @@ if __name__ == "__main__": # remove debug sections src = src.split("\t.file")[0] assert '.extern .shared' not in src - prg = Program("matmul_kernel", src, dname=Device.DEFAULT, + prg = Program("matmul_kernel", src, device=Device.DEFAULT, global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1], mem_estimate=A.nbytes() + B.nbytes() + C.nbytes()) ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata) diff --git a/test/external/fuzz_uops.py b/test/external/fuzz_uops.py index 476b707bcf..cc28c33790 100644 --- a/test/external/fuzz_uops.py +++ b/test/external/fuzz_uops.py @@ -50,10 +50,10 @@ class UOpsFuzzerRunner(CompiledRunner): # setup prg uops = list(path) if DEBUG >= 5: print_uops(uops) - self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops) + self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.device].renderer.render(name, uops), uops=uops) if DEBUG >= 4: print(self.p.src) - self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src) - self.clprg = Device[self.p.dname].runtime(name, self.lib) + self.lib = Device[self.p.device].compiler.compile_cached(self.p.src) + self.clprg = Device[self.p.device].runtime(name, self.lib) for x in (rawbufs:=[init_globals[i] for i in self.p.globals]): x.copyin(init_rawbufs[x]) # verify super().__call__(rawbufs, var_vals, wait) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index b107698588..5f8da35830 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -386,7 +386,7 @@ class TestLinearizer(unittest.TestCase): ast = UOp(Ops.SINK, src=(store0, store1)) k = Kernel(ast) - prg = CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT)) + prg = CompiledRunner(replace(k.to_program(), device=Device.DEFAULT)) inbufs = [x.lazydata.base.buffer] outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src] prg.exec(outbufs+inbufs) @@ -1780,7 +1780,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[] lins: List[Kernel] = [] outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src] - def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT)) + def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=Device.DEFAULT)) def check_opt(opts, create_k, expected_color_size): k = create_k() diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 5ad1bf192c..4c0f5f8d81 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -42,7 +42,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer] for ji in jit_cache: if ji.prg.__class__ in {EmptyOp, ViewOp}: continue ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None. 
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 5ad1bf192c..4c0f5f8d81 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -42,7 +42,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
   for ji in jit_cache:
     if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
-    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
+    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
     elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]
@@ -101,7 +101,7 @@ class GraphRunner(Runner):  # pylint: disable=abstract-method
     self.w_dependency_map: Dict[int, Any] = {}
     self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
-    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
+    super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
                      ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))

   def updated_vars(self, var_vals: Dict[Variable, int]):
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 0eff7db22f..bf578b05fb 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -64,11 +64,11 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
 # **************** Runners ****************

 class Runner:
-  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
-    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
-      True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
+  def __init__(self, display_name:str, device:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
+    self.first_run, self.display_name, self.device, self.op_estimate, self.mem_estimate, self.lds_estimate = \
+      True, display_name, device, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
   @property
-  def device(self): return Device[self.dname]
+  def dev(self): return Device[self.device]
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
     return self(rawbufs, {} if var_vals is None else var_vals)
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
@@ -78,10 +78,10 @@ class CompiledRunner(Runner):
   def __init__(self, p:Program, precompiled:Optional[bytes]=None):
     if DEBUG >= 4: print(p.src)
     self.p:Program = p
-    self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
-    if DEBUG >= 6: Device[p.dname].compiler.disassemble(self.lib)
-    self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
-    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
+    self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
+    if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
+    self.clprg = Device[p.device].runtime(p.function_name, self.lib)
+    super().__init__(p.name, p.device, p.op_estimate, p.mem_estimate, p.lds_estimate)

   def __reduce__(self): return self.__class__, (self.p, self.lib)
@@ -141,18 +141,18 @@ class BufferXfer(BufferCopy):

 # **************** method cache ****************

 method_cache: Dict[Tuple[str, bytes, int, int, bool], CompiledRunner] = {}
-def get_runner(dname:str, ast:UOp) -> CompiledRunner:
-  ckey = (dname, ast.key, BEAM.value, NOOPT.value, False)
+def get_runner(device:str, ast:UOp) -> CompiledRunner:
+  ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (dname.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
+  bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
   if bret:=method_cache.get(bkey):
-    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
+    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
-    prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
+    prg: Program = get_kernel(Device[device].renderer, ast).to_program()
     if getenv("FUZZ_UOPS"):
       from test.external.fuzz_uops import UOpsFuzzerRunner
-      return UOpsFuzzerRunner(replace(prg, dname=dname))
-    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
+      return UOpsFuzzerRunner(replace(prg, device=device))
+    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
   return ret

 # **************** lowering functions ****************
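The get_runner hunk above keeps two cache entries per kernel: ckey for the exact device (e.g. "CUDA:1") and bkey for the base device ("CUDA"), so a binary compiled once per backend can be re-tagged to any device index via replace(p, device=...). A toy model of that lookup order (simplified: the real keys also carry ast.key, BEAM/NOOPT, and a base-key flag; names here are illustrative):

  method_cache: dict = {}

  def toy_get_runner(device:str, key:str):
    ckey = (device, key)                # exact device, e.g. "CUDA:1"
    if (cret := method_cache.get(ckey)) is not None: return cret
    bkey = (device.split(":")[0], key)  # base device, e.g. "CUDA"
    if (bret := method_cache.get(bkey)) is not None:
      ret = method_cache[ckey] = (device, bret[1])  # reuse the compiled lib, retag the device
    else:
      ret = method_cache[ckey] = method_cache[bkey] = (device, f"lib({key})")  # "compile" once
    return ret

  a = toy_get_runner("CUDA", "k0")      # compiles, fills both cache levels
  b = toy_get_runner("CUDA:1", "k0")    # hits bkey, reuses a's lib
  assert a[1] == b[1] and b[0] == "CUDA:1"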
(device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True) if bret:=method_cache.get(bkey): - method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib) + method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib) else: - prg: Program = get_kernel(Device[dname].renderer, ast).to_program() + prg: Program = get_kernel(Device[device].renderer, ast).to_program() if getenv("FUZZ_UOPS"): from test.external.fuzz_uops import UOpsFuzzerRunner - return UOpsFuzzerRunner(replace(prg, dname=dname)) - method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname)) + return UOpsFuzzerRunner(replace(prg, device=device)) + method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device)) return ret # **************** lowering functions **************** @@ -175,7 +175,7 @@ class ExecItem: lds_est = sym_infer(self.prg.lds_estimate, var_vals) mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else "" - print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 + print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501 f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}")) self.prg.first_run = False diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index d2b62a0cc8..5e70cd3391 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -45,7 +45,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li input_bufs = [rawbufs[i] for i in car.p.globals] for _ in range(cnt): if clear_l2: - if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches() + if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches() else: with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False) tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 2652077dc3..31e5b83f6d 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -26,7 +26,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x class Program: name:str src:str - dname:str + device:str uops:Optional[List[UOp]]=None mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 5322a40fdf..7f2e5e602a 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -16,7 +16,7 @@ class HCQGraph(MultiGraphRunner): kernargs_size: Dict[Compiled, int] = collections.defaultdict(int) for ji in 
diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py
index 5322a40fdf..7f2e5e602a 100644
--- a/tinygrad/runtime/graph/hcq.py
+++ b/tinygrad/runtime/graph/hcq.py
@@ -16,7 +16,7 @@ class HCQGraph(MultiGraphRunner):
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
-      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+      kernargs_size[ji.prg.dev] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
     self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}

     # Fill initial arguments.
@@ -25,7 +25,7 @@ class HCQGraph(MultiGraphRunner):
     kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
     for j,ji in enumerate(self.jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
-      kargs_ptrs[ji.prg.device] = (kargs_ptr:=kargs_ptrs[ji.prg.device]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+      kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
       self.ji_args[j] = ji.prg.clprg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)

     # Schedule Dependencies.
@@ -51,7 +51,7 @@ class HCQGraph(MultiGraphRunner):
     for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

     for j,ji in enumerate(self.jit_cache):
-      enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+      enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
       enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
       out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
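HCQGraph (like HSAGraph earlier in this patch) sizes one kernarg buffer per device by summing 16-byte-aligned sizes, then hands out per-kernel regions by bumping a pointer through that allocation in the same order. A toy model of the two-pass bump allocation (round_up mirrors tinygrad's helper; everything else is illustrative):

  def round_up(x:int, align:int) -> int: return (x + align - 1) // align * align

  sizes = [100, 24, 7]                         # kernargs_alloc_size per kernel
  total = sum(round_up(s, 16) for s in sizes)  # pass 1: one _alloc of this many bytes
  ptr, offsets = 0, []
  for s in sizes:                              # pass 2: bump a pointer region by region
    offsets.append(ptr)                        # this kernel's kernarg offset
    ptr += round_up(s, 16)
  assert ptr == total and offsets == [0, 112, 144]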
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index 759880850e..ce167f2ca6 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -28,13 +28,13 @@ class MetalGraph(GraphRunner):
     msg(icb_descriptor, "setInheritPipelineState:", False)
     msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)

-    self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
+    self.icb = msg(self.dev.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
                    icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
     if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
     icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
     self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3

-    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
+    if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
     all_resources = [self.int_buf.buf] if len(self.vars) else []
     all_pipelines = []
     for j,ji in enumerate(self.jit_cache):
@@ -55,12 +55,12 @@ class MetalGraph(GraphRunner):
     self.all_resources = dedup(all_resources)
     self.all_pipelines = dedup(all_pipelines)
     self.command_buffer: Any = None
-    if len(self.vars): self.int_buf_view = self.device.allocator._as_buffer(self.int_buf).cast('i')
+    if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
     self.range = to_struct(0, len(self.jit_cache))

   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
-    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
+    if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
     all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])

     for (j,i),input_idx in self.input_replace.items():
@@ -76,7 +76,7 @@ class MetalGraph(GraphRunner):
                  to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

-    command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
+    command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
     encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
     msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
         MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
@@ -99,5 +99,5 @@ class MetalGraph(GraphRunner):
     if wait:
       wait_check(command_buffer)
       return elapsed_time(command_buffer)
-    self.device.mtl_buffers_in_flight.append(command_buffer)
+    self.dev.mtl_buffers_in_flight.append(command_buffer)
     return None
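MetalGraph.__call__ above begins by draining its own previous submission: the graph reuses one command buffer, so if it is still in flight the call must block before rewriting kernel arguments. A toy sketch of that guard, with wait_check and the in-flight list as made-up stand-ins for the Metal plumbing:

  in_flight: list = []

  def wait_check(cb): in_flight.remove(cb)   # stands in for blocking until the GPU finishes cb

  def run(graph_cb):
    if graph_cb is not None and graph_cb in in_flight: wait_check(graph_cb)
    in_flight.append(graph_cb)               # otherwise resubmit without blocking

  run("cb0"); run("cb0")                     # the second call waits for the first
  assert in_flight == ["cb0"]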
method == "GET": - cls, args = Device[CloudHandler.dname].renderer.__reduce__() + cls, args = Device[CloudHandler.device].renderer.__reduce__() ret = json.dumps((cls.__module__, cls.__name__, args)).encode() else: status_code = 404 self.send_response(status_code) @@ -131,8 +131,8 @@ class CloudHandler(BaseHTTPRequestHandler): def cloud_server(port:int): multiprocessing.current_process().name = "MainProcess" - CloudHandler.dname = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT - print(f"start cloud server on {port} with device {CloudHandler.dname}") + CloudHandler.device = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT + print(f"start cloud server on {port} with device {CloudHandler.device}") server = HTTPServer(('', port), CloudHandler) server.serve_forever() diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index 8cf252d0e8..2d8aa154be 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -242,7 +242,7 @@ class HCQSignal: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})") @contextlib.contextmanager -def hcq_profile(dev, enabled, desc, queue_type=None, queue=None): +def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type=None, queue=None): st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None) if enabled and queue is not None: queue.timestamp(st) @@ -257,7 +257,7 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None): queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev) dev.timeline_value += 1 - if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t)) + if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t)) class HCQArgsState(Generic[ProgramType]): def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg @@ -393,7 +393,7 @@ class HCQCompiled(Compiled): def _ensure_shared_time_base(self): if not self.gpu2cpu_compute_time_diff.is_nan(): return - def _sync_cpu_queue(d, q_t): + def _sync_cpu_queue(d:HCQCompiled, q_t:Type[HWCommandQueue]): q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d) d.timeline_value += 1 st = time.perf_counter_ns() @@ -411,7 +411,7 @@ class HCQCompiled(Compiled): if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l) if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l) - def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t): + def _sync_gpu_to_gpu_queue(d1:HCQCompiled, d2:HCQCompiled, q1_t:Type[HWCommandQueue], q2_t:Type[HWCommandQueue]): q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \ .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1) q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \