mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
late gate for ALLOW_TF32 (#13527)
* remove ALLOW_TF32
* the right place to put that gate
This commit is contained in:
@@ -7,6 +7,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
+from tinygrad.helpers import ALLOW_TF32
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -228,6 +229,7 @@ class Scheduler:
|
||||
except IndexError:
|
||||
raise KernelOptError(f"invalid tensor core choice {tc_select}")
|
||||
for tc in tensor_cores:
|
||||
+if self.ren.device in ("CUDA", "NV") and tc.dtype_in == dtypes.float and not ALLOW_TF32: continue
|
||||
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
|
||||
# tensor cores have three ranges. X, Y, and REDUCE
|
||||
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0], reverse=True)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import math, functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
-from tinygrad.helpers import getenv
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
|
||||
@@ -92,8 +91,7 @@ cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2
|
||||
swizzle=((('r0', 'r1', 'l2', 'l3', 'l4'), ('u1', 'r2'), ('l0', 'l1', 'u0')),
|
||||
(('r0', 'r1', 'u0', 'l0', 'l1'), ('u1', 'r2'), ('l2', 'l3', 'l4'))))]
|
||||
cuda_sm75: list[TensorCore] = cuda_8168_f16
|
||||
-cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
|
||||
-if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
|
||||
+cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16 + cuda_8168_tf32
|
||||
cuda_sm89: list[TensorCore] = cuda_sm80 + cuda_81632_f8
|
||||
|
||||
# ***** AMD *****
|
||||
|
||||
@@ -200,6 +200,8 @@ DEBUG_RANGEIFY = ContextVar("DEBUG_RANGEIFY", 0)
|
||||
TUPLE_ORDER = ContextVar("TUPLE_ORDER", 1)
|
||||
# set to 0 to disable the compiler cache
|
||||
CCACHE = ContextVar("CCACHE", 1)
|
||||
+# allow tf32 to be used on NVIDIA GPUs
|
||||
+ALLOW_TF32 = ContextVar("ALLOW_TF32", 0)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
||||
Reference in New Issue
Block a user