Mirror of https://github.com/tinygrad/tinygrad.git
rename rawbufs to bufs in ExecItem (#4274)
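This is a mechanical rename: the buffer list carried by ExecItem moves from the field rawbufs to bufs, and every call site in the diff below (model export, the linearizer test, TinyJit capture and replay, and the CUDA/HSA/Metal graph backends) is updated to match. As a rough, self-contained sketch of the dataclass shape after the change (stand-in types instead of tinygrad's Runner and Buffer, and a simplified run):

from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@dataclass(frozen=True)
class ExecItem:                      # sketch of the renamed dataclass, not the real tinygrad import
  prg: Any                           # the Runner that executes this cached kernel or transfer
  bufs: List[Optional[Any]]          # was rawbufs; entries may be None until JIT inputs are patched in

  def run(self, var_vals: Optional[Dict[Any, int]] = None) -> Optional[float]:
    # simplified: the real run() also ensures allocation and updates GlobalCounters / DEBUG output
    return self.prg(list(self.bufs), var_vals if var_vals is not None else {})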
@@ -25,7 +25,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
    fxn = ji.prg
    functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
    cargs = []
-    for i,arg in enumerate(ji.rawbufs):
+    for i,arg in enumerate(ji.bufs):
      key = id(arg)
      if key not in bufs:
        if key in special_names:

@@ -55,7 +55,7 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
  # hack to put the inputs back
  for (j,i),idx in run.input_replace.items():
    realized_input = args[idx].lazydata.base.realized
-    run.jit_cache[j].rawbufs[i] = realized_input
+    run.jit_cache[j].bufs[i] = realized_input
    special_names[id(realized_input)] = f'input{idx}'

  # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)

@@ -58,7 +58,7 @@ class TestLinearizer(unittest.TestCase):
    c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
    lowered = list(lower_schedule(create_schedule([c.lazydata])))
    for ei in lowered: ei.run()
-    rawbufs = lowered[-1].rawbufs
+    rawbufs = lowered[-1].bufs
    assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
    np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
    np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

@@ -46,8 +46,8 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
  for ji in jit_cache:
    ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
-    elif isinstance(ji.prg, BufferXfer) and ji.rawbufs[0] and ji.rawbufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
-      ji_graph_dev = Device[ji.rawbufs[0].device]
+    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
+      ji_graph_dev = Device[ji.bufs[0].device]

    can_be_graphed = ji_graph_dev and ji_graph_dev.graph
    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or

@@ -65,7 +65,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
  input_replace: Dict[Tuple[int, int], int] = {}
  for j,ji in enumerate(jit_cache):
-    for i,a in enumerate(ji.rawbufs):
+    for i,a in enumerate(ji.bufs):
      if a in input_rawbuffers:
        input_replace[(j,i)] = input_rawbuffers.index(a)
  return input_replace

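get_input_replace is what ties the renamed field to JIT replay: it records, for each cached ExecItem, which positions in bufs hold JIT input buffers, so later calls can patch fresh inputs in and clear them out again. A small self-contained sketch of that round trip (replay is a hypothetical helper mirroring the TinyJit.__call__ hunks further down; the objects are stand-ins):

from typing import Any, Dict, List, Tuple

def get_input_replace(jit_cache: List[Any], input_rawbuffers: List[Any]) -> Dict[Tuple[int, int], int]:
  # map (kernel index j, argument index i) -> index into the list of JIT input buffers
  input_replace: Dict[Tuple[int, int], int] = {}
  for j, ji in enumerate(jit_cache):
    for i, a in enumerate(ji.bufs):  # field renamed from rawbufs to bufs
      if a in input_rawbuffers: input_replace[(j, i)] = input_rawbuffers.index(a)
  return input_replace

def replay(jit_cache: List[Any], input_replace: Dict[Tuple[int, int], int], new_inputs: List[Any]) -> None:
  # hypothetical helper: patch inputs in, run every cached item, then clear so memory can be freed/reused
  for (j, i), idx in input_replace.items(): jit_cache[j].bufs[i] = new_inputs[idx]
  for ei in jit_cache: ei.run()
  for (j, i) in input_replace: jit_cache[j].bufs[i] = None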
@@ -83,7 +83,7 @@ class TinyJit(Generic[ReturnType]):
    return ret

  def add(self, ei:ExecItem):
-    self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.rawbufs if buf is not None]))
+    self.jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))

  def reset(self):
    self.jit_cache: List[ExecItem] = []

@@ -124,8 +124,8 @@ class TinyJit(Generic[ReturnType]):
      if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")

      # memory planning (optional)
-      assigned = _internal_memory_planner([cast(List[Buffer], x.rawbufs) for x in self.jit_cache], debug_prefix="JIT ")
-      self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.rawbufs if x is not None]) for ei in self.jit_cache]
+      assigned = _internal_memory_planner([cast(List[Buffer], x.bufs) for x in self.jit_cache], debug_prefix="JIT ")
+      self.jit_cache = [ExecItem(ei.prg, [assigned.get(x,x).ensure_allocated() for x in ei.bufs if x is not None]) for ei in self.jit_cache]

      # Condense the items into a graph executor.
      if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)

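A note on the memory-planning lines above: _internal_memory_planner returns a mapping from original buffers to planner-assigned replacements, and assigned.get(x, x) leaves untouched any buffer the planner did not reassign. A tiny sketch of that substitution idiom (the empty plan stands in for the real planner):

# stand-in plan: the real _internal_memory_planner may alias buffers to reduce peak memory
assigned: dict = {}
bufs = ["buf_a", "buf_b", "buf_c"]
planned = [assigned.get(x, x) for x in bufs if x is not None]
assert planned == bufs  # with an empty plan, every buffer maps to itself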
@@ -135,12 +135,12 @@ class TinyJit(Generic[ReturnType]):
    elif self.cnt >= 2:
      # jit exec
      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_rawbuffers[input_idx]
      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
      for ei in self.jit_cache: ei.run(var_vals, jit=True)

    # clear jit inputs
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None

    self.cnt += 1
    return self.ret

@@ -11,9 +11,9 @@ from tinygrad.shape.symbolic import Variable, sym_infer
@dataclass(frozen=True)
class ExecItem:
  prg: Runner
-  rawbufs: List[Optional[Buffer]]
+  bufs: List[Optional[Buffer]]
  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
-    et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
+    et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.bufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
    if do_update_stats:
      GlobalCounters.kernel_count += 1
      GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))

@@ -21,7 +21,7 @@ class ExecItem:
      if et is not None: GlobalCounters.time_sum_s += et
      if DEBUG >= 2:
        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.rawbufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
+        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.bufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) # noqa: E501
      self.prg.first_run = False
    return et

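For readers decoding the DEBUG >= 2 line above: the GFLOPS and GB/s figures are just the runner's static op and memory estimates divided by the measured elapsed time, exactly as in the f-string. A tiny worked example of that arithmetic with made-up numbers:

# made-up numbers, only to illustrate the units in the DEBUG >= 2 print above
op_estimate  = 2_000_000_000   # estimated ops performed by the kernel
mem_estimate = 120_000_000     # estimated bytes moved
et           = 1.5e-3          # measured elapsed time in seconds

gflops = op_estimate  / ((et or 1e-20) * 1e9)   # 2e9 / 1.5e6   -> ~1333.33 GFLOPS
gbps   = mem_estimate / ((et or 1e-20) * 1e9)   # 1.2e8 / 1.5e6 -> 80.00 GB/s
print(f"{gflops:8.2f} GFLOPS, {gbps:7.2f} GB/s")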
@@ -30,17 +30,17 @@ class CUDAGraph(MultiDeviceJITGraph):
        global_size, local_size = ji.prg.launch_dims(var_vals)

        new_node = cuda.CUgraphNode()
-        deps = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=new_node)
+        deps = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=new_node)
        c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None

-        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in ji.prg.vars])
+        c_args, vargs = encode_args([cast(Buffer, x)._buf for x in ji.bufs], [var_vals[x] for x in ji.prg.vars])
        kern_params = cuda.CUDA_KERNEL_NODE_PARAMS(ji.prg.clprg.prg, *global_size, *local_size, 0, None, vargs)
        check(cuda.cuGraphAddKernelNode(ctypes.byref(new_node), self.graph, c_deps, len(deps), ctypes.byref(kern_params)))

        if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs:
          self.updatable_nodes[j] = (new_node, kern_params, c_args, False)
      elif isinstance(ji.prg, BufferXfer):
-        dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
+        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        src_dev, dest_dev = cast(CUDADevice, Device[src.device]), cast(CUDADevice, Device[dest.device])
        node_from = cuda.CUgraphNode()
        deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from)

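One convention worth noting in the graph hunks: when computing dependencies, the first prg.outcount entries of bufs are passed as the buffers the kernel writes and the remainder as the buffers it reads, which is what the bufs[:outs] / bufs[outs:] slicing above encodes (the BufferXfer branch makes the read/write roles explicit via keyword arguments). A small self-contained sketch of that split (split_bufs is an illustrative helper, not the real access_resources):

from typing import Any, List, Tuple

def split_bufs(bufs: List[Any], outcount: int) -> Tuple[List[Any], List[Any]]:
  # mirrors ji.bufs[:outs] / ji.bufs[outs:] in the CUDA and HSA graph code:
  # the first `outcount` buffers are outputs (written), the rest are inputs (read)
  return bufs[:outcount], bufs[outcount:]

# e.g. a kernel with one output and two inputs
writes, reads = split_bufs(["out0", "in0", "in1"], outcount=1)
assert writes == ["out0"] and reads == ["in0", "in1"]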
@@ -68,7 +68,7 @@ class CUDAGraph(MultiDeviceJITGraph):
    self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))

    # clear jit inputs to allow their memory to be freed/reused
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "CUDA", *get_jit_stats(jit_cache))

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:

@@ -37,7 +37,7 @@ class HSAGraph(MultiDeviceJITGraph):
    for ji in self.jit_cache:
      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
      elif isinstance(ji.prg, BufferXfer):
-        for x in ji.rawbufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
+        for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
      else: raise GraphException
    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException

@@ -55,7 +55,7 @@ class HSAGraph(MultiDeviceJITGraph):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
-      for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf)
+      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
      for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]])

    # Build queues.

@@ -75,7 +75,7 @@ class HSAGraph(MultiDeviceJITGraph):

    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
-        wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False)
+        wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
        for i in range(0, len(wait_signals), 5):
          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)

@@ -85,7 +85,7 @@ class HSAGraph(MultiDeviceJITGraph):
                                               ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
      elif isinstance(ji.prg, BufferXfer):
-        dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
+        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
        sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])

@@ -112,7 +112,7 @@ class HSAGraph(MultiDeviceJITGraph):
      hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)

    # clear jit inputs to allow their memory to be freed/reused
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "HSA", *get_jit_stats(jit_cache))

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:

@@ -36,13 +36,13 @@ class MetalGraph(Runner):
      pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None)) # noqa: E501
      icb_command = self.icb.indirectComputeCommandAtIndex_(j)
      icb_command.setComputePipelineState_(pipeline_state)
-      for i,b in enumerate(ji.rawbufs):
+      for i,b in enumerate(ji.bufs):
        if b is not None:
          icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
          all_resources.append(b._buf)
      var_vals_keys = list(var_vals.keys())
      for i,v in enumerate(prg.vars):
-        icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
+        icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, var_vals_keys.index(v)*4, len(ji.bufs)+i)
      if j not in self.jc_idx_with_updatable_launch_dims:
        global_size, local_size = prg.launch_dims(var_vals)
        icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))

@@ -52,7 +52,7 @@ class MetalGraph(Runner):
    if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')

    # clear jit inputs to allow their memory to be freed/reused
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), device.dname, *get_jit_stats(jit_cache))

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: