search: catch RuntimeError when timing acted_lins (#3664)

When compilation succeeds but execution fails at runtime due to thread limits
on METAL, this change allows a beam search to proceed by treating the runtime
failure the same way as a compile failure.
This commit is contained in:
Francis Lam
2024-03-11 13:14:03 -07:00
committed by GitHub
parent 490c5a3ec3
commit 9f13960f72
3 changed files with 5 additions and 3 deletions

View File

@@ -40,7 +40,8 @@ class HSAProgram:
if not hasattr(self, "args_struct_t"):
self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
assert ctypes.sizeof(self.args_struct_t) == self.kernargs_segment_size, f"{ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}"
if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")
kernargs = None
if self.kernargs_segment_size > 0:

View File

@@ -43,7 +43,7 @@ class MetalProgram:
self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(),f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" # noqa: E501
if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)