metal profiler + cpu_profile (#8291)

* metal + cpu_profile

* gpt example

* linter + revert gpt2 for now

* a bit of readme

* linter

* unrelated

* tests

* linter

* b
This commit is contained in:
nimlgen
2024-12-18 00:06:56 +03:00
committed by GitHub
parent 32c995a5da
commit 777d2aec05
5 changed files with 56 additions and 38 deletions

View File

@@ -7,7 +7,7 @@ from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.ops import Variable
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, elapsed_time, objc_id
MTLResourceOptions, cmdbuf_st_time, cmdbuf_en_time, objc_id, to_ns_str
class MTLIndirectCommandType:
MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
@@ -90,11 +90,12 @@ class MetalGraph(GraphRunner):
msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
msg(encoder, "endEncoding")
msg(command_buffer, "setLabel:", to_ns_str(f"batched {len(self.jit_cache)}"))
msg(command_buffer, "commit")
self.command_buffer = command_buffer
self.dev.mtl_buffers_in_flight.append(command_buffer)
if wait:
wait_check(command_buffer)
return elapsed_time(command_buffer)
self.dev.mtl_buffers_in_flight.append(command_buffer)
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
return None

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import os, pathlib, struct, ctypes, tempfile, functools
import os, pathlib, struct, ctypes, tempfile, functools, decimal
from typing import List, Any, Union, Tuple, cast
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent
from tinygrad.renderer.cstyle import MetalRenderer
class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
@@ -44,6 +44,7 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)
return sender(ptr, sel(selector), *args)
def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
def from_ns_str(s): return bytes(msg(s, "UTF8String", restype=ctypes.c_char_p)).decode()
def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)
@@ -51,12 +52,13 @@ def wait_check(cbuf: Any):
msg(cbuf, "waitUntilCompleted")
error_check(msg(cbuf, "error", restype=objc_instance))
def elapsed_time(cbuf: objc_id):
return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double)) - cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
def cmdbuf_label(cbuf: objc_id) -> str: return from_ns_str(msg(cbuf, "label", restype=objc_id))
def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double))
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
if error.value is None: return None
raise error_constructor(bytes(msg(msg(error, "localizedDescription", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode())
raise error_constructor(from_ns_str(msg(error, "localizedDescription", restype=objc_instance)))
def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
@@ -136,11 +138,12 @@ class MetalProgram:
for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
msg(encoder, "endEncoding")
msg(command_buffer, "setLabel:", to_ns_str(self.name))
msg(command_buffer, "commit")
self.dev.mtl_buffers_in_flight.append(command_buffer)
if wait:
wait_check(command_buffer)
return elapsed_time(command_buffer)
self.dev.mtl_buffers_in_flight.append(command_buffer)
return cmdbuf_en_time(command_buffer) - cmdbuf_st_time(command_buffer)
class MetalBuffer:
def __init__(self, buf:Any, size:int, offset=0): self.buf, self.size, self.offset = buf, size, offset
@@ -172,11 +175,13 @@ class MetalAllocator(LRUAllocator):
src_dev.timeline_value += 1
msg(src_command_buffer, "commit")
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
def _cp_mv(self, dst, src, prof_desc):
with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
def _as_buffer(self, src:MetalBuffer) -> memoryview:
self.dev.synchronize()
return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:]
def _copyin(self, dest:MetalBuffer, src:memoryview): self._as_buffer(dest)[:] = src
def _copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self._as_buffer(src)
def _copyin(self, dest:MetalBuffer, src:memoryview): self._cp_mv(self._as_buffer(dest), src, "CPU -> METAL")
def _copyout(self, dest:memoryview, src:MetalBuffer): self._cp_mv(dest, self._as_buffer(src), "METAL -> CPU")
def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
class MetalDevice(Compiled):
@@ -185,14 +190,18 @@ class MetalDevice(Compiled):
self.mtl_queue = msg(self.sysdevice, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
self.mtl_buffers_in_flight: List[Any] = []
self.mv_in_metal: List[memoryview] = []
self.timeline_signal = msg(self.sysdevice, "newSharedEvent", restype=objc_instance)
self.timeline_value = 0
Compiled.profile_events += [ProfileDeviceEvent(device)]
from tinygrad.runtime.graph.metal import MetalGraph
super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler() if getenv("METAL_DIRECT", 1) else Compiler(),
functools.partial(MetalProgram, self), MetalGraph)
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf)
self.mv_in_metal.clear()
for cbuf in self.mtl_buffers_in_flight:
wait_check(cbuf)
st, en = decimal.Decimal(cmdbuf_st_time(cbuf)) * 1000000, decimal.Decimal(cmdbuf_en_time(cbuf)) * 1000000
if PROFILE: Compiled.profile_events += [ProfileRangeEvent(self.device, cmdbuf_label(cbuf), st, en, is_copy=False)]
self.mtl_buffers_in_flight.clear()