mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Raise exception if MTLCommandBuffer fails (#3465)
This commit is contained in:
@@ -5,7 +5,7 @@ from tinygrad.helpers import dedup, unwrap2, GraphException
|
||||
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
|
||||
from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.runtime.ops_metal import MetalDevice
|
||||
from tinygrad.runtime.ops_metal import MetalDevice, wait_check
|
||||
|
||||
class MetalGraph:
|
||||
def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
@@ -53,7 +53,7 @@ class MetalGraph:
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
# NOTE: you at least can't update the ints if this is running
|
||||
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
|
||||
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||
all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
|
||||
@@ -69,7 +69,7 @@ class MetalGraph:
|
||||
command_buffer.commit()
|
||||
self.command_buffer = command_buffer
|
||||
if wait:
|
||||
command_buffer.waitUntilCompleted()
|
||||
wait_check(command_buffer)
|
||||
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
else:
|
||||
self.device.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
@@ -7,6 +7,11 @@ from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
|
||||
from tinygrad.device import Compiled, LRUAllocator, Compiler
|
||||
from tinygrad.renderer.cstyle import MetalRenderer
|
||||
|
||||
def wait_check(cbuf: Any) -> None:
  """Wait for a command buffer to finish and raise if it reports an error.

  Args:
    cbuf: Object exposing `waitUntilCompleted()` and `error()`; presumably a
      Metal MTLCommandBuffer, per the commit title -- the annotation is `Any`.

  Raises:
    RuntimeError: If `cbuf.error()` returns anything other than None after the
      wait. The error object is passed through as the exception argument.
  """
  # Block first: the error is only queried after the wait has returned.
  cbuf.waitUntilCompleted()
  if (error := cbuf.error()) is not None:
    raise RuntimeError(error)
|
||||
|
||||
class MetalCompiler(Compiler):
|
||||
linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=os.uname().machine == "arm64")
|
||||
def __init__(self, device:Optional[MetalDevice]):
|
||||
@@ -48,7 +53,7 @@ class MetalProgram:
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
if wait:
|
||||
command_buffer.waitUntilCompleted()
|
||||
wait_check(command_buffer)
|
||||
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
|
||||
self.device.mtl_buffers_in_flight.append(command_buffer)
|
||||
|
||||
@@ -88,6 +93,6 @@ class MetalDevice(Compiled):
|
||||
super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
|
||||
functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
|
||||
def synchronize(self):
|
||||
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
|
||||
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
||||
self.mv_in_metal.clear()
|
||||
self.mtl_buffers_in_flight.clear()
|
||||
|
||||
Reference in New Issue
Block a user