[spv] enable quad conversion test, refactor binary operations

This commit is contained in:
Dzmitry Malyshau
2020-12-08 00:16:47 -05:00
committed by Dzmitry Malyshau
parent 548bafaf40
commit 148fac0601
6 changed files with 86 additions and 57 deletions

View File

@@ -14,7 +14,7 @@ bitflags = "1"
fxhash = "0.2"
log = "0.4"
num-traits = "0.2"
spirv = { package = "spirv_headers", version = "1.4.2", optional = true }
spirv = { package = "spirv_headers", version = "1.5", optional = true }
pomelo = { version = "0.1.4", optional = true }
thiserror = "1.0.21"
serde = { version = "1.0", features = ["derive"], optional = true }
@@ -37,3 +37,4 @@ difference = "2.0"
env_logger = "0.6"
ron = "0.6"
serde = { version = "1.0", features = ["derive"] }
spirv = { package = "spirv_headers", version = "1.5", features = ["deserialize"] }

View File

@@ -1,21 +1,20 @@
use serde::{Deserialize, Serialize};
use std::{env, fs, path::Path};
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Hash, PartialEq, Eq, serde::Deserialize)]
enum Stage {
Vertex,
Fragment,
Compute,
}
#[derive(Hash, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Hash, PartialEq, Eq, serde::Deserialize)]
struct BindSource {
stage: Stage,
group: u32,
binding: u32,
}
#[derive(Serialize, Deserialize)]
#[derive(serde::Deserialize)]
struct BindTarget {
#[serde(default)]
buffer: Option<u8>,
@@ -27,11 +26,15 @@ struct BindTarget {
mutable: bool,
}
#[derive(Default, Serialize, Deserialize)]
#[derive(Default, serde::Deserialize)]
struct Parameters {
#[serde(default)]
#[cfg_attr(not(feature = "spv-out"), allow(dead_code))]
spv_flow_dump_prefix: String,
metal_bindings: naga::FastHashMap<BindSource, BindTarget>,
#[cfg_attr(not(feature = "spv-out"), allow(dead_code))]
spv_capabilities: naga::FastHashSet<spirv::Capability>,
#[cfg_attr(not(feature = "msl-out"), allow(dead_code))]
mtl_bindings: naga::FastHashMap<BindSource, BindTarget>,
}
fn main() {
@@ -138,7 +141,7 @@ fn main() {
"metal" => {
use naga::back::msl;
let mut binding_map = msl::BindingMap::default();
for (key, value) in params.metal_bindings {
for (key, value) in params.mtl_bindings {
binding_map.insert(
msl::BindSource {
stage: match key.stage {
@@ -177,8 +180,7 @@ fn main() {
}
});
let capabilities = Default::default(); //TODO
let spv = spv::write_vec(&module, debug_flag, capabilities).unwrap();
let spv = spv::write_vec(&module, debug_flag, params.spv_capabilities).unwrap();
let bytes = spv
.iter()

View File

@@ -112,6 +112,7 @@ impl<'a, T> ops::Deref for MaybeOwned<'a, T> {
}
}
#[derive(Debug)]
enum Dimension {
Scalar,
Vector,
@@ -1101,14 +1102,10 @@ impl Writer {
function,
)?;
let left_lookup_ty = left_lookup_ty;
let right_lookup_ty = right_lookup_ty;
let left_ty_inner = self.get_type_inner(&ir_module.types, left_lookup_ty);
let right_ty_inner = self.get_type_inner(&ir_module.types, right_lookup_ty);
let left_result_type_id = self.get_type_id(&ir_module.types, left_lookup_ty)?;
let right_result_type_id = self.get_type_id(&ir_module.types, right_lookup_ty)?;
let left_id = match *left_expression {
@@ -1143,56 +1140,63 @@ impl Writer {
let left_dimension = get_dimension(&left_ty_inner);
let right_dimension = get_dimension(&right_ty_inner);
let (spirv_op, lookup_ty) = match op {
let mut result_side_left = true;
let mut preserve_order = true;
let spirv_op = match op {
crate::BinaryOperator::Add => match *left_ty_inner {
crate::TypeInner::Scalar { kind, .. }
| crate::TypeInner::Vector { kind, .. } => match kind {
crate::ScalarKind::Float => (spirv::Op::FAdd, left_lookup_ty),
_ => (spirv::Op::IAdd, left_lookup_ty),
crate::ScalarKind::Float => spirv::Op::FAdd,
_ => spirv::Op::IAdd,
},
_ => unreachable!(),
},
crate::BinaryOperator::Subtract => match *left_ty_inner {
crate::TypeInner::Scalar { kind, .. }
| crate::TypeInner::Vector { kind, .. } => match kind {
crate::ScalarKind::Float => (spirv::Op::FSub, left_lookup_ty),
_ => (spirv::Op::ISub, left_lookup_ty),
crate::ScalarKind::Float => spirv::Op::FSub,
_ => spirv::Op::ISub,
},
_ => unreachable!(),
},
crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) {
(Dimension::Vector, Dimension::Scalar { .. }) => {
(spirv::Op::VectorTimesScalar, left_lookup_ty)
crate::BinaryOperator::Multiply => {
// whenever there is a vector on the right,
// the result type is a vector.
if let Dimension::Vector = right_dimension {
result_side_left = false;
}
(Dimension::Vector, Dimension::Matrix) => {
(spirv::Op::VectorTimesMatrix, left_lookup_ty)
match (left_dimension, right_dimension) {
(Dimension::Scalar, Dimension::Vector { .. }) => {
preserve_order = false;
spirv::Op::VectorTimesScalar
}
(Dimension::Vector, Dimension::Scalar { .. }) => {
spirv::Op::VectorTimesScalar
}
(Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix,
(Dimension::Matrix, Dimension::Scalar { .. }) => {
spirv::Op::MatrixTimesScalar
}
(Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector,
(Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix,
(Dimension::Vector, Dimension::Vector)
| (Dimension::Scalar, Dimension::Scalar)
if left_ty_inner.scalar_kind()
== Some(crate::ScalarKind::Float) =>
{
spirv::Op::FMul
}
(Dimension::Vector, Dimension::Vector)
| (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
other => unreachable!("Mul {:?}", other),
}
(Dimension::Matrix, Dimension::Scalar { .. }) => {
(spirv::Op::MatrixTimesScalar, left_lookup_ty)
}
(Dimension::Matrix, Dimension::Vector) => {
(spirv::Op::MatrixTimesVector, right_lookup_ty)
}
(Dimension::Matrix, Dimension::Matrix) => {
(spirv::Op::MatrixTimesMatrix, left_lookup_ty)
}
(Dimension::Vector, Dimension::Vector)
| (Dimension::Scalar, Dimension::Scalar)
if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) =>
{
(spirv::Op::FMul, left_lookup_ty)
}
(Dimension::Vector, Dimension::Vector)
| (Dimension::Scalar, Dimension::Scalar) => {
(spirv::Op::IMul, left_lookup_ty)
}
_ => unreachable!(),
},
}
crate::BinaryOperator::Divide => match *left_ty_inner {
crate::TypeInner::Scalar { kind, .. }
| crate::TypeInner::Vector { kind, .. } => match kind {
crate::ScalarKind::Sint => (spirv::Op::SDiv, left_lookup_ty),
crate::ScalarKind::Uint => (spirv::Op::UDiv, left_lookup_ty),
crate::ScalarKind::Sint => spirv::Op::SDiv,
crate::ScalarKind::Uint => spirv::Op::UDiv,
_ => unreachable!(),
},
_ => unreachable!(),
@@ -1200,28 +1204,33 @@ impl Writer {
crate::BinaryOperator::Modulo => match *left_ty_inner {
crate::TypeInner::Scalar { kind, .. }
| crate::TypeInner::Vector { kind, .. } => match kind {
crate::ScalarKind::Sint => (spirv::Op::SMod, left_lookup_ty),
crate::ScalarKind::Uint => (spirv::Op::UMod, left_lookup_ty),
crate::ScalarKind::Float => (spirv::Op::FMod, left_lookup_ty),
crate::ScalarKind::Sint => spirv::Op::SMod,
crate::ScalarKind::Uint => spirv::Op::UMod,
crate::ScalarKind::Float => spirv::Op::FMod,
_ => unreachable!(),
},
_ => unreachable!(),
},
crate::BinaryOperator::And => (spirv::Op::BitwiseAnd, left_lookup_ty),
crate::BinaryOperator::And => spirv::Op::BitwiseAnd,
_ => {
log::error!("unimplemented {:?}", op);
return Err(Error::FeatureNotImplemented("binary operator"));
}
};
let (result_type_id, result_lookup_ty) = if result_side_left {
(left_result_type_id, left_lookup_ty)
} else {
(right_result_type_id, right_lookup_ty)
};
block.body.push(super::instructions::instruction_binary(
spirv_op,
left_result_type_id,
result_type_id,
id,
left_id,
right_id,
if preserve_order { left_id } else { right_id },
if preserve_order { right_id } else { left_id },
));
Ok((id, lookup_ty))
Ok((id, result_lookup_ty))
}
crate::Expression::LocalVariable(variable) => {
let var = &ir_function.local_variables[variable];

View File

@@ -1,6 +1,7 @@
(
spv_flow_dump_prefix: "",
metal_bindings: {
spv_capabilities: [ Shader ],
mtl_bindings: {
(stage: Compute, group: 0, binding: 0): (buffer: Some(0), mutable: false),
(stage: Compute, group: 0, binding: 1): (buffer: Some(1), mutable: true),
(stage: Compute, group: 0, binding: 2): (buffer: Some(2), mutable: true),

View File

@@ -1,5 +1,6 @@
(
metal_bindings: {
spv_capabilities: [ Shader ],
mtl_bindings: {
(stage: Fragment, group: 0, binding: 0): (texture: Some(0)),
(stage: Fragment, group: 0, binding: 1): (sampler: Some(0)),
}

View File

@@ -65,6 +65,12 @@ fn convert_quad() {
};
msl::write_string(&module, &options).unwrap();
}
#[cfg(feature = "spv-out")]
{
use naga::back::spv;
let capabilities = Some(spirv::Capability::Shader).into_iter().collect();
spv::write_vec(&module, spv::WriterFlags::empty(), capabilities).unwrap();
}
}
#[cfg(feature = "wgsl-in")]
@@ -122,6 +128,15 @@ fn convert_boids() {
};
msl::write_string(&module, &options).unwrap();
}
#[cfg(feature = "spv-out")]
{
use naga::back::spv;
let capabilities = Some(spirv::Capability::Shader).into_iter().collect();
if let Err(e) = spv::write_vec(&module, spv::WriterFlags::empty(), capabilities) {
//TODO: panic here when `spv-out` supports it
println!("Quad SPIR-V error {:?}", e);
}
}
}
#[cfg(feature = "spv-in")]