fp8 type support (#357)

* Add two fp8 data types, `tl.float8e4b8` and `tl.float8e5b16`, to Triton.
* Add software (SW) type conversion between `tl.float8e4b8`/`tl.float8e5b16` and `fp16`.
* Change FlashAttention to support fp8 inputs for q/k.
This commit is contained in:
Shucai Xiao
2023-11-02 15:51:23 -05:00
committed by GitHub
parent 38f9136fc8
commit 79bebc4ffe
11 changed files with 445 additions and 131 deletions

View File

@@ -1070,7 +1070,9 @@ def str_to_ty(name):
return language.pointer_type(ty)
tys = {
"fp8e4nv": language.float8e4nv,
"fp8e4b8": language.float8e4b8,
"fp8e5": language.float8e5,
"fp8e5b16": language.float8e5b16,
"fp8e4b15": language.float8e4b15,
"fp8e4b15x4": language.float8e4b15x4,
"fp16": language.float16,

View File

@@ -58,7 +58,9 @@ from .core import (
float8e4b15,
float8e4b15x4,
float8e4nv,
float8e4b8,
float8e5,
float8e5b16,
function_type,
inline_asm_elementwise,
int1,

View File

@@ -75,7 +75,7 @@ def _to_tensor(x, builder):
class dtype:
SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']
@@ -107,10 +107,18 @@ class dtype:
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
self.exponent_bias = 7
elif name == 'fp8e4b8':
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
self.exponent_bias = 8
elif name == 'fp8e5':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
self.exponent_bias = 15
elif name == 'fp8e5b16':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
self.exponent_bias = 16
elif name == 'fp16':
self.fp_mantissa_width = 10
self.primitive_bitwidth = 16
@@ -138,6 +146,9 @@ class dtype:
def is_fp8e4nv(self):
return self.name == 'fp8e4nv'
def is_fp8e4b8(self):
return self.name == 'fp8e4b8'
def is_fp8e4b15(self):
return self.name == 'fp8e4b15'
@@ -147,6 +158,9 @@ class dtype:
def is_fp8e5(self):
return self.name == 'fp8e5'
def is_fp8e5b16(self):
return self.name == 'fp8e5b16'
def is_fp16(self):
return self.name == 'fp16'
@@ -250,8 +264,12 @@ class dtype:
return builder.get_int64_ty()
elif self.name == 'fp8e5':
return builder.get_fp8e5_ty()
elif self.name == 'fp8e5b16':
return builder.get_fp8e5b16_ty()
elif self.name == 'fp8e4nv':
return builder.get_fp8e4nv_ty()
elif self.name == 'fp8e4b8':
return builder.get_fp8e4b8_ty()
elif self.name == 'fp8e4b15':
return builder.get_fp8e4b15_ty()
elif self.name == 'fp8e4b15x4':
@@ -388,7 +406,9 @@ uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
float8e5 = dtype('fp8e5')
float8e5b16 = dtype('fp8e5b16')
float8e4nv = dtype('fp8e4nv')
float8e4b8 = dtype('fp8e4b8')
float8e4b15 = dtype('fp8e4b15')
float8e4b15x4 = dtype('fp8e4b15x4')
float16 = dtype('fp16')

View File

@@ -247,7 +247,13 @@ class JITFunction(KernelInterface[T]):
tys = {
"bool": "i1",
"float8e4nv": "fp8e4nv",
"float8_e4m3fn": "fp8e4nv",
"float8e4b8": "fp8e4b8",
"float8_e4m3fnuz": "fp8e4b8",
"float8e5": "fp8e5",
"float8_e5m2": "fp8e5",
"float8e5b16": "fp8e5b16",
"float8_e5m2fnuz": "fp8e5b16",
"float8e4b15": "fp8e4b15",
"float8e4b15x4": "fp8e4b15x4",
"float16": "fp16",