[wgsl-in/spv-out] Add support for WGSL's atomicCompareExchangeWeak (#2165)

* Add support for WGSL's `atomicCompareExchangeWeak` with the `__atomic_compare_exchange_result` struct, and add SPIR-V codegen for it.

Partially addresses https://github.com/gpuweb/gpuweb/pull/2113, #1755.

* Add tests for `atomicCompareExchangeWeak`, and support both u32 and i32 atomics with it.

* More thorough typechecking of the struct returned by `atomicCompareExchangeWeak`.
This commit is contained in:
Avi Weinstock
2022-12-13 04:47:28 -05:00
committed by GitHub
parent 8f1d82f0d2
commit 5d8fc3fdcf
11 changed files with 472 additions and 44 deletions

View File

@@ -2078,8 +2078,50 @@ impl<'w> BlockContext<'w> {
value_id,
)
}
crate::AtomicFunction::Exchange { compare: Some(_) } => {
return Err(Error::FeatureNotImplemented("atomic CompareExchange"));
crate::AtomicFunction::Exchange { compare: Some(cmp) } => {
let scalar_type_id = match *value_inner {
crate::TypeInner::Scalar { kind, width } => {
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind,
width,
pointer_space: None,
}))
}
_ => unimplemented!(),
};
let bool_type_id =
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
pointer_space: None,
}));
let cas_result_id = self.gen_id();
let equality_result_id = self.gen_id();
let mut cas_instr = Instruction::new(spirv::Op::AtomicCompareExchange);
cas_instr.set_type(scalar_type_id);
cas_instr.set_result(cas_result_id);
cas_instr.add_operand(pointer_id);
cas_instr.add_operand(scope_constant_id);
cas_instr.add_operand(semantics_id); // semantics if equal
cas_instr.add_operand(semantics_id); // semantics if not equal
cas_instr.add_operand(value_id);
cas_instr.add_operand(self.cached[cmp]);
block.body.push(cas_instr);
block.body.push(Instruction::binary(
spirv::Op::IEqual,
bool_type_id,
equality_result_id,
cas_result_id,
self.cached[cmp],
));
Instruction::composite_construct(
result_type_id,
id,
&[cas_result_id, equality_result_id],
)
}
};

View File

@@ -1630,8 +1630,13 @@ impl Parser {
let expression = match *ctx.resolve_type(value)? {
crate::TypeInner::Scalar { kind, width } => crate::Expression::AtomicResult {
kind,
width,
ty: ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar { kind, width },
},
NagaSpan::UNDEFINED,
),
comparison: false,
},
_ => return Err(Error::InvalidAtomicOperandType(value_span)),
@@ -1861,9 +1866,48 @@ impl Parser {
let expression = match *ctx.resolve_type(value)? {
crate::TypeInner::Scalar { kind, width } => {
let bool_ty = ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
},
},
NagaSpan::UNDEFINED,
);
let scalar_ty = ctx.types.insert(
crate::Type {
name: None,
inner: crate::TypeInner::Scalar { kind, width },
},
NagaSpan::UNDEFINED,
);
let struct_ty = ctx.types.insert(
crate::Type {
name: Some("__atomic_compare_exchange_result".to_string()),
inner: crate::TypeInner::Struct {
members: vec![
crate::StructMember {
name: Some("old_value".to_string()),
ty: scalar_ty,
binding: None,
offset: 0,
},
crate::StructMember {
name: Some("exchanged".to_string()),
ty: bool_ty,
binding: None,
offset: 4,
},
],
span: 8,
},
},
NagaSpan::UNDEFINED,
);
crate::Expression::AtomicResult {
kind,
width,
ty: struct_ty,
comparison: true,
}
}

View File

@@ -1401,11 +1401,7 @@ pub enum Expression {
/// Result of calling another function.
CallResult(Handle<Function>),
/// Result of an atomic operation.
AtomicResult {
kind: ScalarKind,
width: Bytes,
comparison: bool,
},
AtomicResult { ty: Handle<Type>, comparison: bool },
/// Get the length of an array.
/// The expression must resolve to a pointer to an array with a dynamic size.
///

View File

@@ -644,21 +644,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ShiftLeft
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult {
kind,
width,
comparison,
} => {
if comparison {
TypeResolution::Value(Ti::Vector {
size: crate::VectorSize::Bi,
kind,
width,
})
} else {
TypeResolution::Value(Ti::Scalar { kind, width })
}
}
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(),
crate::Expression::Relational { fun, argument } => match fun {

View File

@@ -1,5 +1,8 @@
#[cfg(feature = "validate")]
use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags};
use super::{
compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages,
TypeFlags,
};
#[cfg(feature = "validate")]
use crate::arena::UniqueArena;
@@ -115,8 +118,8 @@ pub enum ExpressionError {
WrongArgumentCount(crate::MathFunction),
#[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
#[error("Atomic result type can't be {0:?} of {1} bytes")]
InvalidAtomicResultType(crate::ScalarKind, crate::Bytes),
#[error("Atomic result type can't be {0:?}")]
InvalidAtomicResultType(Handle<crate::Type>),
#[error("Shader requires capability {0:?}")]
MissingCapabilities(super::Capabilities),
}
@@ -1389,19 +1392,27 @@ impl super::Validator {
ShaderStages::all()
}
E::CallResult(function) => other_infos[function.index()].available_stages,
E::AtomicResult {
kind,
width,
comparison: _,
} => {
let good = match kind {
crate::ScalarKind::Uint | crate::ScalarKind::Sint => {
self.check_width(kind, width)
E::AtomicResult { ty, comparison } => {
let scalar_predicate = |ty: &crate::TypeInner| match ty {
&crate::TypeInner::Scalar {
kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint),
width,
} => self.check_width(kind, width),
_ => false,
};
let good = match &module.types[ty].inner {
ty if !comparison => scalar_predicate(ty),
&crate::TypeInner::Struct { ref members, .. } if comparison => {
validate_atomic_compare_exchange_struct(
&module.types,
members,
scalar_predicate,
)
}
_ => false,
};
if !good {
return Err(ExpressionError::InvalidAtomicResultType(kind, width));
return Err(ExpressionError::InvalidAtomicResultType(ty));
}
ShaderStages::all()
}

View File

@@ -2,6 +2,9 @@
use crate::arena::{Arena, UniqueArena};
use crate::arena::{BadHandle, Handle};
#[cfg(feature = "validate")]
use super::validate_atomic_compare_exchange_struct;
use super::{
analyzer::{UniformityDisruptor, UniformityRequirements},
ExpressionError, FunctionInfo, ModuleInfo,
@@ -363,12 +366,26 @@ impl super::Validator {
.into_other());
}
match context.expressions[result] {
//TODO: support atomic result with comparison
crate::Expression::AtomicResult {
kind,
width,
comparison: false,
} if kind == ptr_kind && width == ptr_width => {}
crate::Expression::AtomicResult { ty, comparison }
if {
let scalar_predicate = |ty: &crate::TypeInner| {
*ty == crate::TypeInner::Scalar {
kind: ptr_kind,
width: ptr_width,
}
};
match &context.types[ty].inner {
ty if !comparison => scalar_predicate(ty),
&crate::TypeInner::Struct { ref members, .. } if comparison => {
validate_atomic_compare_exchange_struct(
context.types,
members,
scalar_predicate,
)
}
_ => false,
}
} => {}
_ => {
return Err(AtomicError::ResultTypeMismatch(result)
.with_span_handle(result, context.expressions)

View File

@@ -412,3 +412,20 @@ impl Validator {
Ok(mod_info)
}
}
/// Checks whether `members` has the exact layout expected of the struct
/// produced by an atomic compare-exchange: exactly two members, the first
/// named `old_value` and the second named `exchanged`.
///
/// * `types` - arena used to resolve each member's type handle.
/// * `members` - members of the candidate struct type.
/// * `scalar_predicate` - decides whether the type of `old_value` is
///   acceptable. It is invoked at most once, and only after the member count
///   and the `old_value` name have already been confirmed.
///
/// Returns `true` only when every check passes, including that `exchanged`
/// is a boolean scalar of width `crate::BOOL_WIDTH`.
#[cfg(feature = "validate")]
fn validate_atomic_compare_exchange_struct(
    types: &UniqueArena<crate::Type>,
    members: &[crate::StructMember],
    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
) -> bool {
    // Anything other than exactly two members can never match.
    let (old_value, exchanged) = match members {
        [first, second] => (first, second),
        _ => return false,
    };

    // The checks below run in the same order as a chained `&&`, so the
    // caller's predicate is never consulted for a wrongly named first member.
    if old_value.name.as_deref() != Some("old_value") {
        return false;
    }
    if !scalar_predicate(&types[old_value.ty].inner) {
        return false;
    }
    if exchanged.name.as_deref() != Some("exchanged") {
        return false;
    }

    let expected_flag = crate::TypeInner::Scalar {
        kind: crate::ScalarKind::Bool,
        width: crate::BOOL_WIDTH,
    };
    types[exchanged.ty].inner == expected_flag
}

View File

@@ -0,0 +1,34 @@
// Element count shared by both atomic arrays and by the loops that walk them.
let SIZE: u32 = 128u;
// Read-write storage array of signed 32-bit atomics (group 0, binding 0);
// accessed by test_atomic_compare_exchange_i32.
@group(0) @binding(0)
var<storage,read_write> arr_i32: array<atomic<i32>, SIZE>;
// Read-write storage array of unsigned 32-bit atomics (group 0, binding 1);
// accessed by test_atomic_compare_exchange_u32.
@group(0) @binding(1)
var<storage,read_write> arr_u32: array<atomic<u32>, SIZE>;
// Exercises atomicCompareExchangeWeak on atomic<i32>.
// For every element of arr_i32 it reinterprets the stored bits as an f32,
// adds 1.0, and writes the result back through a compare-exchange retry loop.
@compute @workgroup_size(1)
fn test_atomic_compare_exchange_i32() {
for(var i = 0u; i < SIZE; i++) {
// Initial guess for the value currently stored in the element.
var old = atomicLoad(&arr_i32[i]);
var exchanged = false;
// Keep retrying until the compare-exchange reports that the store happened.
while(!exchanged) {
// Treat the integer bits as a float, add 1.0, then convert back to integer bits.
let new_ = bitcast<i32>(bitcast<f32>(old) + 1.0);
// Attempt to replace `old` with `new_`; the result carries both struct members.
let result = atomicCompareExchangeWeak(&arr_i32[i], old, new_);
// Use the value reported by the exchange as the comparand for the next attempt.
old = result.old_value;
exchanged = result.exchanged;
}
}
}
// Exercises atomicCompareExchangeWeak on atomic<u32>.
// For every element of arr_u32 it reinterprets the stored bits as an f32,
// adds 1.0, and writes the result back through a compare-exchange retry loop.
@compute @workgroup_size(1)
fn test_atomic_compare_exchange_u32() {
for(var i = 0u; i < SIZE; i++) {
// Initial guess for the value currently stored in the element.
var old = atomicLoad(&arr_u32[i]);
var exchanged = false;
// Keep retrying until the compare-exchange reports that the store happened.
while(!exchanged) {
// Treat the integer bits as a float, add 1.0, then convert back to integer bits.
let new_ = bitcast<u32>(bitcast<f32>(old) + 1.0);
// Attempt to replace `old` with `new_`; the result carries both struct members.
let result = atomicCompareExchangeWeak(&arr_u32[i], old, new_);
// Use the value reported by the exchange as the comparand for the next attempt.
old = result.old_value;
exchanged = result.exchanged;
}
}
}

View File

@@ -0,0 +1,188 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 116
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %31 "test_atomic_compare_exchange_i32"
OpEntryPoint GLCompute %79 "test_atomic_compare_exchange_u32"
OpExecutionMode %31 LocalSize 1 1 1
OpExecutionMode %79 LocalSize 1 1 1
OpDecorate %12 ArrayStride 4
OpDecorate %13 ArrayStride 4
OpMemberDecorate %14 0 Offset 0
OpMemberDecorate %14 1 Offset 4
OpMemberDecorate %15 0 Offset 0
OpMemberDecorate %15 1 Offset 4
OpDecorate %16 DescriptorSet 0
OpDecorate %16 Binding 0
OpDecorate %17 Block
OpMemberDecorate %17 0 Offset 0
OpDecorate %19 DescriptorSet 0
OpDecorate %19 Binding 1
OpDecorate %20 Block
OpMemberDecorate %20 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeInt 32 0
%3 = OpConstant %4 128
%5 = OpConstant %4 0
%6 = OpConstant %4 1
%8 = OpTypeBool
%7 = OpConstantFalse %8
%10 = OpTypeFloat 32
%9 = OpConstant %10 1.0
%11 = OpTypeInt 32 1
%12 = OpTypeArray %11 %3
%13 = OpTypeArray %4 %3
%14 = OpTypeStruct %11 %8
%15 = OpTypeStruct %4 %8
%17 = OpTypeStruct %12
%18 = OpTypePointer StorageBuffer %17
%16 = OpVariable %18 StorageBuffer
%20 = OpTypeStruct %13
%21 = OpTypePointer StorageBuffer %20
%19 = OpVariable %21 StorageBuffer
%23 = OpTypePointer Function %4
%25 = OpTypePointer Function %11
%26 = OpConstantNull %11
%28 = OpTypePointer Function %8
%29 = OpConstantNull %8
%32 = OpTypeFunction %2
%33 = OpTypePointer StorageBuffer %12
%35 = OpTypePointer StorageBuffer %13
%46 = OpTypePointer StorageBuffer %11
%49 = OpConstant %11 1
%50 = OpConstant %4 64
%75 = OpConstantNull %4
%77 = OpConstantNull %8
%91 = OpTypePointer StorageBuffer %4
%31 = OpFunction %2 None %32
%30 = OpLabel
%22 = OpVariable %23 Function %5
%24 = OpVariable %25 Function %26
%27 = OpVariable %28 Function %29
%34 = OpAccessChain %33 %16 %5
OpBranch %36
%36 = OpLabel
OpBranch %37
%37 = OpLabel
OpLoopMerge %38 %40 None
OpBranch %39
%39 = OpLabel
%41 = OpLoad %4 %22
%42 = OpULessThan %8 %41 %3
OpSelectionMerge %43 None
OpBranchConditional %42 %43 %44
%44 = OpLabel
OpBranch %38
%43 = OpLabel
%45 = OpLoad %4 %22
%47 = OpAccessChain %46 %34 %45
%48 = OpAtomicLoad %11 %47 %49 %50
OpStore %24 %48
OpStore %27 %7
OpBranch %51
%51 = OpLabel
OpLoopMerge %52 %54 None
OpBranch %53
%53 = OpLabel
%55 = OpLoad %8 %27
%56 = OpLogicalNot %8 %55
OpSelectionMerge %57 None
OpBranchConditional %56 %57 %58
%58 = OpLabel
OpBranch %52
%57 = OpLabel
%59 = OpLoad %11 %24
%60 = OpBitcast %10 %59
%61 = OpFAdd %10 %60 %9
%62 = OpBitcast %11 %61
%63 = OpLoad %4 %22
%64 = OpLoad %11 %24
%66 = OpAccessChain %46 %34 %63
%67 = OpAtomicCompareExchange %11 %66 %49 %50 %50 %62 %64
%68 = OpIEqual %8 %67 %64
%65 = OpCompositeConstruct %14 %67 %68
%69 = OpCompositeExtract %11 %65 0
OpStore %24 %69
%70 = OpCompositeExtract %8 %65 1
OpStore %27 %70
OpBranch %54
%54 = OpLabel
OpBranch %51
%52 = OpLabel
OpBranch %40
%40 = OpLabel
%71 = OpLoad %4 %22
%72 = OpIAdd %4 %71 %6
OpStore %22 %72
OpBranch %37
%38 = OpLabel
OpReturn
OpFunctionEnd
%79 = OpFunction %2 None %32
%78 = OpLabel
%73 = OpVariable %23 Function %5
%74 = OpVariable %23 Function %75
%76 = OpVariable %28 Function %77
%80 = OpAccessChain %35 %19 %5
OpBranch %81
%81 = OpLabel
OpBranch %82
%82 = OpLabel
OpLoopMerge %83 %85 None
OpBranch %84
%84 = OpLabel
%86 = OpLoad %4 %73
%87 = OpULessThan %8 %86 %3
OpSelectionMerge %88 None
OpBranchConditional %87 %88 %89
%89 = OpLabel
OpBranch %83
%88 = OpLabel
%90 = OpLoad %4 %73
%92 = OpAccessChain %91 %80 %90
%93 = OpAtomicLoad %4 %92 %49 %50
OpStore %74 %93
OpStore %76 %7
OpBranch %94
%94 = OpLabel
OpLoopMerge %95 %97 None
OpBranch %96
%96 = OpLabel
%98 = OpLoad %8 %76
%99 = OpLogicalNot %8 %98
OpSelectionMerge %100 None
OpBranchConditional %99 %100 %101
%101 = OpLabel
OpBranch %95
%100 = OpLabel
%102 = OpLoad %4 %74
%103 = OpBitcast %10 %102
%104 = OpFAdd %10 %103 %9
%105 = OpBitcast %4 %104
%106 = OpLoad %4 %73
%107 = OpLoad %4 %74
%109 = OpAccessChain %91 %80 %106
%110 = OpAtomicCompareExchange %4 %109 %49 %50 %50 %105 %107
%111 = OpIEqual %8 %110 %107
%108 = OpCompositeConstruct %15 %110 %111
%112 = OpCompositeExtract %4 %108 0
OpStore %74 %112
%113 = OpCompositeExtract %8 %108 1
OpStore %76 %113
OpBranch %97
%97 = OpLabel
OpBranch %94
%95 = OpLabel
OpBranch %85
%85 = OpLabel
%114 = OpLoad %4 %73
%115 = OpIAdd %4 %114 %6
OpStore %73 %115
OpBranch %82
%83 = OpLabel
OpReturn
OpFunctionEnd

View File

@@ -0,0 +1,92 @@
struct gen___atomic_compare_exchange_result {
old_value: i32,
exchanged: bool,
}
struct gen___atomic_compare_exchange_result_1 {
old_value: u32,
exchanged: bool,
}
let SIZE: u32 = 128u;
@group(0) @binding(0)
var<storage, read_write> arr_i32_: array<atomic<i32>,SIZE>;
@group(0) @binding(1)
var<storage, read_write> arr_u32_: array<atomic<u32>,SIZE>;
@compute @workgroup_size(1, 1, 1)
fn test_atomic_compare_exchange_i32_() {
var i: u32 = 0u;
var old: i32;
var exchanged: bool;
loop {
let _e5 = i;
if (_e5 < SIZE) {
} else {
break;
}
let _e10 = i;
let _e12 = atomicLoad((&arr_i32_[_e10]));
old = _e12;
exchanged = false;
loop {
let _e16 = exchanged;
if !(_e16) {
} else {
break;
}
let _e18 = old;
let new_ = bitcast<i32>((bitcast<f32>(_e18) + 1.0));
let _e23 = i;
let _e25 = old;
let _e26 = atomicCompareExchangeWeak((&arr_i32_[_e23]), _e25, new_);
old = _e26.old_value;
exchanged = _e26.exchanged;
}
continuing {
let _e7 = i;
i = (_e7 + 1u);
}
}
return;
}
@compute @workgroup_size(1, 1, 1)
fn test_atomic_compare_exchange_u32_() {
var i_1: u32 = 0u;
var old_1: u32;
var exchanged_1: bool;
loop {
let _e5 = i_1;
if (_e5 < SIZE) {
} else {
break;
}
let _e10 = i_1;
let _e12 = atomicLoad((&arr_u32_[_e10]));
old_1 = _e12;
exchanged_1 = false;
loop {
let _e16 = exchanged_1;
if !(_e16) {
} else {
break;
}
let _e18 = old_1;
let new_1 = bitcast<u32>((bitcast<f32>(_e18) + 1.0));
let _e23 = i_1;
let _e25 = old_1;
let _e26 = atomicCompareExchangeWeak((&arr_u32_[_e23]), _e25, new_1);
old_1 = _e26.old_value;
exchanged_1 = _e26.exchanged;
}
continuing {
let _e7 = i_1;
i_1 = (_e7 + 1u);
}
}
return;
}

View File

@@ -492,6 +492,7 @@ fn convert_wgsl() {
"access",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("atomicCompareExchange", Targets::SPIRV | Targets::WGSL),
(
"padding",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,