mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
[msl] refactor the options, add override stages
This commit is contained in:
@@ -1,8 +1,16 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{env, fs, path::Path};
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
enum Stage {
|
||||
Vertex,
|
||||
Fragment,
|
||||
Compute,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
struct BindSource {
|
||||
stage: Stage,
|
||||
group: u32,
|
||||
binding: u32,
|
||||
}
|
||||
@@ -120,6 +128,11 @@ fn main() {
|
||||
for (key, value) in params.metal_bindings {
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: match key.stage {
|
||||
Stage::Vertex => naga::ShaderStage::Vertex,
|
||||
Stage::Fragment => naga::ShaderStage::Fragment,
|
||||
Stage::Compute => naga::ShaderStage::Compute,
|
||||
},
|
||||
group: key.group,
|
||||
binding: key.binding,
|
||||
},
|
||||
@@ -132,9 +145,11 @@ fn main() {
|
||||
);
|
||||
}
|
||||
let options = msl::Options {
|
||||
binding_map: &binding_map,
|
||||
lang_version: (1, 0),
|
||||
spirv_cross_compatibility: false,
|
||||
binding_map,
|
||||
};
|
||||
let msl = msl::write_string(&module, options).unwrap();
|
||||
let msl = msl::write_string(&module, &options).unwrap();
|
||||
fs::write(&args[2], msl).unwrap();
|
||||
}
|
||||
#[cfg(feature = "spv-out")]
|
||||
|
||||
@@ -31,6 +31,7 @@ pub struct BindTarget {
|
||||
|
||||
#[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
pub struct BindSource {
|
||||
pub stage: crate::ShaderStage,
|
||||
pub group: u32,
|
||||
pub binding: u32,
|
||||
}
|
||||
@@ -102,14 +103,20 @@ enum LocationMode {
|
||||
Uniform,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Options<'a> {
|
||||
pub binding_map: &'a BindingMap,
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct Options {
|
||||
/// (Major, Minor) target version of the Metal Shading Language.
|
||||
pub lang_version: (u8, u8),
|
||||
/// Make it possible to link different stages via SPIRV-Cross.
|
||||
pub spirv_cross_compatibility: bool,
|
||||
/// Binding model mapping to Metal.
|
||||
pub binding_map: BindingMap,
|
||||
}
|
||||
|
||||
impl Options<'_> {
|
||||
impl Options {
|
||||
fn resolve_binding(
|
||||
self,
|
||||
&self,
|
||||
stage: crate::ShaderStage,
|
||||
binding: &crate::Binding,
|
||||
mode: LocationMode,
|
||||
) -> Result<ResolvedBinding, Error> {
|
||||
@@ -119,13 +126,21 @@ impl Options<'_> {
|
||||
LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(index)),
|
||||
LocationMode::FragmentOutput => Ok(ResolvedBinding::Color(index)),
|
||||
LocationMode::Intermediate => Ok(ResolvedBinding::User {
|
||||
prefix: "loc",
|
||||
prefix: if self.spirv_cross_compatibility {
|
||||
"locn"
|
||||
} else {
|
||||
"loc"
|
||||
},
|
||||
index,
|
||||
}),
|
||||
LocationMode::Uniform => Err(Error::UnexpectedLocation),
|
||||
},
|
||||
crate::Binding::Resource { group, binding } => {
|
||||
let source = BindSource { group, binding };
|
||||
let source = BindSource {
|
||||
stage,
|
||||
group,
|
||||
binding,
|
||||
};
|
||||
self.binding_map
|
||||
.get(&source)
|
||||
.cloned()
|
||||
@@ -790,7 +805,7 @@ impl<W: Write> Writer<W> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write(&mut self, module: &crate::Module, options: Options) -> Result<(), Error> {
|
||||
pub fn write(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> {
|
||||
writeln!(self.out, "#include <metal_stdlib>")?;
|
||||
writeln!(self.out, "#include <simd/simd.h>")?;
|
||||
writeln!(self.out, "using namespace metal;")?;
|
||||
@@ -937,7 +952,7 @@ impl<W: Write> Writer<W> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_functions(&mut self, module: &crate::Module, options: Options) -> Result<(), Error> {
|
||||
fn write_functions(&mut self, module: &crate::Module, options: &Options) -> Result<(), Error> {
|
||||
for (fun_handle, fun) in module.functions.iter() {
|
||||
self.typifier.resolve_all(
|
||||
&fun.expressions,
|
||||
@@ -1081,7 +1096,7 @@ impl<W: Write> Writer<W> {
|
||||
handle,
|
||||
usage: crate::GlobalUse::empty(),
|
||||
};
|
||||
let resolved = options.resolve_binding(binding, in_mode)?;
|
||||
let resolved = options.resolve_binding(stage, binding, in_mode)?;
|
||||
|
||||
write!(self.out, "\t")?;
|
||||
tyvar.try_fmt(&mut self.out)?;
|
||||
@@ -1128,7 +1143,7 @@ impl<W: Write> Writer<W> {
|
||||
write!(self.out, "\t")?;
|
||||
tyvar.try_fmt(&mut self.out)?;
|
||||
if let Some(ref binding) = var.binding {
|
||||
let resolved = options.resolve_binding(binding, out_mode)?;
|
||||
let resolved = options.resolve_binding(stage, binding, out_mode)?;
|
||||
resolved.try_fmt_decorated(&mut self.out, "")?;
|
||||
}
|
||||
writeln!(self.out, ";")?;
|
||||
@@ -1172,7 +1187,8 @@ impl<W: Write> Writer<W> {
|
||||
}
|
||||
_ => LocationMode::Uniform,
|
||||
};
|
||||
let resolved = options.resolve_binding(var.binding.as_ref().unwrap(), loc_mode)?;
|
||||
let resolved =
|
||||
options.resolve_binding(stage, var.binding.as_ref().unwrap(), loc_mode)?;
|
||||
let tyvar = TypedGlobalVariable {
|
||||
module,
|
||||
handle,
|
||||
@@ -1214,7 +1230,7 @@ impl<W: Write> Writer<W> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_string(module: &crate::Module, options: Options) -> Result<String, Error> {
|
||||
pub fn write_string(module: &crate::Module, options: &Options) -> Result<String, Error> {
|
||||
let mut w = Writer {
|
||||
out: String::new(),
|
||||
typifier: Typifier::new(),
|
||||
|
||||
34
src/lib.rs
34
src/lib.rs
@@ -61,7 +61,7 @@ pub struct Header {
|
||||
/// For more, see:
|
||||
/// - https://www.khronos.org/opengl/wiki/Early_Fragment_Test#Explicit_specification
|
||||
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-attributes-earlydepthstencil
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub struct EarlyDepthTest {
|
||||
@@ -77,7 +77,7 @@ pub struct EarlyDepthTest {
|
||||
/// For more, see:
|
||||
/// - https://www.khronos.org/registry/OpenGL/extensions/ARB/ARB_conservative_depth.txt
|
||||
/// - https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-semantics#system-value-semantics
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ConservativeDepth {
|
||||
@@ -92,7 +92,7 @@ pub enum ConservativeDepth {
|
||||
}
|
||||
|
||||
/// Stage of the programmable pipeline.
|
||||
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
#[allow(missing_docs)] // The names are self evident
|
||||
@@ -103,7 +103,7 @@ pub enum ShaderStage {
|
||||
}
|
||||
|
||||
/// Class of storage for variables.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
#[allow(missing_docs)] // The names are self evident
|
||||
@@ -129,7 +129,7 @@ pub enum StorageClass {
|
||||
}
|
||||
|
||||
/// Built-in inputs and outputs.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum BuiltIn {
|
||||
@@ -158,7 +158,7 @@ pub type Bytes = u8;
|
||||
|
||||
/// Number of components in a vector.
|
||||
#[repr(u8)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum VectorSize {
|
||||
@@ -172,7 +172,7 @@ pub enum VectorSize {
|
||||
|
||||
/// Primitive type for a scalar.
|
||||
#[repr(u8)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ScalarKind {
|
||||
@@ -188,7 +188,7 @@ pub enum ScalarKind {
|
||||
|
||||
/// Size of an array.
|
||||
#[repr(u8)]
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ArraySize {
|
||||
@@ -199,7 +199,7 @@ pub enum ArraySize {
|
||||
}
|
||||
|
||||
/// Describes where a struct member is placed.
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum MemberOrigin {
|
||||
@@ -212,7 +212,7 @@ pub enum MemberOrigin {
|
||||
}
|
||||
|
||||
/// The interpolation qualifier of a binding or struct field.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum Interpolation {
|
||||
@@ -247,7 +247,7 @@ pub struct StructMember {
|
||||
}
|
||||
|
||||
/// The number of dimensions an image has.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ImageDimension {
|
||||
@@ -274,7 +274,7 @@ bitflags::bitflags! {
|
||||
}
|
||||
|
||||
// Storage image format.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum StorageFormat {
|
||||
@@ -324,7 +324,7 @@ pub enum StorageFormat {
|
||||
}
|
||||
|
||||
/// Sub-class of the image type.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ImageClass {
|
||||
@@ -479,7 +479,7 @@ pub struct LocalVariable {
|
||||
}
|
||||
|
||||
/// Operation that can be applied on a single value.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum UnaryOperator {
|
||||
@@ -488,7 +488,7 @@ pub enum UnaryOperator {
|
||||
}
|
||||
|
||||
/// Operation that can be applied on two values.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum BinaryOperator {
|
||||
@@ -514,7 +514,7 @@ pub enum BinaryOperator {
|
||||
}
|
||||
|
||||
/// Built-in shader function.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum IntrinsicFunction {
|
||||
@@ -527,7 +527,7 @@ pub enum IntrinsicFunction {
|
||||
}
|
||||
|
||||
/// Axis on which to compute a derivative.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum DerivativeAxis {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
(
|
||||
metal_bindings: {
|
||||
(group: 0, binding: 0): (buffer: Some(0), mutable: false),
|
||||
(group: 0, binding: 1): (buffer: Some(1), mutable: true),
|
||||
(group: 0, binding: 2): (buffer: Some(2), mutable: true),
|
||||
(stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: false),
|
||||
(stage: Compute, group: 0, binding: 1): (buffer: Some(1), mutable: true),
|
||||
(stage: Compute, group: 0, binding: 2): (buffer: Some(2), mutable: true),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
(
|
||||
metal_bindings: {
|
||||
(group: 0, binding: 0): (texture: Some(0)),
|
||||
(group: 0, binding: 1): (sampler: Some(0)),
|
||||
(stage: Fragment, group: 0, binding: 0): (texture: Some(0)),
|
||||
(stage: Fragment, group: 0, binding: 1): (sampler: Some(0)),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -34,6 +34,7 @@ fn convert_quad() {
|
||||
let mut binding_map = msl::BindingMap::default();
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Fragment,
|
||||
group: 0,
|
||||
binding: 0,
|
||||
},
|
||||
@@ -46,6 +47,7 @@ fn convert_quad() {
|
||||
);
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Fragment,
|
||||
group: 0,
|
||||
binding: 1,
|
||||
},
|
||||
@@ -57,9 +59,11 @@ fn convert_quad() {
|
||||
},
|
||||
);
|
||||
let options = msl::Options {
|
||||
binding_map: &binding_map,
|
||||
lang_version: (1, 0),
|
||||
spirv_cross_compatibility: false,
|
||||
binding_map,
|
||||
};
|
||||
msl::write_string(&module, options).unwrap();
|
||||
msl::write_string(&module, &options).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,6 +78,7 @@ fn convert_boids() {
|
||||
let mut binding_map = msl::BindingMap::default();
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Compute,
|
||||
group: 0,
|
||||
binding: 0,
|
||||
},
|
||||
@@ -86,6 +91,7 @@ fn convert_boids() {
|
||||
);
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Compute,
|
||||
group: 0,
|
||||
binding: 1,
|
||||
},
|
||||
@@ -98,6 +104,7 @@ fn convert_boids() {
|
||||
);
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Compute,
|
||||
group: 0,
|
||||
binding: 2,
|
||||
},
|
||||
@@ -109,9 +116,11 @@ fn convert_boids() {
|
||||
},
|
||||
);
|
||||
let options = msl::Options {
|
||||
binding_map: &binding_map,
|
||||
lang_version: (1, 0),
|
||||
spirv_cross_compatibility: false,
|
||||
binding_map,
|
||||
};
|
||||
msl::write_string(&module, options).unwrap();
|
||||
msl::write_string(&module, &options).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,6 +138,7 @@ fn convert_cube() {
|
||||
let mut binding_map = msl::BindingMap::default();
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Vertex,
|
||||
group: 0,
|
||||
binding: 0,
|
||||
},
|
||||
@@ -141,6 +151,7 @@ fn convert_cube() {
|
||||
);
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Fragment,
|
||||
group: 0,
|
||||
binding: 1,
|
||||
},
|
||||
@@ -153,6 +164,7 @@ fn convert_cube() {
|
||||
);
|
||||
binding_map.insert(
|
||||
msl::BindSource {
|
||||
stage: naga::ShaderStage::Fragment,
|
||||
group: 0,
|
||||
binding: 2,
|
||||
},
|
||||
@@ -164,10 +176,12 @@ fn convert_cube() {
|
||||
},
|
||||
);
|
||||
let options = msl::Options {
|
||||
binding_map: &binding_map,
|
||||
lang_version: (1, 0),
|
||||
spirv_cross_compatibility: false,
|
||||
binding_map,
|
||||
};
|
||||
msl::write_string(&vs, options).unwrap();
|
||||
msl::write_string(&fs, options).unwrap();
|
||||
msl::write_string(&vs, &options).unwrap();
|
||||
msl::write_string(&fs, &options).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user