diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 970e9210fa..8619b9ccd6 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -301,7 +301,10 @@ impl Writer { } } } - crate::Expression::Constant(handle) => self.put_constant(handle, context.module)?, + crate::Expression::Constant(handle) => { + let handle_name = &self.names[&NameKey::Constant(handle)]; + write!(self.out, "{}", handle_name)?; + } crate::Expression::Compose { ty, ref components } => { let inner = &context.module.types[ty].inner; match *inner { @@ -428,8 +431,8 @@ impl Writer { } } if let Some(constant) = offset { - write!(self.out, ", ")?; - self.put_constant(constant, context.module)?; + let offset_str = &self.names[&NameKey::Constant(constant)]; + write!(self.out, ", {}", offset_str)?; } write!(self.out, ")")?; } @@ -687,7 +690,8 @@ impl Writer { size: crate::ArraySize::Constant(const_handle), .. } => { - self.put_constant(const_handle, context.module)?; + let size_str = &self.names[&NameKey::Constant(const_handle)]; + write!(self.out, "{}", size_str)?; } crate::TypeInner::Array { .. } => return Err(Error::UnsupportedDynamicArrayLength), _ => return Err(Error::Validation), @@ -696,48 +700,6 @@ impl Writer { Ok(()) } - fn put_constant( - &mut self, - handle: Handle, - module: &crate::Module, - ) -> Result<(), Error> { - let constant = &module.constants[handle]; - match constant.inner { - crate::ConstantInner::Scalar { - width: _, - ref value, - } => match *value { - crate::ScalarValue::Sint(value) => { - write!(self.out, "{}", value)?; - } - crate::ScalarValue::Uint(value) => { - write!(self.out, "{}u", value)?; - } - crate::ScalarValue::Float(value) => { - write!(self.out, "{}", value)?; - if value.fract() == 0.0 { - write!(self.out, ".0")?; - } - } - crate::ScalarValue::Bool(value) => { - write!(self.out, "{}", value)?; - } - }, - crate::ConstantInner::Composite { ty, ref components } => { - let ty_name = &self.names[&NameKey::Type(ty)]; - write!(self.out, "{}(", ty_name)?; - for (i, &handle) in components.iter().enumerate() { - if i != 0 { - write!(self.out, ", ")?; - } - self.put_constant(handle, module)?; - } - write!(self.out, ")")?; - } - } - Ok(()) - } - // Write down any required intermediate results fn prepare_expression( &mut self, @@ -956,6 +918,7 @@ impl Writer { writeln!(self.out)?; self.write_type_defs(module)?; + self.write_constants(module)?; self.write_functions(module, analysis, options) } @@ -1008,14 +971,13 @@ impl Writer { stride: _, } => { let base_name = &self.names[&NameKey::Type(base)]; - write!(self.out, "typedef {} {}[", base_name, name)?; - match size { + let size_str = match size { crate::ArraySize::Constant(const_handle) => { - self.put_constant(const_handle, module)?; - write!(self.out, "]")?; + &self.names[&NameKey::Constant(const_handle)] } - crate::ArraySize::Dynamic => write!(self.out, "1]")?, - } + crate::ArraySize::Dynamic => "1", + }; + write!(self.out, "typedef {} {}[{}]", base_name, name, size_str)?; } crate::TypeInner::Struct { block: _, @@ -1094,6 +1056,47 @@ impl Writer { Ok(()) } + fn write_constants(&mut self, module: &crate::Module) -> Result<(), Error> { + for (handle, constant) in module.constants.iter() { + write!(self.out, "constexpr constant ")?; + let name = &self.names[&NameKey::Constant(handle)]; + match constant.inner { + crate::ConstantInner::Scalar { + width: _, + ref value, + } => match *value { + crate::ScalarValue::Sint(value) => { + write!(self.out, "int {} = {}", name, value)?; + } + crate::ScalarValue::Uint(value) => { + write!(self.out, "unsigned {} = {}u", name, value)?; + } + crate::ScalarValue::Float(value) => { + write!(self.out, "float {} = {}", name, value)?; + if value.fract() == 0.0 { + write!(self.out, ".0")?; + } + } + crate::ScalarValue::Bool(value) => { + write!(self.out, "bool {} = {}", name, value)?; + } + }, + crate::ConstantInner::Composite { ty, ref components } => { + let ty_name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{} {} = {}(", ty_name, name, ty_name)?; + for (i, &sub_handle) in components.iter().enumerate() { + let separator = if i != 0 { ", " } else { "" }; + let sub_name = &self.names[&NameKey::Constant(sub_handle)]; + write!(self.out, "{}{}", separator, sub_name)?; + } + write!(self.out, ")")?; + } + } + writeln!(self.out, ";")?; + } + Ok(()) + } + // Returns the array of mapped entry point names. fn write_functions( &mut self, @@ -1159,8 +1162,8 @@ impl Writer { let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)]; write!(self.out, "{}{} {}", INDENT, ty_name, local_name)?; if let Some(value) = local.init { - write!(self.out, " = ")?; - self.put_constant(value, module)?; + let value_str = &self.names[&NameKey::Constant(value)]; + write!(self.out, " = {}", value_str)?; } writeln!(self.out, ";")?; } @@ -1342,8 +1345,8 @@ impl Writer { tyvar.try_fmt(&mut self.out)?; resolved.try_fmt_decorated(&mut self.out, separator)?; if let Some(value) = var.init { - write!(self.out, " = ")?; - self.put_constant(value, module)?; + let value_str = &self.names[&NameKey::Constant(value)]; + write!(self.out, " = {}", value_str)?; } writeln!(self.out)?; } @@ -1364,8 +1367,8 @@ impl Writer { let ty_name = &self.names[&NameKey::Type(local.ty)]; write!(self.out, "{}{} {}", INDENT, ty_name, name)?; if let Some(value) = local.init { - write!(self.out, " = ")?; - self.put_constant(value, module)?; + let value_str = &self.names[&NameKey::Constant(value)]; + write!(self.out, " = {}", value_str)?; } writeln!(self.out, ";")?; } diff --git a/src/proc/namer.rs b/src/proc/namer.rs index bc083fbd5e..c76ab89146 100644 --- a/src/proc/namer.rs +++ b/src/proc/namer.rs @@ -5,6 +5,7 @@ pub type EntryPointIndex = u16; #[derive(Debug, Eq, Hash, PartialEq)] pub enum NameKey { + Constant(Handle), GlobalVariable(Handle), Type(Handle), StructMember(Handle, u32), @@ -69,11 +70,7 @@ impl Namer { self.unique.clear(); self.unique .extend(reserved.iter().map(|string| (string.to_string(), 0))); - - for (handle, var) in module.global_variables.iter() { - let name = self.call_or(&var.name, "global"); - output.insert(NameKey::GlobalVariable(handle), name); - } + let mut temp = String::new(); for (ty_handle, ty) in module.types.iter() { let ty_name = self.call_or(&ty.name, "type"); @@ -91,6 +88,62 @@ impl Namer { } } + for (handle, var) in module.global_variables.iter() { + let name = self.call_or(&var.name, "global"); + output.insert(NameKey::GlobalVariable(handle), name); + } + + for (handle, constant) in module.constants.iter() { + let label = match constant.name { + Some(ref name) => name, + None => { + use std::fmt::Write; + // Try to be more descriptive about the constant values + temp.clear(); + match constant.inner { + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Sint(v), + } => write!(temp, "const_{}i", v), + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Uint(v), + } => write!(temp, "const_{}u", v), + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Float(v), + } => { + let abs = v.abs(); + write!( + temp, + "const_{}{}", + if v < 0.0 { "n" } else { "" }, + abs.trunc(), + ) + .unwrap(); + let fract = abs.fract(); + if fract == 0.0 { + write!(temp, "f") + } else { + write!(temp, "_{:02}f", (fract * 100.0) as i8) + } + } + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Bool(v), + } => write!(temp, "const_{}", v), + crate::ConstantInner::Composite { ty, components: _ } => { + write!(temp, "const_{}", output[&NameKey::Type(ty)]) + } + } + .unwrap(); + &temp + } + }; + let name = self.call(label); + output.insert(NameKey::Constant(handle), name); + } + for (fun_handle, fun) in module.functions.iter() { let fun_name = self.call_or(&fun.name, "function"); output.insert(NameKey::Function(fun_handle), fun_name); diff --git a/tests/out/boids.msl.snap b/tests/out/boids.msl.snap index 47cdc704cf..13c88f7564 100644 --- a/tests/out/boids.msl.snap +++ b/tests/out/boids.msl.snap @@ -38,6 +38,15 @@ typedef int type5; typedef bool type6; +constexpr constant int NUM_PARTICLES = 1500; +constexpr constant float const_0f = 0.0; +constexpr constant int const_0i = 0; +constexpr constant unsigned const_0u = 0u; +constexpr constant int const_1i = 1; +constexpr constant unsigned const_1u = 1u; +constexpr constant float const_1f = 1.0; +constexpr constant float const_0_10f = 0.1; +constexpr constant float const_n1f = -1.0; kernel void main1( constant SimParams& params [[buffer(0)]], constant Particles& particlesSrc [[buffer(1)]], @@ -49,26 +58,26 @@ kernel void main1( type1 cMass; type1 cVel; type1 colVel; - type5 cMassCount = 0; - type5 cVelCount = 0; + type5 cMassCount = const_0i; + type5 cVelCount = const_0i; type1 pos1; type1 vel1; - type i = 0u; - if ((gl_GlobalInvocationID.x >= 1500)) { + type i = const_0u; + if ((gl_GlobalInvocationID.x >= NUM_PARTICLES)) { return ; } vPos = particlesSrc.particles[gl_GlobalInvocationID.x].pos; vVel = particlesSrc.particles[gl_GlobalInvocationID.x].vel; - cMass = metal::float2(0.0, 0.0); - cVel = metal::float2(0.0, 0.0); - colVel = metal::float2(0.0, 0.0); + cMass = metal::float2(const_0f, const_0f); + cVel = metal::float2(const_0f, const_0f); + colVel = metal::float2(const_0f, const_0f); bool loop_init = true; while(true) { if (!loop_init) { - i = (i + 1u); + i = (i + const_1u); } loop_init = false; - if ((i >= 1500)) { + if ((i >= NUM_PARTICLES)) { break; } if ((i == gl_GlobalInvocationID.x)) { @@ -78,36 +87,36 @@ kernel void main1( vel1 = particlesSrc.particles[i].vel; if ((metal::distance(pos1, vPos) < params.rule1Distance)) { cMass = (cMass + pos1); - cMassCount = (cMassCount + 1); + cMassCount = (cMassCount + const_1i); } if ((metal::distance(pos1, vPos) < params.rule2Distance)) { colVel = (colVel - (pos1 - vPos)); } if ((metal::distance(pos1, vPos) < params.rule3Distance)) { cVel = (cVel + vel1); - cVelCount = (cVelCount + 1); + cVelCount = (cVelCount + const_1i); } } - if ((cMassCount > 0)) { - cMass = ((cMass * (1.0 / static_cast(cMassCount))) - vPos); + if ((cMassCount > const_0i)) { + cMass = ((cMass * (const_1f / static_cast(cMassCount))) - vPos); } - if ((cVelCount > 0)) { - cVel = (cVel * (1.0 / static_cast(cVelCount))); + if ((cVelCount > const_0i)) { + cVel = (cVel * (const_1f / static_cast(cVelCount))); } vVel = (((vVel + (cMass * params.rule1Scale)) + (colVel * params.rule2Scale)) + (cVel * params.rule3Scale)); - vVel = (metal::normalize(vVel) * metal::clamp(metal::length(vVel), 0.0, 0.1)); + vVel = (metal::normalize(vVel) * metal::clamp(metal::length(vVel), const_0f, const_0_10f)); vPos = (vPos + (vVel * params.deltaT)); - if ((vPos.x < -1.0)) { - vPos.x = 1.0; + if ((vPos.x < const_n1f)) { + vPos.x = const_1f; } - if ((vPos.x > 1.0)) { - vPos.x = -1.0; + if ((vPos.x > const_1f)) { + vPos.x = const_n1f; } - if ((vPos.y < -1.0)) { - vPos.y = 1.0; + if ((vPos.y < const_n1f)) { + vPos.y = const_1f; } - if ((vPos.y > 1.0)) { - vPos.y = -1.0; + if ((vPos.y > const_1f)) { + vPos.y = const_n1f; } particlesDst.particles[gl_GlobalInvocationID.x].pos = vPos; particlesDst.particles[gl_GlobalInvocationID.x].vel = vVel; diff --git a/tests/out/collatz.msl.snap b/tests/out/collatz.msl.snap index 10912b0ef4..2ab590d737 100644 --- a/tests/out/collatz.msl.snap +++ b/tests/out/collatz.msl.snap @@ -15,22 +15,26 @@ struct PrimeIndices { type2 data; }; +constexpr constant unsigned const_0u = 0u; +constexpr constant unsigned const_1u = 1u; +constexpr constant unsigned const_2u = 2u; +constexpr constant unsigned const_3u = 3u; type1 collatz_iterations( type1 n_base ) { type1 n; - type1 i = 0u; + type1 i = const_0u; n = n_base; while(true) { - if ((n <= 1u)) { + if ((n <= const_1u)) { break; } - if (((n % 2u) == 0u)) { - n = (n / 2u); + if (((n % const_2u) == const_0u)) { + n = (n / const_2u); } else { - n = ((3u * n) + 1u); + n = ((const_3u * n) + const_1u); } - i = (i + 1u); + i = (i + const_1u); } return i; } diff --git a/tests/out/quad.msl.snap b/tests/out/quad.msl.snap index 488e66966e..be352e9027 100644 --- a/tests/out/quad.msl.snap +++ b/tests/out/quad.msl.snap @@ -15,6 +15,9 @@ typedef metal::texture2d type3; typedef metal::sampler type4; +constexpr constant float c_scale = 1.2; +constexpr constant float const_0f = 0.0; +constexpr constant float const_1f = 1.0; struct main1Input { type1 a_pos [[attribute(0)]]; type1 a_uv [[attribute(1)]]; @@ -30,7 +33,7 @@ vertex main1Output main1( ) { main1Output output; output.v_uv = input.a_uv; - output.o_position = metal::float4((1.2 * input.a_pos), 0.0, 1.0); + output.o_position = metal::float4((c_scale * input.a_pos), const_0f, const_1f); return output; } diff --git a/tests/out/shadow.msl.snap b/tests/out/shadow.msl.snap index 400381564b..53984b1246 100644 --- a/tests/out/shadow.msl.snap +++ b/tests/out/shadow.msl.snap @@ -41,17 +41,26 @@ typedef metal::float3 type9; typedef bool type10; +constexpr constant float const_0f = 0.0; +constexpr constant float const_1f = 1.0; +constexpr constant float const_0_50f = 0.5; +constexpr constant float const_n0_50f = -0.5; +constexpr constant float const_0_05f = 0.05; +constexpr constant type9 c_ambient = type9(const_0_05f, const_0_05f, const_0_05f); +constexpr constant unsigned c_max_lights = 10u; +constexpr constant unsigned const_0u = 0u; +constexpr constant unsigned const_1u = 1u; type7 fetch_shadow( type6 light_id, type2 homogeneous_coords, type4 t_shadow, type5 sampler_shadow ) { - if ((homogeneous_coords.w <= 0.0)) { - return 1.0; + if ((homogeneous_coords.w <= const_0f)) { + return const_1f; } - float expr15 = (1.0 / homogeneous_coords.w); - return t_shadow.sample_compare(sampler_shadow, (((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(0.5, -0.5)) * expr15) + metal::float2(0.5, 0.5)), static_cast(light_id), (homogeneous_coords.z * expr15)); + float expr15 = (const_1f / homogeneous_coords.w); + return t_shadow.sample_compare(sampler_shadow, (((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(const_0_50f, const_n0_50f)) * expr15) + metal::float2(const_0_50f, const_0_50f)), static_cast(light_id), (homogeneous_coords.z * expr15)); } struct fs_mainInput { @@ -71,22 +80,22 @@ fragment fs_mainOutput fs_main( type5 sampler_shadow [[sampler(0)]] ) { fs_mainOutput output; - type9 color1 = type9(0.05, 0.05, 0.05); - type6 i = 0u; + type9 color1 = c_ambient; + type6 i = const_0u; bool loop_init = true; while(true) { if (!loop_init) { - i = (i + 1u); + i = (i + const_1u); } loop_init = false; - if ((i >= metal::min(u_globals.num_lights.x, 10u))) { + if ((i >= metal::min(u_globals.num_lights.x, c_max_lights))) { break; } Light expr18 = s_lights.data[i]; type7 expr21 = fetch_shadow(i, (expr18.proj * input.in_position_fs), t_shadow, sampler_shadow); - color1 = (color1 + ((expr21 * metal::max(0.0, metal::dot(metal::normalize(input.in_normal_fs), metal::normalize((metal::float3(expr18.pos.x, expr18.pos.y, expr18.pos.z) - metal::float3(input.in_position_fs.x, input.in_position_fs.y, input.in_position_fs.z)))))) * metal::float3(expr18.color.x, expr18.color.y, expr18.color.z))); + color1 = (color1 + ((expr21 * metal::max(const_0f, metal::dot(metal::normalize(input.in_normal_fs), metal::normalize((metal::float3(expr18.pos.x, expr18.pos.y, expr18.pos.z) - metal::float3(input.in_position_fs.x, input.in_position_fs.y, input.in_position_fs.z)))))) * metal::float3(expr18.color.x, expr18.color.y, expr18.color.z))); } - output.out_color_fs = metal::float4(color1, 1.0); + output.out_color_fs = metal::float4(color1, const_1f); return output; } diff --git a/tests/out/skybox.msl.snap b/tests/out/skybox.msl.snap index bb036f6ac2..277e9e6818 100644 --- a/tests/out/skybox.msl.snap +++ b/tests/out/skybox.msl.snap @@ -28,6 +28,11 @@ typedef metal::texturecube type7; typedef metal::sampler type8; +constexpr constant int const_2i = 2; +constexpr constant int const_1i = 1; +constexpr constant float const_4f = 4.0; +constexpr constant float const_1f = 1.0; +constexpr constant float const_0f = 0.0; struct vs_mainInput { }; @@ -45,9 +50,9 @@ vertex vs_mainOutput vs_main( type4 tmp1_; type4 tmp2_; type unprojected; - tmp1_ = (static_cast(in_vertex_index) / 2); - tmp2_ = (static_cast(in_vertex_index) & 1); - type expr24 = metal::float4(((static_cast(tmp1_) * 4.0) - 1.0), ((static_cast(tmp2_) * 4.0) - 1.0), 0.0, 1.0); + tmp1_ = (static_cast(in_vertex_index) / const_2i); + tmp2_ = (static_cast(in_vertex_index) & const_1i); + type expr24 = metal::float4(((static_cast(tmp1_) * const_4f) - const_1f), ((static_cast(tmp2_) * const_4f) - const_1f), const_0f, const_1f); unprojected = (r_data.proj_inv * expr24); output.out_uv = (metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(unprojected.x, unprojected.y, unprojected.z)); output.out_position = expr24;