metal globalcounters

This commit is contained in:
George Hotz
2023-02-17 12:02:54 -08:00
parent 67d1df80ba
commit 121bd03cbd
2 changed files with 9 additions and 2 deletions

View File

@@ -12,8 +12,13 @@ from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv
from extra.jit import TinyJit
METAL = getenv("METAL")
try:
from tinygrad.runtime.opencl import CL
if METAL:
from tinygrad.runtime.metal import sync
else:
def sync(): CL.cl_queue.finish()
except ImportError:
CL = None
@@ -44,7 +49,7 @@ def helper_test_speed(f1, *args):
st = time.monotonic()
ret = f1(*args)
if isinstance(ret, Tensor) and CL is not None and ret.device in ["GPU"]:
CL.cl_queue.finish()
sync()
if not isinstance(ret, Tensor) and torch_device != "cpu":
# TODO: better way to sync?
torch.zeros(1, device=torch_device).cpu()

View File

@@ -2,7 +2,7 @@
import Metal # type: ignore
import numpy as np
from typing import List, Any
from tinygrad.ops import DEBUG
from tinygrad.ops import DEBUG, GlobalCounters
device = Metal.MTLCreateSystemDefaultDevice()
mtl_queue = device.newCommandQueue()
@@ -49,6 +49,7 @@ class CLProgram:
global_size += [1] * (3-len(global_size))
if local_size is None: local_size = []
local_size += [1] * (3-len(local_size))
if DEBUG >= 2: print("METAL launch", global_size, local_size)
pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
assert pipeline_state[0] is not None, str(pipeline_state)
command_buffer = mtl_queue.commandBuffer()
@@ -60,4 +61,5 @@ class CLProgram:
encoder.endEncoding()
command_buffer.commit()
mtl_buffers_in_flight.append(command_buffer)
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
return command_buffer