From 88c1c9037dfba40e6634533de0801ab1fc34d5fe Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 5 Aug 2021 15:13:25 -0400 Subject: [PATCH] Add atomic exchange function --- CHANGELOG.md | 1 + src/back/dot/mod.rs | 9 ++++++ src/back/glsl/mod.rs | 13 +++++++++ src/back/msl/writer.rs | 25 ++++++++++++++--- src/back/spv/block.rs | 15 ++++++++++ src/back/wgsl/writer.rs | 32 ++++++++++++++++++++-- src/front/wgsl/mod.rs | 24 ++++++++++++++++ src/lib.rs | 5 ++++ src/proc/typifier.rs | 19 ++++++++++++- src/valid/analyzer.rs | 6 +++- src/valid/expression.rs | 13 ++++++++- tests/in/access.wgsl | 15 ++++++---- tests/out/glsl/access.atomics.Compute.glsl | 13 +++++---- tests/out/msl/access.msl | 13 +++++---- tests/out/spv/access.spvasm | 23 +++++++++------- tests/out/wgsl/access.wgsl | 13 +++++---- 16 files changed, 195 insertions(+), 44 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d095734e8..03f7d72bbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## TBD - API: + - atomic types and functions - WGSL `select()` order of true/false is swapped ## v0.5 (2021-06-18) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 0e8f475236..b824933289 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -300,6 +300,15 @@ fn write_fun( edges.insert("value", value); Cow::Borrowed("Max") } + crate::AtomicFunction::Exchange(value) => { + edges.insert("value", value); + Cow::Borrowed("Exchange") + } + crate::AtomicFunction::CompareExchange { cmp, value } => { + edges.insert("cmp", cmp); + edges.insert("value", value); + Cow::Borrowed("CompareExchange") + } }; (description, 3) } diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 4026178ebe..aabe04a3ef 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2138,6 +2138,19 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(value, ctx)?; write!(self.out, ")")?; } + crate::AtomicFunction::Exchange(value) => { + write!(self.out, "atomicExchange(")?; + self.write_expr(pointer, ctx)?; + write!(self.out, ", ")?; + self.write_expr(value, ctx)?; + write!(self.out, ")")?; + } + crate::AtomicFunction::CompareExchange { .. } => { + //TODO: write a wrapper function to return vec2 + return Err(Error::Custom( + "atomic CompareExchange is not implemented".to_string(), + )); + } }, // `Select` is written as `condition ? accept : reject` // We wrap everything in parentheses to avoid precedence issues diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index eeb6fc01e7..21ad996352 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1066,6 +1066,23 @@ impl Writer { crate::AtomicFunction::Max(value) => { self.put_atomic_fetch(pointer, "max", value, context)?; } + crate::AtomicFunction::Exchange(value) => { + write!( + self.out, + "{}::atomic_exchange_explicit({}", + NAMESPACE, ATOMIC_REFERENCE, + )?; + self.put_expression(pointer, context, true)?; + write!(self.out, ", ")?; + self.put_expression(value, context, true)?; + write!(self.out, ", {}::memory_order_relaxed)", NAMESPACE)?; + } + crate::AtomicFunction::CompareExchange { .. } => { + //TODO: add a wrapper function to return the vector + return Err(Error::FeatureNotImplemented( + "atomic CompareExchange".to_string(), + )); + } }, crate::Expression::Select { condition, @@ -2570,8 +2587,8 @@ fn test_stack_size() { } let stack_size = addresses.end - addresses.start; // check the size (in debug only) - // last observed macOS value: 17664 - if stack_size < 14000 || stack_size > 19000 { + // last observed macOS value: 20752 (from CI) + if stack_size < 16000 || stack_size > 21000 { panic!("`put_expression` stack size {} has changed!", stack_size); } } @@ -2585,8 +2602,8 @@ fn test_stack_size() { } let stack_size = addresses.end - addresses.start; // check the size (in debug only) - // last observed macOS value: 13600 - if stack_size < 11000 || stack_size > 16000 { + // last observed macOS value: 16160 (on CI) + if stack_size < 12000 || stack_size > 17000 { panic!("`put_block` stack size {} has changed!", stack_size); } } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 9f93d3efd3..6ce6890844 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -567,6 +567,21 @@ impl<'w> BlockContext<'w> { value_id, ) } + crate::AtomicFunction::Exchange(value) => { + let value_id = self.cached[value]; + Instruction::atomic_binary( + spirv::Op::AtomicExchange, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::CompareExchange { .. } => { + return Err(Error::FeatureNotImplemented("atomic CompareExchange")); + } }; block.body.push(instruction); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 334bcd1334..44e8fa95f7 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -6,6 +6,11 @@ use crate::{ }; use std::fmt::Write; +// This is a hack: we need to pass a pointer to an atomic, +// but generally the backend isn't putting "&" in front of every pointer. +// Some more general handling of pointers is needed to be implemented here. +const ATOMIC_REFERENCE: &str = "&"; + /// Shorthand result used internally by the backend type BackendResult = Result<(), Error>; @@ -924,26 +929,47 @@ impl Writer { } Expression::Atomic { pointer, fun } => match fun { crate::AtomicFunction::Binary { op, value } => { - write!(self.out, "atomic{}(", op.to_wgsl_atomic_suffix())?; + write!( + self.out, + "atomic{}({}", + op.to_wgsl_atomic_suffix(), + ATOMIC_REFERENCE + )?; self.write_expr(module, pointer, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; write!(self.out, ")")?; } crate::AtomicFunction::Min(value) => { - write!(self.out, "atomicMin(")?; + write!(self.out, "atomicMin({}", ATOMIC_REFERENCE)?; self.write_expr(module, pointer, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; write!(self.out, ")")?; } crate::AtomicFunction::Max(value) => { - write!(self.out, "atomicMax(")?; + write!(self.out, "atomicMax({}", ATOMIC_REFERENCE)?; self.write_expr(module, pointer, func_ctx)?; write!(self.out, ", ")?; self.write_expr(module, value, func_ctx)?; write!(self.out, ")")?; } + crate::AtomicFunction::Exchange(value) => { + write!(self.out, "atomicExchange({}", ATOMIC_REFERENCE)?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + write!(self.out, ")")?; + } + crate::AtomicFunction::CompareExchange { cmp, value } => { + write!(self.out, "atomicCompareExchangeWeak({}", ATOMIC_REFERENCE)?; + self.write_expr(module, pointer, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, cmp, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + write!(self.out, ")")?; + } }, // TODO: copy-paste from glsl-out Expression::Access { base, index } => { diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 7fbda0bf86..469581cf0e 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1211,6 +1211,30 @@ impl Parser { fun: crate::AtomicFunction::Max(value), } } + "atomicExchange" => { + lexer.open_arguments()?; + let pointer = self.parse_singular_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let value = self.parse_singular_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::Atomic { + pointer, + fun: crate::AtomicFunction::Exchange(value), + } + } + "atomicCompareExchangeWeak" => { + lexer.open_arguments()?; + let pointer = self.parse_singular_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let cmp = self.parse_singular_expression(lexer, ctx.reborrow())?; + lexer.expect(Token::Separator(','))?; + let value = self.parse_singular_expression(lexer, ctx.reborrow())?; + lexer.close_arguments()?; + crate::Expression::Atomic { + pointer, + fun: crate::AtomicFunction::CompareExchange { cmp, value }, + } + } // texture sampling "textureSample" => { lexer.open_arguments()?; diff --git a/src/lib.rs b/src/lib.rs index 7ca561c5ad..21a14ed6f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -751,6 +751,11 @@ pub enum AtomicFunction { }, Min(Handle), Max(Handle), + Exchange(Handle), + CompareExchange { + cmp: Handle, + value: Handle, + }, } /// Axis on which to compute a derivative. diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 649a61039a..c29a4c8129 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -508,7 +508,24 @@ impl<'a> ResolveContext<'a> { crate::Expression::Atomic { pointer: _, fun } => match fun { crate::AtomicFunction::Binary { op: _, value } | crate::AtomicFunction::Min(value) - | crate::AtomicFunction::Max(value) => past(value).clone(), + | crate::AtomicFunction::Max(value) + | crate::AtomicFunction::Exchange(value) => past(value).clone(), + crate::AtomicFunction::CompareExchange { cmp: _, value } => { + let (kind, width) = match *past(value).inner_with(types) { + Ti::Scalar { kind, width } => (kind, width), + ref other => { + return Err(ResolveError::IncompatibleOperands(format!( + "atomic ptr {:?}", + other + ))) + } + }; + TypeResolution::Value(Ti::Vector { + size: crate::VectorSize::Bi, + kind, + width, + }) + } }, crate::Expression::Select { accept, .. } => past(accept).clone(), crate::Expression::Derivative { axis: _, expr } => past(expr).clone(), diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 8cfbe48e3f..99c13ca397 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -536,7 +536,11 @@ impl FunctionInfo { let non_uniform_result = match fun { crate::AtomicFunction::Binary { op: _, value } | crate::AtomicFunction::Min(value) - | crate::AtomicFunction::Max(value) => self.add_ref(value), + | crate::AtomicFunction::Max(value) + | crate::AtomicFunction::Exchange(value) => self.add_ref(value), + crate::AtomicFunction::CompareExchange { cmp, value } => { + self.add_ref(value).or(self.add_ref(cmp)) + } }; Uniformity { non_uniform_result: self.add_ref(pointer).or(non_uniform_result), diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 28e5a88971..9a1c907238 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -806,7 +806,18 @@ impl super::Validator { } value } - crate::AtomicFunction::Min(value) | crate::AtomicFunction::Max(value) => value, + crate::AtomicFunction::Min(value) + | crate::AtomicFunction::Max(value) + | crate::AtomicFunction::Exchange(value) => value, + crate::AtomicFunction::CompareExchange { cmp, value } => { + if resolver.resolve(cmp)? != resolver.resolve(value)? { + log::error!( + "Atomic exchange comparison has a different type from the value" + ); + return Err(ExpressionError::InvalidAtomicOperand(cmp)); + } + value + } }; match *resolver.resolve(value)? { Ti::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {} diff --git a/tests/in/access.wgsl b/tests/in/access.wgsl index 64409c3d69..ba54363ed1 100644 --- a/tests/in/access.wgsl +++ b/tests/in/access.wgsl @@ -42,11 +42,14 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4 { fn atomics() { var tmp: i32; let value = atomicLoad(&bar.atom); - tmp = atomicAdd(&bar.atom, 1); - tmp = atomicAnd(&bar.atom, 1); - tmp = atomicOr(&bar.atom, 1); - tmp = atomicXor(&bar.atom, 1); - tmp = atomicMin(&bar.atom, 1); - tmp = atomicMax(&bar.atom, 1); + tmp = atomicAdd(&bar.atom, 5); + tmp = atomicAnd(&bar.atom, 5); + tmp = atomicOr(&bar.atom, 5); + tmp = atomicXor(&bar.atom, 5); + tmp = atomicMin(&bar.atom, 5); + tmp = atomicMax(&bar.atom, 5); + tmp = atomicExchange(&bar.atom, 5); + // https://github.com/gpuweb/gpuweb/issues/2021 + // tmp = atomicCompareExchangeWeak(&bar.atom, 5, 5); atomicStore(&bar.atom, value); } diff --git a/tests/out/glsl/access.atomics.Compute.glsl b/tests/out/glsl/access.atomics.Compute.glsl index 1dd0daac2d..2e0859d2ed 100644 --- a/tests/out/glsl/access.atomics.Compute.glsl +++ b/tests/out/glsl/access.atomics.Compute.glsl @@ -16,12 +16,13 @@ buffer Bar_block_0Cs { void main() { int tmp = 0; int value = _group_0_binding_0.atom; - tmp = atomicAdd(_group_0_binding_0.atom, 1); - tmp = atomicAnd(_group_0_binding_0.atom, 1); - tmp = atomicOr(_group_0_binding_0.atom, 1); - tmp = atomicXor(_group_0_binding_0.atom, 1); - tmp = atomicMin(_group_0_binding_0.atom, 1); - tmp = atomicMax(_group_0_binding_0.atom, 1); + tmp = atomicAdd(_group_0_binding_0.atom, 5); + tmp = atomicAnd(_group_0_binding_0.atom, 5); + tmp = atomicOr(_group_0_binding_0.atom, 5); + tmp = atomicXor(_group_0_binding_0.atom, 5); + tmp = atomicMin(_group_0_binding_0.atom, 5); + tmp = atomicMax(_group_0_binding_0.atom, 5); + tmp = atomicExchange(_group_0_binding_0.atom, 5); _group_0_binding_0.atom = value; return; } diff --git a/tests/out/msl/access.msl b/tests/out/msl/access.msl index 58ea44f6c5..abdec9dcfe 100644 --- a/tests/out/msl/access.msl +++ b/tests/out/msl/access.msl @@ -56,12 +56,13 @@ kernel void atomics( ) { int tmp; int value = metal::atomic_load_explicit(&bar.atom, metal::memory_order_relaxed); - tmp = metal::atomic_fetch_add_explicit(&bar.atom, 1, metal::memory_order_relaxed); - tmp = metal::atomic_fetch_and_explicit(&bar.atom, 1, metal::memory_order_relaxed); - tmp = metal::atomic_fetch_or_explicit(&bar.atom, 1, metal::memory_order_relaxed); - tmp = metal::atomic_fetch_xor_explicit(&bar.atom, 1, metal::memory_order_relaxed); - tmp = metal::atomic_fetch_min_explicit(&bar.atom, 1, metal::memory_order_relaxed); - tmp = metal::atomic_fetch_max_explicit(&bar.atom, 1, metal::memory_order_relaxed); + tmp = metal::atomic_fetch_add_explicit(&bar.atom, 5, metal::memory_order_relaxed); + tmp = metal::atomic_fetch_and_explicit(&bar.atom, 5, metal::memory_order_relaxed); + tmp = metal::atomic_fetch_or_explicit(&bar.atom, 5, metal::memory_order_relaxed); + tmp = metal::atomic_fetch_xor_explicit(&bar.atom, 5, metal::memory_order_relaxed); + tmp = metal::atomic_fetch_min_explicit(&bar.atom, 5, metal::memory_order_relaxed); + tmp = metal::atomic_fetch_max_explicit(&bar.atom, 5, metal::memory_order_relaxed); + tmp = metal::atomic_exchange_explicit(&bar.atom, 5, metal::memory_order_relaxed); metal::atomic_store_explicit(&bar.atom, value, metal::memory_order_relaxed); return; } diff --git a/tests/out/spv/access.spvasm b/tests/out/spv/access.spvasm index 57bf3363f3..33920bc062 100644 --- a/tests/out/spv/access.spvasm +++ b/tests/out/spv/access.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 105 +; Bound: 107 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" @@ -81,7 +81,7 @@ OpDecorate %37 BuiltIn Position %86 = OpTypePointer StorageBuffer %4 %89 = OpConstant %9 66 %92 = OpConstant %9 72 -%104 = OpConstant %9 68 +%106 = OpConstant %9 68 %39 = OpFunction %2 None %40 %33 = OpLabel %29 = OpVariable %30 Function %5 @@ -138,24 +138,27 @@ OpBranch %85 %87 = OpAccessChain %86 %27 %15 %88 = OpAtomicLoad %4 %87 %11 %89 %91 = OpAccessChain %86 %27 %15 -%90 = OpAtomicIAdd %4 %91 %11 %92 %11 +%90 = OpAtomicIAdd %4 %91 %11 %92 %16 OpStore %82 %90 %94 = OpAccessChain %86 %27 %15 -%93 = OpAtomicAnd %4 %94 %11 %92 %11 +%93 = OpAtomicAnd %4 %94 %11 %92 %16 OpStore %82 %93 %96 = OpAccessChain %86 %27 %15 -%95 = OpAtomicOr %4 %96 %11 %92 %11 +%95 = OpAtomicOr %4 %96 %11 %92 %16 OpStore %82 %95 %98 = OpAccessChain %86 %27 %15 -%97 = OpAtomicXor %4 %98 %11 %92 %11 +%97 = OpAtomicXor %4 %98 %11 %92 %16 OpStore %82 %97 %100 = OpAccessChain %86 %27 %15 -%99 = OpAtomicSMin %4 %100 %11 %92 %11 +%99 = OpAtomicSMin %4 %100 %11 %92 %16 OpStore %82 %99 %102 = OpAccessChain %86 %27 %15 -%101 = OpAtomicSMax %4 %102 %11 %92 %11 +%101 = OpAtomicSMax %4 %102 %11 %92 %16 OpStore %82 %101 -%103 = OpAccessChain %86 %27 %15 -OpAtomicStore %103 %11 %104 %88 +%104 = OpAccessChain %86 %27 %15 +%103 = OpAtomicExchange %4 %104 %11 %92 %16 +OpStore %82 %103 +%105 = OpAccessChain %86 %27 %15 +OpAtomicStore %105 %11 %106 %88 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/access.wgsl b/tests/out/wgsl/access.wgsl index 854f033edc..967d6b0d33 100644 --- a/tests/out/wgsl/access.wgsl +++ b/tests/out/wgsl/access.wgsl @@ -35,12 +35,13 @@ fn atomics() { var tmp: i32; let value: i32 = bar.atom; - tmp = atomicAdd(bar.atom, 1); - tmp = atomicAnd(bar.atom, 1); - tmp = atomicOr(bar.atom, 1); - tmp = atomicXor(bar.atom, 1); - tmp = atomicMin(bar.atom, 1); - tmp = atomicMax(bar.atom, 1); + tmp = atomicAdd(&bar.atom, 5); + tmp = atomicAnd(&bar.atom, 5); + tmp = atomicOr(&bar.atom, 5); + tmp = atomicXor(&bar.atom, 5); + tmp = atomicMin(&bar.atom, 5); + tmp = atomicMax(&bar.atom, 5); + tmp = atomicExchange(&bar.atom, 5); bar.atom = value; return; }