mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Without this change, a constexpr assignment (i.e. `A = B & C`, where `B` and `C` are both constexpr) gets assigned to a Triton tensor, which becomes an issue when `A` is used as the condition of an `if` statement. Note: I had to add `not isinstance(node.value, ast.Constant)` to the condition because if we are assigning `x = 0`, the assigned value is also a constexpr, but in this case we do want to assign a Triton tensor to `x` so that we can do, for example, `x.to(tl.int64)`, which cannot be done on a constexpr. --------- Co-authored-by: Philippe Tillet <phil@openai.com>
This commit is contained in:
@@ -42,19 +42,17 @@ static const std::string Fp8E5M2_to_Bf16(bool hasNativeFP) {
|
||||
ret =
|
||||
"{ \n"
|
||||
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
|
||||
".reg .b32 e112; \n"
|
||||
"mov.u32 e112, 0x77807780; \n" // 2**112 represented as
|
||||
// bf16x2
|
||||
"prmt.b32 a0, 0, $2, 0x5140; \n" // a0 = 0xf300f400
|
||||
"prmt.b32 a1, 0, $2, 0x7362; \n" // a1 = 0xf100f200
|
||||
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff
|
||||
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign)
|
||||
"shr.b32 b0, b0, 3; \n" // b0 >>= 3
|
||||
"shr.b32 b1, b1, 3; \n" // shift into bf16 position
|
||||
"lop3.b32 b0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 b1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"mul.rn.bf16x2 $0, b0, e112; \n" // b0.exp += 2**7-2**4
|
||||
"mul.rn.bf16x2 $1, b1, e112; \n" // exponent compensate = 112
|
||||
"add.u32 b0, b0, 0x38003800; \n" // b0.exp += 2**7-2**4
|
||||
// exponent compensate = 112
|
||||
"add.u32 b1, b1, 0x38003800; \n" // b1 += 112<<7 | 112<<7<<16
|
||||
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0)
|
||||
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign)
|
||||
"}";
|
||||
} else {
|
||||
ret = "{ \n"
|
||||
|
||||
@@ -3073,6 +3073,20 @@ def test_constexpr_scalar_shape(device):
|
||||
kernel[(1,)](x_tri, 32)
|
||||
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def static_assert_func():
|
||||
tl.static_assert(tl.constexpr(False), f"Assert is firing because the constexpr progation did not work properly")
|
||||
|
||||
|
||||
def test_constexpr_propagation():
|
||||
@triton.jit
|
||||
def _kernel(COND: tl.constexpr):
|
||||
NEW_COND = COND
|
||||
if NEW_COND:
|
||||
static_assert_func()
|
||||
_kernel[(1,)](False)
|
||||
|
||||
# -------------
|
||||
# test call
|
||||
# -------------
|
||||
|
||||
@@ -415,6 +415,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
if not isinstance(node.value, ast.Constant) and _is_constexpr(values):
|
||||
self.set_value(names, values)
|
||||
return
|
||||
if not _is_list_like(names):
|
||||
names = [names]
|
||||
if not _is_list_like(values):
|
||||
|
||||
Reference in New Issue
Block a user