dname -> device [pr] (#7804)

* dname -> device [pr]

* a few more

* only one left
George Hotz, 2024-11-20 17:57:14 +08:00 (committed by GitHub)
parent 0a74acd90e, commit bc977fec53
14 changed files with 62 additions and 62 deletions
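
A quick orientation for the rename this commit applies everywhere: the device-name string moves from `dname` to `device`, and the resolved device object moves from `device` to `dev`. A minimal, self-contained sketch of that convention (illustrative only; the real class is tinygrad's Runner shown in a hunk further down, and DEVICES here stands in for the Device registry):

```python
# Sketch of the attribute naming after this commit (not tinygrad code).
DEVICES = {"CUDA:0": object()}   # stand-in for tinygrad's Device registry

class Runner:
    def __init__(self, display_name: str, device: str):
        # before this commit the string lived in self.dname
        self.display_name, self.device = display_name, device

    @property
    def dev(self):
        # before this commit this property was named `device`; it resolves
        # the name string to the concrete device object
        return DEVICES[self.device]

r = Runner("example_kernel", "CUDA:0")
print(r.device)   # "CUDA:0" -- the string (formerly r.dname)
print(r.dev)      # the device object (formerly r.device)
```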

View File

@@ -12,7 +12,7 @@ class HIPGraph(CUDAGraph):
def __del__(self):
if hasattr(self, 'graph'): check(hip.hipGraphDestroy(self.graph))
if hasattr(self, 'instance'): check(hip.hipGraphExecDestroy(self.instance))
def set_device(self): hip_set_device(self.device)
def set_device(self): hip_set_device(self.dev)
def encode_args_info(self): return (hip.hipDeviceptr_t, (1,2,3))
def graph_create(self): return init_c_var(hip.hipGraph_t(), lambda x: check(hip.hipGraphCreate(ctypes.byref(x), 0)))
def graph_instantiate(self, graph):

View File

@@ -32,7 +32,7 @@ class HSAGraph(MultiGraphRunner):
# Check all jit items are compatible.
compiled_devices = set()
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.dev)
elif isinstance(ji.prg, BufferXfer):
for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
else: raise GraphException
@@ -43,15 +43,15 @@ class HSAGraph(MultiGraphRunner):
# Allocate kernel args.
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}
# Fill initial arguments.
self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
for j,ji in enumerate(self.jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.dev])
kernargs_ptrs[ji.prg.dev] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf)
for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])
@@ -72,13 +72,13 @@ class HSAGraph(MultiGraphRunner):
if isinstance(ji.prg, CompiledRunner):
wait_signals = self.access_resources(ji.bufs, ji.prg.p.outs, new_dependency=j, sync_with_aql_packets=False)
for i in range(0, len(wait_signals), 5):
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
self.virt_aql_queues[ji.prg.dev].submit_barrier(wait_signals[i:i+5])
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.dev].write_addr)
sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
self.virt_aql_queues[ji.prg.dev].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
if PROFILE: self.profile_info[ji.prg.dev].append((sync_signal, ji.prg.clprg.name, False))
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
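
The kernarg handling in this hunk is a two-pass pattern: first sum the 16-byte-aligned argument-struct sizes per device, then allocate one buffer per device and hand out per-item offsets on a second pass. A simplified, device-agnostic sketch of that pattern (sizes are hypothetical; no HSA or ctypes calls):

```python
import collections
from typing import Dict

def round_up(x: int, a: int) -> int: return (x + a - 1) // a * a

# hypothetical jit cache: (device name, sizeof(args_struct_t)) per CompiledRunner item
jit_cache = [("HSA:0", 40), ("HSA:0", 24), ("HSA:1", 56)]

# pass 1: total kernarg bytes per device, 16-byte aligned per item
kernargs_size: Dict[str, int] = collections.defaultdict(int)
for dev, sz in jit_cache: kernargs_size[dev] += round_up(sz, 16)

# one allocation per device (stand-in for dev.allocator._alloc(sz, BufferOptions()))
kernargs_ptrs = {dev: 0 for dev in kernargs_size}   # base address 0 in this sketch

# pass 2: give each item its slice and bump the per-device pointer
offsets = []
for dev, sz in jit_cache:
    offsets.append((dev, kernargs_ptrs[dev]))
    kernargs_ptrs[dev] += round_up(sz, 16)

print(dict(kernargs_size))  # {'HSA:0': 80, 'HSA:1': 64}
print(offsets)              # [('HSA:0', 0), ('HSA:0', 48), ('HSA:1', 0)]
```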

View File

@@ -142,22 +142,22 @@ class HIPAllocator(LRUAllocator):
class HIPSyncEvent(Runner):
def __init__(self, lb):
self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
self.lb, self.device, self.device = lb, cast(HIPDevice, Device[lb.device]), lb.device
super().__init__()
def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
to_mv(rawbufs[0]._buf, 4).cast("I")[0] = 0
hip_set_device(self.device.device)
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.device)
class HIPWaitEvent(Runner):
def __init__(self, device):
self.device, self.dname = cast(HIPDevice, Device[device]), device
self.device, self.device = cast(HIPDevice, Device[device]), device
super().__init__()
def __call__(self, rawbufs:List[Buffer], var_vals, wait=False, jit=False):
hip_set_device(self.device.device)
check(hip.hipStreamWaitValue32(None, rawbufs[0]._buf, 1, 1, 0xFFFFFFFF))
update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.dname)
update_stats(colored("wait", "RED"), 0, 0, {}, None, 1, jit, device=self.device)
if getenv("HIPCPU"):
rhip = ctypes.CDLL("/usr/local/lib/libremu.so")

View File

@@ -85,7 +85,7 @@ if __name__ == "__main__":
# remove debug sections
src = src.split("\t.file")[0]
assert '.extern .shared' not in src
prg = Program("matmul_kernel", src, dname=Device.DEFAULT,
prg = Program("matmul_kernel", src, device=Device.DEFAULT,
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)

View File

@@ -50,10 +50,10 @@ class UOpsFuzzerRunner(CompiledRunner):
# setup prg
uops = list(path)
if DEBUG >= 5: print_uops(uops)
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.device].renderer.render(name, uops), uops=uops)
if DEBUG >= 4: print(self.p.src)
self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
self.clprg = Device[self.p.dname].runtime(name, self.lib)
self.lib = Device[self.p.device].compiler.compile_cached(self.p.src)
self.clprg = Device[self.p.device].runtime(name, self.lib)
for x in (rawbufs:=[init_globals[i] for i in self.p.globals]): x.copyin(init_rawbufs[x])
# verify
super().__call__(rawbufs, var_vals, wait)
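
The fuzzer re-lowers every mutated UOp list through the same three stages a CompiledRunner uses, now looked up via p.device: render the UOps to source, compile (with caching) to a binary, then load it as a runnable program. A stubbed sketch of that flow (toy classes stand in for the entries in tinygrad's Device registry):

```python
# Toy stand-ins for Device[device].renderer / .compiler / .runtime (not tinygrad code).
class ToyRenderer:
    def render(self, name, uops): return f"// kernel {name} with {len(uops)} uops"

class ToyCompiler:
    def compile_cached(self, src): return src.encode()   # pretend "binary"

class ToyDevice:
    renderer = ToyRenderer()
    compiler = ToyCompiler()
    def runtime(self, name, lib): return lambda: print(f"run {name}: {len(lib)} bytes")

Device = {"CPU": ToyDevice()}                 # stand-in for tinygrad's Device registry
device, name, uops = "CPU", "fuzz0", ["uop"] * 3

src = Device[device].renderer.render(name, uops)    # 1) render UOps to source
lib = Device[device].compiler.compile_cached(src)   # 2) compile (cached) to a binary
clprg = Device[device].runtime(name, lib)           # 3) load it as a callable program
clprg()
```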

View File

@@ -386,7 +386,7 @@ class TestLinearizer(unittest.TestCase):
ast = UOp(Ops.SINK, src=(store0, store1))
k = Kernel(ast)
prg = CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
prg = CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))
inbufs = [x.lazydata.base.buffer]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
prg.exec(outbufs+inbufs)
@@ -1780,7 +1780,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[]
lins: List[Kernel] = []
outbufs = [real_bufs[x.src[0].arg] for x in realized_ast.src]
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), device=Device.DEFAULT))
def check_opt(opts, create_k, expected_color_size):
k = create_k()

View File

@@ -42,7 +42,7 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
for ji in jit_cache:
if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.dev
elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
ji_graph_dev = Device[ji.bufs[0].device]
@@ -101,7 +101,7 @@ class GraphRunner(Runner): # pylint: disable=abstract-method
self.w_dependency_map: Dict[int, Any] = {}
self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0],
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0],
ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
def updated_vars(self, var_vals: Dict[Variable, int]):
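
The graphing decision above, restated as a small standalone function: a CompiledRunner item is graphed on its own prg.dev, a BufferXfer only when the destination buffer lives on a CUDA/NV/AMD device, and anything else is left ungraphed. The runner and transfer classes below are stubs, not tinygrad's:

```python
from typing import Optional

class EmptyOp: pass
class ViewOp: pass
class BufferXfer: pass
class CompiledRunner:
    def __init__(self, dev): self.dev = dev   # .dev is the Compiled device object after this PR

def graph_device(prg, buf0_device: Optional[str], device_registry: dict):
    """Mirror of the branch above: which device (if any) a jit item is graphed on."""
    if prg.__class__ in {EmptyOp, ViewOp}: return None            # never graphed
    if isinstance(prg, CompiledRunner): return prg.dev            # kernel: its own device
    if isinstance(prg, BufferXfer) and buf0_device and \
       buf0_device.split(":", 1)[0] in {"CUDA", "NV", "AMD"}:
        return device_registry[buf0_device]                       # transfer: dest buffer's device
    return None

registry = {"NV:0": "NVDevice(NV:0)"}
print(graph_device(CompiledRunner("NVDevice(NV:0)"), None, registry))  # NVDevice(NV:0)
print(graph_device(BufferXfer(), "NV:0", registry))                    # NVDevice(NV:0)
print(graph_device(BufferXfer(), "CLANG", registry))                   # None
```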

View File

@@ -64,11 +64,11 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
# **************** Runners ****************
class Runner:
def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \
True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
def __init__(self, display_name:str, device:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
self.first_run, self.display_name, self.device, self.op_estimate, self.mem_estimate, self.lds_estimate = \
True, display_name, device, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
@property
def device(self): return Device[self.dname]
def dev(self): return Device[self.device]
def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
return self(rawbufs, {} if var_vals is None else var_vals)
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
@@ -78,10 +78,10 @@ class CompiledRunner(Runner):
def __init__(self, p:Program, precompiled:Optional[bytes]=None):
if DEBUG >= 4: print(p.src)
self.p:Program = p
self.lib:bytes = precompiled if precompiled is not None else Device[p.dname].compiler.compile_cached(p.src)
if DEBUG >= 6: Device[p.dname].compiler.disassemble(self.lib)
self.clprg = Device[p.dname].runtime(p.function_name, self.lib)
super().__init__(p.name, p.dname, p.op_estimate, p.mem_estimate, p.lds_estimate)
self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
self.clprg = Device[p.device].runtime(p.function_name, self.lib)
super().__init__(p.name, p.device, p.op_estimate, p.mem_estimate, p.lds_estimate)
def __reduce__(self): return self.__class__, (self.p, self.lib)
@@ -141,18 +141,18 @@ class BufferXfer(BufferCopy):
# **************** method cache ****************
method_cache: Dict[Tuple[str, bytes, int, int, bool], CompiledRunner] = {}
def get_runner(dname:str, ast:UOp) -> CompiledRunner:
ckey = (dname, ast.key, BEAM.value, NOOPT.value, False)
def get_runner(device:str, ast:UOp) -> CompiledRunner:
ckey = (device, ast.key, BEAM.value, NOOPT.value, False)
if cret:=method_cache.get(ckey): return cret
bkey = (dname.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
bkey = (device.split(":")[0], ast.key, BEAM.value, NOOPT.value, True)
if bret:=method_cache.get(bkey):
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, dname=dname), bret.lib)
method_cache[ckey] = ret = CompiledRunner(replace(bret.p, device=device), bret.lib)
else:
prg: Program = get_kernel(Device[dname].renderer, ast).to_program()
prg: Program = get_kernel(Device[device].renderer, ast).to_program()
if getenv("FUZZ_UOPS"):
from test.external.fuzz_uops import UOpsFuzzerRunner
return UOpsFuzzerRunner(replace(prg, dname=dname))
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, dname=dname))
return UOpsFuzzerRunner(replace(prg, device=device))
method_cache[ckey] = method_cache[bkey] = ret = CompiledRunner(replace(prg, device=device))
return ret
# **************** lowering functions ****************
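
get_runner keeps two cache entries per AST: a concrete key with the full device name (e.g. "NV:1") and a base key with just the device family (e.g. "NV"), so a kernel compiled once per family can be re-wrapped for each concrete device with replace(prg, device=...). A sketch of only the key logic (the AST key and BEAM/NOOPT values are hypothetical):

```python
# Sketch of the two-level method cache keys used after the rename (not tinygrad code).
method_cache: dict = {}
BEAM, NOOPT = 0, 0

def cache_keys(device: str, ast_key: bytes):
    ckey = (device, ast_key, BEAM, NOOPT, False)               # concrete device, e.g. "NV:1"
    bkey = (device.split(":")[0], ast_key, BEAM, NOOPT, True)  # device family, e.g. "NV"
    return ckey, bkey

ckey, bkey = cache_keys("NV:1", b"ast-hash")
method_cache[bkey] = "compiled-for-NV-family"   # populated by a first compile on any NV device
# a later lookup on NV:1 misses the concrete key but hits the family key and reuses the lib:
hit = method_cache.get(ckey) or method_cache.get(bkey)
print(hit)  # compiled-for-NV-family
```
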
@@ -175,7 +175,7 @@ class ExecItem:
lds_est = sym_infer(self.prg.lds_estimate, var_vals)
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
self.prg.first_run = False

View File

@@ -45,7 +45,7 @@ def _time_program(p:Program, lib:bytes, var_vals:Dict[Variable, int], rawbufs:Li
input_bufs = [rawbufs[i] for i in car.p.globals]
for _ in range(cnt):
if clear_l2:
if hasattr(dev:=Device[p.dname], 'invalidate_caches'): dev.invalidate_caches()
if hasattr(dev:=Device[p.device], 'invalidate_caches'): dev.invalidate_caches()
else:
with Context(DEBUG=0, BEAM=0, CAPTURING=0): Tensor.ones(1024,1024).contiguous().realize(do_update_stats=False)
tms.append(cast(float, car(input_bufs, var_vals, wait=True))*factor)
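
Between timed runs, the clear_l2 branch prefers an explicit invalidate_caches() hook on the device and otherwise falls back to realizing a throwaway 1024x1024 tensor to push the working set out of cache. A stubbed sketch of that branch (FakeDev and the bytearray fallback are stand-ins, not tinygrad code):

```python
class FakeDev:
    def invalidate_caches(self): print("explicit cache invalidate")

def clear_l2(dev) -> None:
    # prefer the device hook; otherwise stream ~4 MB of fresh data, roughly what the
    # 1024x1024 ones-tensor realized in the hunk above touches
    if hasattr(dev, "invalidate_caches"): dev.invalidate_caches()
    else: _ = bytearray(1024 * 1024 * 4)

clear_l2(FakeDev())   # uses the hook
clear_l2(object())    # falls back to streaming dummy data
```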

View File

@@ -26,7 +26,7 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
class Program:
name:str
src:str
dname:str
device:str
uops:Optional[List[UOp]]=None
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
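
Since Program is a dataclass, re-targeting an already-rendered kernel to another device after this rename is just dataclasses.replace(p, device=...), the pattern used in the test and get_runner hunks above. A trimmed stand-in with only the fields needed for the example:

```python
from dataclasses import dataclass, replace
from typing import List, Optional

@dataclass(frozen=True)
class Program:               # trimmed sketch of the dataclass above
    name: str
    src: str
    device: str              # renamed from `dname` in this commit
    uops: Optional[List] = None
    mem_estimate: int = 0

p = Program("E_4", "void E_4(...) {}", device="AMD")
p2 = replace(p, device="AMD:1")   # same kernel, re-targeted to a specific device
print(p.device, p2.device)        # AMD AMD:1
```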

View File

@@ -16,7 +16,7 @@ class HCQGraph(MultiGraphRunner):
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
if not isinstance(ji.prg, CompiledRunner): continue
kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
kernargs_size[ji.prg.dev] += round_up(ji.prg.clprg.kernargs_alloc_size, 16)
self.kernargs_bufs: Dict[Compiled, HCQBuffer] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)) for dev,sz in kernargs_size.items()}
# Fill initial arguments.
@@ -25,7 +25,7 @@ class HCQGraph(MultiGraphRunner):
kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
for j,ji in enumerate(self.jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
kargs_ptrs[ji.prg.device] = (kargs_ptr:=kargs_ptrs[ji.prg.device]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg.clprg.kernargs_alloc_size, 16)
self.ji_args[j] = ji.prg.clprg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
# Schedule Dependencies.
@@ -51,7 +51,7 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
for j,ji in enumerate(self.jit_cache):
enqueue_dev = ji.prg.device if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))

View File

@@ -28,13 +28,13 @@ class MetalGraph(GraphRunner):
msg(icb_descriptor, "setInheritPipelineState:", False)
msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)
self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
self.icb = msg(self.dev.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
if len(self.vars): self.int_buf = self.dev.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
all_resources = [self.int_buf.buf] if len(self.vars) else []
all_pipelines = []
for j,ji in enumerate(self.jit_cache):
@@ -55,12 +55,12 @@ class MetalGraph(GraphRunner):
self.all_resources = dedup(all_resources)
self.all_pipelines = dedup(all_pipelines)
self.command_buffer: Any = None
if len(self.vars): self.int_buf_view = self.device.allocator._as_buffer(self.int_buf).cast('i')
if len(self.vars): self.int_buf_view = self.dev.allocator._as_buffer(self.int_buf).cast('i')
self.range = to_struct(0, len(self.jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
if self.command_buffer is not None and self.command_buffer in self.dev.mtl_buffers_in_flight: wait_check(self.command_buffer)
all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
for (j,i),input_idx in self.input_replace.items():
@@ -76,7 +76,7 @@ class MetalGraph(GraphRunner):
to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
@@ -99,5 +99,5 @@ class MetalGraph(GraphRunner):
if wait:
wait_check(command_buffer)
return elapsed_time(command_buffer)
self.device.mtl_buffers_in_flight.append(command_buffer)
self.dev.mtl_buffers_in_flight.append(command_buffer)
return None
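
Three levels of "device" meet in this file after the rename: self.device is the name string, self.dev is tinygrad's MetalDevice resolved from it, and self.dev.device is the underlying Metal API handle passed to msg(...). A stub sketch of that layering (the classes here are placeholders, not the Metal bindings):

```python
class MTLDeviceHandle: pass                    # stand-in for the Metal API device object
class MetalDevice:                             # stand-in for tinygrad's MetalDevice (Compiled)
    def __init__(self): self.device = MTLDeviceHandle()   # raw API handle, reached via self.dev.device

DEVICES = {"METAL": MetalDevice()}             # stand-in for the Device registry

class MetalGraphSketch:
    def __init__(self, device: str):
        self.device = device                   # device-name string (was .dname)
    @property
    def dev(self): return DEVICES[self.device] # resolved MetalDevice (was .device)

g = MetalGraphSketch("METAL")
print(type(g.dev).__name__, type(g.dev.device).__name__)  # MetalDevice MTLDeviceHandle
```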

View File

@@ -80,7 +80,7 @@ class CloudSession:
class CloudHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
dname: str
device: str
sessions: DefaultDict[str, CloudSession] = defaultdict(CloudSession)
def setup(self):
@@ -99,18 +99,18 @@ class CloudHandler(BaseHTTPRequestHandler):
match c:
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options)
session.buffers[c.buffer_num] = (Device[CloudHandler.device].allocator.alloc(c.size, c.options), c.size, c.options)
case BufferFree():
buf,sz,buffer_options = session.buffers[c.buffer_num]
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
Device[CloudHandler.device].allocator.free(buf,sz,buffer_options)
del session.buffers[c.buffer_num]
case CopyIn(): Device[CloudHandler.dname].allocator._copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
case CopyIn(): Device[CloudHandler.device].allocator._copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
case CopyOut():
buf,sz,_ = session.buffers[c.buffer_num]
Device[CloudHandler.dname].allocator._copyout(memoryview(ret:=bytearray(sz)), buf)
Device[CloudHandler.device].allocator._copyout(memoryview(ret:=bytearray(sz)), buf)
case ProgramAlloc():
lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)
lib = Device[CloudHandler.device].compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = Device[CloudHandler.device].runtime(c.name, lib)
case ProgramFree(): del session.programs[(c.name, c.datahash)]
case ProgramExec():
bufs = [session.buffers[x][0] for x in c.bufs]
@@ -118,7 +118,7 @@ class CloudHandler(BaseHTTPRequestHandler):
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
if r is not None: ret = str(r).encode()
elif self.path == "/renderer" and method == "GET":
cls, args = Device[CloudHandler.dname].renderer.__reduce__()
cls, args = Device[CloudHandler.device].renderer.__reduce__()
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
else: status_code = 404
self.send_response(status_code)
@@ -131,8 +131,8 @@ class CloudHandler(BaseHTTPRequestHandler):
def cloud_server(port:int):
multiprocessing.current_process().name = "MainProcess"
CloudHandler.dname = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT
print(f"start cloud server on {port} with device {CloudHandler.dname}")
CloudHandler.device = getenv("CLOUDDEV", "METAL") if Device.DEFAULT == "CLOUD" else Device.DEFAULT
print(f"start cloud server on {port} with device {CloudHandler.device}")
server = HTTPServer(('', port), CloudHandler)
server.serve_forever()
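
The serving device is now published as the class attribute CloudHandler.device; its selection can be read as a small function (os.environ stands in for tinygrad's getenv, and the default-device argument here is hypothetical):

```python
import os

def pick_cloud_device(default_device: str) -> str:
    # mirror of the selection above: when this process itself defaults to CLOUD, serve
    # whatever CLOUDDEV points at (METAL if unset); otherwise serve the default device
    if default_device == "CLOUD":
        return os.environ.get("CLOUDDEV", "METAL")
    return default_device

print(pick_cloud_device("CLOUD"))   # METAL, or $CLOUDDEV if set
print(pick_cloud_device("AMD"))     # AMD
```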

View File

@@ -242,7 +242,7 @@ class HCQSignal:
raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
@contextlib.contextmanager
def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type=None, queue=None):
st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
if enabled and queue is not None: queue.timestamp(st)
@@ -257,7 +257,7 @@ def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
dev.timeline_value += 1
if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
class HCQArgsState(Generic[ProgramType]):
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
@@ -393,7 +393,7 @@ class HCQCompiled(Compiled):
def _ensure_shared_time_base(self):
if not self.gpu2cpu_compute_time_diff.is_nan(): return
def _sync_cpu_queue(d, q_t):
def _sync_cpu_queue(d:HCQCompiled, q_t:Type[HWCommandQueue]):
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
d.timeline_value += 1
st = time.perf_counter_ns()
@@ -411,7 +411,7 @@ class HCQCompiled(Compiled):
if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
def _sync_gpu_to_gpu_queue(d1:HCQCompiled, d2:HCQCompiled, q1_t:Type[HWCommandQueue], q2_t:Type[HWCommandQueue]):
q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \