[FRONTEND] Explicitly forbid dot(.., out_dtype=bfloat16) (#2308)

Fixes: https://github.com/openai/triton/issues/2302
Author: jon-chuang (committed by GitHub)
Date:   2023-09-17 05:15:06 -04:00
Parent: 073aa16379
Commit: 4f2d995fad

@@ -1311,6 +1311,8 @@ def dot(lhs: tl.tensor,
         assert lhs.shape[1].value >= 32, "small blocks not supported!"
         _0 = builder.get_int32(0)
         ret_scalar_ty = tl.int32
+    elif out_dtype.is_bf16():
+        raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`")
     elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
         _0 = builder.get_fp32(0)
         ret_scalar_ty = tl.float32
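
With this change, calling tl.dot with out_dtype=tl.bfloat16 raises the
ValueError above. A minimal sketch of the workaround the error message
suggests, accumulating in float32 and casting the result afterwards (the
kernel and pointer names here are illustrative, not part of this commit):

import triton
import triton.language as tl

@triton.jit
def dot_bf16_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # tl.dot(a, b, out_dtype=tl.bfloat16) now raises ValueError;
    # accumulate in float32 instead and cast when storing.
    acc = tl.dot(a, b, out_dtype=tl.float32)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :],
             acc.to(tl.bfloat16))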