fp8 type support (#357)

* add two fp8 data types `tl.float8e4b8` and `tl.float8e5b16` to triton.
* add SW type conversion between `tl.float8e4b8/tl.float8e5b16` and `fp16`
* change flashattention to support fp8 in q/k.
This commit is contained in:
Shucai Xiao
2023-11-02 15:51:23 -05:00
committed by GitHub
parent 38f9136fc8
commit 79bebc4ffe
11 changed files with 445 additions and 131 deletions

View File

@@ -14,7 +14,7 @@ class TritonTypeDef<string name, string _mnemonic>
 }
 // Floating-point Type
-def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">;
+def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
 def TT_FloatTensor : TensorOf<[TT_Float]>;
 def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;