diff --git a/Cargo.toml b/Cargo.toml index 55d65a7534..b6ac07d55d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ bitflags = "1" fxhash = "0.2" log = "0.4" num-traits = "0.2" -spirv = { package = "spirv_headers", version = "1.4.2", optional = true } +spirv = { package = "spirv_headers", version = "1.5", optional = true } pomelo = { version = "0.1.4", optional = true } thiserror = "1.0.21" serde = { version = "1.0", features = ["derive"], optional = true } @@ -37,3 +37,4 @@ difference = "2.0" env_logger = "0.6" ron = "0.6" serde = { version = "1.0", features = ["derive"] } +spirv = { package = "spirv_headers", version = "1.5", features = ["deserialize"] } diff --git a/examples/convert.rs b/examples/convert.rs index 2589168ab8..d4801fc93b 100644 --- a/examples/convert.rs +++ b/examples/convert.rs @@ -1,21 +1,20 @@ -use serde::{Deserialize, Serialize}; use std::{env, fs, path::Path}; -#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Hash, PartialEq, Eq, serde::Deserialize)] enum Stage { Vertex, Fragment, Compute, } -#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Hash, PartialEq, Eq, serde::Deserialize)] struct BindSource { stage: Stage, group: u32, binding: u32, } -#[derive(Serialize, Deserialize)] +#[derive(serde::Deserialize)] struct BindTarget { #[serde(default)] buffer: Option, @@ -27,11 +26,15 @@ struct BindTarget { mutable: bool, } -#[derive(Default, Serialize, Deserialize)] +#[derive(Default, serde::Deserialize)] struct Parameters { #[serde(default)] + #[cfg_attr(not(feature = "spv-out"), allow(dead_code))] spv_flow_dump_prefix: String, - metal_bindings: naga::FastHashMap, + #[cfg_attr(not(feature = "spv-out"), allow(dead_code))] + spv_capabilities: naga::FastHashSet, + #[cfg_attr(not(feature = "msl-out"), allow(dead_code))] + mtl_bindings: naga::FastHashMap, } fn main() { @@ -138,7 +141,7 @@ fn main() { "metal" => { use naga::back::msl; let mut binding_map = msl::BindingMap::default(); - for (key, value) in params.metal_bindings { + for (key, value) in params.mtl_bindings { binding_map.insert( msl::BindSource { stage: match key.stage { @@ -177,8 +180,7 @@ fn main() { } }); - let capabilities = Default::default(); //TODO - let spv = spv::write_vec(&module, debug_flag, capabilities).unwrap(); + let spv = spv::write_vec(&module, debug_flag, params.spv_capabilities).unwrap(); let bytes = spv .iter() diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 3c8186e87b..50d354dc9e 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -112,6 +112,7 @@ impl<'a, T> ops::Deref for MaybeOwned<'a, T> { } } +#[derive(Debug)] enum Dimension { Scalar, Vector, @@ -1101,14 +1102,10 @@ impl Writer { function, )?; - let left_lookup_ty = left_lookup_ty; - let right_lookup_ty = right_lookup_ty; - let left_ty_inner = self.get_type_inner(&ir_module.types, left_lookup_ty); let right_ty_inner = self.get_type_inner(&ir_module.types, right_lookup_ty); let left_result_type_id = self.get_type_id(&ir_module.types, left_lookup_ty)?; - let right_result_type_id = self.get_type_id(&ir_module.types, right_lookup_ty)?; let left_id = match *left_expression { @@ -1143,56 +1140,63 @@ impl Writer { let left_dimension = get_dimension(&left_ty_inner); let right_dimension = get_dimension(&right_ty_inner); - let (spirv_op, lookup_ty) = match op { + let mut result_side_left = true; + let mut preserve_order = true; + + let spirv_op = match op { crate::BinaryOperator::Add => match *left_ty_inner { crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } => match kind { - crate::ScalarKind::Float => (spirv::Op::FAdd, left_lookup_ty), - _ => (spirv::Op::IAdd, left_lookup_ty), + crate::ScalarKind::Float => spirv::Op::FAdd, + _ => spirv::Op::IAdd, }, _ => unreachable!(), }, crate::BinaryOperator::Subtract => match *left_ty_inner { crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } => match kind { - crate::ScalarKind::Float => (spirv::Op::FSub, left_lookup_ty), - _ => (spirv::Op::ISub, left_lookup_ty), + crate::ScalarKind::Float => spirv::Op::FSub, + _ => spirv::Op::ISub, }, _ => unreachable!(), }, - crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) { - (Dimension::Vector, Dimension::Scalar { .. }) => { - (spirv::Op::VectorTimesScalar, left_lookup_ty) + crate::BinaryOperator::Multiply => { + // whenever there is a vector on the right, + // the result type is a vector. + if let Dimension::Vector = right_dimension { + result_side_left = false; } - (Dimension::Vector, Dimension::Matrix) => { - (spirv::Op::VectorTimesMatrix, left_lookup_ty) + match (left_dimension, right_dimension) { + (Dimension::Scalar, Dimension::Vector { .. }) => { + preserve_order = false; + spirv::Op::VectorTimesScalar + } + (Dimension::Vector, Dimension::Scalar { .. }) => { + spirv::Op::VectorTimesScalar + } + (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix, + (Dimension::Matrix, Dimension::Scalar { .. }) => { + spirv::Op::MatrixTimesScalar + } + (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector, + (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix, + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) + if left_ty_inner.scalar_kind() + == Some(crate::ScalarKind::Float) => + { + spirv::Op::FMul + } + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, + other => unreachable!("Mul {:?}", other), } - (Dimension::Matrix, Dimension::Scalar { .. }) => { - (spirv::Op::MatrixTimesScalar, left_lookup_ty) - } - (Dimension::Matrix, Dimension::Vector) => { - (spirv::Op::MatrixTimesVector, right_lookup_ty) - } - (Dimension::Matrix, Dimension::Matrix) => { - (spirv::Op::MatrixTimesMatrix, left_lookup_ty) - } - (Dimension::Vector, Dimension::Vector) - | (Dimension::Scalar, Dimension::Scalar) - if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) => - { - (spirv::Op::FMul, left_lookup_ty) - } - (Dimension::Vector, Dimension::Vector) - | (Dimension::Scalar, Dimension::Scalar) => { - (spirv::Op::IMul, left_lookup_ty) - } - _ => unreachable!(), - }, + } crate::BinaryOperator::Divide => match *left_ty_inner { crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } => match kind { - crate::ScalarKind::Sint => (spirv::Op::SDiv, left_lookup_ty), - crate::ScalarKind::Uint => (spirv::Op::UDiv, left_lookup_ty), + crate::ScalarKind::Sint => spirv::Op::SDiv, + crate::ScalarKind::Uint => spirv::Op::UDiv, _ => unreachable!(), }, _ => unreachable!(), @@ -1200,28 +1204,33 @@ impl Writer { crate::BinaryOperator::Modulo => match *left_ty_inner { crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } => match kind { - crate::ScalarKind::Sint => (spirv::Op::SMod, left_lookup_ty), - crate::ScalarKind::Uint => (spirv::Op::UMod, left_lookup_ty), - crate::ScalarKind::Float => (spirv::Op::FMod, left_lookup_ty), + crate::ScalarKind::Sint => spirv::Op::SMod, + crate::ScalarKind::Uint => spirv::Op::UMod, + crate::ScalarKind::Float => spirv::Op::FMod, _ => unreachable!(), }, _ => unreachable!(), }, - crate::BinaryOperator::And => (spirv::Op::BitwiseAnd, left_lookup_ty), + crate::BinaryOperator::And => spirv::Op::BitwiseAnd, _ => { log::error!("unimplemented {:?}", op); return Err(Error::FeatureNotImplemented("binary operator")); } }; + let (result_type_id, result_lookup_ty) = if result_side_left { + (left_result_type_id, left_lookup_ty) + } else { + (right_result_type_id, right_lookup_ty) + }; block.body.push(super::instructions::instruction_binary( spirv_op, - left_result_type_id, + result_type_id, id, - left_id, - right_id, + if preserve_order { left_id } else { right_id }, + if preserve_order { right_id } else { left_id }, )); - Ok((id, lookup_ty)) + Ok((id, result_lookup_ty)) } crate::Expression::LocalVariable(variable) => { let var = &ir_function.local_variables[variable]; diff --git a/test-data/boids.param.ron b/test-data/boids.param.ron index 4de90960bf..da2365392a 100644 --- a/test-data/boids.param.ron +++ b/test-data/boids.param.ron @@ -1,6 +1,7 @@ ( spv_flow_dump_prefix: "", - metal_bindings: { + spv_capabilities: [ Shader ], + mtl_bindings: { (stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: false), (stage: Compute, group: 0, binding: 1): (buffer: Some(1), mutable: true), (stage: Compute, group: 0, binding: 2): (buffer: Some(2), mutable: true), diff --git a/test-data/quad.param.ron b/test-data/quad.param.ron index c4245bd185..25747c1694 100644 --- a/test-data/quad.param.ron +++ b/test-data/quad.param.ron @@ -1,5 +1,6 @@ ( - metal_bindings: { + spv_capabilities: [ Shader ], + mtl_bindings: { (stage: Fragment, group: 0, binding: 0): (texture: Some(0)), (stage: Fragment, group: 0, binding: 1): (sampler: Some(0)), } diff --git a/tests/convert.rs b/tests/convert.rs index a05be6eea2..4b8f55fd6c 100644 --- a/tests/convert.rs +++ b/tests/convert.rs @@ -65,6 +65,12 @@ fn convert_quad() { }; msl::write_string(&module, &options).unwrap(); } + #[cfg(feature = "spv-out")] + { + use naga::back::spv; + let capabilities = Some(spirv::Capability::Shader).into_iter().collect(); + spv::write_vec(&module, spv::WriterFlags::empty(), capabilities).unwrap(); + } } #[cfg(feature = "wgsl-in")] @@ -122,6 +128,15 @@ fn convert_boids() { }; msl::write_string(&module, &options).unwrap(); } + #[cfg(feature = "spv-out")] + { + use naga::back::spv; + let capabilities = Some(spirv::Capability::Shader).into_iter().collect(); + if let Err(e) = spv::write_vec(&module, spv::WriterFlags::empty(), capabilities) { + //TODO: panic here when `spv-out` supports it + println!("Quad SPIR-V error {:?}", e); + } + } } #[cfg(feature = "spv-in")]