[FRONTEND][BACKEND] Fix constexpr assignment; revert #2430 (#2496)

Without this change, the result of a constexpr assignment (e.g. `A = B & C`,
where `B` and `C` are both constexpr) is converted to a triton tensor instead
of staying a constexpr, which becomes an issue when `A` is used as the
condition of an `if` statement.
Note: I had to add `not isinstance(node.value, ast.Constant)` to the
condition because if we are assigning `x = 0` then the assigned value is
also a constexpr, but in this case we do want to assign a triton tensor
to `x` so that we can do `x.to(tl.int64)` for example, which cannot be
done on a constexpr.

---------

Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
Zahi Moudallal
2023-10-16 12:35:19 -07:00
committed by GitHub
parent 29828fe491
commit 726bdb984f
3 changed files with 22 additions and 7 deletions

View File

@@ -42,19 +42,17 @@ static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
ret =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
".reg .b32 e112; \n"
"mov.u32 e112, 0x77807780; \n" // 2**112 represented as
// bf16x2
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
"lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4
"mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
// exponent compensate = 112
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
"}";
} else {
ret = "{ \n"

View File

@@ -3073,6 +3073,20 @@ def test_constexpr_scalar_shape(device):
kernel[(1,)](x_tri, 32)
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
# JIT helper whose only statement is a static assertion on a constexpr False.
# It is called from the regression test below, inside a branch that is
# expected never to be taken; the assertion message describes the failure
# as constexpr propagation not working.
@triton.jit
def static_assert_func():
tl.static_assert(tl.constexpr(False), f"Assert is firing because the constexpr progation did not work properly")
# Regression test: a constexpr argument copied through a plain assignment is
# then used as an `if` condition guarding a call to static_assert_func().
def test_constexpr_propagation():
@triton.jit
def _kernel(COND: tl.constexpr):
# Plain assignment of a constexpr argument to a new name; NEW_COND is
# expected to remain a constexpr rather than become a tensor.
NEW_COND = COND
# With COND=False this branch should not be taken, so the always-false
# static assertion in static_assert_func() should not be reached.
if NEW_COND:
static_assert_func()
# Launch the kernel with grid (1,) and COND=False.
_kernel[(1,)](False)
# -------------
# test call
# -------------

View File

@@ -415,6 +415,9 @@ class CodeGenerator(ast.NodeVisitor):
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
names = _names[0]
values = self.visit(node.value)
if not isinstance(node.value, ast.Constant) and _is_constexpr(values):
self.set_value(names, values)
return
if not _is_list_like(names):
names = [names]
if not _is_list_like(values):