diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index aa7af75c8d..e9fc133485 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -125,22 +125,15 @@ impl StatementGraph { } S::Atomic { pointer, - fun, + ref fun, + value, result, } => { self.emits.push((id, result)); self.dependencies.push((id, pointer, "pointer")); - match fun { - crate::AtomicFunction::Binary { op: _, value } - | crate::AtomicFunction::Min(value) - | crate::AtomicFunction::Max(value) - | crate::AtomicFunction::Exchange(value) => { - self.dependencies.push((id, value, "value")); - } - crate::AtomicFunction::CompareExchange { cmp, value } => { - self.dependencies.push((id, cmp, "cmp")); - self.dependencies.push((id, value, "value")); - } + self.dependencies.push((id, value, "value")); + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + self.dependencies.push((id, cmp, "cmp")); } "Atomic" } diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index c14c64694f..e47e89688c 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -68,6 +68,21 @@ pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320]; pub type BindingMap = std::collections::BTreeMap; +impl crate::AtomicFunction { + fn to_glsl(self) -> &'static str { + match self { + Self::Add => "Add", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { compare: Some(_) } => "", //TODO + } + } +} + /// glsl version #[derive(Debug, Copy, Clone, PartialEq)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -1682,7 +1697,8 @@ impl<'a, W: Write> Writer<'a, W> { } Statement::Atomic { pointer, - fun, + ref fun, + value, result, } => { write!(self.out, "{}", INDENT.repeat(indent))?; @@ -1691,44 +1707,17 @@ impl<'a, W: Write> Writer<'a, W> { self.write_value_type(res_ty)?; write!(self.out, " {} = ", res_name)?; self.named_expressions.insert(result, res_name); - match fun { - crate::AtomicFunction::Binary { op, value } => { - let fun_str = match op { - crate::BinaryOperator::Add => "Add", - crate::BinaryOperator::And => "And", - crate::BinaryOperator::InclusiveOr => "Or", - crate::BinaryOperator::ExclusiveOr => "Xor", - _ => unreachable!(), - }; - write!(self.out, "atomic{}(", fun_str)?; - self.write_expr(pointer, ctx)?; - write!(self.out, ", ")?; - self.write_expr(value, ctx)?; - } - crate::AtomicFunction::Min(value) => { - write!(self.out, "atomicMin(")?; - self.write_expr(pointer, ctx)?; - write!(self.out, ", ")?; - self.write_expr(value, ctx)?; - } - crate::AtomicFunction::Max(value) => { - write!(self.out, "atomicMax(")?; - self.write_expr(pointer, ctx)?; - write!(self.out, ", ")?; - self.write_expr(value, ctx)?; - } - crate::AtomicFunction::Exchange(value) => { - write!(self.out, "atomicExchange(")?; - self.write_expr(pointer, ctx)?; - write!(self.out, ", ")?; - self.write_expr(value, ctx)?; - } - crate::AtomicFunction::CompareExchange { .. } => { - return Err(Error::Custom( - "atomic CompareExchange is not implemented".to_string(), - )); - } + + let fun_str = fun.to_glsl(); + write!(self.out, "atomic{}(", fun_str)?; + self.write_expr(pointer, ctx)?; + if let crate::AtomicFunction::Exchange { compare: Some(_) } = *fun { + return Err(Error::Custom( + "atomic CompareExchange is not implemented".to_string(), + )); } + write!(self.out, ", ")?; + self.write_expr(value, ctx)?; writeln!(self.out, ");")?; } } @@ -2371,6 +2360,7 @@ impl<'a, W: Write> Writer<'a, W> { } } } + // These expressions never show up in `Emit`. Expression::CallResult(_) | Expression::AtomicResult { .. } => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 958858c412..501cbbcd2e 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -123,15 +123,18 @@ impl crate::Sampling { } } -impl crate::BinaryOperator { +impl crate::AtomicFunction { /// Return the HLSL suffix for the `InterlockedXxx` method. - pub(super) fn to_hlsl_atomic_suffix(self) -> &'static str { + pub(super) fn to_hlsl_suffix(self) -> &'static str { match self { Self::Add => "Add", Self::And => "And", Self::InclusiveOr => "Or", Self::ExclusiveOr => "Xor", - _ => "", + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { .. } => "", //TODO } } } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 5d651fe41d..f46e4ee9ac 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1108,7 +1108,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } Statement::Atomic { pointer, - fun, + ref fun, + value, result, } => { write!(self.out, "{}", INDENT.repeat(indent))?; @@ -1125,37 +1126,18 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { let chain = mem::take(&mut self.temp_access_chain); let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; - write!(self.out, " {}; {}.Interlocked", res_name, var_name)?; - match fun { - crate::AtomicFunction::Binary { op, value } => { - let suffix = op.to_hlsl_atomic_suffix(); - write!(self.out, "{}(", suffix)?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::Min(value) => { - write!(self.out, "Min(")?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::Max(value) => { - write!(self.out, "Max(")?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::Exchange(value) => { - write!(self.out, "Exchange(")?; - self.write_storage_address(module, &chain, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::CompareExchange { .. } => { - return Err(Error::Unimplemented("atomic CompareExchange".to_string())); - } + let fun_str = fun.to_hlsl_suffix(); + write!( + self.out, + " {}; {}.Interlocked{}(", + res_name, var_name, fun_str + )?; + self.write_storage_address(module, &chain, func_ctx)?; + if let crate::AtomicFunction::Exchange { compare: Some(_) } = *fun { + return Err(Error::Unimplemented("atomic CompareExchange".to_string())); } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; writeln!(self.out, ", {});", res_name)?; self.temp_access_chain = chain; self.named_expressions.insert(result, res_name); diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 6520d6aed1..fa9ec02bf7 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -33,18 +33,6 @@ impl Display for Level { } } -impl crate::BinaryOperator { - fn to_msl_atomic_str(self) -> &'static str { - match self { - Self::Add => "add", - Self::And => "and", - Self::InclusiveOr => "or", - Self::ExclusiveOr => "xor", - _ => unreachable!(), - } - } -} - struct TypeContext<'a> { handle: Handle, arena: &'a crate::Arena, @@ -1629,29 +1617,34 @@ impl Writer { } crate::Statement::Atomic { pointer, - fun, + ref fun, + value, result, } => { write!(self.out, "{}", level)?; let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); self.start_baking_expression(result, &context.expression, &res_name)?; self.named_expressions.insert(result, res_name); - match fun { - crate::AtomicFunction::Binary { op, value } => { - self.put_atomic_fetch( - pointer, - op.to_msl_atomic_str(), - value, - &context.expression, - )?; + match *fun { + crate::AtomicFunction::Add => { + self.put_atomic_fetch(pointer, "add", value, &context.expression)?; } - crate::AtomicFunction::Min(value) => { + crate::AtomicFunction::And => { + self.put_atomic_fetch(pointer, "and", value, &context.expression)?; + } + crate::AtomicFunction::InclusiveOr => { + self.put_atomic_fetch(pointer, "or", value, &context.expression)?; + } + crate::AtomicFunction::ExclusiveOr => { + self.put_atomic_fetch(pointer, "xor", value, &context.expression)?; + } + crate::AtomicFunction::Min => { self.put_atomic_fetch(pointer, "min", value, &context.expression)?; } - crate::AtomicFunction::Max(value) => { + crate::AtomicFunction::Max => { self.put_atomic_fetch(pointer, "max", value, &context.expression)?; } - crate::AtomicFunction::Exchange(value) => { + crate::AtomicFunction::Exchange { compare: None } => { write!( self.out, "{}::atomic_exchange_explicit({}", @@ -1662,7 +1655,7 @@ impl Writer { self.put_expression(value, &context.expression, true)?; write!(self.out, ", {}::memory_order_relaxed)", NAMESPACE)?; } - crate::AtomicFunction::CompareExchange { .. } => { + crate::AtomicFunction::Exchange { .. } => { return Err(Error::FeatureNotImplemented( "atomic CompareExchange".to_string(), )); diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 3707127411..0142e10cef 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -8,9 +8,6 @@ use super::{ use crate::{arena::Handle, proc::TypeResolution}; use spirv::Word; -//TODO: should this ever be `Workgroup`? -const ATOMIC_SCOPE: spirv::Scope = spirv::Scope::Device; - fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { match *type_inner { crate::TypeInner::Scalar { .. } => Dimension::Scalar, @@ -677,9 +674,8 @@ impl<'w> BlockContext<'w> { _ => None, }; let instruction = if let Some(class) = atomic_class { - let semantics = - spirv::MemorySemantics::ACQUIRE | class.to_spirv_semantics(); - let scope_constant_id = self.get_scope_constant(ATOMIC_SCOPE as u32)?; + let (semantics, scope) = class.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32)?; let semantics_id = self.get_index_constant(semantics.bits())?; Instruction::atomic_load( result_type_id, @@ -719,7 +715,7 @@ impl<'w> BlockContext<'w> { } crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => { - self.writer.lookup_statement_result[&expr_handle] + self.cached[expr_handle] } crate::Expression::As { expr, @@ -1619,10 +1615,8 @@ impl<'w> BlockContext<'w> { _ => None, }; let instruction = if let Some(class) = atomic_class { - let semantics = - spirv::MemorySemantics::RELEASE | class.to_spirv_semantics(); - let scope_constant_id = - self.get_scope_constant(ATOMIC_SCOPE as u32)?; + let (semantics, scope) = class.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32)?; let semantics_id = self.get_index_constant(semantics.bits())?; Instruction::atomic_store( pointer_id, @@ -1698,7 +1692,6 @@ impl<'w> BlockContext<'w> { let type_id = match result { Some(expr) => { self.cached[expr] = id; - self.writer.lookup_statement_result.insert(expr, id); self.get_expression_type_id(&self.fun_info[expr].ty)? } None => self.writer.void_type, @@ -1713,14 +1706,14 @@ impl<'w> BlockContext<'w> { } crate::Statement::Atomic { pointer, - fun, + ref fun, + value, result, } => { let id = self.gen_id(); let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty)?; self.cached[result] = id; - self.writer.lookup_statement_result.insert(result, id); let pointer_id = match self.write_expression_pointer(pointer, &mut block)? { ExpressionPointer::Ready { pointer_id } => pointer_id, @@ -1735,19 +1728,59 @@ impl<'w> BlockContext<'w> { crate::TypeInner::Pointer { base: _, class } => class, _ => unimplemented!(), }; - let semantics = - spirv::MemorySemantics::ACQUIRE_RELEASE | class.to_spirv_semantics(); - let scope_constant_id = self.get_scope_constant(ATOMIC_SCOPE as u32)?; + let (semantics, scope) = class.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32)?; let semantics_id = self.get_index_constant(semantics.bits())?; + let value_id = self.cached[value]; + let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types); - let instruction = match fun { - crate::AtomicFunction::Binary { op, value } => { - let value_id = self.cached[value]; - let spirv_op = match op { - crate::BinaryOperator::Add => spirv::Op::AtomicIAdd, - crate::BinaryOperator::And => spirv::Op::AtomicAnd, - crate::BinaryOperator::InclusiveOr => spirv::Op::AtomicOr, - crate::BinaryOperator::ExclusiveOr => spirv::Op::AtomicXor, + let instruction = match *fun { + crate::AtomicFunction::Add => Instruction::atomic_binary( + spirv::Op::AtomicIAdd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::And => Instruction::atomic_binary( + spirv::Op::AtomicAnd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicOr, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicXor, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::Min => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => spirv::Op::AtomicSMin, + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + } => spirv::Op::AtomicUMin, _ => unimplemented!(), }; Instruction::atomic_binary( @@ -1760,20 +1793,18 @@ impl<'w> BlockContext<'w> { value_id, ) } - crate::AtomicFunction::Min(value) => { - let value_id = self.cached[value]; - let spirv_op = - match *self.fun_info[value].ty.inner_with(&self.ir_module.types) { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Sint, - width: _, - } => spirv::Op::AtomicSMin, - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - width: _, - } => spirv::Op::AtomicUMin, - _ => unimplemented!(), - }; + crate::AtomicFunction::Max => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => spirv::Op::AtomicSMax, + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + } => spirv::Op::AtomicUMax, + _ => unimplemented!(), + }; Instruction::atomic_binary( spirv_op, result_type_id, @@ -1784,32 +1815,7 @@ impl<'w> BlockContext<'w> { value_id, ) } - crate::AtomicFunction::Max(value) => { - let value_id = self.cached[value]; - let spirv_op = - match *self.fun_info[value].ty.inner_with(&self.ir_module.types) { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Sint, - width: _, - } => spirv::Op::AtomicSMax, - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - width: _, - } => spirv::Op::AtomicUMax, - _ => unimplemented!(), - }; - Instruction::atomic_binary( - spirv_op, - result_type_id, - id, - pointer_id, - scope_constant_id, - semantics_id, - value_id, - ) - } - crate::AtomicFunction::Exchange(value) => { - let value_id = self.cached[value]; + crate::AtomicFunction::Exchange { compare: None } => { Instruction::atomic_binary( spirv::Op::AtomicExchange, result_type_id, @@ -1820,7 +1826,7 @@ impl<'w> BlockContext<'w> { value_id, ) } - crate::AtomicFunction::CompareExchange { .. } => { + crate::AtomicFunction::Exchange { compare: Some(_) } => { return Err(Error::FeatureNotImplemented("atomic CompareExchange")); } }; diff --git a/src/back/spv/helpers.rs b/src/back/spv/helpers.rs index 9db7041263..fe8fbaf757 100644 --- a/src/back/spv/helpers.rs +++ b/src/back/spv/helpers.rs @@ -50,11 +50,14 @@ pub(super) fn contains_builtin( } impl crate::StorageClass { - pub(super) fn to_spirv_semantics(self) -> spirv::MemorySemantics { + pub(super) fn to_spirv_semantics_and_scope(self) -> (spirv::MemorySemantics, spirv::Scope) { match self { - Self::Storage { .. } => spirv::MemorySemantics::UNIFORM_MEMORY, - Self::WorkGroup => spirv::MemorySemantics::WORKGROUP_MEMORY, - _ => spirv::MemorySemantics::empty(), + Self::Storage { .. } => (spirv::MemorySemantics::UNIFORM_MEMORY, spirv::Scope::Device), + Self::WorkGroup => ( + spirv::MemorySemantics::WORKGROUP_MEMORY, + spirv::Scope::Workgroup, + ), + _ => (spirv::MemorySemantics::empty(), spirv::Scope::Invocation), } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index dd39471753..6530160e5f 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -771,7 +771,7 @@ impl super::Instruction { pointer: Word, scope_id: Word, semantics_id: Word, - operand: Word, + value: Word, ) -> Self { let mut instruction = Self::new(op); instruction.set_type(result_type_id); @@ -779,7 +779,7 @@ impl super::Instruction { instruction.add_operand(pointer); instruction.add_operand(scope_id); instruction.add_operand(semantics_id); - instruction.add_operand(operand); + instruction.add_operand(value); instruction } diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 39b015728e..55c1d6ba4b 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -396,7 +396,6 @@ pub struct Writer { lookup_type: crate::FastHashMap, lookup_function: crate::FastHashMap, Word>, lookup_function_type: crate::FastHashMap, - lookup_statement_result: crate::FastHashMap, Word>, constant_ids: Vec, cached_constants: crate::FastHashMap<(crate::ScalarValue, crate::Bytes), Word>, global_variables: Vec, diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index fec4d5c33e..31a2f98926 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -78,7 +78,6 @@ impl Writer { lookup_type: crate::FastHashMap::default(), lookup_function: crate::FastHashMap::default(), lookup_function_type: crate::FastHashMap::default(), - lookup_statement_result: crate::FastHashMap::default(), constant_ids: Vec::new(), cached_constants: crate::FastHashMap::default(), global_variables: Vec::new(), @@ -127,7 +126,6 @@ impl Writer { lookup_type: take(&mut self.lookup_type).recycle(), lookup_function: take(&mut self.lookup_function).recycle(), lookup_function_type: take(&mut self.lookup_function_type).recycle(), - lookup_statement_result: take(&mut self.lookup_statement_result).recycle(), constant_ids: take(&mut self.constant_ids).recycle(), cached_constants: take(&mut self.cached_constants).recycle(), global_variables: take(&mut self.global_variables).recycle(), diff --git a/src/back/wgsl/mod.rs b/src/back/wgsl/mod.rs index e8fb1e4938..33d01e2458 100644 --- a/src/back/wgsl/mod.rs +++ b/src/back/wgsl/mod.rs @@ -27,14 +27,17 @@ pub fn write_string( Ok(output) } -impl crate::BinaryOperator { - fn to_wgsl_atomic_suffix(self) -> &'static str { +impl crate::AtomicFunction { + fn to_wgsl(self) -> &'static str { match self { Self::Add => "Add", Self::And => "And", Self::InclusiveOr => "Or", Self::ExclusiveOr => "Xor", - _ => unreachable!(), + Self::Min => "Min", + Self::Max => "Max", + Self::Exchange { compare: None } => "Exchange", + Self::Exchange { .. } => "CompareExchangeWeak", } } } diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 68d1563b61..9748c4284d 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -703,7 +703,8 @@ impl Writer { } Statement::Atomic { pointer, - fun, + ref fun, + value, result, } => { write!(self.out, "{}", INDENT.repeat(indent))?; @@ -712,45 +713,15 @@ impl Writer { self.write_expr(module, result, func_ctx)?; self.named_expressions.insert(result, res_name); - match fun { - crate::AtomicFunction::Binary { op, value } => { - write!( - self.out, - "atomic{}({}", - op.to_wgsl_atomic_suffix(), - ATOMIC_REFERENCE - )?; - self.write_expr(module, pointer, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::Min(value) => { - write!(self.out, "atomicMin({}", ATOMIC_REFERENCE)?; - self.write_expr(module, pointer, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::Max(value) => { - write!(self.out, "atomicMax({}", ATOMIC_REFERENCE)?; - self.write_expr(module, pointer, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::Exchange(value) => { - write!(self.out, "atomicExchange({}", ATOMIC_REFERENCE)?; - self.write_expr(module, pointer, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } - crate::AtomicFunction::CompareExchange { cmp, value } => { - write!(self.out, "atomicCompareExchangeWeak({}", ATOMIC_REFERENCE)?; - self.write_expr(module, pointer, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, cmp, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, value, func_ctx)?; - } + let fun_str = fun.to_wgsl(); + write!(self.out, "atomic{}({}", fun_str, ATOMIC_REFERENCE)?; + self.write_expr(module, pointer, func_ctx)?; + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + write!(self.out, ", ")?; + self.write_expr(module, cmp, func_ctx)?; } + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; writeln!(self.out, ");")? } Statement::ImageStore { diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 5dcdb95871..96a3d88b99 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -1101,10 +1101,10 @@ impl Parser { Ok(Some((fun_handle, arguments))) } - fn parse_atomic_helper<'a, F: FnOnce(Handle) -> crate::AtomicFunction>( + fn parse_atomic_helper<'a>( &mut self, lexer: &mut Lexer<'a>, - function_factory: F, + fun: crate::AtomicFunction, mut ctx: ExpressionContext<'a, '_, '_>, ) -> Result, Error<'a>> { lexer.open_arguments()?; @@ -1127,25 +1127,13 @@ impl Parser { let result = ctx.interrupt_emitter(expression); ctx.block.push(crate::Statement::Atomic { pointer, - fun: function_factory(value), + fun, + value, result, }); Ok(result) } - fn parse_atomic_binary_op<'a>( - &mut self, - lexer: &mut Lexer<'a>, - op: crate::BinaryOperator, - ctx: ExpressionContext<'a, '_, '_>, - ) -> Result, Error<'a>> { - self.parse_atomic_helper( - lexer, - |value| crate::AtomicFunction::Binary { op, value }, - ctx, - ) - } - fn parse_function_call_inner<'a>( &mut self, lexer: &mut Lexer<'a>, @@ -1215,33 +1203,33 @@ impl Parser { crate::Expression::Load { pointer } } "atomicAdd" => { - let handle = self.parse_atomic_binary_op( + let handle = self.parse_atomic_helper( lexer, - crate::BinaryOperator::Add, + crate::AtomicFunction::Add, ctx.reborrow(), )?; return Ok(Some(handle)); } "atomicAnd" => { - let handle = self.parse_atomic_binary_op( + let handle = self.parse_atomic_helper( lexer, - crate::BinaryOperator::And, + crate::AtomicFunction::And, ctx.reborrow(), )?; return Ok(Some(handle)); } "atomicOr" => { - let handle = self.parse_atomic_binary_op( + let handle = self.parse_atomic_helper( lexer, - crate::BinaryOperator::InclusiveOr, + crate::AtomicFunction::InclusiveOr, ctx.reborrow(), )?; return Ok(Some(handle)); } "atomicXor" => { - let handle = self.parse_atomic_binary_op( + let handle = self.parse_atomic_helper( lexer, - crate::BinaryOperator::ExclusiveOr, + crate::AtomicFunction::ExclusiveOr, ctx.reborrow(), )?; return Ok(Some(handle)); @@ -1257,8 +1245,11 @@ impl Parser { return Ok(Some(handle)); } "atomicExchange" => { - let handle = - self.parse_atomic_helper(lexer, crate::AtomicFunction::Exchange, ctx)?; + let handle = self.parse_atomic_helper( + lexer, + crate::AtomicFunction::Exchange { compare: None }, + ctx, + )?; return Ok(Some(handle)); } "atomicCompareExchangeWeak" => { @@ -1286,7 +1277,8 @@ impl Parser { let result = ctx.interrupt_emitter(expression); ctx.block.push(crate::Statement::Atomic { pointer, - fun: crate::AtomicFunction::CompareExchange { cmp, value }, + fun: crate::AtomicFunction::Exchange { compare: Some(cmp) }, + value, result, }); return Ok(Some(result)); diff --git a/src/lib.rs b/src/lib.rs index 12a5a0283a..bc934cdde6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -745,20 +745,13 @@ pub enum BinaryOperator { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum AtomicFunction { - /// Binary operation of an atomic with a value. - /// - /// Note: only supports a subset of operations, defined by the validator. - Binary { - op: BinaryOperator, - value: Handle, - }, - Min(Handle), - Max(Handle), - Exchange(Handle), - CompareExchange { - cmp: Handle, - value: Handle, - }, + Add, + And, + ExclusiveOr, + InclusiveOr, + Min, + Max, + Exchange { compare: Option> }, } /// Axis on which to compute a derivative. @@ -1320,6 +1313,8 @@ pub enum Statement { pointer: Handle, /// Function to run on the atomic. fun: AtomicFunction, + /// Value to use in the function. + value: Handle, /// Emitted expression as a result. result: Handle, }, diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 04f766c045..543e25034e 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -780,19 +780,15 @@ impl FunctionInfo { } S::Atomic { pointer, - fun, + ref fun, + value, result: _, } => { let _ = self.add_ref_impl(pointer, GlobalUse::WRITE); - let _ = match fun { - crate::AtomicFunction::Binary { op: _, value } - | crate::AtomicFunction::Min(value) - | crate::AtomicFunction::Max(value) - | crate::AtomicFunction::Exchange(value) => self.add_ref(value), - crate::AtomicFunction::CompareExchange { cmp, value } => { - self.add_ref(value).or(self.add_ref(cmp)) - } - }; + let _ = self.add_ref(value); + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + let _ = self.add_ref(cmp); + } FunctionUniformity::new() } }; diff --git a/src/valid/function.rs b/src/valid/function.rs index 612e979c64..7238c29775 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -271,12 +271,11 @@ impl super::Validator { fn validate_atomic( &mut self, pointer: Handle, - fun: crate::AtomicFunction, + fun: &crate::AtomicFunction, + value: Handle, result: Handle, context: &BlockContext, ) -> Result<(), FunctionError> { - use crate::BinaryOperator as Bo; - let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?; let (ptr_kind, ptr_width) = match *pointer_inner { crate::TypeInner::Pointer { base, .. } => match context.types[base].inner { @@ -292,29 +291,8 @@ impl super::Validator { } }; - let value = match fun { - crate::AtomicFunction::Binary { op, value } => { - match op { - Bo::Add | Bo::And | Bo::InclusiveOr | Bo::ExclusiveOr => {} - _ => return Err(AtomicError::InvalidBinaryOp(op, ptr_kind).into()), - } - value - } - crate::AtomicFunction::Min(value) - | crate::AtomicFunction::Max(value) - | crate::AtomicFunction::Exchange(value) => value, - crate::AtomicFunction::CompareExchange { cmp, value } => { - if context.resolve_type(cmp, &self.valid_expression_set)? - != context.resolve_type(value, &self.valid_expression_set)? - { - log::error!("Atomic exchange comparison has a different type from the value"); - return Err(AtomicError::InvalidOperand(cmp).into()); - } - value - } - }; - - match *context.resolve_type(value, &self.valid_expression_set)? { + let value_inner = context.resolve_type(value, &self.valid_expression_set)?; + match *value_inner { crate::TypeInner::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {} ref other => { log::error!("Atomic operand type {:?}", other); @@ -322,6 +300,13 @@ impl super::Validator { } } + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner { + log::error!("Atomic exchange comparison has a different type from the value"); + return Err(AtomicError::InvalidOperand(cmp).into()); + } + } + if self.valid_expression_set.insert(result.index()) { self.valid_expression_list.push(result); } else { @@ -602,10 +587,11 @@ impl super::Validator { }, S::Atomic { pointer, - fun, + ref fun, + value, result, } => { - self.validate_atomic(pointer, fun, result, context)?; + self.validate_atomic(pointer, fun, value, result, context)?; } } } diff --git a/tests/out/spv/access.spvasm b/tests/out/spv/access.spvasm index 33920bc062..df869e495b 100644 --- a/tests/out/spv/access.spvasm +++ b/tests/out/spv/access.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 107 +; Bound: 105 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" @@ -79,9 +79,7 @@ OpDecorate %37 BuiltIn Position %74 = OpTypePointer Function %4 %78 = OpTypeVector %4 4 %86 = OpTypePointer StorageBuffer %4 -%89 = OpConstant %9 66 -%92 = OpConstant %9 72 -%106 = OpConstant %9 68 +%89 = OpConstant %9 64 %39 = OpFunction %2 None %40 %33 = OpLabel %29 = OpVariable %30 Function %5 @@ -138,27 +136,27 @@ OpBranch %85 %87 = OpAccessChain %86 %27 %15 %88 = OpAtomicLoad %4 %87 %11 %89 %91 = OpAccessChain %86 %27 %15 -%90 = OpAtomicIAdd %4 %91 %11 %92 %16 +%90 = OpAtomicIAdd %4 %91 %11 %89 %16 OpStore %82 %90 -%94 = OpAccessChain %86 %27 %15 -%93 = OpAtomicAnd %4 %94 %11 %92 %16 -OpStore %82 %93 -%96 = OpAccessChain %86 %27 %15 -%95 = OpAtomicOr %4 %96 %11 %92 %16 -OpStore %82 %95 -%98 = OpAccessChain %86 %27 %15 -%97 = OpAtomicXor %4 %98 %11 %92 %16 -OpStore %82 %97 -%100 = OpAccessChain %86 %27 %15 -%99 = OpAtomicSMin %4 %100 %11 %92 %16 -OpStore %82 %99 -%102 = OpAccessChain %86 %27 %15 -%101 = OpAtomicSMax %4 %102 %11 %92 %16 -OpStore %82 %101 +%93 = OpAccessChain %86 %27 %15 +%92 = OpAtomicAnd %4 %93 %11 %89 %16 +OpStore %82 %92 +%95 = OpAccessChain %86 %27 %15 +%94 = OpAtomicOr %4 %95 %11 %89 %16 +OpStore %82 %94 +%97 = OpAccessChain %86 %27 %15 +%96 = OpAtomicXor %4 %97 %11 %89 %16 +OpStore %82 %96 +%99 = OpAccessChain %86 %27 %15 +%98 = OpAtomicSMin %4 %99 %11 %89 %16 +OpStore %82 %98 +%101 = OpAccessChain %86 %27 %15 +%100 = OpAtomicSMax %4 %101 %11 %89 %16 +OpStore %82 %100 +%103 = OpAccessChain %86 %27 %15 +%102 = OpAtomicExchange %4 %103 %11 %89 %16 +OpStore %82 %102 %104 = OpAccessChain %86 %27 %15 -%103 = OpAtomicExchange %4 %104 %11 %92 %16 -OpStore %82 %103 -%105 = OpAccessChain %86 %27 %15 -OpAtomicStore %105 %11 %106 %88 +OpAtomicStore %104 %11 %89 %88 OpReturn OpFunctionEnd \ No newline at end of file