[FRONTEND] Fix tl.full with unsigned dtypes (#1919)

Calling `tl.full` with an unsigned dtype currently fails with the error:
```
AttributeError("'triton._C.libtriton.triton.ir.builder' object has no attribute
'get_uint8'")
```

This PR defines the missing `get_uint*` builder functions rather than
redirecting the calls to the signed versions, so that we can use an unsigned
argument type in C++ and avoid overflow for large uint64 values.
This commit is contained in:
peterbell10
2023-07-10 17:36:22 +01:00
committed by GitHub
parent 5a722b5f74
commit ef947dac31
2 changed files with 26 additions and 2 deletions

View File

@@ -643,6 +643,26 @@ void init_triton_ir(py::module &&m) {
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
v, self.getBuilder().getI64Type()));
})
// Python binding `get_uint8`: emits an arith ConstantIntOp of the builder's
// I8 type. The argument is received as uint64_t so unsigned Python values
// are accepted on the C++ side.
.def("get_uint8",
     [](TritonOpBuilder &self, uint64_t value) -> mlir::Value {
       auto type = self.getBuilder().getI8Type();
       return mlir::Value(
           self.create<mlir::arith::ConstantIntOp>(value, type));
     })
// Python binding `get_uint16`: emits an arith ConstantIntOp of the builder's
// I16 type. The argument is received as uint64_t so unsigned Python values
// are accepted on the C++ side.
.def("get_uint16",
     [](TritonOpBuilder &self, uint64_t value) -> mlir::Value {
       auto type = self.getBuilder().getI16Type();
       return mlir::Value(
           self.create<mlir::arith::ConstantIntOp>(value, type));
     })
// Python binding `get_uint32`: emits an arith ConstantIntOp of the builder's
// I32 type. The argument is received as uint64_t so unsigned Python values
// are accepted on the C++ side.
.def("get_uint32",
     [](TritonOpBuilder &self, uint64_t value) -> mlir::Value {
       auto type = self.getBuilder().getI32Type();
       return mlir::Value(
           self.create<mlir::arith::ConstantIntOp>(value, type));
     })
// Python binding `get_uint64`: emits an arith ConstantIntOp of the builder's
// I64 type. The argument is received as uint64_t (rather than a signed
// type) so that large unsigned 64-bit values can be passed in from Python.
.def("get_uint64",
     [](TritonOpBuilder &self, uint64_t value) -> mlir::Value {
       auto type = self.getBuilder().getI64Type();
       return mlir::Value(
           self.create<mlir::arith::ConstantIntOp>(value, type));
     })
.def("get_bf16",
[](TritonOpBuilder &self, float v) -> mlir::Value {
auto type = self.getBuilder().getBF16Type();

View File

@@ -2171,9 +2171,13 @@ def test_dot_mulbroadcastred(in_dtype, device):
assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir']
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str, device):
dtype = getattr(torch, dtype_str)
if dtype_str in uint_dtypes and not hasattr(torch, dtype_str):
# PyTorch only has unsigned 8, but not 16, 32, or 64
dtype = getattr(torch, dtype_str[1:]) # uintx -> intx
else:
dtype = getattr(torch, dtype_str)
check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested
@triton.jit