[FRONTEND][BACKEND] Add memory synchronization scope parameter to atomic ops. (#2562)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
@@ -1,4 +1,4 @@
#include <mutex>
#include <mutex>
#include <stack>
#include <unordered_map>
@@ -228,6 +228,12 @@ void init_triton_ir(py::module &&m) {
      .value("RELAXED", mlir::triton::MemSemantic::RELAXED)
      .export_values();

  py::enum_<mlir::triton::MemSyncScope>(m, "MEM_SYNC_SCOPE", py::module_local())
      .value("GPU", mlir::triton::MemSyncScope::GPU)
      .value("CTA", mlir::triton::MemSyncScope::CTA)
      .value("SYSTEM", mlir::triton::MemSyncScope::SYSTEM)
      .export_values();

  py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY",
                                          py::module_local())
      .value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
@@ -1418,7 +1424,8 @@ void init_triton_ir(py::module &&m) {
      // // atomic
      .def("create_atomic_cas",
           [](TritonOpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
              mlir::Value &val, mlir::triton::MemSemantic sem) -> mlir::Value {
              mlir::Value &val, mlir::triton::MemSemantic sem,
              mlir::triton::MemSyncScope scope) -> mlir::Value {
             mlir::Type dstType;
             if (auto srcTensorType =
                     ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
@@ -1433,12 +1440,13 @@ void init_triton_ir(py::module &&m) {
               dstType = ptrType.getPointeeType();
             }
             return self.create<mlir::triton::AtomicCASOp>(dstType, ptr, cmp,
                                                           val, sem);
                                                           val, sem, scope);
           })
      .def("create_atomic_rmw",
           [](TritonOpBuilder &self, mlir::triton::RMWOp rmwOp,
              mlir::Value &ptr, mlir::Value &val, mlir::Value &mask,
              mlir::triton::MemSemantic sem) -> mlir::Value {
              mlir::triton::MemSemantic sem,
              mlir::triton::MemSyncScope scope) -> mlir::Value {
             mlir::Type dstType;
             if (auto srcTensorType =
                     ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
@@ -1452,8 +1460,8 @@ void init_triton_ir(py::module &&m) {
                                .cast<mlir::triton::PointerType>();
               dstType = ptrType.getPointeeType();
             }
             return self.create<mlir::triton::AtomicRMWOp>(dstType, rmwOp, ptr,
                                                           val, mask, sem);
             return self.create<mlir::triton::AtomicRMWOp>(
                 dstType, rmwOp, ptr, val, mask, sem, scope);
           })
      // External
      .def("create_extern_elementwise",
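For orientation, the extended builder entry points now take the memory semantic and the new sync scope as their trailing arguments. A minimal sketch of a call from the Python side (illustrative only: builder, ptr, val and mask stand for an existing TritonOpBuilder handle and IR value handles, and the memory-semantic enum is assumed to be exposed as ir.MEM_SEMANTIC alongside the ir.MEM_SYNC_SCOPE enum added above):

# Hypothetical call against the new binding signature; not part of the patch.
old_val = builder.create_atomic_rmw(ir.ATOMIC_OP.ADD, ptr, val, mask,
                                    ir.MEM_SEMANTIC.ACQUIRE_RELEASE,
                                    ir.MEM_SYNC_SCOPE.GPU)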
@@ -1166,6 +1166,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
    :param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
        "ACQUIRE", "RELEASE", or "RELAXED")
    :type sem: str
    :param scope: Scope of threads that observe synchronizing effect of the
        atomic operation ("GPU" (default), "CTA", or "SYSTEM")
    :type scope: str
    """
    func.__doc__ = docstr
    return func
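At the kernel level the new keyword sits alongside sem. A minimal usage sketch (not part of the patch; the kernel and pointer names are invented, and the lowercase string spellings assume the string-to-enum helpers accept lowercase names, as the _str_to_scope helper shown further down does):

import triton
import triton.language as tl

@triton.jit
def _bump_counters(global_ctr, block_ctr):
    # Device-wide counter: "gpu" is the default scope, spelled out for clarity.
    tl.atomic_add(global_ctr, 1, scope="gpu")
    # Block-local counter: ordering only needs to cover the current CTA.
    tl.atomic_add(block_ctr, 1, sem="relaxed", scope="cta")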
@@ -1175,67 +1178,75 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:

@builtin
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
    cmp = _to_tensor(cmp, _builder)
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)


@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)


@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_add(pointer, val, mask, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)


@builtin
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_max(pointer, val, mask, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)


@builtin
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_min(pointer, val, mask, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)


@builtin
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_and(pointer, val, mask, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)


@builtin
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_or(pointer, val, mask, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)


@builtin
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
    val = _to_tensor(val, _builder)
    sem = _constexpr_to_value(sem)
    return semantic.atomic_xor(pointer, val, mask, sem, _builder)
    scope = _constexpr_to_value(scope)
    return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)


# -----------------------
@@ -873,6 +873,20 @@ def _str_to_sem(sem_option):
    return sem


def _str_to_scope(scope_option):
    scope = ir.MEM_SYNC_SCOPE.GPU
    if scope_option:
        if scope_option == "gpu":
            scope = ir.MEM_SYNC_SCOPE.GPU
        elif scope_option == "cta":
            scope = ir.MEM_SYNC_SCOPE.CTA
        elif scope_option == "sys":
            scope = ir.MEM_SYNC_SCOPE.SYSTEM
        else:
raise ValueError(f"Memory semantic {scope_option} not supported")
    return scope


def _canonicalize_boundary_check(boundary_check, block_shape):
    if boundary_check:
        if not hasattr(boundary_check, "__iter__"):
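Note the accepted spellings here: lowercase, with the system scope written "sys" even though the docstring above advertises "SYSTEM". A quick illustration of the helper's mapping (for reference only, not part of the patch):

_str_to_scope(None)    # no scope given -> defaults to ir.MEM_SYNC_SCOPE.GPU
_str_to_scope("cta")   # -> ir.MEM_SYNC_SCOPE.CTA
_str_to_scope("sys")   # -> ir.MEM_SYNC_SCOPE.SYSTEM
_str_to_scope("warp")  # unrecognized string -> raises ValueError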
@@ -1084,12 +1098,14 @@ def atomic_cas(ptr: tl.tensor,
               cmp: tl.tensor,
               val: tl.tensor,
               sem: str,
               scope: str,
               builder: ir.builder) -> tl.tensor:
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    element_ty = ptr.type.scalar.element_ty
    if element_ty.primitive_bitwidth not in [16, 32, 64]:
        raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
    return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem), val.type)
    return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)


def atom_red_typechecking_impl(ptr: tl.tensor,
@@ -1124,9 +1140,11 @@ def atomic_max(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               sem: str,
               scope: str,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    sca_ty = val.type.scalar
    # direct call to atomic_max for integers
    if sca_ty.is_int():
@@ -1135,14 +1153,16 @@ def atomic_max(ptr: tl.tensor,
                                                   ptr.handle,
                                                   val.handle,
                                                   mask.handle,
                                                   sem),
                                                   sem,
                                                   scope),
                         val.type)
    else:
        return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
                                                   ptr.handle,
                                                   val.handle,
                                                   mask.handle,
                                                   sem),
                                                   sem,
                                                   scope),
                         val.type)
    # for float
    # return atomic_smax(i_ptr, i_val) if val >= 0
@@ -1157,8 +1177,8 @@ def atomic_max(ptr: tl.tensor,
    i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
    pos = greater_equal(val, zero, builder)
    neg = less_than(val, zero, builder)
    pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
    neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
    pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem, scope), i_val.type)
    neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem, scope), i_val.type)
    ret = where(pos, pos_ret, neg_ret, builder)
    return bitcast(ret, sca_ty, builder)
@@ -1167,9 +1187,11 @@ def atomic_min(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               sem: str,
               scope: str,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    sca_ty = val.type.scalar
    # direct call to atomic_min for integers
    if sca_ty.is_int():
@@ -1178,14 +1200,16 @@ def atomic_min(ptr: tl.tensor,
                                                   ptr.handle,
                                                   val.handle,
                                                   mask.handle,
                                                   sem),
                                                   sem,
                                                   scope),
                         val.type)
    else:
        return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
                                                   ptr.handle,
                                                   val.handle,
                                                   mask.handle,
                                                   sem),
                                                   sem,
                                                   scope),
                         val.type)
    # for float
    # return atomic_smin(i_ptr, i_val) if val >= 0
@@ -1204,13 +1228,15 @@ def atomic_min(ptr: tl.tensor,
                                                  i_ptr.handle,
                                                  i_val.handle,
                                                  and_(mask, pos, builder).handle,
                                                  sem),
                                                  sem,
                                                  scope),
                        i_val.type)
    neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
                                                  i_ptr.handle,
                                                  i_val.handle,
                                                  and_(mask, neg, builder).handle,
                                                  sem),
                                                  sem,
                                                  scope),
                        i_val.type)
    ret = where(pos, pos_ret, neg_ret, builder)
    return bitcast(ret, sca_ty, builder)
@@ -1220,52 +1246,62 @@ def atomic_add(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               sem: str,
               scope: str,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
    sem = _str_to_sem(sem)
    scope = _str_to_scope(scope)
    sca_ty = val.type.scalar
    op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
    return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem), val.type)
    return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)


def atomic_and(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               sem: str,
               scope: str,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
    sem = _str_to_sem(sem)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem), val.type)
    scope = _str_to_scope(scope)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type)


def atomic_or(ptr: tl.tensor,
              val: tl.tensor,
              mask: tl.tensor,
              sem: str,
              scope: str,
              builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
    sem = _str_to_sem(sem)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem), val.type)
    scope = _str_to_scope(scope)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)


def atomic_xor(ptr: tl.tensor,
               val: tl.tensor,
               mask: tl.tensor,
               sem: str,
               scope: str,
               builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
    sem = _str_to_sem(sem)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem), val.type)
    scope = _str_to_scope(scope)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)


def atomic_xchg(ptr: tl.tensor,
                val: tl.tensor,
                mask: tl.tensor,
                sem: str,
                scope: str,
                builder: ir.builder) -> tl.tensor:
    ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
    sem = _str_to_sem(sem)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem), val.type)
    scope = _str_to_scope(scope)
    return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), val.type)

# ===----------------------------------------------------------------------===//
# Linear Algebra