[msl-out] Replace per_stage_map with per_entry_point_map (#2237)

The existing `per_stage_map` field of MSL backend options specifies
resource binding maps that apply to all entry points of each stage type.
It is useful to have the ability to provide a separate binding index map
for each entry point, especially when the same shader module defines
multiple entry points of the same stage kind.

This patch replaces `per_stage_map` with a new `per_entry_point_map`
option where resources are keyed by the entry-point function name.
This commit is contained in:
Arman Uguray
2023-02-22 09:15:39 -08:00
committed by GitHub
parent 9742f1616c
commit 00be08e9f8
16 changed files with 227 additions and 82 deletions

View File

@@ -27,10 +27,7 @@ holding the result.
*/
use crate::{arena::Handle, proc::index, valid::ModuleInfo};
use std::{
fmt::{Error as FmtError, Write},
ops,
};
use std::fmt::{Error as FmtError, Write};
mod keywords;
pub mod sampler;
@@ -69,7 +66,7 @@ pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, BindTar
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct PerStageResources {
pub struct EntryPointResources {
pub resources: BindingMap,
pub push_constant_buffer: Option<Slot>,
@@ -80,26 +77,7 @@ pub struct PerStageResources {
pub sizes_buffer: Option<Slot>,
}
#[derive(Clone, Debug, Default, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(default))]
pub struct PerStageMap {
pub vs: PerStageResources,
pub fs: PerStageResources,
pub cs: PerStageResources,
}
impl ops::Index<crate::ShaderStage> for PerStageMap {
type Output = PerStageResources;
fn index(&self, stage: crate::ShaderStage) -> &PerStageResources {
match stage {
crate::ShaderStage::Vertex => &self.vs,
crate::ShaderStage::Fragment => &self.fs,
crate::ShaderStage::Compute => &self.cs,
}
}
}
pub type EntryPointResourceMap = std::collections::BTreeMap<String, EntryPointResources>;
enum ResolvedBinding {
BuiltIn(crate::BuiltIn),
@@ -198,8 +176,8 @@ enum LocationMode {
pub struct Options {
/// (Major, Minor) target version of the Metal Shading Language.
pub lang_version: (u8, u8),
/// Map of per-stage resources to slots.
pub per_stage_map: PerStageMap,
/// Map of entry-point resources, indexed by entry point function name, to slots.
pub per_entry_point_map: EntryPointResourceMap,
/// Samplers to be inlined into the code.
pub inline_samplers: Vec<sampler::InlineSampler>,
/// Make it possible to link different stages via SPIRV-Cross.
@@ -217,7 +195,7 @@ impl Default for Options {
fn default() -> Self {
Options {
lang_version: (2, 0),
per_stage_map: PerStageMap::default(),
per_entry_point_map: EntryPointResourceMap::default(),
inline_samplers: Vec::new(),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
@@ -296,12 +274,26 @@ impl Options {
}
}
fn get_entry_point_resources(&self, ep: &crate::EntryPoint) -> Option<&EntryPointResources> {
self.per_entry_point_map.get(&ep.name)
}
fn get_resource_binding_target(
&self,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Option<&BindTarget> {
self.get_entry_point_resources(ep)
.and_then(|res| res.resources.get(res_binding))
}
fn resolve_resource_binding(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
res_binding: &crate::ResourceBinding,
) -> Result<ResolvedBinding, EntryPointError> {
match self.per_stage_map[stage].resources.get(res_binding) {
let target = self.get_resource_binding_target(ep, res_binding);
match target {
Some(target) => Ok(ResolvedBinding::Resource(target.clone())),
None if self.fake_missing_bindings => Ok(ResolvedBinding::User {
prefix: "fake",
@@ -312,15 +304,13 @@ impl Options {
}
}
const fn resolve_push_constants(
fn resolve_push_constants(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = match stage {
crate::ShaderStage::Vertex => self.per_stage_map.vs.push_constant_buffer,
crate::ShaderStage::Fragment => self.per_stage_map.fs.push_constant_buffer,
crate::ShaderStage::Compute => self.per_stage_map.cs.push_constant_buffer,
};
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.push_constant_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),
@@ -340,9 +330,11 @@ impl Options {
fn resolve_sizes_buffer(
&self,
stage: crate::ShaderStage,
ep: &crate::EntryPoint,
) -> Result<ResolvedBinding, EntryPointError> {
let slot = self.per_stage_map[stage].sizes_buffer;
let slot = self
.get_entry_point_resources(ep)
.and_then(|res| res.sizes_buffer);
match slot {
Some(slot) => Ok(ResolvedBinding::Resource(BindTarget {
buffer: Some(slot),

View File

@@ -3406,7 +3406,8 @@ impl<W: Write> Writer<W> {
break;
}
};
let good = match options.per_stage_map[ep.stage].resources.get(br) {
let target = options.get_resource_binding_target(ep, br);
let good = match target {
Some(target) => {
let binding_ty = match module.types[var.ty].inner {
crate::TypeInner::BindingArray { base, .. } => {
@@ -3431,7 +3432,7 @@ impl<W: Write> Writer<W> {
}
}
crate::AddressSpace::PushConstant => {
if let Err(e) = options.resolve_push_constants(ep.stage) {
if let Err(e) = options.resolve_push_constants(ep) {
ep_error = Some(e);
break;
}
@@ -3442,7 +3443,7 @@ impl<W: Write> Writer<W> {
}
}
if supports_array_length {
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
if let Err(err) = options.resolve_sizes_buffer(ep) {
ep_error = Some(err);
}
}
@@ -3711,15 +3712,13 @@ impl<W: Write> Writer<W> {
}
// the resolves have already been checked for `!fake_missing_bindings` case
let resolved = match var.space {
crate::AddressSpace::PushConstant => {
options.resolve_push_constants(ep.stage).ok()
}
crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(),
crate::AddressSpace::WorkGroup => None,
crate::AddressSpace::Storage { .. } if options.lang_version < (2, 0) => {
return Err(Error::UnsupportedAddressSpace(var.space))
}
_ => options
.resolve_resource_binding(ep.stage, var.binding.as_ref().unwrap())
.resolve_resource_binding(ep, var.binding.as_ref().unwrap())
.ok(),
};
if let Some(ref resolved) = resolved {
@@ -3764,7 +3763,7 @@ impl<W: Write> Writer<W> {
// passed as a final struct-typed argument.
if supports_array_length {
// this is checked earlier
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
let resolved = options.resolve_sizes_buffer(ep).unwrap();
let separator = if module.global_variables.is_empty() {
' '
} else {
@@ -3824,7 +3823,7 @@ impl<W: Write> Writer<W> {
};
} else if let Some(ref binding) = var.binding {
// write an inline sampler
let resolved = options.resolve_resource_binding(ep.stage, binding).unwrap();
let resolved = options.resolve_resource_binding(ep, binding).unwrap();
if let Some(sampler) = resolved.as_inline_sampler(options) {
let name = &self.names[&NameKey::GlobalVariable(handle)];
writeln!(

View File

@@ -6,8 +6,8 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
vs: (
per_entry_point_map: {
"foo_vert": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
@@ -16,20 +16,20 @@
},
sizes_buffer: Some(24),
),
fs: (
"foo_frag": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
},
sizes_buffer: Some(24),
),
cs: (
"atomics": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
},
sizes_buffer: Some(24),
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,

View File

@@ -19,14 +19,14 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
fs: (
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (texture: Some(0), binding_array_size: Some(10), mutable: false),
},
sizes_buffer: None,
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: true,

View File

@@ -1,13 +1,13 @@
(
msl: (
lang_version: (1, 2),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
},
sizes_buffer: Some(0),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,

View File

@@ -1,13 +1,13 @@
(
msl: (
lang_version: (1, 2),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
},
sizes_buffer: Some(0),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,

View File

@@ -6,8 +6,8 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
@@ -15,7 +15,7 @@
},
sizes_buffer: Some(3),
)
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,

View File

@@ -5,11 +5,11 @@
),
msl: (
lang_version: (2, 2),
per_stage_map: (
fs: (
per_entry_point_map: {
"main": (
push_constant_buffer: Some(1),
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,

View File

@@ -19,7 +19,7 @@
),
msl: (
lang_version: (2, 1),
per_stage_map: (),
per_entry_point_map: {},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,

View File

@@ -6,15 +6,15 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
vs: (
per_entry_point_map: {
"vertex": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: false),
(group: 0, binding: 2): (buffer: Some(2), mutable: false),
},
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,

View File

@@ -0,0 +1,53 @@
(
god_mode: true,
msl: (
lang_version: (2, 0),
per_entry_point_map: {
"entry_point_one": (
resources: {
(group: 0, binding: 0): (texture: Some(0)),
(group: 0, binding: 2): (sampler: Some(Inline(0))),
(group: 0, binding: 4): (buffer: Some(0)),
}
),
"entry_point_two": (
resources: {
(group: 0, binding: 0): (texture: Some(0)),
(group: 0, binding: 2): (sampler: Some(Resource(0))),
(group: 0, binding: 4): (buffer: Some(0)),
}
),
"entry_point_three": (
resources: {
(group: 0, binding: 0): (texture: Some(0)),
(group: 0, binding: 1): (texture: Some(1)),
(group: 0, binding: 2): (sampler: Some(Inline(0))),
(group: 0, binding: 3): (sampler: Some(Resource(1))),
(group: 0, binding: 4): (buffer: Some(0)),
(group: 1, binding: 0): (buffer: Some(1)),
}
)
},
inline_samplers: [
(
coord: Normalized,
address: (ClampToEdge, ClampToEdge, ClampToEdge),
mag_filter: Linear,
min_filter: Linear,
mip_filter: None,
border_color: TransparentBlack,
compare_func: Never,
lod_clamp: Some((start: 0.5, end: 10.0)),
max_anisotropy: Some(8),
),
],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
zero_initialize_workgroup_memory: true,
),
bounds_check_policies: (
index: ReadZeroSkipWrite,
buffer: ReadZeroSkipWrite,
image: ReadZeroSkipWrite,
)
)

View File

@@ -0,0 +1,23 @@
@group(0) @binding(0) var t1: texture_2d<f32>;
@group(0) @binding(1) var t2: texture_2d<f32>;
@group(0) @binding(2) var s1: sampler;
@group(0) @binding(3) var s2: sampler;
@group(0) @binding(4) var<uniform> uniformOne: vec2<f32>;
@group(1) @binding(0) var<uniform> uniformTwo: vec2<f32>;
@fragment
fn entry_point_one(@builtin(position) pos: vec4<f32>) -> @location(0) vec4<f32> {
return textureSample(t1, s1, pos.xy);
}
@fragment
fn entry_point_two() -> @location(0) vec4<f32> {
return textureSample(t1, s1, uniformOne);
}
@fragment
fn entry_point_three() -> @location(0) vec4<f32> {
return textureSample(t1, s1, uniformTwo + uniformOne) +
textureSample(t2, s2, uniformOne);
}

View File

@@ -7,19 +7,19 @@
),
msl: (
lang_version: (2, 1),
per_stage_map: (
vs: (
per_entry_point_map: {
"vs_main": (
resources: {
(group: 0, binding: 0): (buffer: Some(0)),
},
),
fs: (
"fs_main": (
resources: {
(group: 0, binding: 1): (texture: Some(0)),
(group: 0, binding: 2): (sampler: Some(Inline(0))),
},
),
),
},
inline_samplers: [
(
coord: Normalized,

View File

@@ -6,17 +6,17 @@
),
msl: (
lang_version: (2, 0),
per_stage_map: (
cs: (
per_entry_point_map: {
"main": (
resources: {
(group: 0, binding: 0): (buffer: Some(0), mutable: true),
},
sizes_buffer: None,
),
),
},
inline_samplers: [],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
zero_initialize_workgroup_memory: true,
),
)
)

View File

@@ -0,0 +1,74 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
struct DefaultConstructible {
template<typename T>
operator T() && {
return T {};
}
};
struct entry_point_oneInput {
};
struct entry_point_oneOutput {
metal::float4 member [[color(0)]];
};
fragment entry_point_oneOutput entry_point_one(
metal::float4 pos [[position]]
, metal::texture2d<float, metal::access::sample> t1_ [[texture(0)]]
) {
constexpr metal::sampler s1_(
metal::s_address::clamp_to_edge,
metal::t_address::clamp_to_edge,
metal::r_address::clamp_to_edge,
metal::mag_filter::linear,
metal::min_filter::linear,
metal::coord::normalized
);
metal::float4 _e4 = t1_.sample(s1_, pos.xy);
return entry_point_oneOutput { _e4 };
}
struct entry_point_twoOutput {
metal::float4 member_1 [[color(0)]];
};
fragment entry_point_twoOutput entry_point_two(
metal::texture2d<float, metal::access::sample> t1_ [[texture(0)]]
, metal::sampler s1_ [[sampler(0)]]
, constant metal::float2& uniformOne [[buffer(0)]]
) {
metal::float2 _e3 = uniformOne;
metal::float4 _e4 = t1_.sample(s1_, _e3);
return entry_point_twoOutput { _e4 };
}
struct entry_point_threeOutput {
metal::float4 member_2 [[color(0)]];
};
fragment entry_point_threeOutput entry_point_three(
metal::texture2d<float, metal::access::sample> t1_ [[texture(0)]]
, metal::texture2d<float, metal::access::sample> t2_ [[texture(1)]]
, metal::sampler s2_ [[sampler(1)]]
, constant metal::float2& uniformOne [[buffer(0)]]
, constant metal::float2& uniformTwo [[buffer(1)]]
) {
constexpr metal::sampler s1_(
metal::s_address::clamp_to_edge,
metal::t_address::clamp_to_edge,
metal::r_address::clamp_to_edge,
metal::mag_filter::linear,
metal::min_filter::linear,
metal::coord::normalized
);
metal::float2 _e3 = uniformTwo;
metal::float2 _e5 = uniformOne;
metal::float4 _e7 = t1_.sample(s1_, _e3 + _e5);
metal::float2 _e11 = uniformOne;
metal::float4 _e12 = t2_.sample(s2_, _e11);
return entry_point_threeOutput { _e7 + _e12 };
}

View File

@@ -90,8 +90,11 @@ struct Parameters {
#[allow(unused_variables)]
fn check_targets(module: &naga::Module, name: &str, targets: Targets) {
let root = env!("CARGO_MANIFEST_DIR");
let params = match fs::read_to_string(format!("{root}/{BASE_DIR_IN}/{name}.param.ron")) {
Ok(string) => ron::de::from_str(&string).expect("Couldn't parse param file"),
let filepath = format!("{root}/{BASE_DIR_IN}/{name}.param.ron");
let params = match fs::read_to_string(&filepath) {
Ok(string) => {
ron::de::from_str(&string).expect(&format!("Couldn't parse param file: {}", filepath))
}
Err(_) => Parameters::default(),
};
@@ -543,6 +546,7 @@ fn convert_wgsl() {
"binding-arrays",
Targets::WGSL | Targets::HLSL | Targets::METAL | Targets::SPIRV,
),
("resource-binding-map", Targets::METAL),
("multiview", Targets::SPIRV | Targets::GLSL | Targets::WGSL),
("multiview_webgl", Targets::GLSL),
(