Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 06:58:11 -05:00
fix fp8 vectorization (#12977)
* fix fp8 vectorization
* add fp8 tc to benchmark
.github/workflows/benchmark.yml (vendored)
@@ -211,6 +211,7 @@ jobs:
           CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
           CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
           CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_tf32.txt
+          CUDA=1 SHOULD_USE_TC=1 FP8E4M3=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_fp8.txt
       - name: Run Tensor Core GEMM (PTX)
         run: NV=1 NV_PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
       - name: Run Tensor Core GEMM (NV)
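The new FP8E4M3=1 entry adds an fp8 tensor-core GEMM to the CUDA benchmark matrix, alongside the existing half, bfloat16 and tf32 runs. For orientation only, here is a minimal standalone sketch of the kind of fp8 matmul this entry exercises; it is not the actual extra/gemm/simple_matmul.py, and the size N, the casts and the half accumulation dtype are illustrative assumptions.

# Minimal sketch of an fp8e4m3 GEMM in tinygrad (illustrative, not the benchmark script).
# Assumes a CUDA-capable device on which the fp8 dtypes are usable.
from tinygrad import Tensor, dtypes

N = 4096
a = Tensor.rand(N, N).cast(dtypes.fp8e4m3)
b = Tensor.rand(N, N).cast(dtypes.fp8e4m3)
# cast up before the matmul; accumulating in half is an illustrative choice here
c = (a.cast(dtypes.half) @ b.cast(dtypes.half)).realize()
print(c.shape, c.dtype)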
@@ -148,7 +148,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   if ctx is not None and ctx.device == "DSP":
     lengths = [128,64,32,16,8,4]
     must_divide = False
-  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
+  elif buf.dtype.base not in (dtypes.float, dtypes.half, *dtypes.fp8s) and not isinstance(buf.dtype, ImageDType):
     pass
   elif buf.ptrdtype.addrspace == AddrSpace.REG:
     pass
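This devectorizer change is the core of the fix: previously any buffer whose base dtype was not float or half (and not an image) fell into the early "pass" branch, so its loads and stores were never split into vector widths and fp8 buffers stayed scalar. With the fp8 dtypes added to the allowed set, they now fall through to the branches that pick vector lengths. A small standalone sketch of just the membership test, assuming dtypes.fp8s is the tuple of the two fp8 dtypes used elsewhere in this diff:

# Sketch of the new dtype check in split_load_store, reduced to the membership
# test; not the full function.
from tinygrad import dtypes

vectorizable_bases = (dtypes.float, dtypes.half, *dtypes.fp8s)
for dt in (dtypes.float, dtypes.half, dtypes.fp8e4m3, dtypes.fp8e5m2, dtypes.int32):
  print(dt, "eligible for vector load/store splitting" if dt in vectorizable_bases else "hits the pass branch")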
@@ -388,7 +388,7 @@ class CUDARenderer(CStyleLanguage):
     if any(dt.scalar() == dtypes.half for dt in used_dtypes): prefix.append("#include <cuda_fp16.h>")
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("#include <cuda_bf16.h>")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16})
-      or (dt.count in (8,16) and dt.scalar() in dtypes.fp8s)]
+      or (dt.count in (2,4,8,16) and dt.scalar() in dtypes.fp8s)]
     dt_map_in = { dtypes.float: "tf32", dtypes.half: "f16", dtypes.bfloat16: "bf16", dtypes.fp8e4m3: "e4m3", dtypes.fp8e5m2: "e5m2" }
     dt_map_out = { dtypes.float: "f32", dtypes.half: "f16" }
     for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in wmma_args(uops):
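In the CUDA renderer, render_vector_prefix emits the vector type definitions used by these vectorized loads and stores; the condition change means 2- and 4-wide fp8 vectors now get a definition too, not only 8- and 16-wide ones, presumably matching the narrower widths the devectorizer can now produce for fp8 buffers. A standalone sketch of the widened predicate (the needs_prefix helper is hypothetical, introduced here only for illustration):

# Hypothetical standalone version of the predicate used in CUDARenderer above;
# it only illustrates which vector dtypes now receive a vector-type prefix.
from tinygrad import dtypes

def needs_prefix(dt) -> bool:
  return (dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}) \
      or (dt.count in (2,4,8,16) and dt.scalar() in dtypes.fp8s)

# before this change, fp8 vectors of width 2 or 4 had no rendered vector type
for n in (2, 4, 8, 16):
  print(n, needs_prefix(dtypes.fp8e4m3.vec(n)))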