Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
cuda fp8 (#12782)
* cuda fp8
* tensor core
* tc test
* clean
* clean pm
This commit is contained in:
@@ -5,8 +5,10 @@ from tinygrad.dtype import _to_np_dtype
|
||||
from tinygrad.codegen.opt import OptOps
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
|
||||
acc_dtype = dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else None
|
||||
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
|
||||
dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
|
||||
acc_dtype = (dtypes.half if getenv("ACC_HALF") else dtypes.bfloat16 if getenv("ACC_BFLOAT16") else
|
||||
dtypes.fp8e4m3 if getenv("ACC_FP8E4M3") else dtypes.fp8e5m2 if getenv("ACC_FP8E5M2") else None)
|
||||
if getenv("INT"): dtype_in, acc_dtype = dtypes.int8, dtypes.int32
|
||||
if getenv("UINT"): dtype_in, acc_dtype = dtypes.uint8, dtypes.int32
|
||||
|
||||
@@ -14,8 +16,10 @@ N = getenv("N", 4096)
|
||||
M = getenv("M", N)
|
||||
K = getenv("K", N)
|
||||
CNT = getenv("CNT", 10)
|
||||
ATOL = getenv("ATOL", 1e-4)
|
||||
RTOL = getenv("RTOL", 3e-2)
|
||||
|
||||
atol, rtol = {dtypes.bfloat16:(1e-3, 1e-2), dtypes.fp8e4m3:(1e-1, 1e-1), dtypes.fp8e5m2:(1.0, 5e-1)}.get(dtype_in, (1e-4, 3e-2))
|
||||
ATOL, RTOL = getenv("ATOL", atol), getenv("RTOL", rtol)
|
||||
|
||||
INT_LOW = getenv("INT_LOW", 0)
|
||||
INT_HIGH = getenv("INT_HIGH", 10)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user