Add support for retrieving array length in msl backend (#806)

* Add support for arrayLength to the wgsl frontend

* Fix clippy warning

* Add draft support for array length to the msl backend

* Finalize support for array length in msl

* Convert buffer size to array length in msl backend

* Fix clippy warning

* Fix misleading documentation

* Changes based on review of PR

* Use a fake binding for sizes buffer in msl backend if necessary

* Only generate the msl buffer size structure if globals are present that have unsized arrays

* Make sure to generate separators

* Replace uses_unsized_buffers with !self.runtime_sized_buffers.is_empty()

* Clear self.runtime_sized_buffers

* Run snapshot of shadow.msl

* Add Expression::ArrayLength support to the spirv backend

* Remove wgsl target from the access snapshot test

* Update access.msl snapshot

* Update stack size for msl backend

* Update stack size again for msl backend
This commit is contained in:
Lachlan Sneff
2021-04-30 22:40:32 -04:00
committed by GitHub
parent 992bdd83ce
commit 232fd65ec7
12 changed files with 307 additions and 79 deletions

View File

@@ -157,6 +157,10 @@ pub struct Options {
pub spirv_cross_compatibility: bool,
/// Don't panic on missing bindings, instead generate invalid MSL.
pub fake_missing_bindings: bool,
/// The slot of a buffer that contains an array of `u32` values,
/// one holding the size in bytes of each bound buffer that contains a runtime array,
/// in order of [`GlobalVariable`] declarations.
pub sizes_buffer_binding: Option<Slot>,
}
impl Default for Options {
@@ -168,6 +172,7 @@ impl Default for Options {
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
sizes_buffer_binding: None,
}
}
}

View File

@@ -1,6 +1,6 @@
use super::{
keywords::RESERVED, sampler as sm, Error, LocationMode, Options, PipelineOptions,
TranslationInfo,
keywords::RESERVED, sampler as sm, BindTarget, Error, LocationMode, Options, PipelineOptions,
ResolvedBinding, TranslationInfo,
};
use crate::{
arena::{Arena, Handle},
@@ -290,6 +290,7 @@ pub struct Writer<W> {
names: FastHashMap<NameKey, String>,
named_expressions: BitSet,
namer: Namer,
runtime_sized_buffers: FastHashMap<Handle<crate::GlobalVariable>, usize>,
#[cfg(test)]
put_expression_stack_pointers: crate::FastHashSet<*const ()>,
#[cfg(test)]
@@ -435,6 +436,7 @@ impl<W: Write> Writer<W> {
names: FastHashMap::default(),
named_expressions: BitSet::new(),
namer: Namer::default(),
runtime_sized_buffers: FastHashMap::default(),
#[cfg(test)]
put_expression_stack_pointers: Default::default(),
#[cfg(test)]
@@ -1060,26 +1062,53 @@ impl<W: Write> Writer<W> {
}
// has to be a named expression
crate::Expression::Call(_) => unreachable!(),
crate::Expression::ArrayLength(expr) => match *context.resolve_type(expr) {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(const_handle),
..
} => {
let coco = ConstantContext {
handle: const_handle,
arena: &context.module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, "{}", coco)?;
crate::Expression::ArrayLength(expr) => {
let handle = match context.function.expressions[expr] {
crate::Expression::AccessIndex { base, .. } => {
match context.function.expressions[base] {
crate::Expression::GlobalVariable(handle) => handle,
_ => return Err(Error::Validation),
}
}
_ => return Err(Error::Validation),
};
let global = &context.module.global_variables[handle];
if let crate::TypeInner::Struct { ref members, .. } =
context.module.types[global.ty].inner
{
if let Some(&crate::StructMember {
offset,
ty: array_ty,
..
}) = members.last()
{
let (span, stride) = match context.module.types[array_ty].inner {
crate::TypeInner::Array { base, stride, .. } => (
context.module.types[base]
.inner
.span(&context.module.constants),
stride,
),
_ => return Err(Error::Validation),
};
let buffer_idx = self.runtime_sized_buffers[&handle];
write!(
self.out,
"(1 + (_buffer_sizes.size{idx} - {offset} - {span}) / {stride})",
idx = buffer_idx,
offset = offset,
span = span,
stride = stride,
)?;
} else {
return Err(Error::Validation);
}
} else {
return Err(Error::Validation);
}
crate::TypeInner::Array { .. } => {
return Err(Error::FeatureNotImplemented(
"dynamic array size".to_string(),
))
}
_ => return Err(Error::Validation),
},
}
}
Ok(())
}
@@ -1405,6 +1434,14 @@ impl<W: Write> Writer<W> {
write!(self.out, "{}", name)?;
}
}
if !self.runtime_sized_buffers.is_empty() {
if separate {
write!(self.out, ", ")?;
}
write!(self.out, "_buffer_sizes")?;
}
// done
writeln!(self.out, ");")?;
}
@@ -1432,11 +1469,42 @@ impl<W: Write> Writer<W> {
) -> Result<TranslationInfo, Error> {
self.names.clear();
self.namer.reset(module, RESERVED, &mut self.names);
self.runtime_sized_buffers.clear();
writeln!(self.out, "#include <metal_stdlib>")?;
writeln!(self.out, "#include <simd/simd.h>")?;
writeln!(self.out)?;
{
let mut indices = vec![];
for (handle, gv) in module.global_variables.iter() {
if let crate::TypeInner::Struct { ref members, .. } = module.types[gv.ty].inner {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = module.types[member.ty].inner
{
let idx = handle.index();
self.runtime_sized_buffers.insert(handle, idx);
indices.push(idx);
}
}
}
}
if !indices.is_empty() {
writeln!(self.out, "struct _mslBufferSizes {{")?;
for idx in indices {
writeln!(self.out, "{}{}::uint size{};", INDENT, NAMESPACE, idx)?;
}
writeln!(self.out, "}};")?;
writeln!(self.out)?;
}
};
self.write_scalar_constants(module)?;
self.write_type_defs(module)?;
self.write_composite_constants(module)?;
@@ -1746,8 +1814,11 @@ impl<W: Write> Writer<W> {
access: crate::StorageAccess::empty(),
first_time: false,
};
let separator =
separate(!pass_through_globals.is_empty() || index + 1 != fun.arguments.len());
let separator = separate(
!pass_through_globals.is_empty()
|| index + 1 != fun.arguments.len()
|| !self.runtime_sized_buffers.is_empty(),
);
writeln!(
self.out,
"{}{} {}{}",
@@ -1762,11 +1833,23 @@ impl<W: Write> Writer<W> {
usage: fun_info[handle],
reference: true,
};
let separator = separate(index + 1 != pass_through_globals.len());
let separator = separate(
index + 1 != pass_through_globals.len()
|| !self.runtime_sized_buffers.is_empty(),
);
write!(self.out, "{}", INDENT)?;
tyvar.try_fmt(&mut self.out)?;
writeln!(self.out, "{}", separator)?;
}
if !self.runtime_sized_buffers.is_empty() {
writeln!(
self.out,
"{}constant _mslBufferSizes& _buffer_sizes",
INDENT
)?;
}
writeln!(self.out, ") {{")?;
for (local_handle, local) in fun.local_variables.iter() {
@@ -2049,6 +2132,34 @@ impl<W: Write> Writer<W> {
writeln!(self.out)?;
}
if !self.runtime_sized_buffers.is_empty() {
let resolved = if let Some(slot) = options.sizes_buffer_binding {
ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
mutable: false,
..Default::default()
})
} else {
ResolvedBinding::User {
prefix: "fake",
index: 0,
interpolation: None,
}
};
let separator = if module.global_variables.is_empty() {
' '
} else {
','
};
write!(
self.out,
"{} constant _mslBufferSizes& _buffer_sizes",
separator,
)?;
resolved.try_fmt_decorated(&mut self.out, "\n")?;
}
// end of the entry point argument list
writeln!(self.out, ") {{")?;
@@ -2231,8 +2342,8 @@ fn test_stack_size() {
}
let stack_size = addresses.end - addresses.start;
// check the size (in debug only)
// last observed macOS value: 21760
if stack_size < 21000 || stack_size > 23000 {
// last observed macOS value: 23040
if stack_size < 21000 || stack_size > 25000 {
panic!("`put_expression` stack size {} has changed!", stack_size);
}
}
@@ -2246,8 +2357,8 @@ fn test_stack_size() {
}
let stack_size = addresses.end - addresses.start;
// check the size (in debug only)
// last observed macOS value: 12736
if stack_size < 12000 || stack_size > 13500 {
// last observed macOS value: 13600
if stack_size < 12000 || stack_size > 14500 {
panic!("`put_block` stack size {} has changed!", stack_size);
}
}

View File

@@ -473,6 +473,20 @@ impl super::Instruction {
instruction
}
/// Builds an `Op::ArrayLength` instruction.
///
/// The instruction's type is set to `result_type_id` and its result to `id`;
/// `structure_id` and `array_member` are then appended as operands, in that
/// order.
///
/// * `result_type_id` - id of the type of the value the instruction produces.
/// * `id` - result id assigned to the instruction.
/// * `structure_id` - id of the structure that contains the array.
/// * `array_member` - index of the array member inside that structure.
pub(super) fn array_length(
    result_type_id: Word,
    id: Word,
    structure_id: Word,
    array_member: Word,
) -> Self {
    let mut array_len = Self::new(Op::ArrayLength);
    array_len.set_type(result_type_id);
    array_len.set_result(id);
    // Operand order matters: the structure comes first, then the member index.
    for &operand in &[structure_id, array_member] {
        array_len.add_operand(operand);
    }
    array_len
}
//
// Function Instructions
//

View File

@@ -2349,9 +2349,37 @@ impl Writer {
.push(Instruction::relational(op, result_type_id, id, arg_id));
id
}
crate::Expression::ArrayLength(_) => {
log::error!("unimplemented {:?}", ir_function.expressions[expr_handle]);
return Err(Error::FeatureNotImplemented("expression"));
crate::Expression::ArrayLength(expr) => {
let (structure_id, member_idx) = match ir_function.expressions[expr] {
crate::Expression::AccessIndex { base, .. } => {
match ir_function.expressions[base] {
crate::Expression::GlobalVariable(handle) => {
let global = &ir_module.global_variables[handle];
let last_idx = match ir_module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => {
members.len() as u32 - 1
}
_ => return Err(Error::Validation("array length expression")),
};
(self.global_variables[handle.index()].id, last_idx)
}
_ => return Err(Error::Validation("array length expression")),
}
}
_ => return Err(Error::Validation("array length expression")),
};
// let structure_id = self.get_expression_global(ir_function, global);
let id = self.id_gen.next();
block.body.push(Instruction::array_length(
result_type_id,
id,
structure_id,
member_idx,
));
id
}
};

View File

@@ -3,5 +3,17 @@
spv_capabilities: [ Shader, Image1D, Sampled1D ],
spv_debug: true,
spv_adjust_coordinate_space: false,
msl_custom: false,
msl_custom: true,
msl: (
lang_version: (2, 0),
binding_map: {
(stage: Vertex, group: 0, binding: 0): (buffer: Some(0), mutable: true),
},
push_constants_map: (
),
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
sizes_buffer_binding: Some(24),
),
)

View File

@@ -1,8 +1,18 @@
// This snapshot tests accessing various containers, dereferencing pointers.
[[block]]
struct Bar {
data: [[stride(4)]] array<i32>;
};
[[group(0), binding(0)]]
var<storage> bar: [[access(read_write)]] Bar;
[[stage(vertex)]]
fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
let array = array<i32, 5>(1, 2, 3, 4, 5);
let a = bar.data[arrayLength(&bar.data) - 1u];
let array = array<i32, 5>(a, 2, 3, 4, 5);
let value = array[vi];
return vec4<f32>(vec4<i32>(value));
}

View File

@@ -1,7 +1,15 @@
#include <metal_stdlib>
#include <simd/simd.h>
struct type3 {
struct _mslBufferSizes {
metal::uint size0;
};
typedef int type1[1];
struct Bar {
type1 data;
};
struct type4 {
int inner[5];
};
@@ -12,6 +20,8 @@ struct fooOutput {
};
vertex fooOutput foo(
metal::uint vi [[vertex_id]]
, device Bar& bar [[buffer(0)]]
, constant _mslBufferSizes& _buffer_sizes [[buffer(24)]]
) {
return fooOutput { static_cast<float4>(int4(type3 {1, 2, 3, 4, 5}.inner[vi])) };
return fooOutput { static_cast<float4>(int4(type4 {bar.data[(1 + (_buffer_sizes.size0 - 0 - 4) / 4) - 1u], 2, 3, 4, 5}.inner[vi])) };
}

View File

@@ -1,50 +1,70 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 31
; Bound: 42
OpCapability Image1D
OpCapability Shader
OpCapability Sampled1D
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %19 "foo" %14 %17
OpEntryPoint Vertex %23 "foo" %18 %21
OpSource GLSL 450
OpName %14 "vi"
OpName %19 "foo"
OpDecorate %12 ArrayStride 4
OpDecorate %14 BuiltIn VertexIndex
OpDecorate %17 BuiltIn Position
OpName %11 "Bar"
OpMemberName %11 0 "data"
OpName %15 "bar"
OpName %18 "vi"
OpName %23 "foo"
OpDecorate %10 ArrayStride 4
OpDecorate %11 Block
OpMemberDecorate %11 0 Offset 0
OpDecorate %14 ArrayStride 4
OpDecorate %15 DescriptorSet 0
OpDecorate %15 Binding 0
OpDecorate %18 BuiltIn VertexIndex
OpDecorate %21 BuiltIn Position
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpConstant %4 5
%5 = OpConstant %4 1
%6 = OpConstant %4 2
%7 = OpConstant %4 3
%8 = OpConstant %4 4
%9 = OpTypeInt 32 0
%11 = OpTypeFloat 32
%10 = OpTypeVector %11 4
%12 = OpTypeArray %4 %3
%15 = OpTypePointer Input %9
%14 = OpVariable %15 Input
%18 = OpTypePointer Output %10
%17 = OpVariable %18 Output
%20 = OpTypeFunction %2
%23 = OpTypePointer Function %12
%26 = OpTypePointer Function %4
%28 = OpTypeVector %4 4
%19 = OpFunction %2 None %20
%13 = OpLabel
%24 = OpVariable %23 Function
%16 = OpLoad %9 %14
OpBranch %21
%21 = OpLabel
%22 = OpCompositeConstruct %12 %5 %6 %7 %8 %3
OpStore %24 %22
%25 = OpAccessChain %26 %24 %16
%27 = OpLoad %4 %25
%29 = OpCompositeConstruct %28 %27 %27 %27 %27
%30 = OpConvertSToF %10 %29
OpStore %17 %30
%4 = OpTypeInt 32 0
%3 = OpConstant %4 1
%6 = OpTypeInt 32 1
%5 = OpConstant %6 5
%7 = OpConstant %6 2
%8 = OpConstant %6 3
%9 = OpConstant %6 4
%10 = OpTypeRuntimeArray %6
%11 = OpTypeStruct %10
%13 = OpTypeFloat 32
%12 = OpTypeVector %13 4
%14 = OpTypeArray %6 %5
%16 = OpTypePointer StorageBuffer %11
%15 = OpVariable %16 StorageBuffer
%19 = OpTypePointer Input %4
%18 = OpVariable %19 Input
%22 = OpTypePointer Output %12
%21 = OpVariable %22 Output
%24 = OpTypeFunction %2
%26 = OpTypePointer StorageBuffer %10
%29 = OpTypePointer StorageBuffer %6
%30 = OpConstant %6 0
%34 = OpTypePointer Function %14
%37 = OpTypePointer Function %6
%39 = OpTypeVector %6 4
%23 = OpFunction %2 None %24
%17 = OpLabel
%35 = OpVariable %34 Function
%20 = OpLoad %4 %18
OpBranch %25
%25 = OpLabel
%27 = OpArrayLength %4 %15 0
%28 = OpISub %4 %27 %3
%31 = OpAccessChain %29 %15 %30 %28
%32 = OpLoad %6 %31
%33 = OpCompositeConstruct %14 %32 %7 %8 %9 %5
OpStore %35 %33
%36 = OpAccessChain %37 %35 %20
%38 = OpLoad %6 %36
%40 = OpCompositeConstruct %39 %38 %38 %38 %38
%41 = OpConvertSToF %12 %40
OpStore %21 %41
OpReturn
OpFunctionEnd

View File

@@ -1,6 +1,11 @@
#include <metal_stdlib>
#include <simd/simd.h>
struct _mslBufferSizes {
metal::uint size1;
metal::uint size2;
};
constexpr constant unsigned NUM_PARTICLES = 1500u;
struct Particle {
metal::float2 pos;
@@ -27,6 +32,7 @@ kernel void main1(
, constant SimParams& params [[buffer(0)]]
, constant Particles& particlesSrc [[buffer(1)]]
, device Particles& particlesDst [[buffer(2)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
metal::float2 vPos;
metal::float2 vVel;

View File

@@ -1,13 +1,18 @@
#include <metal_stdlib>
#include <simd/simd.h>
struct _mslBufferSizes {
metal::uint size0;
};
typedef metal::uint type1[1];
struct PrimeIndices {
type1 data;
};
metal::uint collatz_iterations(
metal::uint n_base
metal::uint n_base,
constant _mslBufferSizes& _buffer_sizes
) {
metal::uint n;
metal::uint i = 0u;
@@ -31,8 +36,9 @@ struct main1Input {
kernel void main1(
metal::uint3 global_id [[thread_position_in_grid]]
, device PrimeIndices& v_indices [[user(fake0)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x]);
metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x], _buffer_sizes);
v_indices.data[global_id.x] = _e9;
return;
}

View File

@@ -1,6 +1,10 @@
#include <metal_stdlib>
#include <simd/simd.h>
struct _mslBufferSizes {
metal::uint size1;
};
constexpr constant unsigned c_max_lights = 10u;
struct Globals {
metal::uint4 num_lights;
@@ -20,7 +24,8 @@ float fetch_shadow(
metal::uint light_id,
metal::float4 homogeneous_coords,
metal::depth2d_array<float, metal::access::sample> t_shadow,
metal::sampler sampler_shadow
metal::sampler sampler_shadow,
constant _mslBufferSizes& _buffer_sizes
) {
if (homogeneous_coords.w <= 0.0) {
return 1.0;
@@ -42,6 +47,7 @@ fragment fs_mainOutput fs_main(
, constant Lights& s_lights [[user(fake0)]]
, metal::depth2d_array<float, metal::access::sample> t_shadow [[user(fake0)]]
, metal::sampler sampler_shadow [[user(fake0)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
const auto raw_normal = varyings.raw_normal;
const auto position = varyings.position;
@@ -57,7 +63,7 @@ fragment fs_mainOutput fs_main(
break;
}
Light _e21 = s_lights.data[i];
float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow);
float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow, _buffer_sizes);
color1 = color1 + ((_e25 * metal::max(0.0, metal::dot(metal::normalize(raw_normal), metal::normalize(_e21.pos.xyz - position.xyz)))) * _e21.color.xyz);
}
return fs_mainOutput { metal::float4(color1, 1.0) };

View File

@@ -154,7 +154,7 @@ fn check_output_msl(
let options = &params.msl;
#[cfg(not(feature = "deserialize"))]
let options = if params.msl_custom {
println!("Skipping {}", destination);
println!("Skipping {}", destination.display());
return;
} else {
&default_options
@@ -248,7 +248,7 @@ fn convert_wgsl() {
"interpolate",
Targets::SPIRV | Targets::METAL | Targets::GLSL,
),
("access", Targets::SPIRV | Targets::METAL | Targets::WGSL),
("access", Targets::SPIRV | Targets::METAL),
];
for &(name, targets) in inputs.iter() {