From 38f97aa0fed000735033260d41d68bed94148f75 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 24 Apr 2024 07:27:27 +0400 Subject: [PATCH] rename rawbufs to bufs in ExecItem (#4274) --- extra/export_model.py | 4 ++-- test/test_linearizer.py | 2 +- tinygrad/engine/jit.py | 16 ++++++++-------- tinygrad/engine/realize.py | 6 +++--- tinygrad/runtime/graph/cuda.py | 8 ++++---- tinygrad/runtime/graph/hsa.py | 10 +++++----- tinygrad/runtime/graph/metal.py | 6 +++--- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/extra/export_model.py b/extra/export_model.py index fd7cba5ae3..20b1b78096 100644 --- a/extra/export_model.py +++ b/extra/export_model.py @@ -25,7 +25,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str] fxn = ji.prg functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same cargs = [] - for i,arg in enumerate(ji.rawbufs): + for i,arg in enumerate(ji.bufs): key = id(arg) if key not in bufs: if key in special_names: @@ -55,7 +55,7 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]: # hack to put the inputs back for (j,i),idx in run.input_replace.items(): realized_input = args[idx].lazydata.base.realized - run.jit_cache[j].rawbufs[i] = realized_input + run.jit_cache[j].bufs[i] = realized_input special_names[id(realized_input)] = f'input{idx}' # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index f34e9da3ba..f9561f8520 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -58,7 +58,7 @@ class TestLinearizer(unittest.TestCase): c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))) lowered = list(lower_schedule(create_schedule([c.lazydata]))) for ei in lowered: ei.run() - rawbufs = lowered[-1].rawbufs + rawbufs = lowered[-1].bufs assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized} np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:]) np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 2d5b924414..b6d1ddd7d9 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -46,8 +46,8 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer] for ji in jit_cache: ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None. if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device - elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}: - ji_graph_dev = Device[ji.rawbufs[0].device] + elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}: + ji_graph_dev = Device[ji.bufs[0].device] can_be_graphed = ji_graph_dev and ji_graph_dev.graph can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or @@ -65,7 +65,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer] def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]: input_replace: Dict[Tuple[int, int], int] = {} for j,ji in enumerate(jit_cache): - for i,a in enumerate(ji.rawbufs): + for i,a in enumerate(ji.bufs): if a in input_rawbuffers: input_replace[(j,i)] = input_rawbuffers.index(a) return input_replace @@ -83,7 +83,7 @@ class TinyJit(Generic[ReturnType]): return ret def add(self, ei:ExecItem): - self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.rawbufs if buf is not None])) + self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None])) def reset(self): self.jit_cache: List[ExecItem] = [] @@ -124,8 +124,8 @@ class TinyJit(Generic[ReturnType]): if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") # memory planning (optional) - assigned = _internal_memory_planner([cast(List[Buffer], x.rawbufs) for x in self.jit_cache], debug_prefix="JIT ") - self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.rawbufs if x is not None]) for ei in self.jit_cache] + assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ") + self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache] # Condense the items into a graph executor. if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals) @@ -135,12 +135,12 @@ class TinyJit(Generic[ReturnType]): elif self.cnt >= 2: # jit exec assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT" - for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] + for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx] if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels") for ei in self.jit_cache: ei.run(var_vals, jit=True) # clear jit inputs - for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None self.cnt += 1 return self.ret diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index a42b39d8a2..9aa3eb0804 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -11,9 +11,9 @@ from tinygrad.shape.symbolic import Variable, sym_infer @dataclass(frozen=True) class ExecItem: prg: Runner - rawbufs: List[Optional[Buffer]] + bufs: List[Optional[Buffer]] def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]: - et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2) + et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.bufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2) if do_update_stats: GlobalCounters.kernel_count += 1 GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals)) @@ -21,7 +21,7 @@ class ExecItem: if et is not None: GlobalCounters.time_sum_s += et if DEBUG >= 2: ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else "" - print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.rawbufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 + print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501 self.prg.first_run = False return et diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 60e9d2f117..25341f16d3 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -30,17 +30,17 @@ class CUDAGraph(MultiDeviceJITGraph): global_size, local_size = ji.prg.launch_dims(var_vals) new_node = cuda.CUgraphNode() - deps = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=new_node) + deps = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node) c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None - c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in ji.prg.vars]) + c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars]) kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs) check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params))) if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs: self.updatable_nodes[j] = (new_node, kern_params, c_args, False) elif isinstance(ji.prg, BufferXfer): - dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] + dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device]) node_from = cuda.CUgraphNode() deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from) @@ -68,7 +68,7 @@ class CUDAGraph(MultiDeviceJITGraph): self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0))) # clear jit inputs to allow their memory to be freed/reused - for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None super().__init__(colored(f"", "cyan"), "CUDA", *get_jit_stats(jit_cache)) def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index 8d5de5442f..4179915b45 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -37,7 +37,7 @@ class HSAGraph(MultiDeviceJITGraph): for ji in self.jit_cache: if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device) elif isinstance(ji.prg, BufferXfer): - for x in ji.rawbufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device]) + for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device]) else: raise GraphException if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException @@ -55,7 +55,7 @@ class HSAGraph(MultiDeviceJITGraph): if not isinstance(ji.prg, CompiledRunner): continue self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device]) kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16) - for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf) + for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf) for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]]) # Build queues. @@ -75,7 +75,7 @@ class HSAGraph(MultiDeviceJITGraph): for j,ji in enumerate(self.jit_cache): if isinstance(ji.prg, CompiledRunner): - wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False) + wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False) for i in range(0, len(wait_signals), 5): self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5]) self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr) @@ -85,7 +85,7 @@ class HSAGraph(MultiDeviceJITGraph): ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal) if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False)) elif isinstance(ji.prg, BufferXfer): - dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] + dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device]) sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev]) @@ -112,7 +112,7 @@ class HSAGraph(MultiDeviceJITGraph): hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0) # clear jit inputs to allow their memory to be freed/reused - for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None super().__init__(colored(f"", "cyan"), "HSA", *get_jit_stats(jit_cache)) def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py index eeba44c632..fe541797ea 100644 --- a/tinygrad/runtime/graph/metal.py +++ b/tinygrad/runtime/graph/metal.py @@ -36,13 +36,13 @@ class MetalGraph(Runner): pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501 icb_command = self.icb.indirectComputeCommandAtIndex_(j) icb_command.setComputePipelineState_(pipeline_state) - for i,b in enumerate(ji.rawbufs): + for i,b in enumerate(ji.bufs): if b is not None: icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i) all_resources.append(b._buf) var_vals_keys = list(var_vals.keys()) for i,v in enumerate(prg.vars): - icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i) + icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.bufs)+i) if j not in self.jc_idx_with_updatable_launch_dims: global_size, local_size = prg.launch_dims(var_vals) icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) @@ -52,7 +52,7 @@ class MetalGraph(Runner): if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i') # clear jit inputs to allow their memory to be freed/reused - for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None super().__init__(colored(f"", "cyan"), device.dname, *get_jit_stats(jit_cache)) def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: