[msl] refactor the options, add override stages

This commit is contained in:
Dzmitry Malyshau
2020-10-28 00:30:22 -04:00
parent ce49afa391
commit 587dc01a2c
6 changed files with 89 additions and 44 deletions

View File

@@ -1,8 +1,16 @@
use serde::{Deserialize, Serialize};
use std::{env, fs, path::Path};
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
enum Stage {
Vertex,
Fragment,
Compute,
}
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
struct BindSource {
stage: Stage,
group: u32,
binding: u32,
}
@@ -120,6 +128,11 @@ fn main() {
for (key, value) in params.metal_bindings {
binding_map.insert(
msl::BindSource {
stage: match key.stage {
Stage::Vertex => naga::ShaderStage::Vertex,
Stage::Fragment => naga::ShaderStage::Fragment,
Stage::Compute => naga::ShaderStage::Compute,
},
group: key.group,
binding: key.binding,
},
@@ -132,9 +145,11 @@ fn main() {
);
}
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
let msl = msl::write_string(&module, options).unwrap();
let msl = msl::write_string(&module, &options).unwrap();
fs::write(&args[2], msl).unwrap();
}
#[cfg(feature = "spv-out")]

View File

@@ -31,6 +31,7 @@ pub struct BindTarget {
#[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct BindSource {
pub stage: crate::ShaderStage,
pub group: u32,
pub binding: u32,
}
@@ -102,14 +103,20 @@ enum LocationMode {
Uniform,
}
#[derive(Debug, Clone, Copy)]
pub struct Options<'a> {
pub binding_map: &'a BindingMap,
#[derive(Debug, Default, Clone)]
pub struct Options {
/// (Major, Minor) target version of the Metal Shading Language.
pub lang_version: (u8, u8),
/// Make it possible to link different stages via SPIRV-Cross.
pub spirv_cross_compatibility: bool,
/// Binding model mapping to Metal.
pub binding_map: BindingMap,
}
impl Options<'_> {
impl Options {
fn resolve_binding(
self,
&self,
stage: crate::ShaderStage,
binding: &crate::Binding,
mode: LocationMode,
) -> Result<ResolvedBinding, Error> {
@@ -119,13 +126,21 @@ impl Options<'_> {
LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(index)),
LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(index)),
LocationMode::Intermediate => Ok(ResolvedBinding::User {
prefix: "loc",
prefix: if self.spirv_cross_compatibility {
"locn"
} else {
"loc"
},
index,
}),
LocationMode::Uniform => Err(Error::UnexpectedLocation),
},
crate::Binding::Resource { group, binding } => {
let source = BindSource { group, binding };
let source = BindSource {
stage,
group,
binding,
};
self.binding_map
.get(&source)
.cloned()
@@ -790,7 +805,7 @@ impl<W: Write> Writer<W> {
Ok(())
}
pub fn write(&mut self, module: &crate::Module, options: Options) -> Result<(), Error> {
pub fn write(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> {
writeln!(self.out, "#include <metal_stdlib>")?;
writeln!(self.out, "#include <simd/simd.h>")?;
writeln!(self.out, "using namespace metal;")?;
@@ -937,7 +952,7 @@ impl<W: Write> Writer<W> {
Ok(())
}
fn write_functions(&mut self, module: &crate::Module, options: Options) -> Result<(), Error> {
fn write_functions(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> {
for (fun_handle, fun) in module.functions.iter() {
self.typifier.resolve_all(
&fun.expressions,
@@ -1081,7 +1096,7 @@ impl<W: Write> Writer<W> {
handle,
usage: crate::GlobalUse::empty(),
};
let resolved = options.resolve_binding(binding, in_mode)?;
let resolved = options.resolve_binding(stage, binding, in_mode)?;
write!(self.out, "\t")?;
tyvar.try_fmt(&mut self.out)?;
@@ -1128,7 +1143,7 @@ impl<W: Write> Writer<W> {
write!(self.out, "\t")?;
tyvar.try_fmt(&mut self.out)?;
if let Some(ref binding) = var.binding {
let resolved = options.resolve_binding(binding, out_mode)?;
let resolved = options.resolve_binding(stage, binding, out_mode)?;
resolved.try_fmt_decorated(&mut self.out, "")?;
}
writeln!(self.out, ";")?;
@@ -1172,7 +1187,8 @@ impl<W: Write> Writer<W> {
}
_ => LocationMode::Uniform,
};
let resolved = options.resolve_binding(var.binding.as_ref().unwrap(), loc_mode)?;
let resolved =
options.resolve_binding(stage, var.binding.as_ref().unwrap(), loc_mode)?;
let tyvar = TypedGlobalVariable {
module,
handle,
@@ -1214,7 +1230,7 @@ impl<W: Write> Writer<W> {
}
}
pub fn write_string(module: &crate::Module, options: Options) -> Result<String, Error> {
pub fn write_string(module: &crate::Module, options: &Options) -> Result<String, Error> {
let mut w = Writer {
out: String::new(),
typifier: Typifier::new(),

View File

@@ -61,7 +61,7 @@ pub struct Header {
/// For more, see:
/// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct EarlyDepthTest {
@@ -77,7 +77,7 @@ pub struct EarlyDepthTest {
/// For more, see:
/// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ConservativeDepth {
@@ -92,7 +92,7 @@ pub enum ConservativeDepth {
}
/// Stage of the programmable pipeline.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident
@@ -103,7 +103,7 @@ pub enum ShaderStage {
}
/// Class of storage for variables.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[allow(missing_docs)] // The names are self evident
@@ -129,7 +129,7 @@ pub enum StorageClass {
}
/// Built-in inputs and outputs.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum BuiltIn {
@@ -158,7 +158,7 @@ pub type Bytes = u8;
/// Number of components in a vector.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum VectorSize {
@@ -172,7 +172,7 @@ pub enum VectorSize {
/// Primitive type for a scalar.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ScalarKind {
@@ -188,7 +188,7 @@ pub enum ScalarKind {
/// Size of an array.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ArraySize {
@@ -199,7 +199,7 @@ pub enum ArraySize {
}
/// Describes where a struct member is placed.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum MemberOrigin {
@@ -212,7 +212,7 @@ pub enum MemberOrigin {
}
/// The interpolation qualifier of a binding or struct field.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum Interpolation {
@@ -247,7 +247,7 @@ pub struct StructMember {
}
/// The number of dimensions an image has.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ImageDimension {
@@ -274,7 +274,7 @@ bitflags::bitflags! {
}
// Storage image format.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum StorageFormat {
@@ -324,7 +324,7 @@ pub enum StorageFormat {
}
/// Sub-class of the image type.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ImageClass {
@@ -479,7 +479,7 @@ pub struct LocalVariable {
}
/// Operation that can be applied on a single value.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum UnaryOperator {
@@ -488,7 +488,7 @@ pub enum UnaryOperator {
}
/// Operation that can be applied on two values.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum BinaryOperator {
@@ -514,7 +514,7 @@ pub enum BinaryOperator {
}
/// Built-in shader function.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum IntrinsicFunction {
@@ -527,7 +527,7 @@ pub enum IntrinsicFunction {
}
/// Axis on which to compute a derivative.
#[derive(Clone, Copy, Debug, PartialEq)]
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum DerivativeAxis {

View File

@@ -1,7 +1,7 @@
(
metal_bindings: {
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
(stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: false),
(stage: Compute, group: 0, binding: 1): (buffer: Some(1), mutable: true),
(stage: Compute, group: 0, binding: 2): (buffer: Some(2), mutable: true),
}
)

View File

@@ -1,6 +1,6 @@
(
metal_bindings: {
(group: 0, binding: 0): (texture: Some(0)),
(group: 0, binding: 1): (sampler: Some(0)),
(stage: Fragment, group: 0, binding: 0): (texture: Some(0)),
(stage: Fragment, group: 0, binding: 1): (sampler: Some(0)),
}
)

View File

@@ -34,6 +34,7 @@ fn convert_quad() {
let mut binding_map = msl::BindingMap::default();
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 0,
},
@@ -46,6 +47,7 @@ fn convert_quad() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 1,
},
@@ -57,9 +59,11 @@ fn convert_quad() {
},
);
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
msl::write_string(&module, options).unwrap();
msl::write_string(&module, &options).unwrap();
}
}
@@ -74,6 +78,7 @@ fn convert_boids() {
let mut binding_map = msl::BindingMap::default();
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0,
binding: 0,
},
@@ -86,6 +91,7 @@ fn convert_boids() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0,
binding: 1,
},
@@ -98,6 +104,7 @@ fn convert_boids() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Compute,
group: 0,
binding: 2,
},
@@ -109,9 +116,11 @@ fn convert_boids() {
},
);
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
msl::write_string(&module, options).unwrap();
msl::write_string(&module, &options).unwrap();
}
}
@@ -129,6 +138,7 @@ fn convert_cube() {
let mut binding_map = msl::BindingMap::default();
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Vertex,
group: 0,
binding: 0,
},
@@ -141,6 +151,7 @@ fn convert_cube() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 1,
},
@@ -153,6 +164,7 @@ fn convert_cube() {
);
binding_map.insert(
msl::BindSource {
stage: naga::ShaderStage::Fragment,
group: 0,
binding: 2,
},
@@ -164,10 +176,12 @@ fn convert_cube() {
},
);
let options = msl::Options {
binding_map: &binding_map,
lang_version: (1, 0),
spirv_cross_compatibility: false,
binding_map,
};
msl::write_string(&vs, options).unwrap();
msl::write_string(&fs, options).unwrap();
msl::write_string(&vs, &options).unwrap();
msl::write_string(&fs, &options).unwrap();
}
}