mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix tl.full with unsigned dtypes (#1919)
Calling `tl.full` with an unsigned dtype currently fails with the error:
```
AttributeError("'triton._C.libtriton.triton.ir.builder' object has no attribute
'get_uint8'")
```
This PR defines those functions rather than changing the calls to the
signed versions so that we can use an unsigned argument type in C++ and
avoid overflow for large uint64 values.
This commit is contained in:
@@ -643,6 +643,26 @@ void init_triton_ir(py::module &&m) {
|
||||
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
v, self.getBuilder().getI64Type()));
|
||||
})
|
||||
.def("get_uint8",
|
||||
[](TritonOpBuilder &self, uint64_t v) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
v, self.getBuilder().getI8Type()));
|
||||
})
|
||||
.def("get_uint16",
|
||||
[](TritonOpBuilder &self, uint64_t v) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
v, self.getBuilder().getI16Type()));
|
||||
})
|
||||
.def("get_uint32",
|
||||
[](TritonOpBuilder &self, uint64_t v) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
v, self.getBuilder().getI32Type()));
|
||||
})
|
||||
.def("get_uint64",
|
||||
[](TritonOpBuilder &self, uint64_t v) -> mlir::Value {
|
||||
return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
|
||||
v, self.getBuilder().getI64Type()));
|
||||
})
|
||||
.def("get_bf16",
|
||||
[](TritonOpBuilder &self, float v) -> mlir::Value {
|
||||
auto type = self.getBuilder().getBF16Type();
|
||||
|
||||
@@ -2171,9 +2171,13 @@ def test_dot_mulbroadcastred(in_dtype, device):
|
||||
assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir']
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16'])
|
||||
def test_full(dtype_str, device):
|
||||
dtype = getattr(torch, dtype_str)
|
||||
if dtype_str in uint_dtypes and not hasattr(torch, dtype_str):
|
||||
# PyTorch only has unsigned 8, but not 16, 32, or 64
|
||||
dtype = getattr(torch, dtype_str[1:]) # uintx -> intx
|
||||
else:
|
||||
dtype = getattr(torch, dtype_str)
|
||||
check_type_supported(dtype, device) # bfloat16 on cc < 80 will not be tested
|
||||
|
||||
@triton.jit
|
||||
|
||||
Reference in New Issue
Block a user