diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 0c4979a891..e06d727912 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -229,6 +229,42 @@ impl<'a> TypedGlobalVariable<'a> { } } +struct ConstantContext<'a> { + handle: Handle, + arena: &'a Arena, + names: &'a FastHashMap, + first_time: bool, +} + +impl<'a> Display for ConstantContext<'a> { + fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> { + let con = &self.arena[self.handle]; + if con.needs_alias() && !self.first_time { + let name = &self.names[&NameKey::Constant(self.handle)]; + return write!(out, "{}", name); + } + + match con.inner { + crate::ConstantInner::Scalar { value, width: _ } => match value { + crate::ScalarValue::Sint(value) => { + write!(out, "{}", value) + } + crate::ScalarValue::Uint(value) => { + write!(out, "{}u", value) + } + crate::ScalarValue::Float(value) => { + let suffix = if value.fract() == 0.0 { ".0" } else { "" }; + write!(out, "{}{}", value, suffix) + } + crate::ScalarValue::Bool(value) => { + write!(out, "{}", value) + } + }, + crate::ConstantInner::Composite { .. } => unreachable!("should be aliased"), + } + } +} + pub struct Writer { out: W, names: FastHashMap, @@ -299,7 +335,7 @@ impl crate::StorageClass { } impl crate::Type { - // Returns `true` if we need to emit a type alias for this type. + // Returns `true` if we need to emit an alias for this type. fn needs_alias(&self) -> bool { use crate::TypeInner as Ti; match self.inner { @@ -317,6 +353,16 @@ impl crate::Type { } } +impl crate::Constant { + // Returns `true` if we need to emit an alias for this constant. + fn needs_alias(&self) -> bool { + match self.inner { + crate::ConstantInner::Scalar { .. } => self.name.is_some(), + crate::ConstantInner::Composite { .. } => true, + } + } +} + enum FunctionOrigin { Handle(Handle), EntryPoint(EntryPointIndex), @@ -545,8 +591,13 @@ impl Writer { } } crate::Expression::Constant(handle) => { - let handle_name = &self.names[&NameKey::Constant(handle)]; - write!(self.out, "{}", handle_name)?; + let coco = ConstantContext { + handle, + arena: &context.module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, "{}", coco)?; } crate::Expression::Compose { ty, ref components } => { let inner = &context.module.types[ty].inner; @@ -670,8 +721,13 @@ impl Writer { } } if let Some(constant) = offset { - let offset_str = &self.names[&NameKey::Constant(constant)]; - write!(self.out, ", {}", offset_str)?; + let coco = ConstantContext { + handle: constant, + arena: &context.module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, ", {}", coco)?; } write!(self.out, ")")?; } @@ -914,8 +970,13 @@ impl Writer { size: crate::ArraySize::Constant(const_handle), .. } => { - let size_str = &self.names[&NameKey::Constant(const_handle)]; - write!(self.out, "{}", size_str)?; + let coco = ConstantContext { + handle: const_handle, + arena: &context.module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, "{}", coco)?; } crate::TypeInner::Array { .. } => { return Err(Error::FeatureNotImplemented( @@ -1296,13 +1357,21 @@ impl Writer { access: crate::StorageAccess::empty(), first_time: false, }; - let size_str = match size { + write!(self.out, "typedef {} {}", base_name, name)?; + match size { crate::ArraySize::Constant(const_handle) => { - &self.names[&NameKey::Constant(const_handle)] + let coco = ConstantContext { + handle: const_handle, + arena: &module.constants, + names: &self.names, + first_time: false, + }; + writeln!(self.out, "[{}];", coco)?; } - crate::ArraySize::Dynamic => "1", - }; - writeln!(self.out, "typedef {} {}[{}];", base_name, name, size_str)?; + crate::ArraySize::Dynamic => { + writeln!(self.out, "[1];")?; + } + } } crate::TypeInner::Struct { block: _, @@ -1345,29 +1414,33 @@ impl Writer { crate::ConstantInner::Scalar { width: _, ref value, - } => { - let name = &self.names[&NameKey::Constant(handle)]; + } if constant.name.is_some() => { + debug_assert!(constant.needs_alias()); write!(self.out, "constexpr constant ")?; match *value { - crate::ScalarValue::Sint(value) => { - write!(self.out, "int {} = {}", name, value)?; + crate::ScalarValue::Sint(_) => { + write!(self.out, "int")?; } - crate::ScalarValue::Uint(value) => { - write!(self.out, "unsigned {} = {}u", name, value)?; + crate::ScalarValue::Uint(_) => { + write!(self.out, "unsigned")?; } - crate::ScalarValue::Float(value) => { - write!(self.out, "float {} = {}", name, value)?; - if value.fract() == 0.0 { - write!(self.out, ".0")?; - } + crate::ScalarValue::Float(_) => { + write!(self.out, "float")?; } - crate::ScalarValue::Bool(value) => { - write!(self.out, "bool {} = {}", name, value)?; + crate::ScalarValue::Bool(_) => { + write!(self.out, "bool")?; } } - writeln!(self.out, ";")?; + let name = &self.names[&NameKey::Constant(handle)]; + let coco = ConstantContext { + handle, + arena: &module.constants, + names: &self.names, + first_time: true, + }; + writeln!(self.out, " {} = {};", name, coco)?; } - crate::ConstantInner::Composite { .. } => {} + _ => {} } } Ok(()) @@ -1378,6 +1451,7 @@ impl Writer { match constant.inner { crate::ConstantInner::Scalar { .. } => {} crate::ConstantInner::Composite { ty, ref components } => { + debug_assert!(constant.needs_alias()); let name = &self.names[&NameKey::Constant(handle)]; let ty_name = TypeContext { handle: ty, @@ -1390,8 +1464,13 @@ impl Writer { write!(self.out, "constexpr constant {} {} = {{", ty_name, name,)?; for (i, &sub_handle) in components.iter().enumerate() { let separator = if i != 0 { ", " } else { "" }; - let sub_name = &self.names[&NameKey::Constant(sub_handle)]; - write!(self.out, "{}{}", separator, sub_name)?; + let coco = ConstantContext { + handle: sub_handle, + arena: &module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, "{}{}", separator, coco)?; } writeln!(self.out, "}};")?; } @@ -1561,8 +1640,13 @@ impl Writer { let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)]; write!(self.out, "{}{} {}", INDENT, ty_name, local_name)?; if let Some(value) = local.init { - let value_str = &self.names[&NameKey::Constant(value)]; - write!(self.out, " = {}", value_str)?; + let coco = ConstantContext { + handle: value, + arena: &module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, " = {}", coco)?; } writeln!(self.out, ";")?; } @@ -1801,8 +1885,13 @@ impl Writer { resolved.try_fmt_decorated(&mut self.out, "")?; } if let Some(value) = var.init { - let value_str = &self.names[&NameKey::Constant(value)]; - write!(self.out, " = {}", value_str)?; + let coco = ConstantContext { + handle: value, + arena: &module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, " = {}", coco)?; } writeln!(self.out)?; } @@ -1827,11 +1916,20 @@ impl Writer { }; write!(self.out, "{}", INDENT)?; tyvar.try_fmt(&mut self.out)?; - let value_str = match var.init { - Some(value) => &self.names[&NameKey::Constant(value)], - None => "{}", + match var.init { + Some(value) => { + let coco = ConstantContext { + handle: value, + arena: &module.constants, + names: &self.names, + first_time: false, + }; + writeln!(self.out, " = {};", coco)?; + } + None => { + writeln!(self.out, " = {{}};")?; + } }; - writeln!(self.out, " = {};", value_str)?; } else if let Some(ref binding) = var.binding { // write an inline sampler let resolved = options.resolve_global_binding(ep.stage, binding).unwrap(); @@ -1902,8 +2000,13 @@ impl Writer { }; write!(self.out, "{}{} {}", INDENT, ty_name, name)?; if let Some(value) = local.init { - let value_str = &self.names[&NameKey::Constant(value)]; - write!(self.out, " = {}", value_str)?; + let coco = ConstantContext { + handle: value, + arena: &module.constants, + names: &self.names, + first_time: false, + }; + write!(self.out, " = {}", coco)?; } writeln!(self.out, ";")?; } diff --git a/tests/out/boids.msl.snap b/tests/out/boids.msl.snap index 9f302992a6..d59f82aea3 100644 --- a/tests/out/boids.msl.snap +++ b/tests/out/boids.msl.snap @@ -6,14 +6,6 @@ expression: msl #include constexpr constant unsigned NUM_PARTICLES = 1500u; -constexpr constant float const_0f = 0.0; -constexpr constant int const_0i = 0; -constexpr constant unsigned const_0u = 0u; -constexpr constant int const_1i = 1; -constexpr constant unsigned const_1u = 1u; -constexpr constant float const_1f = 1.0; -constexpr constant float const_0_10f = 0.1; -constexpr constant float const_n1f = -1.0; struct Particle { metal::float2 pos; metal::float2 vel; @@ -45,23 +37,23 @@ kernel void main1( metal::float2 cMass; metal::float2 cVel; metal::float2 colVel; - int cMassCount = const_0i; - int cVelCount = const_0i; + int cMassCount = 0; + int cVelCount = 0; metal::float2 pos1; metal::float2 vel1; - metal::uint i = const_0u; + metal::uint i = 0u; if (global_invocation_id.x >= NUM_PARTICLES) { return; } vPos = particlesSrc.particles[global_invocation_id.x].pos; vVel = particlesSrc.particles[global_invocation_id.x].vel; - cMass = metal::float2(const_0f, const_0f); - cVel = metal::float2(const_0f, const_0f); - colVel = metal::float2(const_0f, const_0f); + cMass = metal::float2(0.0, 0.0); + cVel = metal::float2(0.0, 0.0); + colVel = metal::float2(0.0, 0.0); bool loop_init = true; while(true) { if (!loop_init) { - i = i + const_1u; + i = i + 1u; } loop_init = false; if (i >= NUM_PARTICLES) { @@ -74,36 +66,36 @@ kernel void main1( vel1 = particlesSrc.particles[i].vel; if (metal::distance(pos1, vPos) < params.rule1Distance) { cMass = cMass + pos1; - cMassCount = cMassCount + const_1i; + cMassCount = cMassCount + 1; } if (metal::distance(pos1, vPos) < params.rule2Distance) { colVel = colVel - (pos1 - vPos); } if (metal::distance(pos1, vPos) < params.rule3Distance) { cVel = cVel + vel1; - cVelCount = cVelCount + const_1i; + cVelCount = cVelCount + 1; } } - if (cMassCount > const_0i) { - cMass = (cMass * (const_1f / static_cast(cMassCount))) - vPos; + if (cMassCount > 0) { + cMass = (cMass * (1.0 / static_cast(cMassCount))) - vPos; } - if (cVelCount > const_0i) { - cVel = cVel * (const_1f / static_cast(cVelCount)); + if (cVelCount > 0) { + cVel = cVel * (1.0 / static_cast(cVelCount)); } vVel = ((vVel + (cMass * params.rule1Scale)) + (colVel * params.rule2Scale)) + (cVel * params.rule3Scale); - vVel = metal::normalize(vVel) * metal::clamp(metal::length(vVel), const_0f, const_0_10f); + vVel = metal::normalize(vVel) * metal::clamp(metal::length(vVel), 0.0, 0.1); vPos = vPos + (vVel * params.deltaT); - if (vPos.x < const_n1f) { - vPos.x = const_1f; + if (vPos.x < -1.0) { + vPos.x = 1.0; } - if (vPos.x > const_1f) { - vPos.x = const_n1f; + if (vPos.x > 1.0) { + vPos.x = -1.0; } - if (vPos.y < const_n1f) { - vPos.y = const_1f; + if (vPos.y < -1.0) { + vPos.y = 1.0; } - if (vPos.y > const_1f) { - vPos.y = const_n1f; + if (vPos.y > 1.0) { + vPos.y = -1.0; } particlesDst.particles[global_invocation_id.x].pos = vPos; particlesDst.particles[global_invocation_id.x].vel = vVel; diff --git a/tests/out/collatz.msl.snap b/tests/out/collatz.msl.snap index 2aa4fe3446..f86b138a7c 100644 --- a/tests/out/collatz.msl.snap +++ b/tests/out/collatz.msl.snap @@ -5,10 +5,6 @@ expression: msl #include #include -constexpr constant unsigned const_0u = 0u; -constexpr constant unsigned const_1u = 1u; -constexpr constant unsigned const_2u = 2u; -constexpr constant unsigned const_3u = 3u; typedef metal::uint type1[1]; struct PrimeIndices { type1 data; @@ -18,18 +14,18 @@ metal::uint collatz_iterations( metal::uint n_base ) { metal::uint n; - metal::uint i = const_0u; + metal::uint i = 0u; n = n_base; while(true) { - if (n <= const_1u) { + if (n <= 1u) { break; } - if ((n % const_2u) == const_0u) { - n = n / const_2u; + if ((n % 2u) == 0u) { + n = n / 2u; } else { - n = (const_3u * n) + const_1u; + n = (3u * n) + 1u; } - i = i + const_1u; + i = i + 1u; } return i; } diff --git a/tests/out/image-copy.msl.snap b/tests/out/image-copy.msl.snap index 3b9ca6f741..da5c6b1550 100644 --- a/tests/out/image-copy.msl.snap +++ b/tests/out/image-copy.msl.snap @@ -5,8 +5,6 @@ expression: msl #include #include -constexpr constant int const_10i = 10; -constexpr constant int const_20i = 20; struct main1Input { }; @@ -15,7 +13,7 @@ kernel void main1( , metal::texture2d image_src [[user(fake0)]] , metal::texture1d image_dst [[user(fake0)]] ) { - metal::int2 _expr12 = (int2(image_src.get_width(), image_src.get_height()) * static_cast(metal::uint2(local_id.x, local_id.y))) % metal::int2(const_10i, const_20i); + metal::int2 _expr12 = (int2(image_src.get_width(), image_src.get_height()) * static_cast(metal::uint2(local_id.x, local_id.y))) % metal::int2(10, 20); metal::uint4 _expr13 = image_src.read(metal::uint2(_expr12)); image_dst.write(_expr13, metal::uint(_expr12.x)); return; diff --git a/tests/out/quad-vert.msl.snap b/tests/out/quad-vert.msl.snap index a871820759..245ffea218 100644 --- a/tests/out/quad-vert.msl.snap +++ b/tests/out/quad-vert.msl.snap @@ -5,15 +5,7 @@ expression: msl #include #include -constexpr constant int const_0i = 0; -constexpr constant int const_1i = 1; -constexpr constant int const_2i = 2; -constexpr constant int const_3i = 3; -constexpr constant unsigned const_1u = 1u; -constexpr constant int const_0i1 = 0; -constexpr constant float const_0f = 0.0; -constexpr constant float const_1f = 1.0; -typedef float type6[const_1u]; +typedef float type6[1u]; struct gl_PerVertex { metal::float4 gl_Position; float gl_PointSize; @@ -35,7 +27,7 @@ void main1( ) { v_uv = a_uv; metal::float2 _expr13 = a_pos; - _.gl_Position = metal::float4(_expr13.x, _expr13.y, const_0f, const_1f); + _.gl_Position = metal::float4(_expr13.x, _expr13.y, 0.0, 1.0); return; } diff --git a/tests/out/quad.msl.snap b/tests/out/quad.msl.snap index 434a62aa64..b45e13d338 100644 --- a/tests/out/quad.msl.snap +++ b/tests/out/quad.msl.snap @@ -6,8 +6,6 @@ expression: msl #include constexpr constant float c_scale = 1.2; -constexpr constant float const_0f = 0.0; -constexpr constant float const_1f = 1.0; struct VertexOutput { metal::float2 uv; metal::float4 position; @@ -28,7 +26,7 @@ vertex main1Output main1( const auto uv1 = varyings.uv1; VertexOutput out; out.uv = uv1; - out.position = metal::float4(c_scale * pos, const_0f, const_1f); + out.position = metal::float4(c_scale * pos, 0.0, 1.0); const auto _tmp = out; return main1Output { _tmp.uv, _tmp.position }; } @@ -47,7 +45,7 @@ fragment main2Output main2( ) { const auto uv2 = varyings1.uv2; metal::float4 _expr4 = u_texture.sample(u_sampler, uv2); - if (_expr4.w == const_0f) { + if (_expr4.w == 0.0) { metal::discard_fragment(); } return main2Output { _expr4.w * _expr4 }; diff --git a/tests/out/shadow.msl.snap b/tests/out/shadow.msl.snap index 057daff8e8..9ac8176a13 100644 --- a/tests/out/shadow.msl.snap +++ b/tests/out/shadow.msl.snap @@ -5,14 +5,7 @@ expression: msl #include #include -constexpr constant float const_0f = 0.0; -constexpr constant float const_1f = 1.0; -constexpr constant float const_0_50f = 0.5; -constexpr constant float const_n0_50f = -0.5; -constexpr constant float const_0_05f = 0.05; constexpr constant unsigned c_max_lights = 10u; -constexpr constant unsigned const_0u = 0u; -constexpr constant unsigned const_1u = 1u; struct Globals { metal::uint4 num_lights; }; @@ -25,7 +18,7 @@ typedef Light type3[1]; struct Lights { type3 data; }; -constexpr constant metal::float3 c_ambient = {const_0_05f, const_0_05f, const_0_05f}; +constexpr constant metal::float3 c_ambient = {0.05, 0.05, 0.05}; float fetch_shadow( metal::uint light_id, @@ -33,11 +26,11 @@ float fetch_shadow( metal::depth2d_array t_shadow, metal::sampler sampler_shadow ) { - if (homogeneous_coords.w <= const_0f) { - return const_1f; + if (homogeneous_coords.w <= 0.0) { + return 1.0; } - float _expr15 = const_1f / homogeneous_coords.w; - float _expr28 = t_shadow.sample_compare(sampler_shadow, ((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(const_0_50f, const_n0_50f)) * _expr15) + metal::float2(const_0_50f, const_0_50f), static_cast(light_id), homogeneous_coords.z * _expr15); + float _expr15 = 1.0 / homogeneous_coords.w; + float _expr28 = t_shadow.sample_compare(sampler_shadow, ((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(0.5, -0.5)) * _expr15) + metal::float2(0.5, 0.5), static_cast(light_id), homogeneous_coords.z * _expr15); return _expr28; } @@ -58,11 +51,11 @@ fragment fs_mainOutput fs_main( const auto raw_normal = varyings.raw_normal; const auto position = varyings.position; metal::float3 color1 = c_ambient; - metal::uint i = const_0u; + metal::uint i = 0u; bool loop_init = true; while(true) { if (!loop_init) { - i = i + const_1u; + i = i + 1u; } loop_init = false; if (i >= metal::min(u_globals.num_lights.x, c_max_lights)) { @@ -70,8 +63,8 @@ fragment fs_mainOutput fs_main( } Light _expr21 = s_lights.data[i]; float _expr25 = fetch_shadow(i, _expr21.proj * position, t_shadow, sampler_shadow); - color1 = color1 + ((_expr25 * metal::max(const_0f, metal::dot(metal::normalize(raw_normal), metal::normalize(metal::float3(_expr21.pos.x, _expr21.pos.y, _expr21.pos.z) - metal::float3(position.x, position.y, position.z))))) * metal::float3(_expr21.color.x, _expr21.color.y, _expr21.color.z)); + color1 = color1 + ((_expr25 * metal::max(0.0, metal::dot(metal::normalize(raw_normal), metal::normalize(metal::float3(_expr21.pos.x, _expr21.pos.y, _expr21.pos.z) - metal::float3(position.x, position.y, position.z))))) * metal::float3(_expr21.color.x, _expr21.color.y, _expr21.color.z)); } - return fs_mainOutput { metal::float4(color1, const_1f) }; + return fs_mainOutput { metal::float4(color1, 1.0) }; } diff --git a/tests/out/skybox.msl.snap b/tests/out/skybox.msl.snap index 96bbabe890..67988bf82b 100644 --- a/tests/out/skybox.msl.snap +++ b/tests/out/skybox.msl.snap @@ -5,11 +5,6 @@ expression: msl #include #include -constexpr constant int const_2i = 2; -constexpr constant int const_1i = 1; -constexpr constant float const_4f = 4.0; -constexpr constant float const_1f = 1.0; -constexpr constant float const_0f = 0.0; struct VertexOutput { metal::float4 position; metal::float3 uv; @@ -32,9 +27,9 @@ vertex vs_mainOutput vs_main( int tmp1_; int tmp2_; VertexOutput out; - tmp1_ = static_cast(vertex_index) / const_2i; - tmp2_ = static_cast(vertex_index) & const_1i; - metal::float4 _expr24 = metal::float4((static_cast(tmp1_) * const_4f) - const_1f, (static_cast(tmp2_) * const_4f) - const_1f, const_0f, const_1f); + tmp1_ = static_cast(vertex_index) / 2; + tmp2_ = static_cast(vertex_index) & 1; + metal::float4 _expr24 = metal::float4((static_cast(tmp1_) * 4.0) - 1.0, (static_cast(tmp2_) * 4.0) - 1.0, 0.0, 1.0); metal::float4 _expr50 = r_data.proj_inv * _expr24; out.uv = metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(_expr50.x, _expr50.y, _expr50.z); out.position = _expr24;