fp8 type support (#357)

* add two fp8 data types `tl.float8e4b8` and `tl.float8e5b16` to triton.
* add SW type conversion between `tl.float8e4b8/tl.float8e5b16` and `fp16`
* change flashattention to support fp8 in q/k.
This commit is contained in:
Shucai Xiao
2023-11-02 15:51:23 -05:00
committed by GitHub
parent 38f9136fc8
commit 79bebc4ffe
11 changed files with 445 additions and 131 deletions

View File

@@ -247,7 +247,13 @@ class JITFunction(KernelInterface[T]):
tys = {
"bool": "i1",
"float8e4nv": "fp8e4nv",
"float8_e4m3fn": "fp8e4nv",
"float8e4b8": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4b8",
"float8e5": "fp8e5",
"float8_e5m2": "fp8e5",
"float8e5b16": "fp8e5b16",
"float8_e5m2fnuz": "fp8e5b16",
"float8e4b15": "fp8e4b15",
"float8e4b15x4": "fp8e4b15x4",
"float16": "fp16",