From 587dc01a2cc43fdf4637f1bf6d2b75b703e9f681 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Wed, 28 Oct 2020 00:30:22 -0400 Subject: [PATCH] [msl] refactor the options, add override stages --- examples/convert.rs | 19 ++++++++++++++++-- src/back/msl.rs | 42 +++++++++++++++++++++++++++------------ src/lib.rs | 34 +++++++++++++++---------------- test-data/boids.param.ron | 6 +++--- test-data/quad.param.ron | 4 ++-- tests/convert.rs | 28 +++++++++++++++++++------- 6 files changed, 89 insertions(+), 44 deletions(-) diff --git a/examples/convert.rs b/examples/convert.rs index a8160c14cc..7f72548a5f 100644 --- a/examples/convert.rs +++ b/examples/convert.rs @@ -1,8 +1,16 @@ use serde::{Deserialize, Serialize}; use std::{env, fs, path::Path}; +#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)] +enum Stage { + Vertex, + Fragment, + Compute, +} + #[derive(Hash, PartialEq, Eq, Serialize, Deserialize)] struct BindSource { + stage: Stage, group: u32, binding: u32, } @@ -120,6 +128,11 @@ fn main() { for (key, value) in params.metal_bindings { binding_map.insert( msl::BindSource { + stage: match key.stage { + Stage::Vertex => naga::ShaderStage::Vertex, + Stage::Fragment => naga::ShaderStage::Fragment, + Stage::Compute => naga::ShaderStage::Compute, + }, group: key.group, binding: key.binding, }, @@ -132,9 +145,11 @@ fn main() { ); } let options = msl::Options { - binding_map: &binding_map, + lang_version: (1, 0), + spirv_cross_compatibility: false, + binding_map, }; - let msl = msl::write_string(&module, options).unwrap(); + let msl = msl::write_string(&module, &options).unwrap(); fs::write(&args[2], msl).unwrap(); } #[cfg(feature = "spv-out")] diff --git a/src/back/msl.rs b/src/back/msl.rs index 007fa85aea..c14be2f8f9 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -31,6 +31,7 @@ pub struct BindTarget { #[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] pub struct BindSource { + pub stage: crate::ShaderStage, pub group: u32, pub binding: u32, } @@ -102,14 +103,20 @@ enum LocationMode { Uniform, } -#[derive(Debug, Clone, Copy)] -pub struct Options<'a> { - pub binding_map: &'a BindingMap, +#[derive(Debug, Default, Clone)] +pub struct Options { + /// (Major, Minor) target version of the Metal Shading Language. + pub lang_version: (u8, u8), + /// Make it possible to link different stages via SPIRV-Cross. + pub spirv_cross_compatibility: bool, + /// Binding model mapping to Metal. + pub binding_map: BindingMap, } -impl Options<'_> { +impl Options { fn resolve_binding( - self, + &self, + stage: crate::ShaderStage, binding: &crate::Binding, mode: LocationMode, ) -> Result { @@ -119,13 +126,21 @@ impl Options<'_> { LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(index)), LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(index)), LocationMode::Intermediate => Ok(ResolvedBinding::User { - prefix: "loc", + prefix: if self.spirv_cross_compatibility { + "locn" + } else { + "loc" + }, index, }), LocationMode::Uniform => Err(Error::UnexpectedLocation), }, crate::Binding::Resource { group, binding } => { - let source = BindSource { group, binding }; + let source = BindSource { + stage, + group, + binding, + }; self.binding_map .get(&source) .cloned() @@ -790,7 +805,7 @@ impl Writer { Ok(()) } - pub fn write(&mut self, module: &crate::Module, options: Options) -> Result<(), Error> { + pub fn write(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> { writeln!(self.out, "#include ")?; writeln!(self.out, "#include ")?; writeln!(self.out, "using namespace metal;")?; @@ -937,7 +952,7 @@ impl Writer { Ok(()) } - fn write_functions(&mut self, module: &crate::Module, options: Options) -> Result<(), Error> { + fn write_functions(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> { for (fun_handle, fun) in module.functions.iter() { self.typifier.resolve_all( &fun.expressions, @@ -1081,7 +1096,7 @@ impl Writer { handle, usage: crate::GlobalUse::empty(), }; - let resolved = options.resolve_binding(binding, in_mode)?; + let resolved = options.resolve_binding(stage, binding, in_mode)?; write!(self.out, "\t")?; tyvar.try_fmt(&mut self.out)?; @@ -1128,7 +1143,7 @@ impl Writer { write!(self.out, "\t")?; tyvar.try_fmt(&mut self.out)?; if let Some(ref binding) = var.binding { - let resolved = options.resolve_binding(binding, out_mode)?; + let resolved = options.resolve_binding(stage, binding, out_mode)?; resolved.try_fmt_decorated(&mut self.out, "")?; } writeln!(self.out, ";")?; @@ -1172,7 +1187,8 @@ impl Writer { } _ => LocationMode::Uniform, }; - let resolved = options.resolve_binding(var.binding.as_ref().unwrap(), loc_mode)?; + let resolved = + options.resolve_binding(stage, var.binding.as_ref().unwrap(), loc_mode)?; let tyvar = TypedGlobalVariable { module, handle, @@ -1214,7 +1230,7 @@ impl Writer { } } -pub fn write_string(module: &crate::Module, options: Options) -> Result { +pub fn write_string(module: &crate::Module, options: &Options) -> Result { let mut w = Writer { out: String::new(), typifier: Typifier::new(), diff --git a/src/lib.rs b/src/lib.rs index 63ca33f9a2..c8c1b0107b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,7 +61,7 @@ pub struct Header { /// For more, see: /// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification /// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub struct EarlyDepthTest { @@ -77,7 +77,7 @@ pub struct EarlyDepthTest { /// For more, see: /// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt /// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ConservativeDepth { @@ -92,7 +92,7 @@ pub enum ConservativeDepth { } /// Stage of the programmable pipeline. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[allow(missing_docs)] // The names are self evident @@ -103,7 +103,7 @@ pub enum ShaderStage { } /// Class of storage for variables. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[allow(missing_docs)] // The names are self evident @@ -129,7 +129,7 @@ pub enum StorageClass { } /// Built-in inputs and outputs. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum BuiltIn { @@ -158,7 +158,7 @@ pub type Bytes = u8; /// Number of components in a vector. #[repr(u8)] -#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum VectorSize { @@ -172,7 +172,7 @@ pub enum VectorSize { /// Primitive type for a scalar. #[repr(u8)] -#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ScalarKind { @@ -188,7 +188,7 @@ pub enum ScalarKind { /// Size of an array. #[repr(u8)] -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ArraySize { @@ -199,7 +199,7 @@ pub enum ArraySize { } /// Describes where a struct member is placed. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum MemberOrigin { @@ -212,7 +212,7 @@ pub enum MemberOrigin { } /// The interpolation qualifier of a binding or struct field. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum Interpolation { @@ -247,7 +247,7 @@ pub struct StructMember { } /// The number of dimensions an image has. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ImageDimension { @@ -274,7 +274,7 @@ bitflags::bitflags! { } // Storage image format. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum StorageFormat { @@ -324,7 +324,7 @@ pub enum StorageFormat { } /// Sub-class of the image type. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ImageClass { @@ -479,7 +479,7 @@ pub struct LocalVariable { } /// Operation that can be applied on a single value. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum UnaryOperator { @@ -488,7 +488,7 @@ pub enum UnaryOperator { } /// Operation that can be applied on two values. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum BinaryOperator { @@ -514,7 +514,7 @@ pub enum BinaryOperator { } /// Built-in shader function. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum IntrinsicFunction { @@ -527,7 +527,7 @@ pub enum IntrinsicFunction { } /// Axis on which to compute a derivative. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum DerivativeAxis { diff --git a/test-data/boids.param.ron b/test-data/boids.param.ron index 4f5503c1cc..32fe7e8e32 100644 --- a/test-data/boids.param.ron +++ b/test-data/boids.param.ron @@ -1,7 +1,7 @@ ( metal_bindings: { - (group: 0, binding: 0): (buffer: Some(0), mutable: false), - (group: 0, binding: 1): (buffer: Some(1), mutable: true), - (group: 0, binding: 2): (buffer: Some(2), mutable: true), + (stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: false), + (stage: Compute, group: 0, binding: 1): (buffer: Some(1), mutable: true), + (stage: Compute, group: 0, binding: 2): (buffer: Some(2), mutable: true), } ) diff --git a/test-data/quad.param.ron b/test-data/quad.param.ron index cd26cc9e8f..c4245bd185 100644 --- a/test-data/quad.param.ron +++ b/test-data/quad.param.ron @@ -1,6 +1,6 @@ ( metal_bindings: { - (group: 0, binding: 0): (texture: Some(0)), - (group: 0, binding: 1): (sampler: Some(0)), + (stage: Fragment, group: 0, binding: 0): (texture: Some(0)), + (stage: Fragment, group: 0, binding: 1): (sampler: Some(0)), } ) diff --git a/tests/convert.rs b/tests/convert.rs index 99c0e8b0dc..50ed7288db 100644 --- a/tests/convert.rs +++ b/tests/convert.rs @@ -34,6 +34,7 @@ fn convert_quad() { let mut binding_map = msl::BindingMap::default(); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Fragment, group: 0, binding: 0, }, @@ -46,6 +47,7 @@ fn convert_quad() { ); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Fragment, group: 0, binding: 1, }, @@ -57,9 +59,11 @@ fn convert_quad() { }, ); let options = msl::Options { - binding_map: &binding_map, + lang_version: (1, 0), + spirv_cross_compatibility: false, + binding_map, }; - msl::write_string(&module, options).unwrap(); + msl::write_string(&module, &options).unwrap(); } } @@ -74,6 +78,7 @@ fn convert_boids() { let mut binding_map = msl::BindingMap::default(); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Compute, group: 0, binding: 0, }, @@ -86,6 +91,7 @@ fn convert_boids() { ); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Compute, group: 0, binding: 1, }, @@ -98,6 +104,7 @@ fn convert_boids() { ); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Compute, group: 0, binding: 2, }, @@ -109,9 +116,11 @@ fn convert_boids() { }, ); let options = msl::Options { - binding_map: &binding_map, + lang_version: (1, 0), + spirv_cross_compatibility: false, + binding_map, }; - msl::write_string(&module, options).unwrap(); + msl::write_string(&module, &options).unwrap(); } } @@ -129,6 +138,7 @@ fn convert_cube() { let mut binding_map = msl::BindingMap::default(); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Vertex, group: 0, binding: 0, }, @@ -141,6 +151,7 @@ fn convert_cube() { ); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Fragment, group: 0, binding: 1, }, @@ -153,6 +164,7 @@ fn convert_cube() { ); binding_map.insert( msl::BindSource { + stage: naga::ShaderStage::Fragment, group: 0, binding: 2, }, @@ -164,10 +176,12 @@ fn convert_cube() { }, ); let options = msl::Options { - binding_map: &binding_map, + lang_version: (1, 0), + spirv_cross_compatibility: false, + binding_map, }; - msl::write_string(&vs, options).unwrap(); - msl::write_string(&fs, options).unwrap(); + msl::write_string(&vs, &options).unwrap(); + msl::write_string(&fs, &options).unwrap(); } }