No more metal flakiness (#3643)

This commit is contained in:
uuuvn
2024-03-08 18:54:44 +02:00
committed by GitHub
parent e25879d50e
commit daa4034e80

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools
import Metal, libdispatch
from typing import List, Any, Tuple, Optional
from typing import List, Set, Any, Tuple, Optional
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
from tinygrad.device import Compiled, LRUAllocator, Compiler
@@ -60,7 +60,13 @@ class MetalProgram:
class MetalAllocator(LRUAllocator):
def __init__(self, device:MetalDevice):
self.device:MetalDevice = device
self.track_cross_device: Set[MetalDevice] = set()
super().__init__()
def free_cache(self):
self.device.synchronize()
for x in self.track_cross_device: x.synchronize()
self.track_cross_device.clear()
return super().free_cache()
def _alloc(self, size:int) -> Any:
ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
@@ -89,6 +95,7 @@ class MetalDevice(Compiled):
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
self.mtl_buffers_in_flight: List[Any] = []
self.mv_in_metal: List[memoryview] = []
self.track_cross_buffer: List[Any] = []
from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
@@ -96,3 +103,4 @@ class MetalDevice(Compiled):
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
self.mv_in_metal.clear()
self.mtl_buffers_in_flight.clear()
self.track_cross_buffer.clear()