diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index ca4523fd31..819a4f15fb 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -342,6 +342,21 @@ fn should_pack_struct_member( } } +fn needs_array_length(ty: Handle, arena: &Arena) -> bool { + if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner { + if let Some(member) = members.last() { + if let crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = arena[member.ty].inner + { + return true; + } + } + } + false +} + impl crate::StorageClass { /// Returns true for storage classes, for which the global /// variables are passed in function arguments. @@ -1488,8 +1503,12 @@ impl Writer { // follow-up with any global resources used let mut separate = !arguments.is_empty(); let fun_info = &context.mod_info[function]; + let mut supports_array_length = false; for (handle, var) in context.expression.module.global_variables.iter() { - if !fun_info[handle].is_empty() && var.class.needs_pass_through() { + if fun_info[handle].is_empty() { + continue; + } + if var.class.needs_pass_through() { let name = &self.names[&NameKey::GlobalVariable(handle)]; if separate { write!(self.out, ", ")?; @@ -1498,12 +1517,13 @@ impl Writer { } write!(self.out, "{}", name)?; } + supports_array_length |= + needs_array_length(var.ty, &context.expression.module.types); } - if !self.runtime_sized_buffers.is_empty() { + if supports_array_length { if separate { write!(self.out, ", ")?; } - write!(self.out, "_buffer_sizes")?; } @@ -1542,19 +1562,11 @@ impl Writer { { let mut indices = vec![]; - for (handle, gv) in module.global_variables.iter() { - if let crate::TypeInner::Struct { ref members, .. } = module.types[gv.ty].inner { - if let Some(member) = members.last() { - if let crate::TypeInner::Array { - size: crate::ArraySize::Dynamic, - .. - } = module.types[member.ty].inner - { - let idx = handle.index(); - self.runtime_sized_buffers.insert(handle, idx); - indices.push(idx); - } - } + for (handle, var) in module.global_variables.iter() { + if needs_array_length(var.ty, &module.types) { + let idx = handle.index(); + self.runtime_sized_buffers.insert(handle, idx); + indices.push(idx); } } @@ -1845,9 +1857,13 @@ impl Writer { for (fun_handle, fun) in module.functions.iter() { let fun_info = &mod_info[fun_handle]; pass_through_globals.clear(); + let mut supports_array_length = false; for (handle, var) in module.global_variables.iter() { - if !fun_info[handle].is_empty() && var.class.needs_pass_through() { - pass_through_globals.push(handle); + if !fun_info[handle].is_empty() { + if var.class.needs_pass_through() { + pass_through_globals.push(handle); + } + supports_array_length |= needs_array_length(var.ty, &module.types); } } @@ -1882,7 +1898,7 @@ impl Writer { let separator = separate( !pass_through_globals.is_empty() || index + 1 != fun.arguments.len() - || !self.runtime_sized_buffers.is_empty(), + || supports_array_length, ); writeln!( self.out, @@ -1898,16 +1914,14 @@ impl Writer { usage: fun_info[handle], reference: true, }; - let separator = separate( - index + 1 != pass_through_globals.len() - || !self.runtime_sized_buffers.is_empty(), - ); + let separator = + separate(index + 1 != pass_through_globals.len() || supports_array_length); write!(self.out, "{}", INDENT)?; tyvar.try_fmt(&mut self.out)?; writeln!(self.out, "{}", separator)?; } - if !self.runtime_sized_buffers.is_empty() { + if supports_array_length { writeln!( self.out, "{}constant _mslBufferSizes& _buffer_sizes", @@ -1961,42 +1975,45 @@ impl Writer { for (ep_index, ep) in module.entry_points.iter().enumerate() { let fun = &ep.function; let fun_info = mod_info.get_entry_point(ep_index); + let mut ep_error = None; + let mut supports_array_length = false; + // skip this entry point if any global bindings are missing if !options.fake_missing_bindings { - if let Some(err) = module - .global_variables - .iter() - .find_map(|(var_handle, var)| { - if !fun_info[var_handle].is_empty() { - if let Some(ref br) = var.binding { - if let Err(e) = options.resolve_resource_binding(ep.stage, br) { - return Some(e); - } - } - if var.class == crate::StorageClass::PushConstant { - if let Err(e) = options.resolve_push_constants(ep.stage) { - return Some(e); - } - } - } - None - }) - { - info.entry_point_names.push(Err(err)); - continue; - } - if !self.runtime_sized_buffers.is_empty() { - if let Err(err) = options.resolve_sizes_buffer(ep.stage) { - info.entry_point_names.push(Err(err)); + for (var_handle, var) in module.global_variables.iter() { + if fun_info[var_handle].is_empty() { continue; } + if let Some(ref br) = var.binding { + if let Err(e) = options.resolve_resource_binding(ep.stage, br) { + ep_error = Some(e); + break; + } + } + if var.class == crate::StorageClass::PushConstant { + if let Err(e) = options.resolve_push_constants(ep.stage) { + ep_error = Some(e); + break; + } + } + supports_array_length |= needs_array_length(var.ty, &module.types); + } + if supports_array_length { + if let Err(err) = options.resolve_sizes_buffer(ep.stage) { + ep_error = Some(err); + } } } - writeln!(self.out)?; + if let Some(err) = ep_error { + info.entry_point_names.push(Err(err)); + continue; + } let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)]; info.entry_point_names.push(Ok(fun_name.clone())); + writeln!(self.out)?; + let stage_out_name = format!("{}Output", fun_name); let stage_in_name = format!("{}Input", fun_name); @@ -2212,7 +2229,7 @@ impl Writer { writeln!(self.out)?; } - if !self.runtime_sized_buffers.is_empty() { + if supports_array_length { // this is checked earlier let resolved = options.resolve_sizes_buffer(ep.stage).unwrap(); let separator = if module.global_variables.is_empty() { diff --git a/tests/out/collatz.msl b/tests/out/collatz.msl index 7342aa8ecd..b33506022b 100644 --- a/tests/out/collatz.msl +++ b/tests/out/collatz.msl @@ -11,8 +11,7 @@ struct PrimeIndices { }; metal::uint collatz_iterations( - metal::uint n_base, - constant _mslBufferSizes& _buffer_sizes + metal::uint n_base ) { metal::uint n; metal::uint i = 0u; @@ -36,9 +35,8 @@ struct main1Input { kernel void main1( metal::uint3 global_id [[thread_position_in_grid]] , device PrimeIndices& v_indices [[user(fake0)]] -, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { - metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x], _buffer_sizes); + metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x]); v_indices.data[global_id.x] = _e9; return; } diff --git a/tests/out/shadow.msl b/tests/out/shadow.msl index 17f1c82eaf..ed9b5bc55a 100644 --- a/tests/out/shadow.msl +++ b/tests/out/shadow.msl @@ -24,8 +24,7 @@ float fetch_shadow( metal::uint light_id, metal::float4 homogeneous_coords, metal::depth2d_array t_shadow, - metal::sampler sampler_shadow, - constant _mslBufferSizes& _buffer_sizes + metal::sampler sampler_shadow ) { if (homogeneous_coords.w <= 0.0) { return 1.0; @@ -47,7 +46,6 @@ fragment fs_mainOutput fs_main( , constant Lights& s_lights [[user(fake0)]] , metal::depth2d_array t_shadow [[user(fake0)]] , metal::sampler sampler_shadow [[user(fake0)]] -, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { const auto raw_normal = varyings.raw_normal; const auto position = varyings.position; @@ -63,7 +61,7 @@ fragment fs_mainOutput fs_main( break; } Light _e21 = s_lights.data[i]; - float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow, _buffer_sizes); + float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow); color1 = color1 + ((_e25 * metal::max(0.0, metal::dot(metal::normalize(raw_normal), metal::normalize(_e21.pos.xyz - position.xyz)))) * _e21.color.xyz); } return fs_mainOutput { metal::float4(color1, 1.0) }; diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 5fe0c2b450..8a85ee32cf 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -303,7 +303,11 @@ fn convert_spv(name: &str, adjust_coordinate_space: bool, targets: Targets) { #[cfg(feature = "spv-in")] #[test] fn convert_spv_quad_vert() { - convert_spv("quad-vert", false, Targets::METAL | Targets::GLSL | Targets::WGSL); + convert_spv( + "quad-vert", + false, + Targets::METAL | Targets::GLSL | Targets::WGSL, + ); } #[cfg(feature = "spv-in")]