From 72ddcdb4d17f4b7e0108df0179f40333f66fa175 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 24 Oct 2024 09:38:57 +0700 Subject: [PATCH] move metal tc check to renderer [pr] (#7248) --- tinygrad/renderer/cstyle.py | 4 ++-- tinygrad/runtime/ops_metal.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index e36cad2a7a..174e036892 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast -import os, math +import math from collections import defaultdict, Counter from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX @@ -254,7 +254,7 @@ class MetalRenderer(CStyleLanguage): tensor_cores = [TensorCore(dims=(8,8,8),threads=[(0,2),(1,4),(0,2),(1,2)],expanded_shape=(2,2,2,2),upcast_axes=([(1,2)],[(1,2)],[(1,2)]), st1_pattern=(((1,1),(0,1),(1,0),(0,3)),((0,0),(0,2),(1,3),(1,2))),st2_pattern=(((0,0),(1,1),(1,2),(0,2),(1,0)),((0,1),(0,3),(1,3))), dtype_in=di,dtype_out=do,reduce_axes=[(0,8)]) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]] - def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if os.uname().machine == "arm64" else [] + def __init__(self, supports_tensor_cores=False): self.tensor_cores = MetalRenderer.tensor_cores if supports_tensor_cores else [] # language options kernel_prefix = "kernel " diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 4fbfad0321..f5e469315a 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -172,8 +172,8 @@ class MetalDevice(Compiled): self.timeline_value = 0 from tinygrad.runtime.graph.metal import MetalGraph - super().__init__(device, MetalAllocator(self), MetalRenderer(), MetalCompiler(None if getenv("METAL_XCODE") else self), - functools.partial(MetalProgram, self), MetalGraph) + super().__init__(device, MetalAllocator(self), MetalRenderer(os.uname().machine == "arm64"), + MetalCompiler(None if getenv("METAL_XCODE") else self), functools.partial(MetalProgram, self), MetalGraph) def synchronize(self): for cbuf in self.mtl_buffers_in_flight: wait_check(cbuf) self.mv_in_metal.clear()