[FRONTEND][BACKEND] Add memory synchronization scope parameter to atomic ops. (#2562)

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
Chris Jones
2023-10-31 02:18:27 +00:00
committed by GitHub
parent 70fca00b67
commit 2398b82f18
9 changed files with 128 additions and 46 deletions

View File

@@ -1,4 +1,4 @@
#include <mutex>
#include <mutex>
#include <stack>
#include <unordered_map>
@@ -228,6 +228,12 @@ void init_triton_ir(py::module &&m) {
.value("RELAXED", mlir::triton::MemSemantic::RELAXED)
.export_values();
// Expose mlir::triton::MemSyncScope to Python under the name MEM_SYNC_SCOPE
// with the members GPU, CTA and SYSTEM. export_values() additionally
// publishes those members as attributes of the enclosing scope `m`.
py::enum_<mlir::triton::MemSyncScope>(m, "MEM_SYNC_SCOPE", py::module_local())
.value("GPU", mlir::triton::MemSyncScope::GPU)
.value("CTA", mlir::triton::MemSyncScope::CTA)
.value("SYSTEM", mlir::triton::MemSyncScope::SYSTEM)
.export_values();
py::enum_<mlir::triton::EvictionPolicy>(m, "EVICTION_POLICY",
py::module_local())
.value("NORMAL", mlir::triton::EvictionPolicy::NORMAL)
@@ -1418,7 +1424,8 @@ void init_triton_ir(py::module &&m) {
// // atomic
.def("create_atomic_cas",
[](TritonOpBuilder &self, mlir::Value &ptr, mlir::Value &cmp,
mlir::Value &val, mlir::triton::MemSemantic sem) -> mlir::Value {
mlir::Value &val, mlir::triton::MemSemantic sem,
mlir::triton::MemSyncScope scope) -> mlir::Value {
mlir::Type dstType;
if (auto srcTensorType =
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
@@ -1433,12 +1440,13 @@ void init_triton_ir(py::module &&m) {
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicCASOp>(dstType, ptr, cmp,
val, sem);
val, sem, scope);
})
.def("create_atomic_rmw",
[](TritonOpBuilder &self, mlir::triton::RMWOp rmwOp,
mlir::Value &ptr, mlir::Value &val, mlir::Value &mask,
mlir::triton::MemSemantic sem) -> mlir::Value {
mlir::triton::MemSemantic sem,
mlir::triton::MemSyncScope scope) -> mlir::Value {
mlir::Type dstType;
if (auto srcTensorType =
ptr.getType().dyn_cast<mlir::RankedTensorType>()) {
@@ -1452,8 +1460,8 @@ void init_triton_ir(py::module &&m) {
.cast<mlir::triton::PointerType>();
dstType = ptrType.getPointeeType();
}
return self.create<mlir::triton::AtomicRMWOp>(dstType, rmwOp, ptr,
val, mask, sem);
return self.create<mlir::triton::AtomicRMWOp>(
dstType, rmwOp, ptr, val, mask, sem, scope);
})
// External
.def("create_extern_elementwise",

View File

@@ -1166,6 +1166,9 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
:param sem: Memory semantics to use ("ACQUIRE_RELEASE" (default),
"ACQUIRE", "RELEASE", or "RELAXED")
:type sem: str
:param scope: Scope of threads that observe synchronizing effect of the
atomic operation ("GPU" (default), "CTA", or "SYSTEM")
:type scope: str
"""
func.__doc__ = docstr
return func
@@ -1175,67 +1178,75 @@ def _add_atomic_docstr(name: str, has_cmp: bool = False) -> Callable[[T], T]:
@builtin
@_add_atomic_docstr("compare-and-swap", has_cmp=True)
def atomic_cas(pointer, cmp, val, sem=None, _builder=None):
def atomic_cas(pointer, cmp, val, sem=None, scope=None, _builder=None):
cmp = _to_tensor(cmp, _builder)
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_cas(pointer, cmp, val, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_cas(pointer, cmp, val, sem, scope, _builder)
@builtin
@_add_atomic_docstr("exchange")
def atomic_xchg(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xchg(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xchg(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_xchg(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("add")
def atomic_add(pointer, val, mask=None, sem=None, _builder=None):
def atomic_add(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_add(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_add(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("max")
def atomic_max(pointer, val, mask=None, sem=None, _builder=None):
def atomic_max(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_max(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_max(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("min")
def atomic_min(pointer, val, mask=None, sem=None, _builder=None):
def atomic_min(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_min(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_min(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical and")
def atomic_and(pointer, val, mask=None, sem=None, _builder=None):
def atomic_and(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_and(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_and(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical or")
def atomic_or(pointer, val, mask=None, sem=None, _builder=None):
def atomic_or(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_or(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_or(pointer, val, mask, sem, scope, _builder)
@builtin
@_add_atomic_docstr("logical xor")
def atomic_xor(pointer, val, mask=None, sem=None, _builder=None):
def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None):
val = _to_tensor(val, _builder)
sem = _constexpr_to_value(sem)
return semantic.atomic_xor(pointer, val, mask, sem, _builder)
scope = _constexpr_to_value(scope)
return semantic.atomic_xor(pointer, val, mask, sem, scope, _builder)
# -----------------------

View File

@@ -873,6 +873,20 @@ def _str_to_sem(sem_option):
return sem
def _str_to_scope(scope_option):
    """Translate a scope name into the matching ``ir.MEM_SYNC_SCOPE`` member.

    :param scope_option: one of ``"gpu"``, ``"cta"`` or ``"sys"``. The
        comparison is exact, so the names must be lowercase. Any falsy value
        (``None``, ``""``) selects the default.
    :return: ``ir.MEM_SYNC_SCOPE.GPU`` (also the default),
        ``ir.MEM_SYNC_SCOPE.CTA`` or ``ir.MEM_SYNC_SCOPE.SYSTEM``.
    :raises ValueError: if ``scope_option`` is truthy but not a recognized name.
    """
    # Nothing requested: fall back to the GPU default.
    if not scope_option:
        return ir.MEM_SYNC_SCOPE.GPU
    if scope_option == "gpu":
        return ir.MEM_SYNC_SCOPE.GPU
    if scope_option == "cta":
        return ir.MEM_SYNC_SCOPE.CTA
    if scope_option == "sys":
        return ir.MEM_SYNC_SCOPE.SYSTEM
    # Name the offending argument kind correctly (a scope, not a memory
    # semantic) so the error is not mistaken for a bad `sem` value.
    raise ValueError(f"Memory synchronization scope {scope_option} not supported")
def _canonicalize_boundary_check(boundary_check, block_shape):
if boundary_check:
if not hasattr(boundary_check, "__iter__"):
@@ -1084,12 +1098,14 @@ def atomic_cas(ptr: tl.tensor,
cmp: tl.tensor,
val: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
element_ty = ptr.type.scalar.element_ty
if element_ty.primitive_bitwidth not in [16, 32, 64]:
raise ValueError("atomic_cas only supports elements with width {16, 32, 64}")
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem), val.type)
return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type)
def atom_red_typechecking_impl(ptr: tl.tensor,
@@ -1124,9 +1140,11 @@ def atomic_max(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
sca_ty = val.type.scalar
# direct call to atomic_max for integers
if sca_ty.is_int():
@@ -1135,14 +1153,16 @@ def atomic_max(ptr: tl.tensor,
ptr.handle,
val.handle,
mask.handle,
sem),
sem,
scope),
val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
ptr.handle,
val.handle,
mask.handle,
sem),
sem,
scope),
val.type)
# for float
# return atomic_smax(i_ptr, i_val) if val >= 0
@@ -1157,8 +1177,8 @@ def atomic_max(ptr: tl.tensor,
i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder)
pos = greater_equal(val, zero, builder)
neg = less_than(val, zero, builder)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type)
pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem, scope), i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem, scope), i_val.type)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
@@ -1167,9 +1187,11 @@ def atomic_min(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
sca_ty = val.type.scalar
# direct call to atomic_min for integers
if sca_ty.is_int():
@@ -1178,14 +1200,16 @@ def atomic_min(ptr: tl.tensor,
ptr.handle,
val.handle,
mask.handle,
sem),
sem,
scope),
val.type)
else:
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN,
ptr.handle,
val.handle,
mask.handle,
sem),
sem,
scope),
val.type)
# for float
# return atomic_smin(i_ptr, i_val) if val >= 0
@@ -1204,13 +1228,15 @@ def atomic_min(ptr: tl.tensor,
i_ptr.handle,
i_val.handle,
and_(mask, pos, builder).handle,
sem),
sem,
scope),
i_val.type)
neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX,
i_ptr.handle,
i_val.handle,
and_(mask, neg, builder).handle,
sem),
sem,
scope),
i_val.type)
ret = where(pos, pos_ret, neg_ret, builder)
return bitcast(ret, sca_ty, builder)
@@ -1220,52 +1246,62 @@ def atomic_add(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder)
sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
sca_ty = val.type.scalar
op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem), val.type)
return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
def atomic_and(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
def atomic_or(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
def atomic_xor(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
def atomic_xchg(ptr: tl.tensor,
val: tl.tensor,
mask: tl.tensor,
sem: str,
scope: str,
builder: ir.builder) -> tl.tensor:
ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder)
sem = _str_to_sem(sem)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem), val.type)
scope = _str_to_scope(scope)
return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), val.type)
# ===----------------------------------------------------------------------===//
# Linear Algebra