rename rawbufs to bufs in ExecItem (#4274)

George Hotz
2024-04-24 07:27:27 +04:00
committed by GitHub
parent 60e3aa5cb1
commit 38f97aa0fe
7 changed files with 26 additions and 26 deletions
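
Outside this repo, any code that walks a TinyJit's captured kernels migrates mechanically: every .rawbufs access on an ExecItem becomes .bufs, with no behavioral change. A minimal sketch of the post-rename access pattern; the jitted function below is made up for illustration and is not part of this commit:

    from tinygrad import Tensor, TinyJit

    @TinyJit
    def f(x: Tensor) -> Tensor: return (x * 2).realize()

    for _ in range(3): f(Tensor.rand(4).realize())   # run a few times so the JIT captures kernels

    for ei in f.jit_cache:                 # captured ExecItems
        for buf in ei.bufs:                # was ei.rawbufs before this commit
            if buf is not None:            # bufs is List[Optional[Buffer]]; input slots may be cleared to None
                print(buf.device)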

View File

@@ -25,7 +25,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
fxn = ji.prg
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
- for i,arg in enumerate(ji.rawbufs):
+ for i,arg in enumerate(ji.bufs):
key = id(arg)
if key not in bufs:
if key in special_names:
@@ -55,7 +55,7 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
# hack to put the inputs back
for (j,i),idx in run.input_replace.items():
realized_input = args[idx].lazydata.base.realized
- run.jit_cache[j].rawbufs[i] = realized_input
+ run.jit_cache[j].bufs[i] = realized_input
special_names[id(realized_input)] = f'input{idx}'
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)

View File

@@ -58,7 +58,7 @@ class TestLinearizer(unittest.TestCase):
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
lowered = list(lower_schedule(create_schedule([c.lazydata])))
for ei in lowered: ei.run()
- rawbufs = lowered[-1].rawbufs
+ rawbufs = lowered[-1].bufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

View File

@@ -46,8 +46,8 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
for ji in jit_cache:
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
- elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
- ji_graph_dev = Device[ji.rawbufs[0].device]
+ elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
+ ji_graph_dev = Device[ji.bufs[0].device]
can_be_graphed = ji_graph_dev and ji_graph_dev.graph
can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or
@@ -65,7 +65,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
- for i,a in enumerate(ji.rawbufs):
+ for i,a in enumerate(ji.bufs):
if a in input_rawbuffers:
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
@@ -83,7 +83,7 @@ class TinyJit(Generic[ReturnType]):
return ret
def add(self, ei:ExecItem):
- self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.rawbufs if buf is not None]))
+ self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
def reset(self):
self.jit_cache: List[ExecItem] = []
@@ -124,8 +124,8 @@ class TinyJit(Generic[ReturnType]):
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# memory planning (optional)
- assigned = _internal_memory_planner([cast(List[Buffer], x.rawbufs) for x in self.jit_cache], debug_prefix="JIT ")
- self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.rawbufs if x is not None]) for ei in self.jit_cache]
+ assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ")
+ self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache]
# Condense the items into a graph executor.
if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
@@ -135,12 +135,12 @@ class TinyJit(Generic[ReturnType]):
elif self.cnt >= 2:
# jit exec
assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
- for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+ for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx]
if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
for ei in self.jit_cache: ei.run(var_vals, jit=True)
# clear jit inputs
- for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+ for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
self.cnt += 1
return self.ret
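
The input_replace bookkeeping above is the main consumer of the renamed field: it maps (jit_cache index, buffer slot) to a position in the caller's input buffers, the slot is overwritten before replay, and it is reset to None afterwards so the input memory can be freed or reused. A toy, self-contained sketch of that flow; FakeExecItem and the string buffers are stand-ins, not tinygrad types:

    from dataclasses import dataclass
    from typing import Dict, List, Optional, Tuple

    @dataclass
    class FakeExecItem:                    # stand-in for tinygrad's ExecItem, illustration only
        prg: str
        bufs: List[Optional[str]]          # the real field holds Optional[Buffer]

    jit_cache = [FakeExecItem("kernel_0", ["out0", None, "weight0"])]
    input_rawbuffers = ["user_input_0"]
    input_replace: Dict[Tuple[int, int], int] = {(0, 1): 0}   # slot (j=0, i=1) <- input_rawbuffers[0]

    # before replay: patch the caller's inputs into the cached bufs lists (previously .rawbufs)
    for (j, i), input_idx in input_replace.items():
        jit_cache[j].bufs[i] = input_rawbuffers[input_idx]
    # ... each ExecItem would run here ...
    # after replay: clear the slots so the input buffers can be freed/reused
    for (j, i) in input_replace:
        jit_cache[j].bufs[i] = None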

View File

@@ -11,9 +11,9 @@ from tinygrad.shape.symbolic import Variable, sym_infer
@dataclass(frozen=True)
class ExecItem:
prg: Runner
- rawbufs: List[Optional[Buffer]]
+ bufs: List[Optional[Buffer]]
def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
- et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
+ et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.bufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
if do_update_stats:
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
@@ -21,7 +21,7 @@ class ExecItem:
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 2:
ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.rawbufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
self.prg.first_run = False
return et
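
For reference, ExecItem now pairs a Runner with its buffer list under the name bufs; run() ensures every buffer is allocated, dispatches the program, and updates the global counters. A minimal, hypothetical construction; runner, out, a, and b are assumed to be an existing Runner and realized Buffers, not objects from this diff:

    ei = ExecItem(runner, [out, a, b])   # the buffer list is now bufs (previously rawbufs)
    et = ei.run(wait=True)               # allocates any unallocated bufs, dispatches, may return elapsed seconds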

View File

@@ -30,17 +30,17 @@ class CUDAGraph(MultiDeviceJITGraph):
global_size, local_size = ji.prg.launch_dims(var_vals)
new_node = cuda.CUgraphNode()
- deps = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=new_node)
+ deps = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
- c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in ji.prg.vars])
+ c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars])
kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))
if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
elif isinstance(ji.prg, BufferXfer):
- dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
+ dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device])
node_from = cuda.CUgraphNode()
deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from)
@@ -68,7 +68,7 @@ class CUDAGraph(MultiDeviceJITGraph):
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
# clear jit inputs to allow their memory to be freed/reused
- for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+ for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "CUDA", *get_jit_stats(jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
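
One convention the graph builders rely on when slicing the renamed list: prg.outcount gives the number of leading output buffers, so ji.bufs[:outs] are the buffers the kernel writes and ji.bufs[outs:] are the ones it reads, matching the read/write arguments of access_resources. A tiny illustration with assumed values:

    # assumed example: a kernel with outcount == 1 and three buffers (strings stand in for Buffers)
    bufs = ["out0", "in0", "in1"]
    outs = 1                                  # ji.prg.outcount
    writes, reads = bufs[:outs], bufs[outs:]  # ["out0"], ["in0", "in1"]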

View File

@@ -37,7 +37,7 @@ class HSAGraph(MultiDeviceJITGraph):
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
elif isinstance(ji.prg, BufferXfer):
- for x in ji.rawbufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
+ for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
else: raise GraphException
if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException
@@ -55,7 +55,7 @@ class HSAGraph(MultiDeviceJITGraph):
if not isinstance(ji.prg, CompiledRunner): continue
self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
- for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf)
+ for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]])
# Build queues.
@@ -75,7 +75,7 @@ class HSAGraph(MultiDeviceJITGraph):
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledRunner):
- wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False)
+ wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
for i in range(0, len(wait_signals), 5):
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
@@ -85,7 +85,7 @@ class HSAGraph(MultiDeviceJITGraph):
ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
elif isinstance(ji.prg, BufferXfer):
- dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
+ dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
@@ -112,7 +112,7 @@ class HSAGraph(MultiDeviceJITGraph):
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
# clear jit inputs to allow their memory to be freed/reused
- for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+ for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "HSA", *get_jit_stats(jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:

View File

@@ -36,13 +36,13 @@ class MetalGraph(Runner):
pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command.setComputePipelineState_(pipeline_state)
- for i,b in enumerate(ji.rawbufs):
+ for i,b in enumerate(ji.bufs):
if b is not None:
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
all_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(prg.vars):
- icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
+ icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.bufs)+i)
if j not in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = prg.launch_dims(var_vals)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
@@ -52,7 +52,7 @@ class MetalGraph(Runner):
if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
# clear jit inputs to allow their memory to be freed/reused
- for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+ for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), device.dname, *get_jit_stats(jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: