mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Explicitly forbid dot(.., out_dtype=bfloat16) (#2308)
Fixes: https://github.com/openai/triton/issues/2302
This commit is contained in:
@@ -1311,6 +1311,8 @@ def dot(lhs: tl.tensor,
|
||||
assert lhs.shape[1].value >= 32, "small blocks not supported!"
|
||||
_0 = builder.get_int32(0)
|
||||
ret_scalar_ty = tl.int32
|
||||
elif out_dtype.is_bf16():
|
||||
raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
|
||||
elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
|
||||
_0 = builder.get_fp32(0)
|
||||
ret_scalar_ty = tl.float32
|
||||
|
||||
Reference in New Issue
Block a user