diff --git a/src/back/msl/keywords.rs b/src/back/msl/keywords.rs index af4be8827c..a62113c097 100644 --- a/src/back/msl/keywords.rs +++ b/src/back/msl/keywords.rs @@ -210,4 +210,6 @@ pub const RESERVED: &[&str] = &[ "M_2_SQRTPI", "M_SQRT2", "M_SQRT1_2", + // Naga utilities + "DefaultConstructible", ]; diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 7b01779642..3a917a3562 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -822,7 +822,7 @@ impl Writer { { write!(self.out, " ? ")?; self.put_access_chain(expr_handle, policy, context)?; - write!(self.out, " : 0")?; + write!(self.out, " : DefaultConstructible()")?; if !is_scoped { write!(self.out, ")")?; @@ -1529,7 +1529,7 @@ impl Writer { { write!(self.out, " ? ")?; self.put_unchecked_load(pointer, policy, context)?; - write!(self.out, " : 0")?; + write!(self.out, " : DefaultConstructible()")?; if !is_scoped { write!(self.out, ")")?; @@ -2154,6 +2154,13 @@ impl Writer { writeln!(self.out, "#include <simd/simd.h>")?; writeln!(self.out)?; + if options + .bounds_check_policies + .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite) + { + self.put_default_constructible()?; + } + { let mut indices = vec![]; for (handle, var) in module.global_variables.iter() { @@ -2181,6 +2188,28 @@ impl Writer { self.write_functions(module, info, options, pipeline_options) } + /// Write the definition for the `DefaultConstructible` class. + /// + /// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to + /// produce 'zero' values for any type, including structs, arrays, and so + /// on. We could do this by emitting default constructor applications, but + /// that would entail printing the name of the type, which is more trouble + /// than you'd think. Instead, we just construct this magic C++14 class that + /// can be converted to any type that can be default constructed, using + /// template parameter inference to detect which type is needed, so we don't + /// have to figure out the name. 
+ /// + /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite + fn put_default_constructible(&mut self) -> BackendResult { + writeln!(self.out, "struct DefaultConstructible {{")?; + writeln!(self.out, " template<typename T>")?; + writeln!(self.out, " operator T() && {{")?; + writeln!(self.out, " return T {{}};")?; + writeln!(self.out, " }}")?; + writeln!(self.out, "}};")?; + Ok(()) + } + fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult { for (handle, ty) in module.types.iter() { if !ty.needs_alias() { diff --git a/src/proc/index.rs b/src/proc/index.rs index 15381c9d23..ae20cc5d6e 100644 --- a/src/proc/index.rs +++ b/src/proc/index.rs @@ -148,6 +148,11 @@ impl BoundsCheckPolicies { _ => self.index, } } + + /// Return `true` if any of `self`'s policies are `policy`. + pub fn contains(&self, policy: BoundsCheckPolicy) -> bool { + self.index == policy || self.buffer == policy || self.image == policy + } } /// An index that may be statically known, or may need to be computed at runtime. @@ -204,9 +209,7 @@ pub fn find_checked_indexes( let mut guarded_indices = BitSet::new(); // Don't bother scanning if we never need `ReadZeroSkipWrite`. - if policies.index == BoundsCheckPolicy::ReadZeroSkipWrite - || policies.buffer == BoundsCheckPolicy::ReadZeroSkipWrite - { + if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) { for (_handle, expr) in function.expressions.iter() { // There's no need to handle `AccessIndex` expressions, as their // indices never need to be cached. 
diff --git a/tests/out/msl/bounds-check-zero.msl b/tests/out/msl/bounds-check-zero.msl index 6fa6613204..0662b33409 100644 --- a/tests/out/msl/bounds-check-zero.msl +++ b/tests/out/msl/bounds-check-zero.msl @@ -2,6 +2,12 @@ #include <metal_stdlib> #include <simd/simd.h> +struct DefaultConstructible { + template<typename T> + operator T() && { + return T {}; + } +}; struct _mslBufferSizes { metal::uint size0; }; @@ -23,7 +29,7 @@ float index_array( device Globals& globals, constant _mslBufferSizes& _buffer_sizes ) { - float _e4 = metal::uint(i) < 10 ? globals.a.inner[i] : 0; + float _e4 = metal::uint(i) < 10 ? globals.a.inner[i] : DefaultConstructible(); return _e4; } @@ -32,7 +38,7 @@ float index_dynamic_array( device Globals& globals, constant _mslBufferSizes& _buffer_sizes ) { - float _e4 = metal::uint(i_1) < 1 + (_buffer_sizes.size0 - 112 - 4) / 4 ? globals.d[i_1] : 0; + float _e4 = metal::uint(i_1) < 1 + (_buffer_sizes.size0 - 112 - 4) / 4 ? globals.d[i_1] : DefaultConstructible(); return _e4; } @@ -41,7 +47,7 @@ float index_vector( device Globals& globals, constant _mslBufferSizes& _buffer_sizes ) { - float _e4 = metal::uint(i_2) < 4 ? globals.v[i_2] : 0; + float _e4 = metal::uint(i_2) < 4 ? globals.v[i_2] : DefaultConstructible(); return _e4; } @@ -49,7 +55,7 @@ float index_vector_by_value( metal::float4 v, int i_3 ) { - return metal::uint(i_3) < 4 ? v[i_3] : 0; + return metal::uint(i_3) < 4 ? v[i_3] : DefaultConstructible(); } metal::float4 index_matrix( @@ -57,7 +63,7 @@ metal::float4 index_matrix( device Globals& globals, constant _mslBufferSizes& _buffer_sizes ) { - metal::float4 _e4 = metal::uint(i_4) < 3 ? globals.m[i_4] : 0; + metal::float4 _e4 = metal::uint(i_4) < 3 ? globals.m[i_4] : DefaultConstructible(); return _e4; } @@ -67,7 +73,7 @@ float index_twice( device Globals& globals, constant _mslBufferSizes& _buffer_sizes ) { - float _e6 = metal::uint(j) < 4 && metal::uint(i_5) < 3 ? globals.m[i_5][j] : 0; + float _e6 = metal::uint(j) < 4 && metal::uint(i_5) < 3 ? 
globals.m[i_5][j] : DefaultConstructible(); return _e6; } @@ -77,7 +83,7 @@ float index_expensive( constant _mslBufferSizes& _buffer_sizes ) { int _e9 = static_cast<int>(metal::sin(static_cast<float>(i_6) / 100.0) * 100.0); - float _e11 = metal::uint(_e9) < 10 ? globals.a.inner[_e9] : 0; + float _e11 = metal::uint(_e9) < 10 ? globals.a.inner[_e9] : DefaultConstructible(); return _e11; } diff --git a/tests/out/msl/policy-mix.msl b/tests/out/msl/policy-mix.msl index c3b5bb1dda..06ed990be3 100644 --- a/tests/out/msl/policy-mix.msl +++ b/tests/out/msl/policy-mix.msl @@ -2,6 +2,12 @@ #include <metal_stdlib> #include <simd/simd.h> +struct DefaultConstructible { + template<typename T> + operator T() && { + return T {}; + } +}; struct type_1 { metal::float4 inner[10]; };