mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
search: catch RuntimeError when timing acted_lins (#3664)
when compilation succeeds, but runtime fails due to thread limits on METAL, this allows a beam search to proceed, treating this the same way as a compile failure.
This commit is contained in:
@@ -40,7 +40,8 @@ class HSAProgram:
|
||||
if not hasattr(self, "args_struct_t"):
|
||||
self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
|
||||
[(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
|
||||
assert ctypes.sizeof(self.args_struct_t) == self.kernargs_segment_size, f"{ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}"
|
||||
if ctypes.sizeof(self.args_struct_t) != self.kernargs_segment_size:
|
||||
raise RuntimeError(f"HSAProgram.__call__: incorrect args struct size {ctypes.sizeof(self.args_struct_t)} != {self.kernargs_segment_size}")
|
||||
|
||||
kernargs = None
|
||||
if self.kernargs_segment_size > 0:
|
||||
|
||||
@@ -43,7 +43,7 @@ class MetalProgram:
|
||||
self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
|
||||
|
||||
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
||||
assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(),f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" # noqa: E501
|
||||
if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
|
||||
command_buffer = self.device.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.setComputePipelineState_(self.pipeline_state)
|
||||
|
||||
Reference in New Issue
Block a user