[FRONTEND] dot now uses tl.float32 by default for out_dtype.

This commit is contained in:
Phil Tillet
2023-03-19 21:57:32 -07:00
parent b4decbe155
commit e650d3708b
2 changed files with 2 additions and 2 deletions

View File

@@ -162,7 +162,7 @@ def test_elementwise(N):
flash_attention_data = {
"a100": {
(4, 48, 4096, 64, 'forward', 'float16'): 0.37,
(4, 48, 4096, 64, 'backward', 'float16'): 0.26,
(4, 48, 4096, 64, 'backward', 'float16'): 0.25,
}
}

View File

@@ -862,7 +862,7 @@ def reshape(input, shape, _builder=None):
@builtin
def dot(input, other, allow_tf32=True, out_dtype=float16, _builder=None):
def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
"""
Returns the matrix product of two blocks.