new style metal compiler (#14632)

This commit is contained in:
Christopher Milan
2026-02-08 18:58:25 -08:00
committed by GitHub
parent 9eef9f38ad
commit 0ebb508b85
2 changed files with 7 additions and 3 deletions

View File

@@ -3,9 +3,10 @@ import os, math, sys, struct
from collections import defaultdict, Counter
from tinygrad.codegen.opt import tc
from tinygrad.uop.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, range_str, axis_letters
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX, CPU_COUNT
from tinygrad.helpers import strip_parens, getenv, prod, dedup, select_first_inited, AMX, CPU_COUNT
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, AddrSpace, truncate, float_to_bf16
from tinygrad.renderer import Renderer
from tinygrad.device import Compiler
from tinygrad.codegen.late.devectorizer import no_vectorized_alu
@@ -340,7 +341,10 @@ class IntelRenderer(OpenCLRenderer):
class MetalRenderer(CStyleLanguage):
device = "METAL"
shared_max = 32768
def __init__(self): self.tensor_cores = tc.metal if hasattr(os, 'uname') and os.uname().machine == "arm64" else []
def __init__(self):
from tinygrad.runtime.ops_metal import MetalCompiler
self.compiler = select_first_inited([MetalCompiler, Compiler], "No compiler for METAL is available")
self.tensor_cores = tc.metal if hasattr(os, 'uname') and os.uname().machine == "arm64" else []
# language options
kernel_typedef = "kernel void"

View File

@@ -42,7 +42,7 @@ class MetalDevice(Compiled):
from tinygrad.runtime.graph.metal import MetalGraph
# NOTE: GitHub CI macOS runners use paravirtualized metal which is broken with graph.
# This can be reproduced locally with any virtualization software (like utm) that can create macOS VMs with apple's own virtualization framework.
super().__init__(device, MetalAllocator(self), CompilerSet([CompilerPair(MetalRenderer, MetalCompiler), CompilerPair(MetalRenderer, Compiler)]),
super().__init__(device, MetalAllocator(self), CompilerSet([CompilerPair(MetalRenderer, None)]),
functools.partial(MetalProgram, self), MetalGraph if 'virtual' not in from_ns_str(self.sysdevice.name()).lower() else None)
def synchronize(self):