Revert "fix TF32 tensor core dropped in tc_sm89 (#9798)"

This reverts commit 7c9a96824f.
Author: George Hotz
Date:   2025-04-09 12:27:39 +08:00
parent 1ed4eae510
commit 14928fecff

3 changed files with 15 additions and 24 deletions

CI workflow (.github/workflows)

@@ -67,11 +67,11 @@ jobs:
           DEBUG=2 CPU=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
           DEBUG=2 LLVM=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
       - name: Run Tensor Core GEMM (float)
-        run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
+        run: DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
       - name: Run Tensor Core GEMM (half)
-        run: DEBUG=2 SHOULD_USE_TC=1 HALF=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_half.txt
+        run: DEBUG=2 HALF=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_half.txt
       - name: Run Tensor Core GEMM (bfloat16)
-        run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
+        run: DEBUG=2 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
       - name: Fuzz Padded Tensor Core GEMM
         run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
       - name: Run LLaMA
@@ -178,13 +178,13 @@ jobs:
           PTX=1 ALLOW_TF32=1 NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
       - name: Run Tensor Core GEMM (CUDA)
         run: |
-          CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
-          CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
-          CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_tf32.txt
+          CUDA=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
+          CUDA=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
+          CUDA=1 ALLOW_TF32=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_tf32.txt
       - name: Run Tensor Core GEMM (PTX)
-        run: NV=1 PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
+        run: NV=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
       - name: Run Tensor Core GEMM (NV)
-        run: NV=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
+        run: NV=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
       - name: Test NV=1
         run: DEBUG=2 NV=1 python -m pytest -rA test/test_tiny.py
       - name: Test CUDA=1
@@ -371,9 +371,9 @@ jobs:
       - name: Test tensor cores
         run: |
           AMD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
-          AMD=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
+          AMD=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
       - name: Run Tensor Core GEMM (AMD)
-        run: AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
+        run: AMD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
       - name: Test AMD=1
         run: DEBUG=2 AMD=1 python -m pytest -rA test/test_tiny.py
       - name: Test HIP=1
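
All of the flags in these commands (HALF=1, BFLOAT16=1, ALLOW_TF32=1, and the now-removed SHOULD_USE_TC=1) are plain environment variables read through tinygrad's getenv helper. A minimal sketch of the behavior being relied on, assuming the helper casts the value to the type of its default as the real one does for int flags like these:

import os

def getenv(key: str, default=0):
  # cast the environment value to the type of the default (int here),
  # so HALF=1 reads as 1 and an unset flag falls back to 0
  return type(default)(os.getenv(key, default))

This is also why the getenv("ALLOW_TF32") and getenv("ALLOW_TF32", 0) spellings that differ in the renderer diff below behave identically: the default is 0 either way.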

extra/gemm/simple_matmul.py

@@ -2,8 +2,6 @@ import numpy as np
 from tinygrad.helpers import getenv
 from tinygrad.dtype import _to_np_dtype
 from tinygrad import dtypes, Tensor
-from tinygrad.codegen.kernel import OptOps
-from tinygrad.engine.realize import lower_schedule
 dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
             dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
@@ -35,12 +33,6 @@ if __name__ == "__main__":
     if i > 0 and getenv("RAND", 0) != 0:
       a, b = init_matrix(M, K), init_matrix(K, N)
     c = a.matmul(b, dtype=acc_dtype).realize()
-    if getenv("SHOULD_USE_TC"):
-      sched = a.matmul(b, dtype=acc_dtype).schedule()
-      lowered = list(lower_schedule(sched))
-      assert len(lowered) == 1
-      ei = lowered[0][1]
-      assert any(opt.op is OptOps.TC for opt in ei.prg.p.applied_opts), f"TC not triggered, {ei.prg.p.applied_opts}"
     ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
     res = c.numpy()
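
For reference, the deleted SHOULD_USE_TC guard reads as follows when pulled out of the diff: a standalone sketch with hypothetical 64x64 inputs standing in for init_matrix(M, K) and init_matrix(K, N), using the same APIs the removed imports name.

from tinygrad import Tensor
from tinygrad.codegen.kernel import OptOps
from tinygrad.engine.realize import lower_schedule

a, b = Tensor.rand(64, 64), Tensor.rand(64, 64)  # hypothetical inputs
sched = a.matmul(b).schedule()         # build the schedule without realizing it
lowered = list(lower_schedule(sched))  # lower it to executable kernels
assert len(lowered) == 1               # the matmul should land in a single kernel
ei = lowered[0][1]                     # the ExecItem for that kernel
# fail loudly if the TC (tensor core) opt was never applied
assert any(opt.op is OptOps.TC for opt in ei.prg.p.applied_opts), f"TC not triggered, {ei.prg.p.applied_opts}"

Whether the assertion holds depends on the active backend actually offering tensor cores for the chosen dtypes, which is exactly what the SHOULD_USE_TC=1 runs in the workflow above were checking before this revert.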

tinygrad/renderer/cstyle.py

@@ -338,15 +338,14 @@ class CUDARenderer(CStyleLanguage):
   tc_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
     swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
-  tc_sm75 = tc_8168_f16
+  tc_sm89 = tc_81616 + tc_8168_f16 + tc_81632_f8
   tc_sm80 = tc_81616 + tc_8168_f16
-  if getenv("ALLOW_TF32"): tc_sm80 += tc_8168_tf32
-  tc_sm89 = tc_sm80 + tc_81632_f8
+  if getenv("ALLOW_TF32", 0): tc_sm80 += tc_8168_tf32
+  tc_sm75 = tc_8168_f16
   def __init__(self, arch: str):
-    self.arch = arch
-    tensor_cores_versions = [(89, CUDARenderer.tc_sm89), (80, CUDARenderer.tc_sm80), (75, CUDARenderer.tc_sm75)]
-    self.tensor_cores = next((tc for version, tc in tensor_cores_versions if int(arch[3:]) >= version), [])
+    tensor_cores_map = {89: CUDARenderer.tc_sm89, 80: CUDARenderer.tc_sm80, 75: CUDARenderer.tc_sm75}
+    self.tensor_cores = next((tc for version, tc in sorted(tensor_cores_map.items(), reverse=True) if int(arch[3:]) >= version), [])
   def __reduce__(self): return self.__class__, (self.arch,)
   # language options
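
The renderer change being reverted is subtle: Python's list + copies at the moment it runs, so whether tc_sm89 carries the TF32 entry depends entirely on statement order. A toy sketch with strings standing in for the TensorCore entries:

f16, f8, tf32 = ["f16"], ["f8"], ["tf32"]

# removed (#9798) order: tc_sm89 is derived from tc_sm80 *after* the
# ALLOW_TF32 append, so it inherits the TF32 tensor core
tc_sm80 = ["81616"] + f16
tc_sm80 += tf32                  # the getenv("ALLOW_TF32") branch
tc_sm89 = tc_sm80 + f8
assert "tf32" in tc_sm89

# restored pre-#9798 order: tc_sm89 is built independently, so the later
# append to tc_sm80 never reaches it -- TF32 is dropped from tc_sm89 again
tc_sm89 = ["81616"] + f16 + f8
tc_sm80 = ["81616"] + f16
tc_sm80 += tf32
assert "tf32" not in tc_sm89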