mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 23:08:06 -05:00
No more metal flakiness (#3643)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, subprocess, pathlib, ctypes, tempfile, functools
|
||||
import Metal, libdispatch
|
||||
from typing import List, Any, Tuple, Optional
|
||||
from typing import List, Set, Any, Tuple, Optional
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
|
||||
from tinygrad.device import Compiled, LRUAllocator, Compiler
|
||||
@@ -60,7 +60,13 @@ class MetalProgram:
|
||||
class MetalAllocator(LRUAllocator):
|
||||
def __init__(self, device:MetalDevice):
|
||||
self.device:MetalDevice = device
|
||||
self.track_cross_device: Set[MetalDevice] = set()
|
||||
super().__init__()
|
||||
def free_cache(self):
|
||||
self.device.synchronize()
|
||||
for x in self.track_cross_device: x.synchronize()
|
||||
self.track_cross_device.clear()
|
||||
return super().free_cache()
|
||||
def _alloc(self, size:int) -> Any:
|
||||
ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
|
||||
if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
|
||||
@@ -89,6 +95,7 @@ class MetalDevice(Compiled):
|
||||
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
|
||||
self.mtl_buffers_in_flight: List[Any] = []
|
||||
self.mv_in_metal: List[memoryview] = []
|
||||
self.track_cross_buffer: List[Any] = []
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
super().__init__(device, MetalAllocator(self), MetalCompiler(None if getenv("METAL_XCODE") else self),
|
||||
functools.partial(MetalProgram, self), functools.partial(MetalGraph, self))
|
||||
@@ -96,3 +103,4 @@ class MetalDevice(Compiled):
|
||||
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
|
||||
self.mv_in_metal.clear()
|
||||
self.mtl_buffers_in_flight.clear()
|
||||
self.track_cross_buffer.clear()
|
||||
|
||||
Reference in New Issue
Block a user