[spv-in] work around row-major matrices

This commit is contained in:
Dzmitry Malyshau
2021-04-08 23:02:51 -04:00
committed by Dzmitry Malyshau
parent 0cc22c8c65
commit 67e3e0a697
3 changed files with 182 additions and 64 deletions

View File

@@ -54,6 +54,5 @@ pub enum Error {
InvalidEdgeClassification,
FunctionCallCycle(spirv::Word),
// incomplete implementation error
UnsupportedRowMajorMatrix,
UnsupportedMatrixStride(spirv::Word),
}

View File

@@ -63,6 +63,9 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
}
pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
self.lookup_expression.clear();
self.lookup_load_override.clear();
let result_type_id = self.next()?;
let fun_id = self.next()?;
let _fun_control = self.next()?;

View File

@@ -16,6 +16,13 @@ populated at the start of an entry point. The outputs are saved at the end.
The function associated with an entry point is wrapped in another function,
such that we can handle any `Return` statements without problems.
## Row-major matrices
We don't handle them natively, since the IR only expects column majority.
Instead, we detect when such matrix is accessed in the `OpAccessChain`,
and we generate a parallel expression that loads the value, but transposed.
This value then gets used instead of `OpLoad` result later on.
!*/
#![allow(dead_code)]
@@ -192,7 +199,7 @@ bitflags::bitflags! {
}
}
#[derive(Debug)]
#[derive(Debug, PartialEq)]
enum Majority {
Column,
Row,
@@ -300,6 +307,21 @@ struct LookupExpression {
type_id: spirv::Word,
}
#[derive(Debug)]
struct LookupMember {
type_id: spirv::Word,
// This is true for either matrices, or arrays of matrices (yikes).
row_major: bool,
}
#[derive(Clone, Debug)]
enum LookupLoadOverride {
/// For arrays of matrices, we track them but not loading yet.
Pending,
/// For matrices, vectors, and scalars, we pre-load the data.
Loaded(Handle<crate::Expression>),
}
#[derive(Clone, Debug)]
struct Assignment {
to: Handle<crate::Expression>,
@@ -337,7 +359,7 @@ pub struct Parser<I> {
ext_glsl_id: Option<spirv::Word>,
future_decor: FastHashMap<spirv::Word, Decoration>,
future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
lookup_member_type_id: FastHashMap<(Handle<crate::Type>, MemberIndex), spirv::Word>,
lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,
lookup_type: FastHashMap<spirv::Word, LookupType>,
lookup_void_type: Option<spirv::Word>,
@@ -346,6 +368,8 @@ pub struct Parser<I> {
lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
lookup_variable: FastHashMap<spirv::Word, LookupVariable>,
lookup_expression: FastHashMap<spirv::Word, LookupExpression>,
// Load overrides are used to work around row-major matrices
lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>,
lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>,
lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>,
@@ -373,13 +397,14 @@ impl<I: Iterator<Item = u32>> Parser<I> {
future_decor: FastHashMap::default(),
future_member_decor: FastHashMap::default(),
handle_sampling: FastHashMap::default(),
lookup_member_type_id: FastHashMap::default(),
lookup_member: FastHashMap::default(),
lookup_type: FastHashMap::default(),
lookup_void_type: None,
lookup_storage_buffer_types: FastHashSet::default(),
lookup_constant: FastHashMap::default(),
lookup_variable: FastHashMap::default(),
lookup_expression: FastHashMap::default(),
lookup_load_override: FastHashMap::default(),
lookup_sampled_image: FastHashMap::default(),
lookup_function_type: FastHashMap::default(),
lookup_function: FastHashMap::default(),
@@ -636,11 +661,11 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let root_lookup = self.lookup_type.lookup(root_type_id)?;
let (count, child_type_id) = match type_arena[root_lookup.handle].inner {
crate::TypeInner::Struct { ref members, .. } => {
let child_type_id = *self
.lookup_member_type_id
let child_member = self
.lookup_member
.get(&(root_lookup.handle, selection))
.ok_or(Error::InvalidAccessType(root_type_id))?;
(members.len(), child_type_id)
(members.len(), child_member.type_id)
}
// crate::TypeInner::Array //TODO?
crate::TypeInner::Vector { size, .. }
@@ -797,6 +822,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
struct AccessExpression {
base_handle: Handle<crate::Expression>,
type_id: spirv::Word,
load_override: Option<LookupLoadOverride>,
}
inst.expect_at_least(4)?;
@@ -813,50 +839,96 @@ impl<I: Iterator<Item = u32>> Parser<I> {
AccessExpression {
base_handle: lexp.handle,
type_id: lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))?,
load_override: self.lookup_load_override.get(&base_id).cloned(),
}
};
for _ in 4..inst.wc {
let access_id = self.next()?;
log::trace!("\t\t\tlooking up index expr {:?}", access_id);
let index_expr = self.lookup_expression.lookup(access_id)?.clone();
let index_type_handle = self.lookup_type.lookup(index_expr.type_id)?.handle;
match type_arena[index_type_handle].inner {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
let index_expr_data = &expressions[index_expr.handle];
let index_maybe = match *index_expr_data {
crate::Expression::Constant(const_handle) => {
match const_arena[const_handle].inner {
crate::ConstantInner::Scalar {
width: _,
value: crate::ScalarValue::Uint(v),
} => Some(v as u32),
crate::ConstantInner::Scalar {
width: _,
value: crate::ScalarValue::Sint(v),
} => Some(v as u32),
_ => {
return Err(Error::InvalidAccess(
crate::Expression::Constant(const_handle),
))
}
}
}
| crate::TypeInner::Scalar {
kind: crate::ScalarKind::Sint,
..
} => (),
ref other => {
log::warn!("access index type {:?}", other);
return Err(Error::UnsupportedType(index_type_handle));
}
}
_ => None,
};
log::trace!("\t\t\tlooking up type {:?}", acex.type_id);
let type_lookup = self.lookup_type.lookup(acex.type_id)?;
acex = match type_arena[type_lookup.handle].inner {
// can only index a struct with a constant
crate::TypeInner::Struct { .. } => {
let index = match expressions[index_expr.handle] {
crate::Expression::Constant(const_handle) => {
match const_arena[const_handle].inner {
crate::ConstantInner::Scalar {
width: 4,
value: crate::ScalarValue::Uint(v),
} => v as u32,
crate::ConstantInner::Scalar {
width: 4,
value: crate::ScalarValue::Sint(v),
} => v as u32,
_ => {
return Err(Error::InvalidAccess(
crate::Expression::Constant(const_handle),
))
let index = index_maybe
.ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?;
let lookup_member = self
.lookup_member
.get(&(type_lookup.handle, index))
.ok_or(Error::InvalidAccessType(acex.type_id))?;
let base_handle =
expressions.append(crate::Expression::AccessIndex {
base: acex.base_handle,
index,
});
AccessExpression {
base_handle,
type_id: lookup_member.type_id,
load_override: if lookup_member.row_major {
debug_assert!(acex.load_override.is_none());
let sub_type_lookup =
self.lookup_type.lookup(lookup_member.type_id)?;
Some(match type_arena[sub_type_lookup.handle].inner {
// load it transposed, to match column major expectations
crate::TypeInner::Matrix { .. } => {
let loaded =
expressions.append(crate::Expression::Load {
pointer: base_handle,
});
let transposed =
expressions.append(crate::Expression::Math {
fun: crate::MathFunction::Transpose,
arg: loaded,
arg1: None,
arg2: None,
});
LookupLoadOverride::Loaded(transposed)
}
}
_ => LookupLoadOverride::Pending,
})
} else {
None
},
}
}
// we can't dynamically index matrices, so expecting constant index here
crate::TypeInner::Matrix { .. } => {
let index = index_maybe
.ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?;
let load_override = match acex.load_override {
// We are indexing inside a row-major matrix
Some(LookupLoadOverride::Loaded(load_expr)) => {
let sub_expr =
expressions.append(crate::Expression::AccessIndex {
base: load_expr,
index,
});
Some(LookupLoadOverride::Loaded(sub_expr))
}
ref other => return Err(Error::InvalidAccess(other.clone())),
_ => None,
};
AccessExpression {
base_handle: expressions.append(
@@ -865,24 +937,62 @@ impl<I: Iterator<Item = u32>> Parser<I> {
index,
},
),
type_id: *self
.lookup_member_type_id
.get(&(type_lookup.handle, index))
type_id: type_lookup
.base_id
.ok_or(Error::InvalidAccessType(acex.type_id))?,
load_override,
}
}
_ => AccessExpression {
base_handle: expressions.append(crate::Expression::Access {
// This must be a vector or an array.
_ => {
let base_handle = expressions.append(crate::Expression::Access {
base: acex.base_handle,
index: index_expr.handle,
}),
type_id: type_lookup
.base_id
.ok_or(Error::InvalidAccessType(acex.type_id))?,
},
});
let load_override = match acex.load_override {
// If there is a load override in place, then we always end up
// with a side-loaded value here.
Some(lookup_load_override) => {
let sub_expr = match lookup_load_override {
// We must be indexing into the array of row-major matrices.
// Let's load the result of indexing and transpose it.
LookupLoadOverride::Pending => {
let loaded =
expressions.append(crate::Expression::Load {
pointer: base_handle,
});
expressions.append(crate::Expression::Math {
fun: crate::MathFunction::Transpose,
arg: loaded,
arg1: None,
arg2: None,
})
}
// We are indexing inside a row-major matrix.
LookupLoadOverride::Loaded(load_expr) => expressions
.append(crate::Expression::Access {
base: load_expr,
index: index_expr.handle,
}),
};
Some(LookupLoadOverride::Loaded(sub_expr))
}
None => None,
};
AccessExpression {
base_handle,
type_id: type_lookup
.base_id
.ok_or(Error::InvalidAccessType(acex.type_id))?,
load_override,
}
}
};
}
if let Some(load_expr) = acex.load_override {
self.lookup_load_override.insert(result_id, load_expr);
}
let lookup_expression = LookupExpression {
handle: acex.base_handle,
type_id: result_type_id,
@@ -997,10 +1107,12 @@ impl<I: Iterator<Item = u32>> Parser<I> {
log::trace!("\t\t\tlooking up type {:?}", lexp.type_id);
let type_lookup = self.lookup_type.lookup(lexp.type_id)?;
let type_id = match type_arena[type_lookup.handle].inner {
crate::TypeInner::Struct { .. } => *self
.lookup_member_type_id
.get(&(type_lookup.handle, index))
.ok_or(Error::InvalidAccessType(lexp.type_id))?,
crate::TypeInner::Struct { .. } => {
self.lookup_member
.get(&(type_lookup.handle, index))
.ok_or(Error::InvalidAccessType(lexp.type_id))?
.type_id
}
crate::TypeInner::Array { .. }
| crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. } => type_lookup
@@ -1100,9 +1212,13 @@ impl<I: Iterator<Item = u32>> Parser<I> {
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {
base_lexp.handle
}
_ => expressions.append(crate::Expression::Load {
pointer: base_lexp.handle,
}),
_ => match self.lookup_load_override.get(&pointer_id) {
Some(&LookupLoadOverride::Loaded(handle)) => handle,
//Note: we aren't handling `LookupLoadOverride::Pending` properly here
_ => expressions.append(crate::Expression::Load {
pointer: base_lexp.handle,
}),
},
};
self.lookup_expression.insert(
@@ -2465,16 +2581,20 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let block_decor = parent_decor.as_ref().and_then(|decor| decor.block.clone());
let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2);
let mut member_type_ids = Vec::with_capacity(members.capacity());
let mut member_lookups = Vec::with_capacity(members.capacity());
for i in 0..u32::from(inst.wc) - 2 {
let type_id = self.next()?;
member_type_ids.push(type_id);
let ty = self.lookup_type.lookup(type_id)?.handle;
let decor = self
.future_member_decor
.remove(&(id, i))
.unwrap_or_default();
member_lookups.push(LookupMember {
type_id,
row_major: decor.matrix_major == Some(Majority::Row),
});
let binding = decor.io_binding().ok();
let offset = match decor.offset {
Some(offset) => offset,
@@ -2499,10 +2619,6 @@ impl<I: Iterator<Item = u32>> Parser<I> {
return Err(Error::UnsupportedMatrixStride(stride.get()));
}
}
match decor.matrix_major {
None | Some(Majority::Column) => (),
Some(Majority::Row) => return Err(Error::UnsupportedRowMajorMatrix),
}
}
members.push(crate::StructMember {
@@ -2540,9 +2656,9 @@ impl<I: Iterator<Item = u32>> Parser<I> {
if block_decor == Some(Block { buffer: true }) {
self.lookup_storage_buffer_types.insert(ty_handle);
}
for (i, type_id) in member_type_ids.into_iter().enumerate() {
self.lookup_member_type_id
.insert((ty_handle, i as u32), type_id);
for (i, member_lookup) in member_lookups.into_iter().enumerate() {
self.lookup_member
.insert((ty_handle, i as u32), member_lookup);
}
self.lookup_type.insert(
id,