diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 61bc1d4338..ba0f1403b9 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -501,7 +501,7 @@ impl Writer { .scalar_kind() .ok_or(Error::UnsupportedBinaryOp(op))?; if op == crate::BinaryOperator::Modulo && kind == crate::ScalarKind::Float { - write!(self.out, "fmod(")?; + write!(self.out, "{}::fmod(", NAMESPACE)?; self.put_expression(left, context, true)?; write!(self.out, ", ")?; self.put_expression(right, context, true)?; @@ -907,8 +907,9 @@ impl Writer { writeln!(self.out, "#include ")?; writeln!(self.out)?; + self.write_scalar_constants(module)?; self.write_type_defs(module)?; - self.write_constants(module)?; + self.write_composite_constants(module)?; self.write_functions(module, info, options) } @@ -1076,43 +1077,60 @@ impl Writer { Ok(()) } - fn write_constants(&mut self, module: &crate::Module) -> Result<(), Error> { + fn write_scalar_constants(&mut self, module: &crate::Module) -> Result<(), Error> { for (handle, constant) in module.constants.iter() { - write!(self.out, "constexpr constant ")?; - let name = &self.names[&NameKey::Constant(handle)]; match constant.inner { crate::ConstantInner::Scalar { width: _, ref value, - } => match *value { - crate::ScalarValue::Sint(value) => { - write!(self.out, "int {} = {}", name, value)?; - } - crate::ScalarValue::Uint(value) => { - write!(self.out, "unsigned {} = {}u", name, value)?; - } - crate::ScalarValue::Float(value) => { - write!(self.out, "float {} = {}", name, value)?; - if value.fract() == 0.0 { - write!(self.out, ".0")?; + } => { + let name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, "constexpr constant ")?; + match *value { + crate::ScalarValue::Sint(value) => { + write!(self.out, "int {} = {}", name, value)?; + } + crate::ScalarValue::Uint(value) => { + write!(self.out, "unsigned {} = {}u", name, value)?; + } + crate::ScalarValue::Float(value) => { + write!(self.out, "float {} = {}", name, value)?; + if value.fract() == 0.0 { + write!(self.out, ".0")?; + } + } + crate::ScalarValue::Bool(value) => { + write!(self.out, "bool {} = {}", name, value)?; } } - crate::ScalarValue::Bool(value) => { - write!(self.out, "bool {} = {}", name, value)?; - } - }, + writeln!(self.out, ";")?; + } + crate::ConstantInner::Composite { .. } => {} + } + } + Ok(()) + } + + fn write_composite_constants(&mut self, module: &crate::Module) -> Result<(), Error> { + for (handle, constant) in module.constants.iter() { + match constant.inner { + crate::ConstantInner::Scalar { .. } => {} crate::ConstantInner::Composite { ty, ref components } => { + let name = &self.names[&NameKey::Constant(handle)]; let ty_name = &self.names[&NameKey::Type(ty)]; - write!(self.out, "{} {} = {}(", ty_name, name, ty_name)?; + write!( + self.out, + "constexpr constant {} {} = {}(", + ty_name, name, ty_name + )?; for (i, &sub_handle) in components.iter().enumerate() { let separator = if i != 0 { ", " } else { "" }; let sub_name = &self.names[&NameKey::Constant(sub_handle)]; write!(self.out, "{}{}", separator, sub_name)?; } - write!(self.out, ")")?; + writeln!(self.out, ");")?; } } - writeln!(self.out, ";")?; } Ok(()) } diff --git a/tests/out/boids.msl.snap b/tests/out/boids.msl.snap index a7d8102413..82bcf12b24 100644 --- a/tests/out/boids.msl.snap +++ b/tests/out/boids.msl.snap @@ -5,6 +5,15 @@ expression: msl #include #include +constexpr constant unsigned NUM_PARTICLES = 1500u; +constexpr constant float const_0f = 0.0; +constexpr constant int const_0i = 0; +constexpr constant unsigned const_0u = 0u; +constexpr constant int const_1i = 1; +constexpr constant unsigned const_1u = 1u; +constexpr constant float const_1f = 1.0; +constexpr constant float const_0_10f = 0.1; +constexpr constant float const_n1f = -1.0; typedef uint type; typedef metal::float2 type1; struct Particle { @@ -27,15 +36,6 @@ struct Particles { }; typedef metal::uint3 type4; typedef int type5; -constexpr constant unsigned NUM_PARTICLES = 1500u; -constexpr constant float const_0f = 0.0; -constexpr constant int const_0i = 0; -constexpr constant unsigned const_0u = 0u; -constexpr constant int const_1i = 1; -constexpr constant unsigned const_1u = 1u; -constexpr constant float const_1f = 1.0; -constexpr constant float const_0_10f = 0.1; -constexpr constant float const_n1f = -1.0; struct main1Input { }; kernel void main1( diff --git a/tests/out/collatz.msl.snap b/tests/out/collatz.msl.snap index b7d56f09c2..5614404713 100644 --- a/tests/out/collatz.msl.snap +++ b/tests/out/collatz.msl.snap @@ -5,16 +5,16 @@ expression: msl #include #include +constexpr constant unsigned const_0u = 0u; +constexpr constant unsigned const_1u = 1u; +constexpr constant unsigned const_2u = 2u; +constexpr constant unsigned const_3u = 3u; typedef uint type; typedef type type1[1]; struct PrimeIndices { type1 data; }; typedef metal::uint3 type2; -constexpr constant unsigned const_0u = 0u; -constexpr constant unsigned const_1u = 1u; -constexpr constant unsigned const_2u = 2u; -constexpr constant unsigned const_3u = 3u; type collatz_iterations( type n_base ) { diff --git a/tests/out/quad.msl.snap b/tests/out/quad.msl.snap index 0a01f1484a..084d7fb87b 100644 --- a/tests/out/quad.msl.snap +++ b/tests/out/quad.msl.snap @@ -5,6 +5,9 @@ expression: msl #include #include +constexpr constant float c_scale = 1.2; +constexpr constant float const_0f = 0.0; +constexpr constant float const_1f = 1.0; typedef float type; typedef metal::float2 type1; typedef metal::float4 type2; @@ -14,9 +17,6 @@ struct VertexOutput { }; typedef metal::texture2d type3; typedef metal::sampler type4; -constexpr constant float c_scale = 1.2; -constexpr constant float const_0f = 0.0; -constexpr constant float const_1f = 1.0; struct main1Input { type1 pos [[attribute(0)]]; type1 uv1 [[attribute(1)]]; diff --git a/tests/out/shadow.msl.snap b/tests/out/shadow.msl.snap index b2eae0b4c7..906e53bfd0 100644 --- a/tests/out/shadow.msl.snap +++ b/tests/out/shadow.msl.snap @@ -5,6 +5,14 @@ expression: msl #include #include +constexpr constant float const_0f = 0.0; +constexpr constant float const_1f = 1.0; +constexpr constant float const_0_50f = 0.5; +constexpr constant float const_n0_50f = -0.5; +constexpr constant float const_0_05f = 0.05; +constexpr constant unsigned c_max_lights = 10u; +constexpr constant unsigned const_0u = 0u; +constexpr constant unsigned const_1u = 1u; typedef metal::uint4 type; struct Globals { type num_lights; @@ -26,15 +34,7 @@ typedef uint type6; typedef float type7; typedef metal::float2 type8; typedef metal::float3 type9; -constexpr constant float const_0f = 0.0; -constexpr constant float const_1f = 1.0; -constexpr constant float const_0_50f = 0.5; -constexpr constant float const_n0_50f = -0.5; -constexpr constant float const_0_05f = 0.05; constexpr constant type9 c_ambient = type9(const_0_05f, const_0_05f, const_0_05f); -constexpr constant unsigned c_max_lights = 10u; -constexpr constant unsigned const_0u = 0u; -constexpr constant unsigned const_1u = 1u; type7 fetch_shadow( type6 light_id, type2 homogeneous_coords, diff --git a/tests/out/skybox.msl.snap b/tests/out/skybox.msl.snap index eb27b7cfeb..7b22fd95f6 100644 --- a/tests/out/skybox.msl.snap +++ b/tests/out/skybox.msl.snap @@ -5,6 +5,11 @@ expression: msl #include #include +constexpr constant int const_2i = 2; +constexpr constant int const_1i = 1; +constexpr constant float const_4f = 4.0; +constexpr constant float const_1f = 1.0; +constexpr constant float const_0f = 0.0; typedef metal::float4 type; typedef metal::float3 type1; struct VertexOutput { @@ -21,11 +26,6 @@ typedef int type4; typedef metal::float3x3 type5; typedef metal::texturecube type6; typedef metal::sampler type7; -constexpr constant int const_2i = 2; -constexpr constant int const_1i = 1; -constexpr constant float const_4f = 4.0; -constexpr constant float const_1f = 1.0; -constexpr constant float const_0f = 0.0; struct vs_mainInput { }; struct vs_mainOutput {