Implemented atomics in HLSL-out

This commit is contained in:
Dzmitry Malyshau
2021-08-06 01:18:51 -04:00
committed by Dzmitry Malyshau
parent bc4576c0a2
commit 78f225a37a
6 changed files with 134 additions and 43 deletions

View File

@@ -122,3 +122,16 @@ impl crate::Sampling {
}
}
}
impl crate::BinaryOperator {
/// Return the HLSL suffix for the `InterlockedXxx` method.
pub(super) fn to_hlsl_atomic_suffix(self) -> &'static str {
match self {
Self::Add => "Add",
Self::And => "And",
Self::InclusiveOr => "Or",
Self::ExclusiveOr => "Xor",
_ => "",
}
}
}

View File

@@ -40,7 +40,7 @@ pub(super) enum StoreValue {
}
impl<W: fmt::Write> super::Writer<'_, W> {
fn write_storage_address(
pub(super) fn write_storage_address(
&mut self,
module: &crate::Module,
chain: &[SubAccess],

View File

@@ -8,7 +8,7 @@ use crate::{
proc::{self, NameKey},
valid, Handle, Module, ShaderStage, TypeInner,
};
use std::fmt;
use std::{fmt, mem};
const LOCATION_SEMANTIC: &str = "LOC";
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
@@ -992,38 +992,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, ";")?
}
}
Statement::Call {
function,
ref arguments,
result,
} => {
write!(self.out, "{}", INDENT.repeat(indent))?;
if let Some(expr) = result {
write!(self.out, "const ")?;
let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
let expr_ty = &func_ctx.info[expr].ty;
match *expr_ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {} = ", name)?;
self.write_expr(module, expr, func_ctx)?;
self.named_expressions.insert(expr, name);
}
let func_name = &self.names[&NameKey::Function(function)];
write!(self.out, "{}(", func_name)?;
for (index, argument) in arguments.iter().enumerate() {
self.write_expr(module, *argument, func_ctx)?;
// Only write a comma if isn't the last element
if index != arguments.len().saturating_sub(1) {
// The leading space is for readability only
write!(self.out, ", ")?;
}
}
writeln!(self.out, ");")?
}
Statement::Loop {
ref body,
ref continuing,
@@ -1107,7 +1075,94 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ";")?;
}
_ => return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt))),
Statement::Call {
function,
ref arguments,
result,
} => {
write!(self.out, "{}", INDENT.repeat(indent))?;
if let Some(expr) = result {
write!(self.out, "const ")?;
let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
let expr_ty = &func_ctx.info[expr].ty;
match *expr_ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
write!(self.out, " {} = ", name)?;
self.named_expressions.insert(expr, name);
}
let func_name = &self.names[&NameKey::Function(function)];
write!(self.out, "{}(", func_name)?;
for (index, argument) in arguments.iter().enumerate() {
self.write_expr(module, *argument, func_ctx)?;
// Only write a comma if isn't the last element
if index != arguments.len().saturating_sub(1) {
// The leading space is for readability only
write!(self.out, ", ")?;
}
}
writeln!(self.out, ");")?
}
Statement::Atomic {
pointer,
fun,
result,
} => {
write!(self.out, "{}", INDENT.repeat(indent))?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
match func_ctx.info[result].ty {
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
proc::TypeResolution::Value(ref value) => {
self.write_value_type(module, value)?
}
};
let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
// working around the borrow checker in `self.write_expr`
let chain = mem::take(&mut self.temp_access_chain);
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
write!(self.out, " {}; {}.Interlocked", res_name, var_name)?;
match fun {
crate::AtomicFunction::Binary { op, value } => {
let suffix = op.to_hlsl_atomic_suffix();
write!(self.out, "{}(", suffix)?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
}
crate::AtomicFunction::Min(value) => {
write!(self.out, "Min(")?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
}
crate::AtomicFunction::Max(value) => {
write!(self.out, "Max(")?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
}
crate::AtomicFunction::Exchange(value) => {
write!(self.out, "Exchange(")?;
self.write_storage_address(module, &chain, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, value, func_ctx)?;
}
crate::AtomicFunction::CompareExchange { .. } => {
return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
}
}
writeln!(self.out, ", {});", res_name)?;
self.temp_access_chain = chain;
self.named_expressions.insert(result, res_name);
}
Statement::Switch { .. } => {
return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt)))
}
}
Ok(())
@@ -1667,7 +1722,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?
}
// Nothing to do here, since call expression already cached
Expression::CallResult(_) => {}
Expression::CallResult(_) | Expression::AtomicResult { .. } => {}
_ => return Err(Error::Unimplemented(format!("write_expr {:?}", expression))),
}

View File

@@ -20,10 +20,10 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
float baz = foo1;
foo1 = 1.0;
float4x4 matrix1 = transpose(float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48))));
uint2 arr[2] = {asuint(bar.Load2(64+0)), asuint(bar.Load2(64+8))};
uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))};
float4 _expr13 = asfloat(bar.Load4(48+0));
float b = _expr13.x;
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 80) / 4) - 2u)*4+80));
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88));
bar.Store(8+16+0, asuint(1.0));
{
float4x4 _value2 = transpose(float4x4(float4(0.0.xxxx), float4(1.0.xxxx), float4(2.0.xxxx), float4(3.0.xxxx)));
@@ -34,8 +34,8 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
}
{
uint2 _value2[2] = { uint2(0u.xx), uint2(1u.xx) };
bar.Store2(64+0, asuint(_value2[0]));
bar.Store2(64+8, asuint(_value2[1]));
bar.Store2(72+0, asuint(_value2[0]));
bar.Store2(72+8, asuint(_value2[1]));
}
{
int _result[5]={ a, int(b), 3, 4, 5 };
@@ -45,3 +45,27 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
int value = c[vertexinput_foo.vi1];
return mul(matrix1, float4(int4(value.xxxx)));
}
[numthreads(1, 1, 1)]
void atomics()
{
int tmp = (int)0;
int value = asint(bar.Load(64));
int _e6; bar.InterlockedAdd(64, 5, _e6);
tmp = _e6;
int _e9; bar.InterlockedAnd(64, 5, _e9);
tmp = _e9;
int _e12; bar.InterlockedOr(64, 5, _e12);
tmp = _e12;
int _e15; bar.InterlockedXor(64, 5, _e15);
tmp = _e15;
int _e18; bar.InterlockedMin(64, 5, _e18);
tmp = _e18;
int _e21; bar.InterlockedMax(64, 5, _e21);
tmp = _e21;
int _e24; bar.InterlockedExchange(64, 5, _e24);
tmp = _e24;
bar.Store(64, asuint(value));
return;
}

View File

@@ -1,3 +1,3 @@
vertex=(foo:vs_5_1 )
fragment=()
compute=()
compute=(atomics:cs_5_1 )

View File

@@ -421,8 +421,7 @@ fn convert_wgsl() {
),
(
"access",
//TODO: atomics on HLSL
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"control-flow",