diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index 1732f1c414..f7913b1007 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -7,6 +7,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos from tinygrad.device import Buffer from tinygrad.dtype import dtypes, ImageDType from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten +from tinygrad.helpers import ALLOW_TF32 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check from tinygrad.codegen.simplify import pm_flatten_range from tinygrad.renderer import Renderer @@ -228,6 +229,7 @@ class Scheduler: except IndexError: raise KernelOptError(f"invalid tensor core choice {tc_select}") for tc in tensor_cores: + if self.ren.device in ("CUDA", "NV") and tc.dtype_in == dtypes.float and not ALLOW_TF32: continue if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar(): # tensor cores have three ranges. X, Y, and REDUCE in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0], reverse=True) diff --git a/tinygrad/codegen/opt/tc.py b/tinygrad/codegen/opt/tc.py index 7dbdf4b071..6a05f0bd16 100644 --- a/tinygrad/codegen/opt/tc.py +++ b/tinygrad/codegen/opt/tc.py @@ -1,7 +1,6 @@ import math, functools from dataclasses import dataclass from tinygrad.dtype import DType, dtypes -from tinygrad.helpers import getenv @dataclass(frozen=True) class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N) @@ -92,8 +91,7 @@ cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2 swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')), (('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))] cuda_sm75: list[TensorCore] = cuda_8168_f16 -cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16 -if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32 +cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16 + cuda_8168_tf32 cuda_sm89: list[TensorCore] = cuda_sm80 + cuda_81632_f8 # ***** AMD ***** diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 5a48d20c93..8fbf11d074 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -200,6 +200,8 @@ DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0) TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1) # set to 0 to disable the compiler cache CCACHE = ContextVar("CCACHE", 1) +# allow tf32 to be used on NVIDIA GPUs +ALLOW_TF32 = ContextVar("ALLOW_TF32", 0) @dataclass(frozen=True) class Metadata: