From 78f225a37a3ca405299b2d2991ac07d2742e9abb Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 6 Aug 2021 01:18:51 -0400 Subject: [PATCH] Implemented atomics in HLSL-out --- src/back/hlsl/conv.rs | 13 ++++ src/back/hlsl/storage.rs | 2 +- src/back/hlsl/writer.rs | 125 +++++++++++++++++++++--------- tests/out/hlsl/access.hlsl | 32 +++++++- tests/out/hlsl/access.hlsl.config | 2 +- tests/snapshots.rs | 3 +- 6 files changed, 134 insertions(+), 43 deletions(-) diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 2291589731..958858c412 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -122,3 +122,16 @@ impl crate::Sampling { } } } + +impl crate::BinaryOperator { + /// Return the HLSL suffix for the `InterlockedXxx` method. + pub(super) fn to_hlsl_atomic_suffix(self) -> &'static str { + match self { + Self::Add => "Add", + Self::And => "And", + Self::InclusiveOr => "Or", + Self::ExclusiveOr => "Xor", + _ => "", + } + } +} diff --git a/src/back/hlsl/storage.rs b/src/back/hlsl/storage.rs index d16c17c477..3ba13f545b 100644 --- a/src/back/hlsl/storage.rs +++ b/src/back/hlsl/storage.rs @@ -40,7 +40,7 @@ pub(super) enum StoreValue { } impl super::Writer<'_, W> { - fn write_storage_address( + pub(super) fn write_storage_address( &mut self, module: &crate::Module, chain: &[SubAccess], diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 06868349e7..5d651fe41d 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -8,7 +8,7 @@ use crate::{ proc::{self, NameKey}, valid, Handle, Module, ShaderStage, TypeInner, }; -use std::fmt; +use std::{fmt, mem}; const LOCATION_SEMANTIC: &str = "LOC"; const SPECIAL_CBUF_TYPE: &str = "NagaConstants"; @@ -992,38 +992,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ";")? } } - Statement::Call { - function, - ref arguments, - result, - } => { - write!(self.out, "{}", INDENT.repeat(indent))?; - if let Some(expr) = result { - write!(self.out, "const ")?; - let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); - let expr_ty = &func_ctx.info[expr].ty; - match *expr_ty { - proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, - proc::TypeResolution::Value(ref value) => { - self.write_value_type(module, value)? - } - }; - write!(self.out, " {} = ", name)?; - self.write_expr(module, expr, func_ctx)?; - self.named_expressions.insert(expr, name); - } - let func_name = &self.names[&NameKey::Function(function)]; - write!(self.out, "{}(", func_name)?; - for (index, argument) in arguments.iter().enumerate() { - self.write_expr(module, *argument, func_ctx)?; - // Only write a comma if isn't the last element - if index != arguments.len().saturating_sub(1) { - // The leading space is for readability only - write!(self.out, ", ")?; - } - } - writeln!(self.out, ");")? - } Statement::Loop { ref body, ref continuing, @@ -1107,7 +1075,94 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, value, func_ctx)?; writeln!(self.out, ";")?; } - _ => return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt))), + Statement::Call { + function, + ref arguments, + result, + } => { + write!(self.out, "{}", INDENT.repeat(indent))?; + if let Some(expr) = result { + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, expr.index()); + let expr_ty = &func_ctx.info[expr].ty; + match *expr_ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {} = ", name)?; + self.named_expressions.insert(expr, name); + } + let func_name = &self.names[&NameKey::Function(function)]; + write!(self.out, "{}(", func_name)?; + for (index, argument) in arguments.iter().enumerate() { + self.write_expr(module, *argument, func_ctx)?; + // Only write a comma if isn't the last element + if index != arguments.len().saturating_sub(1) { + // The leading space is for readability only + write!(self.out, ", ")?; + } + } + writeln!(self.out, ");")? + } + Statement::Atomic { + pointer, + fun, + result, + } => { + write!(self.out, "{}", INDENT.repeat(indent))?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + + let var_handle = self.fill_access_chain(module, pointer, func_ctx)?; + // working around the borrow checker in `self.write_expr` + let chain = mem::take(&mut self.temp_access_chain); + let var_name = &self.names[&NameKey::GlobalVariable(var_handle)]; + + write!(self.out, " {}; {}.Interlocked", res_name, var_name)?; + match fun { + crate::AtomicFunction::Binary { op, value } => { + let suffix = op.to_hlsl_atomic_suffix(); + write!(self.out, "{}(", suffix)?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + } + crate::AtomicFunction::Min(value) => { + write!(self.out, "Min(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + } + crate::AtomicFunction::Max(value) => { + write!(self.out, "Max(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + } + crate::AtomicFunction::Exchange(value) => { + write!(self.out, "Exchange(")?; + self.write_storage_address(module, &chain, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, value, func_ctx)?; + } + crate::AtomicFunction::CompareExchange { .. } => { + return Err(Error::Unimplemented("atomic CompareExchange".to_string())); + } + } + writeln!(self.out, ", {});", res_name)?; + self.temp_access_chain = chain; + self.named_expressions.insert(result, res_name); + } + Statement::Switch { .. } => { + return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt))) + } } Ok(()) @@ -1667,7 +1722,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")? } // Nothing to do here, since call expression already cached - Expression::CallResult(_) => {} + Expression::CallResult(_) | Expression::AtomicResult { .. } => {} _ => return Err(Error::Unimplemented(format!("write_expr {:?}", expression))), } diff --git a/tests/out/hlsl/access.hlsl b/tests/out/hlsl/access.hlsl index 4e1494a9c2..d724bbcaf5 100644 --- a/tests/out/hlsl/access.hlsl +++ b/tests/out/hlsl/access.hlsl @@ -20,10 +20,10 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position float baz = foo1; foo1 = 1.0; float4x4 matrix1 = transpose(float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48)))); - uint2 arr[2] = {asuint(bar.Load2(64+0)), asuint(bar.Load2(64+8))}; + uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))}; float4 _expr13 = asfloat(bar.Load4(48+0)); float b = _expr13.x; - int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 80) / 4) - 2u)*4+80)); + int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88)); bar.Store(8+16+0, asuint(1.0)); { float4x4 _value2 = transpose(float4x4(float4(0.0.xxxx), float4(1.0.xxxx), float4(2.0.xxxx), float4(3.0.xxxx))); @@ -34,8 +34,8 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position } { uint2 _value2[2] = { uint2(0u.xx), uint2(1u.xx) }; - bar.Store2(64+0, asuint(_value2[0])); - bar.Store2(64+8, asuint(_value2[1])); + bar.Store2(72+0, asuint(_value2[0])); + bar.Store2(72+8, asuint(_value2[1])); } { int _result[5]={ a, int(b), 3, 4, 5 }; @@ -45,3 +45,27 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position int value = c[vertexinput_foo.vi1]; return mul(matrix1, float4(int4(value.xxxx))); } + +[numthreads(1, 1, 1)] +void atomics() +{ + int tmp = (int)0; + + int value = asint(bar.Load(64)); + int _e6; bar.InterlockedAdd(64, 5, _e6); + tmp = _e6; + int _e9; bar.InterlockedAnd(64, 5, _e9); + tmp = _e9; + int _e12; bar.InterlockedOr(64, 5, _e12); + tmp = _e12; + int _e15; bar.InterlockedXor(64, 5, _e15); + tmp = _e15; + int _e18; bar.InterlockedMin(64, 5, _e18); + tmp = _e18; + int _e21; bar.InterlockedMax(64, 5, _e21); + tmp = _e21; + int _e24; bar.InterlockedExchange(64, 5, _e24); + tmp = _e24; + bar.Store(64, asuint(value)); + return; +} diff --git a/tests/out/hlsl/access.hlsl.config b/tests/out/hlsl/access.hlsl.config index 0f887fb4cf..49dcb6821d 100644 --- a/tests/out/hlsl/access.hlsl.config +++ b/tests/out/hlsl/access.hlsl.config @@ -1,3 +1,3 @@ vertex=(foo:vs_5_1 ) fragment=() -compute=() +compute=(atomics:cs_5_1 ) diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 12574b814d..191ef8b622 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -421,8 +421,7 @@ fn convert_wgsl() { ), ( "access", - //TODO: atomics on HLSL - Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ( "control-flow",