mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
metal globalcounters
This commit is contained in:
@@ -12,8 +12,13 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.helpers import colored, getenv
|
||||
from extra.jit import TinyJit
|
||||
METAL = getenv("METAL")
|
||||
try:
|
||||
from tinygrad.runtime.opencl import CL
|
||||
if METAL:
|
||||
from tinygrad.runtime.metal import sync
|
||||
else:
|
||||
def sync(): CL.cl_queue.finish()
|
||||
except ImportError:
|
||||
CL = None
|
||||
|
||||
@@ -44,7 +49,7 @@ def helper_test_speed(f1, *args):
|
||||
st = time.monotonic()
|
||||
ret = f1(*args)
|
||||
if isinstance(ret, Tensor) and CL is not None and ret.device in ["GPU"]:
|
||||
CL.cl_queue.finish()
|
||||
sync()
|
||||
if not isinstance(ret, Tensor) and torch_device != "cpu":
|
||||
# TODO: better way to sync?
|
||||
torch.zeros(1, device=torch_device).cpu()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import Metal # type: ignore
|
||||
import numpy as np
|
||||
from typing import List, Any
|
||||
from tinygrad.ops import DEBUG
|
||||
from tinygrad.ops import DEBUG, GlobalCounters
|
||||
|
||||
device = Metal.MTLCreateSystemDefaultDevice()
|
||||
mtl_queue = device.newCommandQueue()
|
||||
@@ -49,6 +49,7 @@ class CLProgram:
|
||||
global_size += [1] * (3-len(global_size))
|
||||
if local_size is None: local_size = []
|
||||
local_size += [1] * (3-len(local_size))
|
||||
if DEBUG >= 2: print("METAL launch", global_size, local_size)
|
||||
pipeline_state = device.newComputePipelineStateWithFunction_error_(self.fxn, None)
|
||||
assert pipeline_state[0] is not None, str(pipeline_state)
|
||||
command_buffer = mtl_queue.commandBuffer()
|
||||
@@ -60,4 +61,5 @@ class CLProgram:
|
||||
encoder.endEncoding()
|
||||
command_buffer.commit()
|
||||
mtl_buffers_in_flight.append(command_buffer)
|
||||
GlobalCounters.log_kernel(self.op_estimate, self.mem_estimate)
|
||||
return command_buffer
|
||||
|
||||
Reference in New Issue
Block a user