* cuda fp8

* tensor core

* tc test

* clean

* clean pm
This commit is contained in:
b1tg
2025-10-22 03:05:25 +08:00
committed by GitHub
parent 587ccc0e5c
commit 60d7e232f2
7 changed files with 43 additions and 15 deletions

View File

@@ -5,8 +5,10 @@ from tinygrad.dtype import _to_np_dtype
from tinygrad.codegen.opt import OptOps
from tinygrad.engine.realize import lower_schedule
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32
@@ -14,8 +16,10 @@ N = getenv("N", 4096)
M = getenv("M", N)
K = getenv("K", N)
CNT = getenv("CNT", 10)
ATOL = getenv("ATOL", 1e-4)
RTOL = getenv("RTOL", 3e-2)
atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
ATOL, RTOL = getenv("ATOL", atol), getenv("RTOL", rtol)
INT_LOW = getenv("INT_LOW", 0)
INT_HIGH = getenv("INT_HIGH", 10)