mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] dot now uses tl.float32 by default for out_dtype.
This commit is contained in:
@@ -162,7 +162,7 @@ def test_elementwise(N):
|
||||
flash_attention_data = {
|
||||
"a100": {
|
||||
(4, 48, 4096, 64, 'forward', 'float16'): 0.37,
|
||||
(4, 48, 4096, 64, 'backward', 'float16'): 0.26,
|
||||
(4, 48, 4096, 64, 'backward', 'float16'): 0.25,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -862,7 +862,7 @@ def reshape(input, shape, _builder=None):
|
||||
|
||||
|
||||
@builtin
|
||||
def dot(input, other, allow_tf32=True, out_dtype=float16, _builder=None):
|
||||
def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None):
|
||||
"""
|
||||
Returns the matrix product of two blocks.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user