From daa4034e80345b13130e2e5e97459291fb2eb846 Mon Sep 17 00:00:00 2001 From: uuuvn <83587632+uuuvn@users.noreply.github.com> Date: Fri, 8 Mar 2024 18:54:44 +0200 Subject: [PATCH] No more metal flakiness (#3643) --- tinygrad/runtime/ops_metal.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 4c9600da3a..f9154403c2 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -1,7 +1,7 @@ from __future__ import annotations import os, subprocess, pathlib, ctypes, tempfile, functools import Metal, libdispatch -from typing import List, Any, Tuple, Optional +from typing import List, Set, Any, Tuple, Optional from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.helpers import prod, getenv, DEBUG, unwrap2 from tinygrad.device import Compiled, LRUAllocator, Compiler @@ -60,7 +60,13 @@ class MetalProgram: class MetalAllocator(LRUAllocator): def __init__(self, device:MetalDevice): self.device:MetalDevice = device + self.track_cross_device: Set[MetalDevice] = set() super().__init__() + def free_cache(self): + self.device.synchronize() + for x in self.track_cross_device: x.synchronize() + self.track_cross_device.clear() + return super().free_cache() def _alloc(self, size:int) -> Any: ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared) if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}") @@ -89,6 +95,7 @@ class MetalDevice(Compiled): self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024) self.mtl_buffers_in_flight: List[Any] = [] self.mv_in_metal: List[memoryview] = [] + self.track_cross_buffer: List[Any] = [] from tinygrad.runtime.graph.metal import MetalGraph super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self), functools.partial(MetalProgram, self), functools.partial(MetalGraph, self)) @@ -96,3 +103,4 @@ class MetalDevice(Compiled): for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf) self.mv_in_metal.clear() self.mtl_buffers_in_flight.clear() + self.track_cross_buffer.clear()