late gate for ALLOW_TF32 (#13527)

* remove ALLOW_TF32

* the right place to put that gate
This commit is contained in:
George Hotz
2025-12-02 07:51:58 -08:00
committed by GitHub
parent 6a7c58abf1
commit 037edc151c
3 changed files with 5 additions and 3 deletions

View File

@@ -7,6 +7,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.helpers import ALLOW_TF32
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
@@ -228,6 +229,7 @@ class Scheduler:
except IndexError:
raise KernelOptError(f"invalid tensor core choice {tc_select}")
for tc in tensor_cores:
if self.ren.device in ("CUDA", "NV") and tc.dtype_in == dtypes.float and not ALLOW_TF32: continue
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
# tensor cores have three ranges: X, Y, and REDUCE
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0], reverse=True)

View File

@@ -1,7 +1,6 @@
import math, functools
from dataclasses import dataclass
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import getenv
@dataclass(frozen=True)
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
@@ -92,8 +91,7 @@ cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2
swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
(('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
cuda_sm75: list[TensorCore] = cuda_8168_f16
cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16 + cuda_8168_tf32
cuda_sm89: list[TensorCore] = cuda_sm80 + cuda_81632_f8
# ***** AMD *****

View File

@@ -200,6 +200,8 @@ DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)
TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1)
# set to 0 to disable the compiler cache
CCACHE = ContextVar("CCACHE", 1)
# set to 1 to allow TF32 tensor cores to be used on NVIDIA GPUs (disabled by default)
ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
@dataclass(frozen=True)
class Metadata: