diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index e38beed88d..0579dcfec2 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -157,6 +157,10 @@ pub struct Options { pub spirv_cross_compatibility: bool, /// Don't panic on missing bindings, instead generate invalid MSL. pub fake_missing_bindings: bool, + /// The slot of a buffer that contains an array of `u32`, + /// one for the size of each bound buffer that contains a runtime array, + /// in order of [`GlobalVariable`] declarations. + pub sizes_buffer_binding: Option, } impl Default for Options { @@ -168,6 +172,7 @@ impl Default for Options { inline_samplers: Vec::new(), spirv_cross_compatibility: false, fake_missing_bindings: true, + sizes_buffer_binding: None, } } } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index bfccf1dcd5..66951b44b6 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1,6 +1,6 @@ use super::{ - keywords::RESERVED, sampler as sm, Error, LocationMode, Options, PipelineOptions, - TranslationInfo, + keywords::RESERVED, sampler as sm, BindTarget, Error, LocationMode, Options, PipelineOptions, + ResolvedBinding, TranslationInfo, }; use crate::{ arena::{Arena, Handle}, @@ -290,6 +290,7 @@ pub struct Writer { names: FastHashMap, named_expressions: BitSet, namer: Namer, + runtime_sized_buffers: FastHashMap, usize>, #[cfg(test)] put_expression_stack_pointers: crate::FastHashSet<*const ()>, #[cfg(test)] @@ -435,6 +436,7 @@ impl Writer { names: FastHashMap::default(), named_expressions: BitSet::new(), namer: Namer::default(), + runtime_sized_buffers: FastHashMap::default(), #[cfg(test)] put_expression_stack_pointers: Default::default(), #[cfg(test)] @@ -1060,26 +1062,53 @@ impl Writer { } // has to be a named expression crate::Expression::Call(_) => unreachable!(), - crate::Expression::ArrayLength(expr) => match *context.resolve_type(expr) { - crate::TypeInner::Array { - size: crate::ArraySize::Constant(const_handle), - .. - } => { - let coco = ConstantContext { - handle: const_handle, - arena: &context.module.constants, - names: &self.names, - first_time: false, - }; - write!(self.out, "{}", coco)?; + crate::Expression::ArrayLength(expr) => { + let handle = match context.function.expressions[expr] { + crate::Expression::AccessIndex { base, .. } => { + match context.function.expressions[base] { + crate::Expression::GlobalVariable(handle) => handle, + _ => return Err(Error::Validation), + } + } + _ => return Err(Error::Validation), + }; + + let global = &context.module.global_variables[handle]; + if let crate::TypeInner::Struct { ref members, .. } = + context.module.types[global.ty].inner + { + if let Some(&crate::StructMember { + offset, + ty: array_ty, + .. + }) = members.last() + { + let (span, stride) = match context.module.types[array_ty].inner { + crate::TypeInner::Array { base, stride, .. } => ( + context.module.types[base] + .inner + .span(&context.module.constants), + stride, + ), + _ => return Err(Error::Validation), + }; + + let buffer_idx = self.runtime_sized_buffers[&handle]; + write!( + self.out, + "(1 + (_buffer_sizes.size{idx} - {offset} - {span}) / {stride})", + idx = buffer_idx, + offset = offset, + span = span, + stride = stride, + )?; + } else { + return Err(Error::Validation); + } + } else { + return Err(Error::Validation); } - crate::TypeInner::Array { .. } => { - return Err(Error::FeatureNotImplemented( - "dynamic array size".to_string(), - )) - } - _ => return Err(Error::Validation), - }, + } } Ok(()) } @@ -1405,6 +1434,14 @@ impl Writer { write!(self.out, "{}", name)?; } } + if !self.runtime_sized_buffers.is_empty() { + if separate { + write!(self.out, ", ")?; + } + + write!(self.out, "_buffer_sizes")?; + } + // done writeln!(self.out, ");")?; } @@ -1432,11 +1469,42 @@ impl Writer { ) -> Result { self.names.clear(); self.namer.reset(module, RESERVED, &mut self.names); + self.runtime_sized_buffers.clear(); writeln!(self.out, "#include ")?; writeln!(self.out, "#include ")?; writeln!(self.out)?; + { + let mut indices = vec![]; + for (handle, gv) in module.global_variables.iter() { + if let crate::TypeInner::Struct { ref members, .. } = module.types[gv.ty].inner { + if let Some(member) = members.last() { + if let crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = module.types[member.ty].inner + { + let idx = handle.index(); + self.runtime_sized_buffers.insert(handle, idx); + indices.push(idx); + } + } + } + } + + if !indices.is_empty() { + writeln!(self.out, "struct _mslBufferSizes {{")?; + + for idx in indices { + writeln!(self.out, "{}{}::uint size{};", INDENT, NAMESPACE, idx)?; + } + + writeln!(self.out, "}};")?; + writeln!(self.out)?; + } + }; + self.write_scalar_constants(module)?; self.write_type_defs(module)?; self.write_composite_constants(module)?; @@ -1746,8 +1814,11 @@ impl Writer { access: crate::StorageAccess::empty(), first_time: false, }; - let separator = - separate(!pass_through_globals.is_empty() || index + 1 != fun.arguments.len()); + let separator = separate( + !pass_through_globals.is_empty() + || index + 1 != fun.arguments.len() + || !self.runtime_sized_buffers.is_empty(), + ); writeln!( self.out, "{}{} {}{}", @@ -1762,11 +1833,23 @@ impl Writer { usage: fun_info[handle], reference: true, }; - let separator = separate(index + 1 != pass_through_globals.len()); + let separator = separate( + index + 1 != pass_through_globals.len() + || !self.runtime_sized_buffers.is_empty(), + ); write!(self.out, "{}", INDENT)?; tyvar.try_fmt(&mut self.out)?; writeln!(self.out, "{}", separator)?; } + + if !self.runtime_sized_buffers.is_empty() { + writeln!( + self.out, + "{}constant _mslBufferSizes& _buffer_sizes", + INDENT + )?; + } + writeln!(self.out, ") {{")?; for (local_handle, local) in fun.local_variables.iter() { @@ -2049,6 +2132,34 @@ impl Writer { writeln!(self.out)?; } + if !self.runtime_sized_buffers.is_empty() { + let resolved = if let Some(slot) = options.sizes_buffer_binding { + ResolvedBinding::Resource(BindTarget { + buffer: Some(slot), + mutable: false, + ..Default::default() + }) + } else { + ResolvedBinding::User { + prefix: "fake", + index: 0, + interpolation: None, + } + }; + + let separator = if module.global_variables.is_empty() { + ' ' + } else { + ',' + }; + write!( + self.out, + "{} constant _mslBufferSizes& _buffer_sizes", + separator, + )?; + resolved.try_fmt_decorated(&mut self.out, "\n")?; + } + // end of the entry point argument list writeln!(self.out, ") {{")?; @@ -2231,8 +2342,8 @@ fn test_stack_size() { } let stack_size = addresses.end - addresses.start; // check the size (in debug only) - // last observed macOS value: 21760 - if stack_size < 21000 || stack_size > 23000 { + // last observed macOS value: 23040 + if stack_size < 21000 || stack_size > 25000 { panic!("`put_expression` stack size {} has changed!", stack_size); } } @@ -2246,8 +2357,8 @@ fn test_stack_size() { } let stack_size = addresses.end - addresses.start; // check the size (in debug only) - // last observed macOS value: 12736 - if stack_size < 12000 || stack_size > 13500 { + // last observed macOS value: 13600 + if stack_size < 12000 || stack_size > 14500 { panic!("`put_block` stack size {} has changed!", stack_size); } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 308d9b6d63..0b3b42b42f 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -473,6 +473,20 @@ impl super::Instruction { instruction } + pub(super) fn array_length( + result_type_id: Word, + id: Word, + structure_id: Word, + array_member: Word, + ) -> Self { + let mut instruction = Self::new(Op::ArrayLength); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(structure_id); + instruction.add_operand(array_member); + instruction + } + // // Function Instructions // diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index c370b0aa69..d6c1a6cd16 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -2349,9 +2349,37 @@ impl Writer { .push(Instruction::relational(op, result_type_id, id, arg_id)); id } - crate::Expression::ArrayLength(_) => { - log::error!("unimplemented {:?}", ir_function.expressions[expr_handle]); - return Err(Error::FeatureNotImplemented("expression")); + crate::Expression::ArrayLength(expr) => { + let (structure_id, member_idx) = match ir_function.expressions[expr] { + crate::Expression::AccessIndex { base, .. } => { + match ir_function.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + let global = &ir_module.global_variables[handle]; + let last_idx = match ir_module.types[global.ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + members.len() as u32 - 1 + } + _ => return Err(Error::Validation("array length expression")), + }; + + (self.global_variables[handle.index()].id, last_idx) + } + _ => return Err(Error::Validation("array length expression")), + } + } + _ => return Err(Error::Validation("array length expression")), + }; + + // let structure_id = self.get_expression_global(ir_function, global); + let id = self.id_gen.next(); + + block.body.push(Instruction::array_length( + result_type_id, + id, + structure_id, + member_idx, + )); + id } }; diff --git a/tests/in/access.param.ron b/tests/in/access.param.ron index 52f49f360f..0464ddf323 100644 --- a/tests/in/access.param.ron +++ b/tests/in/access.param.ron @@ -3,5 +3,17 @@ spv_capabilities: [ Shader, Image1D, Sampled1D ], spv_debug: true, spv_adjust_coordinate_space: false, - msl_custom: false, + msl_custom: true, + msl: ( + lang_version: (2, 0), + binding_map: { + (stage: Vertex, group: 0, binding: 0): (buffer: Some(0), mutable: true), + }, + push_constants_map: ( + ), + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + sizes_buffer_binding: Some(24), + ), ) diff --git a/tests/in/access.wgsl b/tests/in/access.wgsl index c5f6c27e74..1860ff1c43 100644 --- a/tests/in/access.wgsl +++ b/tests/in/access.wgsl @@ -1,8 +1,18 @@ // This snapshot tests accessing various containers, dereferencing pointers. +[[block]] +struct Bar { + data: [[stride(4)]] array; +}; + +[[group(0), binding(0)]] +var bar: [[access(read_write)]] Bar; + [[stage(vertex)]] fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4 { - let array = array(1, 2, 3, 4, 5); + let a = bar.data[arrayLength(&bar.data) - 1u]; + + let array = array(a, 2, 3, 4, 5); let value = array[vi]; return vec4(vec4(value)); } diff --git a/tests/out/access.msl b/tests/out/access.msl index 36ed564822..a87a402ae0 100644 --- a/tests/out/access.msl +++ b/tests/out/access.msl @@ -1,7 +1,15 @@ #include #include -struct type3 { +struct _mslBufferSizes { + metal::uint size0; +}; + +typedef int type1[1]; +struct Bar { + type1 data; +}; +struct type4 { int inner[5]; }; @@ -12,6 +20,8 @@ struct fooOutput { }; vertex fooOutput foo( metal::uint vi [[vertex_id]] +, device Bar& bar [[buffer(0)]] +, constant _mslBufferSizes& _buffer_sizes [[buffer(24)]] ) { - return fooOutput { static_cast(int4(type3 {1, 2, 3, 4, 5}.inner[vi])) }; + return fooOutput { static_cast(int4(type4 {bar.data[(1 + (_buffer_sizes.size0 - 0 - 4) / 4) - 1u], 2, 3, 4, 5}.inner[vi])) }; } diff --git a/tests/out/access.spvasm b/tests/out/access.spvasm index 985b9a8408..307305a9e6 100644 --- a/tests/out/access.spvasm +++ b/tests/out/access.spvasm @@ -1,50 +1,70 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 31 +; Bound: 42 OpCapability Image1D OpCapability Shader OpCapability Sampled1D +OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %19 "foo" %14 %17 +OpEntryPoint Vertex %23 "foo" %18 %21 OpSource GLSL 450 -OpName %14 "vi" -OpName %19 "foo" -OpDecorate %12 ArrayStride 4 -OpDecorate %14 BuiltIn VertexIndex -OpDecorate %17 BuiltIn Position +OpName %11 "Bar" +OpMemberName %11 0 "data" +OpName %15 "bar" +OpName %18 "vi" +OpName %23 "foo" +OpDecorate %10 ArrayStride 4 +OpDecorate %11 Block +OpMemberDecorate %11 0 Offset 0 +OpDecorate %14 ArrayStride 4 +OpDecorate %15 DescriptorSet 0 +OpDecorate %15 Binding 0 +OpDecorate %18 BuiltIn VertexIndex +OpDecorate %21 BuiltIn Position %2 = OpTypeVoid -%4 = OpTypeInt 32 1 -%3 = OpConstant %4 5 -%5 = OpConstant %4 1 -%6 = OpConstant %4 2 -%7 = OpConstant %4 3 -%8 = OpConstant %4 4 -%9 = OpTypeInt 32 0 -%11 = OpTypeFloat 32 -%10 = OpTypeVector %11 4 -%12 = OpTypeArray %4 %3 -%15 = OpTypePointer Input %9 -%14 = OpVariable %15 Input -%18 = OpTypePointer Output %10 -%17 = OpVariable %18 Output -%20 = OpTypeFunction %2 -%23 = OpTypePointer Function %12 -%26 = OpTypePointer Function %4 -%28 = OpTypeVector %4 4 -%19 = OpFunction %2 None %20 -%13 = OpLabel -%24 = OpVariable %23 Function -%16 = OpLoad %9 %14 -OpBranch %21 -%21 = OpLabel -%22 = OpCompositeConstruct %12 %5 %6 %7 %8 %3 -OpStore %24 %22 -%25 = OpAccessChain %26 %24 %16 -%27 = OpLoad %4 %25 -%29 = OpCompositeConstruct %28 %27 %27 %27 %27 -%30 = OpConvertSToF %10 %29 -OpStore %17 %30 +%4 = OpTypeInt 32 0 +%3 = OpConstant %4 1 +%6 = OpTypeInt 32 1 +%5 = OpConstant %6 5 +%7 = OpConstant %6 2 +%8 = OpConstant %6 3 +%9 = OpConstant %6 4 +%10 = OpTypeRuntimeArray %6 +%11 = OpTypeStruct %10 +%13 = OpTypeFloat 32 +%12 = OpTypeVector %13 4 +%14 = OpTypeArray %6 %5 +%16 = OpTypePointer StorageBuffer %11 +%15 = OpVariable %16 StorageBuffer +%19 = OpTypePointer Input %4 +%18 = OpVariable %19 Input +%22 = OpTypePointer Output %12 +%21 = OpVariable %22 Output +%24 = OpTypeFunction %2 +%26 = OpTypePointer StorageBuffer %10 +%29 = OpTypePointer StorageBuffer %6 +%30 = OpConstant %6 0 +%34 = OpTypePointer Function %14 +%37 = OpTypePointer Function %6 +%39 = OpTypeVector %6 4 +%23 = OpFunction %2 None %24 +%17 = OpLabel +%35 = OpVariable %34 Function +%20 = OpLoad %4 %18 +OpBranch %25 +%25 = OpLabel +%27 = OpArrayLength %4 %15 0 +%28 = OpISub %4 %27 %3 +%31 = OpAccessChain %29 %15 %30 %28 +%32 = OpLoad %6 %31 +%33 = OpCompositeConstruct %14 %32 %7 %8 %9 %5 +OpStore %35 %33 +%36 = OpAccessChain %37 %35 %20 +%38 = OpLoad %6 %36 +%40 = OpCompositeConstruct %39 %38 %38 %38 %38 +%41 = OpConvertSToF %12 %40 +OpStore %21 %41 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/boids.msl b/tests/out/boids.msl index 3cabb8e664..1deab9f731 100644 --- a/tests/out/boids.msl +++ b/tests/out/boids.msl @@ -1,6 +1,11 @@ #include #include +struct _mslBufferSizes { + metal::uint size1; + metal::uint size2; +}; + constexpr constant unsigned NUM_PARTICLES = 1500u; struct Particle { metal::float2 pos; @@ -27,6 +32,7 @@ kernel void main1( , constant SimParams& params [[buffer(0)]] , constant Particles& particlesSrc [[buffer(1)]] , device Particles& particlesDst [[buffer(2)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { metal::float2 vPos; metal::float2 vVel; diff --git a/tests/out/collatz.msl b/tests/out/collatz.msl index 403789659f..7342aa8ecd 100644 --- a/tests/out/collatz.msl +++ b/tests/out/collatz.msl @@ -1,13 +1,18 @@ #include #include +struct _mslBufferSizes { + metal::uint size0; +}; + typedef metal::uint type1[1]; struct PrimeIndices { type1 data; }; metal::uint collatz_iterations( - metal::uint n_base + metal::uint n_base, + constant _mslBufferSizes& _buffer_sizes ) { metal::uint n; metal::uint i = 0u; @@ -31,8 +36,9 @@ struct main1Input { kernel void main1( metal::uint3 global_id [[thread_position_in_grid]] , device PrimeIndices& v_indices [[user(fake0)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { - metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x]); + metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x], _buffer_sizes); v_indices.data[global_id.x] = _e9; return; } diff --git a/tests/out/shadow.msl b/tests/out/shadow.msl index 0ff4e18f40..420dfda576 100644 --- a/tests/out/shadow.msl +++ b/tests/out/shadow.msl @@ -1,6 +1,10 @@ #include #include +struct _mslBufferSizes { + metal::uint size1; +}; + constexpr constant unsigned c_max_lights = 10u; struct Globals { metal::uint4 num_lights; @@ -20,7 +24,8 @@ float fetch_shadow( metal::uint light_id, metal::float4 homogeneous_coords, metal::depth2d_array t_shadow, - metal::sampler sampler_shadow + metal::sampler sampler_shadow, + constant _mslBufferSizes& _buffer_sizes ) { if (homogeneous_coords.w <= 0.0) { return 1.0; @@ -42,6 +47,7 @@ fragment fs_mainOutput fs_main( , constant Lights& s_lights [[user(fake0)]] , metal::depth2d_array t_shadow [[user(fake0)]] , metal::sampler sampler_shadow [[user(fake0)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] ) { const auto raw_normal = varyings.raw_normal; const auto position = varyings.position; @@ -57,7 +63,7 @@ fragment fs_mainOutput fs_main( break; } Light _e21 = s_lights.data[i]; - float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow); + float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow, _buffer_sizes); color1 = color1 + ((_e25 * metal::max(0.0, metal::dot(metal::normalize(raw_normal), metal::normalize(_e21.pos.xyz - position.xyz)))) * _e21.color.xyz); } return fs_mainOutput { metal::float4(color1, 1.0) }; diff --git a/tests/snapshots.rs b/tests/snapshots.rs index cf2ce6a48b..dfce79d745 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -154,7 +154,7 @@ fn check_output_msl( let options = ¶ms.msl; #[cfg(not(feature = "deserialize"))] let options = if params.msl_custom { - println!("Skipping {}", destination); + println!("Skipping {}", destination.display()); return; } else { &default_options @@ -248,7 +248,7 @@ fn convert_wgsl() { "interpolate", Targets::SPIRV | Targets::METAL | Targets::GLSL, ), - ("access", Targets::SPIRV | Targets::METAL | Targets::WGSL), + ("access", Targets::SPIRV | Targets::METAL), ]; for &(name, targets) in inputs.iter() {