mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Implemented atomics in HLSL-out
This commit is contained in:
committed by
Dzmitry Malyshau
parent
bc4576c0a2
commit
78f225a37a
@@ -122,3 +122,16 @@ impl crate::Sampling {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::BinaryOperator {
|
||||
/// Return the HLSL suffix for the `InterlockedXxx` method.
|
||||
pub(super) fn to_hlsl_atomic_suffix(self) -> &'static str {
|
||||
match self {
|
||||
Self::Add => "Add",
|
||||
Self::And => "And",
|
||||
Self::InclusiveOr => "Or",
|
||||
Self::ExclusiveOr => "Xor",
|
||||
_ => "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ pub(super) enum StoreValue {
|
||||
}
|
||||
|
||||
impl<W: fmt::Write> super::Writer<'_, W> {
|
||||
fn write_storage_address(
|
||||
pub(super) fn write_storage_address(
|
||||
&mut self,
|
||||
module: &crate::Module,
|
||||
chain: &[SubAccess],
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::{
|
||||
proc::{self, NameKey},
|
||||
valid, Handle, Module, ShaderStage, TypeInner,
|
||||
};
|
||||
use std::fmt;
|
||||
use std::{fmt, mem};
|
||||
|
||||
const LOCATION_SEMANTIC: &str = "LOC";
|
||||
const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
|
||||
@@ -992,38 +992,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
writeln!(self.out, ";")?
|
||||
}
|
||||
}
|
||||
Statement::Call {
|
||||
function,
|
||||
ref arguments,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{}", INDENT.repeat(indent))?;
|
||||
if let Some(expr) = result {
|
||||
write!(self.out, "const ")?;
|
||||
let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
|
||||
let expr_ty = &func_ctx.info[expr].ty;
|
||||
match *expr_ty {
|
||||
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
|
||||
proc::TypeResolution::Value(ref value) => {
|
||||
self.write_value_type(module, value)?
|
||||
}
|
||||
};
|
||||
write!(self.out, " {} = ", name)?;
|
||||
self.write_expr(module, expr, func_ctx)?;
|
||||
self.named_expressions.insert(expr, name);
|
||||
}
|
||||
let func_name = &self.names[&NameKey::Function(function)];
|
||||
write!(self.out, "{}(", func_name)?;
|
||||
for (index, argument) in arguments.iter().enumerate() {
|
||||
self.write_expr(module, *argument, func_ctx)?;
|
||||
// Only write a comma if isn't the last element
|
||||
if index != arguments.len().saturating_sub(1) {
|
||||
// The leading space is for readability only
|
||||
write!(self.out, ", ")?;
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?
|
||||
}
|
||||
Statement::Loop {
|
||||
ref body,
|
||||
ref continuing,
|
||||
@@ -1107,7 +1075,94 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
writeln!(self.out, ";")?;
|
||||
}
|
||||
_ => return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt))),
|
||||
Statement::Call {
|
||||
function,
|
||||
ref arguments,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{}", INDENT.repeat(indent))?;
|
||||
if let Some(expr) = result {
|
||||
write!(self.out, "const ")?;
|
||||
let name = format!("{}{}", back::BAKE_PREFIX, expr.index());
|
||||
let expr_ty = &func_ctx.info[expr].ty;
|
||||
match *expr_ty {
|
||||
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
|
||||
proc::TypeResolution::Value(ref value) => {
|
||||
self.write_value_type(module, value)?
|
||||
}
|
||||
};
|
||||
write!(self.out, " {} = ", name)?;
|
||||
self.named_expressions.insert(expr, name);
|
||||
}
|
||||
let func_name = &self.names[&NameKey::Function(function)];
|
||||
write!(self.out, "{}(", func_name)?;
|
||||
for (index, argument) in arguments.iter().enumerate() {
|
||||
self.write_expr(module, *argument, func_ctx)?;
|
||||
// Only write a comma if isn't the last element
|
||||
if index != arguments.len().saturating_sub(1) {
|
||||
// The leading space is for readability only
|
||||
write!(self.out, ", ")?;
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ");")?
|
||||
}
|
||||
Statement::Atomic {
|
||||
pointer,
|
||||
fun,
|
||||
result,
|
||||
} => {
|
||||
write!(self.out, "{}", INDENT.repeat(indent))?;
|
||||
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
|
||||
match func_ctx.info[result].ty {
|
||||
proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
|
||||
proc::TypeResolution::Value(ref value) => {
|
||||
self.write_value_type(module, value)?
|
||||
}
|
||||
};
|
||||
|
||||
let var_handle = self.fill_access_chain(module, pointer, func_ctx)?;
|
||||
// working around the borrow checker in `self.write_expr`
|
||||
let chain = mem::take(&mut self.temp_access_chain);
|
||||
let var_name = &self.names[&NameKey::GlobalVariable(var_handle)];
|
||||
|
||||
write!(self.out, " {}; {}.Interlocked", res_name, var_name)?;
|
||||
match fun {
|
||||
crate::AtomicFunction::Binary { op, value } => {
|
||||
let suffix = op.to_hlsl_atomic_suffix();
|
||||
write!(self.out, "{}(", suffix)?;
|
||||
self.write_storage_address(module, &chain, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
}
|
||||
crate::AtomicFunction::Min(value) => {
|
||||
write!(self.out, "Min(")?;
|
||||
self.write_storage_address(module, &chain, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
}
|
||||
crate::AtomicFunction::Max(value) => {
|
||||
write!(self.out, "Max(")?;
|
||||
self.write_storage_address(module, &chain, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
}
|
||||
crate::AtomicFunction::Exchange(value) => {
|
||||
write!(self.out, "Exchange(")?;
|
||||
self.write_storage_address(module, &chain, func_ctx)?;
|
||||
write!(self.out, ", ")?;
|
||||
self.write_expr(module, value, func_ctx)?;
|
||||
}
|
||||
crate::AtomicFunction::CompareExchange { .. } => {
|
||||
return Err(Error::Unimplemented("atomic CompareExchange".to_string()));
|
||||
}
|
||||
}
|
||||
writeln!(self.out, ", {});", res_name)?;
|
||||
self.temp_access_chain = chain;
|
||||
self.named_expressions.insert(result, res_name);
|
||||
}
|
||||
Statement::Switch { .. } => {
|
||||
return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt)))
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -1667,7 +1722,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
write!(self.out, ")")?
|
||||
}
|
||||
// Nothing to do here, since call expression already cached
|
||||
Expression::CallResult(_) => {}
|
||||
Expression::CallResult(_) | Expression::AtomicResult { .. } => {}
|
||||
_ => return Err(Error::Unimplemented(format!("write_expr {:?}", expression))),
|
||||
}
|
||||
|
||||
|
||||
@@ -20,10 +20,10 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
|
||||
float baz = foo1;
|
||||
foo1 = 1.0;
|
||||
float4x4 matrix1 = transpose(float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48))));
|
||||
uint2 arr[2] = {asuint(bar.Load2(64+0)), asuint(bar.Load2(64+8))};
|
||||
uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))};
|
||||
float4 _expr13 = asfloat(bar.Load4(48+0));
|
||||
float b = _expr13.x;
|
||||
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 80) / 4) - 2u)*4+80));
|
||||
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88));
|
||||
bar.Store(8+16+0, asuint(1.0));
|
||||
{
|
||||
float4x4 _value2 = transpose(float4x4(float4(0.0.xxxx), float4(1.0.xxxx), float4(2.0.xxxx), float4(3.0.xxxx)));
|
||||
@@ -34,8 +34,8 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
|
||||
}
|
||||
{
|
||||
uint2 _value2[2] = { uint2(0u.xx), uint2(1u.xx) };
|
||||
bar.Store2(64+0, asuint(_value2[0]));
|
||||
bar.Store2(64+8, asuint(_value2[1]));
|
||||
bar.Store2(72+0, asuint(_value2[0]));
|
||||
bar.Store2(72+8, asuint(_value2[1]));
|
||||
}
|
||||
{
|
||||
int _result[5]={ a, int(b), 3, 4, 5 };
|
||||
@@ -45,3 +45,27 @@ float4 foo(VertexInput_foo vertexinput_foo) : SV_Position
|
||||
int value = c[vertexinput_foo.vi1];
|
||||
return mul(matrix1, float4(int4(value.xxxx)));
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void atomics()
|
||||
{
|
||||
int tmp = (int)0;
|
||||
|
||||
int value = asint(bar.Load(64));
|
||||
int _e6; bar.InterlockedAdd(64, 5, _e6);
|
||||
tmp = _e6;
|
||||
int _e9; bar.InterlockedAnd(64, 5, _e9);
|
||||
tmp = _e9;
|
||||
int _e12; bar.InterlockedOr(64, 5, _e12);
|
||||
tmp = _e12;
|
||||
int _e15; bar.InterlockedXor(64, 5, _e15);
|
||||
tmp = _e15;
|
||||
int _e18; bar.InterlockedMin(64, 5, _e18);
|
||||
tmp = _e18;
|
||||
int _e21; bar.InterlockedMax(64, 5, _e21);
|
||||
tmp = _e21;
|
||||
int _e24; bar.InterlockedExchange(64, 5, _e24);
|
||||
tmp = _e24;
|
||||
bar.Store(64, asuint(value));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
vertex=(foo:vs_5_1 )
|
||||
fragment=()
|
||||
compute=()
|
||||
compute=(atomics:cs_5_1 )
|
||||
|
||||
@@ -421,8 +421,7 @@ fn convert_wgsl() {
|
||||
),
|
||||
(
|
||||
"access",
|
||||
//TODO: atomics on HLSL
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL,
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
(
|
||||
"control-flow",
|
||||
|
||||
Reference in New Issue
Block a user