mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
fp8 type support (#357)
* add two fp8 data types `tl.float8e4b8` and `tl.float8e5b16` to triton. * add SW type conversion between `tl.float8e4b8/tl.float8e5b16` and `fp16` * change flashattention to support fp8 in q/k.
This commit is contained in:
@@ -1070,7 +1070,9 @@ def str_to_ty(name):
|
||||
return language.pointer_type(ty)
|
||||
tys = {
|
||||
"fp8e4nv": language.float8e4nv,
|
||||
"fp8e4b8": language.float8e4b8,
|
||||
"fp8e5": language.float8e5,
|
||||
"fp8e5b16": language.float8e5b16,
|
||||
"fp8e4b15": language.float8e4b15,
|
||||
"fp8e4b15x4": language.float8e4b15x4,
|
||||
"fp16": language.float16,
|
||||
|
||||
@@ -58,7 +58,9 @@ from .core import (
|
||||
float8e4b15,
|
||||
float8e4b15x4,
|
||||
float8e4nv,
|
||||
float8e4b8,
|
||||
float8e5,
|
||||
float8e5b16,
|
||||
function_type,
|
||||
inline_asm_elementwise,
|
||||
int1,
|
||||
|
||||
@@ -75,7 +75,7 @@ def _to_tensor(x, builder):
|
||||
class dtype:
|
||||
SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
|
||||
UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
|
||||
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
|
||||
FP_TYPES = ['fp8e4b15', 'fp8e4b15x4', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
|
||||
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
|
||||
OTHER_TYPES = ['void']
|
||||
|
||||
@@ -107,10 +107,18 @@ class dtype:
|
||||
self.fp_mantissa_width = 3
|
||||
self.primitive_bitwidth = 8
|
||||
self.exponent_bias = 7
|
||||
elif name == 'fp8e4b8':
|
||||
self.fp_mantissa_width = 3
|
||||
self.primitive_bitwidth = 8
|
||||
self.exponent_bias = 8
|
||||
elif name == 'fp8e5':
|
||||
self.fp_mantissa_width = 2
|
||||
self.primitive_bitwidth = 8
|
||||
self.exponent_bias = 15
|
||||
elif name == 'fp8e5b16':
|
||||
self.fp_mantissa_width = 2
|
||||
self.primitive_bitwidth = 8
|
||||
self.exponent_bias = 16
|
||||
elif name == 'fp16':
|
||||
self.fp_mantissa_width = 10
|
||||
self.primitive_bitwidth = 16
|
||||
@@ -138,6 +146,9 @@ class dtype:
|
||||
def is_fp8e4nv(self):
|
||||
return self.name == 'fp8e4nv'
|
||||
|
||||
def is_fp8e4b8(self):
|
||||
return self.name == 'fp8e4b8'
|
||||
|
||||
def is_fp8e4b15(self):
|
||||
return self.name == 'fp8e4b15'
|
||||
|
||||
@@ -147,6 +158,9 @@ class dtype:
|
||||
def is_fp8e5(self):
|
||||
return self.name == 'fp8e5'
|
||||
|
||||
def is_fp8e5b16(self):
|
||||
return self.name == 'fp8e5b16'
|
||||
|
||||
def is_fp16(self):
|
||||
return self.name == 'fp16'
|
||||
|
||||
@@ -250,8 +264,12 @@ class dtype:
|
||||
return builder.get_int64_ty()
|
||||
elif self.name == 'fp8e5':
|
||||
return builder.get_fp8e5_ty()
|
||||
elif self.name == 'fp8e5b16':
|
||||
return builder.get_fp8e5b16_ty()
|
||||
elif self.name == 'fp8e4nv':
|
||||
return builder.get_fp8e4nv_ty()
|
||||
elif self.name == 'fp8e4b8':
|
||||
return builder.get_fp8e4b8_ty()
|
||||
elif self.name == 'fp8e4b15':
|
||||
return builder.get_fp8e4b15_ty()
|
||||
elif self.name == 'fp8e4b15x4':
|
||||
@@ -388,7 +406,9 @@ uint16 = dtype('uint16')
|
||||
uint32 = dtype('uint32')
|
||||
uint64 = dtype('uint64')
|
||||
float8e5 = dtype('fp8e5')
|
||||
float8e5b16 = dtype('fp8e5b16')
|
||||
float8e4nv = dtype('fp8e4nv')
|
||||
float8e4b8 = dtype('fp8e4b8')
|
||||
float8e4b15 = dtype('fp8e4b15')
|
||||
float8e4b15x4 = dtype('fp8e4b15x4')
|
||||
float16 = dtype('fp16')
|
||||
|
||||
@@ -247,7 +247,13 @@ class JITFunction(KernelInterface[T]):
|
||||
tys = {
|
||||
"bool": "i1",
|
||||
"float8e4nv": "fp8e4nv",
|
||||
"float8_e4m3fn": "fp8e4nv",
|
||||
"float8e4b8": "fp8e4b8",
|
||||
"float8_e4m3fnuz": "fp8e4b8",
|
||||
"float8e5": "fp8e5",
|
||||
"float8_e5m2": "fp8e5",
|
||||
"float8e5b16": "fp8e5b16",
|
||||
"float8_e5m2fnuz": "fp8e5b16",
|
||||
"float8e4b15": "fp8e4b15",
|
||||
"float8e4b15x4": "fp8e4b15x4",
|
||||
"float16": "fp16",
|
||||
|
||||
Reference in New Issue
Block a user