[FRONTEND] improved error messages (#2363)

this is a combination of #1774 and #2006, which I cannot edit but fail
CI pre-commit hook
This commit is contained in:
Philippe Tillet
2023-09-21 15:05:57 -07:00
committed by GitHub
parent c4bc3fd92f
commit 32c9d2bb8f
2 changed files with 7 additions and 1 deletions

View File

@@ -746,7 +746,7 @@ def test_invalid_pid_axis(device):
def _kernel(dst):
pid = tl.program_id(20)
with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"):
with pytest.raises(triton.CompilationError, match=r"program_id axis must be 0, 1, or 2 but got 20"):
_kernel[(1,)](dst)

View File

@@ -25,10 +25,14 @@ class IncompatibleTypeErrorImpl(Exception):
# ===----------------------------------------------------------------------===##
def program_id(axis: int, builder: ir.builder) -> tl.tensor:
if axis not in (0, 1, 2):
raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}")
return tl.tensor(builder.create_get_program_id(axis), tl.int32)
def num_programs(axis: int, builder: ir.builder) -> tl.tensor:
if axis not in (0, 1, 2):
raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}")
return tl.tensor(builder.create_get_num_programs(axis), tl.int32)
# ===----------------------------------------------------------------------===//
@@ -128,6 +132,8 @@ def add(input: tl.tensor,
input, other = binary_op_type_checking_impl(input, other, builder, True, True)
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr():
raise ValueError("cannot add pointers together")
# offset + ptr
# ptr + offset