Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
dname -> device [pr] (#7804)
* dname -> device [pr]
* a few more
* only one left
@@ -12,7 +12,7 @@ class HIPGraph(CUDAGraph):
   def __del__(self):
     if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
     if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance))
-  def set_device(self): hip_set_device(self.device)
+  def set_device(self): hip_set_device(self.dev)
   def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
   def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
   def graph_instantiate(self, graph):
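
The hasattr guards in __del__ above are load-bearing: if __init__ raises before self.graph is assigned, CPython still runs __del__ on the partially constructed object. A minimal sketch of the pattern (Res and its handle are hypothetical, not from this diff):

    class Res:
      def __init__(self, fail=False):
        if fail: raise RuntimeError("failed before the handle existed")
        self.handle = object()
      def __del__(self):
        # guard: self.handle may never have been assigned
        if hasattr(self, 'handle'): print("handle released")

    try: Res(fail=True)   # __del__ still runs; the guard skips the release
    except RuntimeError: pass
    Res()                 # prints "handle released" once refcount hits zero
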
@@ -32,7 +32,7 @@ class HSAGraph(MultiGraphRunner):
     # Check all jit items are compatible.
     compiled_devices = set()
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
+      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.dev)
       elif isinstance(ji.prg, BufferXfer):
         for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
       else: raise GraphException
@@ -43,15 +43,15 @@ class HSAGraph(MultiGraphRunner):
     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}

     # Fill initial arguments.
     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
     for j,ji in enumerate(self.jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
-      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
-      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev])
+      kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
       for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
       for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
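
The argument filling above overlays a ctypes struct directly on a preallocated kernarg slab, so field writes land in device-visible memory with no copy. A sketch of that from_address pattern (the Args layout with two buffer fields and one variable is invented; tinygrad generates args_struct_t per kernel):

    import ctypes

    def round_up(x:int, a:int) -> int: return (x + a - 1) // a * a

    class Args(ctypes.Structure):  # hypothetical args layout
      _fields_ = [("f0", ctypes.c_uint64), ("f1", ctypes.c_uint64), ("v0", ctypes.c_int32)]

    slab = (ctypes.c_char * round_up(ctypes.sizeof(Args), 16))()  # stand-in for the kernarg buffer
    args = Args.from_address(ctypes.addressof(slab))  # no copy: writes go into slab
    args.f0, args.f1, args.v0 = 0x1000, 0x2000, 42
    # note: from_address does not keep slab alive; the owner must hold a reference
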
@@ -72,13 +72,13 @@ class HSAGraph(MultiGraphRunner):
       if isinstance(ji.prg, CompiledRunner):
         wait_signals = self.access_resources(ji.bufs, ji.prg.p.outs, new_dependency=j, sync_with_aql_packets=False)
         for i in range(0, len(wait_signals), 5):
-          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
-        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
+          self.virt_aql_queues[ji.prg.dev].submit_barrier(wait_signals[i:i+5])
+        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.dev].write_addr)

         sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
-        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
+        self.virt_aql_queues[ji.prg.dev].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
           ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
-        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
+        if PROFILE: self.profile_info[ji.prg.dev].append((sync_signal, ji.prg.clprg.name, False))
       elif isinstance(ji.prg, BufferXfer):
         dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
         dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
@@ -142,22 +142,22 @@ class HIPAllocator(LRUAllocator):

 class HIPSyncEvent(Runner):
   def __init__(self, lb):
-    self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
+    self.lb, self.dev, self.device = lb, cast(HIPDevice, Device[lb.device]), lb.device
     super().__init__()
   def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
     to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0
-    hip_set_device(self.device.device)
+    hip_set_device(self.dev.device)
     check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
-    update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
+    update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.device)

 class HIPWaitEvent(Runner):
   def __init__(self, device):
-    self.device, self.dname = cast(HIPDevice, Device[device]), device
+    self.dev, self.device = cast(HIPDevice, Device[device]), device
     super().__init__()
   def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
-    hip_set_device(self.device.device)
+    hip_set_device(self.dev.device)
     check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF))
-    update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.dname)
+    update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.device)

 if getenv("HIPCPU"):
   rhip = ctypes.CDLL("/usr/local/lib/libremu.so")
@@ -85,7 +85,7 @@ if __name__ == "__main__":
   # remove debug sections
   src = src.split("\t.file")[0]
   assert '.extern .shared' not in src
-  prg = Program("matmul_kernel", src, dname=Device.DEFAULT,
+  prg = Program("matmul_kernel", src, device=Device.DEFAULT,
                 global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
                 mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
   ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
test/external/fuzz_uops.py (vendored): 6 changed lines
@@ -50,10 +50,10 @@ class UOpsFuzzerRunner(CompiledRunner):
       # setup prg
       uops = list(path)
       if DEBUG >= 5: print_uops(uops)
-      self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
+      self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.device].renderer.render(name, uops), uops=uops)
       if DEBUG >= 4: print(self.p.src)
-      self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
-      self.clprg = Device[self.p.dname].runtime(name, self.lib)
+      self.lib = Device[self.p.device].compiler.compile_cached(self.p.src)
+      self.clprg = Device[self.p.device].runtime(name, self.lib)
       for x in (rawbufs:=[init_globals[i] for i in self.p.globals]): x.copyin(init_rawbufs[x])
       # verify
       super().__call__(rawbufs, var_vals, wait)
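
The fuzzer rebuilds the full lowering pipeline for each uop permutation. Stripped of the fuzzing bookkeeping, the three steps it exercises look like this (a generic sketch using the same renderer/compiler/runtime API the hunk touches; `dev` is any device string):

    from tinygrad.device import Device

    def build(dev:str, name:str, uops):
      src = Device[dev].renderer.render(name, uops)   # uops -> source code
      lib = Device[dev].compiler.compile_cached(src)  # source -> binary (disk-cached)
      return Device[dev].runtime(name, lib)           # binary -> callable program
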
@@ -386,7 +386,7 @@ class TestLinearizer(unittest.TestCase):

     ast = UOp(Ops.SINK, src=(store0, store1))
     k = Kernel(ast)
-    prg = CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
+    prg = CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))
     inbufs = [x.lazydata.base.buffer]
     outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
     prg.exec(outbufs+inbufs)
@@ -1780,7 +1780,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[]
   lins: List[Kernel] = []
   outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src]

-  def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
+  def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))

   def check_opt(opts, create_k, expected_color_size):
     k = create_k()
@@ -42,7 +42,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
   for ji in jit_cache:
     if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
-    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
+    if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
     elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]
@@ -101,7 +101,7 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
     self.w_dependency_map: Dict[int, Any] = {}
     self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)

-    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
                      ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))

   def updated_vars(self, var_vals: Dict[Variable, int]):
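
Device strings are "FAMILY" or "FAMILY:index", so the `.split(":")[0]` in the batched-runner header reduces any instance to its family:

    assert "CUDA:1".split(":")[0] == "CUDA"
    assert "METAL".split(":")[0] == "METAL"  # no index: split is a no-op
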
@@ -64,11 +64,11 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
 # **************** Runners ****************

 class Runner:
-  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
-    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
-      True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
+  def __init__(self, display_name:str, device:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
+    self.first_run, self.display_name, self.device, self.op_estimate, self.mem_estimate, self.lds_estimate = \
+      True, display_name, device, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
   @property
-  def device(self): return Device[self.dname]
+  def dev(self): return Device[self.device]
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
     return self(rawbufs, {} if var_vals is None else var_vals)
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
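
This hunk is the heart of the rename: Runner.device is now the device *string*, and the Runner.dev property resolves it to the opened Compiled object. A minimal sketch of the pair (ToyRunner is hypothetical, not tinygrad code):

    from tinygrad.device import Device

    class ToyRunner:
      def __init__(self, display_name:str, device:str):
        self.display_name, self.device = display_name, device
      @property
      def dev(self): return Device[self.device]  # string -> Compiled, resolved lazily

    r = ToyRunner("noop", Device.DEFAULT)
    print(r.device)  # e.g. "METAL": just the name
    print(r.dev)     # the Compiled device instance behind that name
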
@@ -78,10 +78,10 @@ class CompiledRunner(Runner):
   def __init__(self, p:Program, precompiled:Optional[bytes]=None):
     if DEBUG >= 4: print(p.src)
     self.p:Program = p
-    self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
-    if DEBUG >= 6: Device[p.dname].compiler.disassemble(self.lib)
-    self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
-    super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
+    self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
+    if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
+    self.clprg = Device[p.device].runtime(p.function_name, self.lib)
+    super().__init__(p.name, p.device, p.op_estimate, p.mem_estimate, p.lds_estimate)

   def __reduce__(self): return self.__class__, (self.p, self.lib)

@@ -141,18 +141,18 @@ class BufferXfer(BufferCopy):
 # **************** method cache ****************

 method_cache: Dict[Tuple[str, bytes, int, int, bool], CompiledRunner] = {}
-def get_runner(dname:str, ast:UOp) -> CompiledRunner:
-  ckey = (dname, ast.key, BEAM.value, NOOPT.value, False)
+def get_runner(device:str, ast:UOp) -> CompiledRunner:
+  ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
   if cret:=method_cache.get(ckey): return cret
-  bkey = (dname.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
+  bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
   if bret:=method_cache.get(bkey):
-    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
+    method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
   else:
-    prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
+    prg: Program = get_kernel(Device[device].renderer, ast).to_program()
     if getenv("FUZZ_UOPS"):
       from test.external.fuzz_uops import UOpsFuzzerRunner
-      return UOpsFuzzerRunner(replace(prg, dname=dname))
-    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
+      return UOpsFuzzerRunner(replace(prg, device=device))
+    method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
   return ret

 # **************** lowering functions ****************
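
get_runner keeps two cache entries per AST: an exact key (ckey, full device string) and a family key (bkey, `device.split(":")[0]`), so a kernel compiled for one instance can be rebound to a sibling with `replace(p, device=...)` instead of recompiling. A toy model of that lookup (ToyPrg and the string stand-in for ast.key are assumptions):

    from dataclasses import dataclass, replace
    from typing import Dict, Tuple

    @dataclass(frozen=True)
    class ToyPrg:
      key: str
      device: str

    cache: Dict[Tuple[str, str], ToyPrg] = {}
    def toy_get_runner(device:str, key:str) -> ToyPrg:
      if (cret:=cache.get((device, key))): return cret
      if (bret:=cache.get((device.split(":")[0], key))):  # family hit: rebind, no recompile
        cache[(device, key)] = ret = replace(bret, device=device)
      else:  # miss: "compile" once and fill both keys
        cache[(device, key)] = cache[(device.split(":")[0], key)] = ret = ToyPrg(key, device)
      return ret
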
@@ -175,7 +175,7 @@ class ExecItem:
       lds_est = sym_infer(self.prg.lds_estimate, var_vals)
       mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
       ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-      print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
+      print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
             (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
             f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
     self.prg.first_run = False
@@ -45,7 +45,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
   input_bufs = [rawbufs[i] for i in car.p.globals]
   for _ in range(cnt):
     if clear_l2:
-      if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
+      if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
       else:
         with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
     tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
@@ -26,7 +26,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
 class Program:
   name:str
   src:str
-  dname:str
+  device:str
   uops:Optional[List[UOp]]=None
   mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good

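
Program is a dataclass, which is why this one field rename ripples through every `replace(..., device=...)` call site in the diff: dataclasses.replace builds a new instance keyed by field name. Sketch (ToyProgram is hypothetical):

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class ToyProgram:
      name:str
      src:str
      device:str

    p = ToyProgram("matmul_kernel", "...", "CUDA")
    print(replace(p, device="CUDA:1"))  # same name/src, only device rebound
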
@@ -16,7 +16,7 @@ class HCQGraph(MultiGraphRunner):
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
       if not isinstance(ji.prg, CompiledRunner): continue
-      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+      kernargs_size[ji.prg.dev] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
     self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}

     # Fill initial arguments.
@@ -25,7 +25,7 @@ class HCQGraph(MultiGraphRunner):
     kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
     for j,ji in enumerate(self.jit_cache):
       if not isinstance(ji.prg, CompiledRunner): continue
-      kargs_ptrs[ji.prg.device] = (kargs_ptr:=kargs_ptrs[ji.prg.device]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
+      kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
       self.ji_args[j] = ji.prg.clprg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)

     # Schedule Dependencies.
@@ -51,7 +51,7 @@ class HCQGraph(MultiGraphRunner):
     for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

     for j,ji in enumerate(self.jit_cache):
-      enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
+      enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
       enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
       out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))

@@ -28,13 +28,13 @@ class MetalGraph(GraphRunner):
     msg(icb_descriptor, "setInheritPipelineState:", False)
     msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)

-    self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
+    self.icb = msg(self.dev.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
                    icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
     if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
     icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
     self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3

-    if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
+    if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
     all_resources = [self.int_buf.buf] if len(self.vars) else []
     all_pipelines = []
     for j,ji in enumerate(self.jit_cache):
@@ -55,12 +55,12 @@ class MetalGraph(GraphRunner):
     self.all_resources = dedup(all_resources)
     self.all_pipelines = dedup(all_pipelines)
     self.command_buffer: Any = None
-    if len(self.vars): self.int_buf_view = self.device.allocator._as_buffer(self.int_buf).cast('i')
+    if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
     self.range = to_struct(0, len(self.jit_cache))

   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:

-    if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
+    if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
     all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])

     for (j,i),input_idx in self.input_replace.items():
@@ -76,7 +76,7 @@ class MetalGraph(GraphRunner):
               to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

-    command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
+    command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
     encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
     msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
         MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
@@ -99,5 +99,5 @@ class MetalGraph(GraphRunner):
     if wait:
       wait_check(command_buffer)
       return elapsed_time(command_buffer)
-    self.device.mtl_buffers_in_flight.append(command_buffer)
+    self.dev.mtl_buffers_in_flight.append(command_buffer)
     return None
@@ -80,7 +80,7 @@ class CloudSession:

 class CloudHandler(BaseHTTPRequestHandler):
   protocol_version = 'HTTP/1.1'
-  dname: str
+  device: str
   sessions: DefaultDict[str, CloudSession] = defaultdict(CloudSession)

   def setup(self):
@@ -99,18 +99,18 @@ class CloudHandler(BaseHTTPRequestHandler):
       match c:
         case BufferAlloc():
           assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
-          session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options)
+          session.buffers[c.buffer_num] = (Device[CloudHandler.device].allocator.alloc(c.size, c.options), c.size, c.options)
         case BufferFree():
           buf,sz,buffer_options = session.buffers[c.buffer_num]
-          Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
+          Device[CloudHandler.device].allocator.free(buf,sz,buffer_options)
           del session.buffers[c.buffer_num]
-        case CopyIn(): Device[CloudHandler.dname].allocator._copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
+        case CopyIn(): Device[CloudHandler.device].allocator._copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
         case CopyOut():
           buf,sz,_ = session.buffers[c.buffer_num]
-          Device[CloudHandler.dname].allocator._copyout(memoryview(ret:=bytearray(sz)), buf)
+          Device[CloudHandler.device].allocator._copyout(memoryview(ret:=bytearray(sz)), buf)
         case ProgramAlloc():
-          lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
-          session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)
+          lib = Device[CloudHandler.device].compiler.compile_cached(req._h[c.datahash].decode())
+          session.programs[(c.name, c.datahash)] = Device[CloudHandler.device].runtime(c.name, lib)
         case ProgramFree(): del session.programs[(c.name, c.datahash)]
         case ProgramExec():
           bufs = [session.buffers[x][0] for x in c.bufs]
@@ -118,7 +118,7 @@ class CloudHandler(BaseHTTPRequestHandler):
           r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
           if r is not None: ret = str(r).encode()
     elif self.path == "/renderer" and method == "GET":
-      cls, args = Device[CloudHandler.dname].renderer.__reduce__()
+      cls, args = Device[CloudHandler.device].renderer.__reduce__()
       ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
     else: status_code = 404
     self.send_response(status_code)
@@ -131,8 +131,8 @@ class CloudHandler(BaseHTTPRequestHandler):

 def cloud_server(port:int):
   multiprocessing.current_process().name = "MainProcess"
-  CloudHandler.dname = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT
-  print(f"start cloud server on {port} with device {CloudHandler.dname}")
+  CloudHandler.device = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT
+  print(f"start cloud server on {port} with device {CloudHandler.device}")
   server = HTTPServer(('', port), CloudHandler)
   server.serve_forever()
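
After the rename, CloudHandler.device is the class-level device string that every request-handler instance reads; the CLOUDDEV env var picks the backend when the server itself defaults to CLOUD. A minimal launch sketch (the port value is illustrative, not from this diff):

    if __name__ == "__main__":
      cloud_server(6667)  # serves alloc/copy/exec requests on the chosen backend
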
@@ -242,7 +242,7 @@ class HCQSignal:
       raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")

 @contextlib.contextmanager
-def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
+def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type=None, queue=None):
   st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)

   if enabled and queue is not None: queue.timestamp(st)
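
hcq_profile brackets a region with a start/end timestamp pair on the device timeline. The same shape in pure Python, as a hedged stand-in (toy_profile and its wall-clock timing are assumptions; the real version timestamps GPU queues, not the CPU clock):

    import contextlib, time

    @contextlib.contextmanager
    def toy_profile(enabled:bool, desc:str):
      st = time.perf_counter_ns() if enabled else None  # "start timestamp"
      try:
        yield
      finally:
        if enabled: print(f"{desc}: {(time.perf_counter_ns()-st)*1e-6:.3f} ms")

    with toy_profile(True, "matmul"): sum(range(10**6))
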
@@ -257,7 +257,7 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
     queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
     dev.timeline_value += 1

-  if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
+  if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))

 class HCQArgsState(Generic[ProgramType]):
   def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
@@ -393,7 +393,7 @@ class HCQCompiled(Compiled):
   def _ensure_shared_time_base(self):
     if not self.gpu2cpu_compute_time_diff.is_nan(): return

-    def _sync_cpu_queue(d, q_t):
+    def _sync_cpu_queue(d:HCQCompiled, q_t:Type[HWCommandQueue]):
       q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
       d.timeline_value += 1
       st = time.perf_counter_ns()
@@ -411,7 +411,7 @@ class HCQCompiled(Compiled):
       if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
       if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)

-    def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
+    def _sync_gpu_to_gpu_queue(d1:HCQCompiled, d2:HCQCompiled, q1_t:Type[HWCommandQueue], q2_t:Type[HWCommandQueue]):
       q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
         .timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
       q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
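
Both _sync_* helpers feed gpu2cpu_*_time_diff with statistics.median over repeated timestamp samples; the median resists the occasional scheduling outlier that a mean would absorb. A self-contained illustration (the toy clocks and their ~5 us offset are invented):

    import random, statistics

    def estimate_offset_ns(read_cpu, read_gpu, n=20):
      # median of repeated (cpu - gpu) samples, as in _ensure_shared_time_base
      return statistics.median(read_cpu() - read_gpu() for _ in range(n))

    cpu = lambda: 5_000 + random.randint(0, 300)  # gpu clock + ~5us + jitter
    gpu = lambda: random.randint(0, 300)
    print(estimate_offset_ns(cpu, gpu))           # ~= 5000
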