mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
[spv-in] work around row-major matrices
This commit is contained in:
committed by
Dzmitry Malyshau
parent
0cc22c8c65
commit
67e3e0a697
@@ -54,6 +54,5 @@ pub enum Error {
|
||||
InvalidEdgeClassification,
|
||||
FunctionCallCycle(spirv::Word),
|
||||
// incomplete implementation error
|
||||
UnsupportedRowMajorMatrix,
|
||||
UnsupportedMatrixStride(spirv::Word),
|
||||
}
|
||||
|
||||
@@ -63,6 +63,9 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
|
||||
}
|
||||
|
||||
pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> {
|
||||
self.lookup_expression.clear();
|
||||
self.lookup_load_override.clear();
|
||||
|
||||
let result_type_id = self.next()?;
|
||||
let fun_id = self.next()?;
|
||||
let _fun_control = self.next()?;
|
||||
|
||||
@@ -16,6 +16,13 @@ populated at the start of an entry point. The outputs are saved at the end.
|
||||
The function associated with an entry point is wrapped in another function,
|
||||
such that we can handle any `Return` statements without problems.
|
||||
|
||||
## Row-major matrices
|
||||
|
||||
We don't handle them natively, since the IR only expects column majority.
|
||||
Instead, we detect when such matrix is accessed in the `OpAccessChain`,
|
||||
and we generate a parallel expression that loads the value, but transposed.
|
||||
This value then gets used instead of `OpLoad` result later on.
|
||||
|
||||
!*/
|
||||
#![allow(dead_code)]
|
||||
|
||||
@@ -192,7 +199,7 @@ bitflags::bitflags! {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum Majority {
|
||||
Column,
|
||||
Row,
|
||||
@@ -300,6 +307,21 @@ struct LookupExpression {
|
||||
type_id: spirv::Word,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct LookupMember {
|
||||
type_id: spirv::Word,
|
||||
// This is true for either matrices, or arrays of matrices (yikes).
|
||||
row_major: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum LookupLoadOverride {
|
||||
/// For arrays of matrices, we track them but not loading yet.
|
||||
Pending,
|
||||
/// For matrices, vectors, and scalars, we pre-load the data.
|
||||
Loaded(Handle<crate::Expression>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct Assignment {
|
||||
to: Handle<crate::Expression>,
|
||||
@@ -337,7 +359,7 @@ pub struct Parser<I> {
|
||||
ext_glsl_id: Option<spirv::Word>,
|
||||
future_decor: FastHashMap<spirv::Word, Decoration>,
|
||||
future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>,
|
||||
lookup_member_type_id: FastHashMap<(Handle<crate::Type>, MemberIndex), spirv::Word>,
|
||||
lookup_member: FastHashMap<(Handle<crate::Type>, MemberIndex), LookupMember>,
|
||||
handle_sampling: FastHashMap<Handle<crate::GlobalVariable>, image::SamplingFlags>,
|
||||
lookup_type: FastHashMap<spirv::Word, LookupType>,
|
||||
lookup_void_type: Option<spirv::Word>,
|
||||
@@ -346,6 +368,8 @@ pub struct Parser<I> {
|
||||
lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
|
||||
lookup_variable: FastHashMap<spirv::Word, LookupVariable>,
|
||||
lookup_expression: FastHashMap<spirv::Word, LookupExpression>,
|
||||
// Load overrides are used to work around row-major matrices
|
||||
lookup_load_override: FastHashMap<spirv::Word, LookupLoadOverride>,
|
||||
lookup_sampled_image: FastHashMap<spirv::Word, image::LookupSampledImage>,
|
||||
lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
|
||||
lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>,
|
||||
@@ -373,13 +397,14 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
future_decor: FastHashMap::default(),
|
||||
future_member_decor: FastHashMap::default(),
|
||||
handle_sampling: FastHashMap::default(),
|
||||
lookup_member_type_id: FastHashMap::default(),
|
||||
lookup_member: FastHashMap::default(),
|
||||
lookup_type: FastHashMap::default(),
|
||||
lookup_void_type: None,
|
||||
lookup_storage_buffer_types: FastHashSet::default(),
|
||||
lookup_constant: FastHashMap::default(),
|
||||
lookup_variable: FastHashMap::default(),
|
||||
lookup_expression: FastHashMap::default(),
|
||||
lookup_load_override: FastHashMap::default(),
|
||||
lookup_sampled_image: FastHashMap::default(),
|
||||
lookup_function_type: FastHashMap::default(),
|
||||
lookup_function: FastHashMap::default(),
|
||||
@@ -636,11 +661,11 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
let root_lookup = self.lookup_type.lookup(root_type_id)?;
|
||||
let (count, child_type_id) = match type_arena[root_lookup.handle].inner {
|
||||
crate::TypeInner::Struct { ref members, .. } => {
|
||||
let child_type_id = *self
|
||||
.lookup_member_type_id
|
||||
let child_member = self
|
||||
.lookup_member
|
||||
.get(&(root_lookup.handle, selection))
|
||||
.ok_or(Error::InvalidAccessType(root_type_id))?;
|
||||
(members.len(), child_type_id)
|
||||
(members.len(), child_member.type_id)
|
||||
}
|
||||
// crate::TypeInner::Array //TODO?
|
||||
crate::TypeInner::Vector { size, .. }
|
||||
@@ -797,6 +822,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
struct AccessExpression {
|
||||
base_handle: Handle<crate::Expression>,
|
||||
type_id: spirv::Word,
|
||||
load_override: Option<LookupLoadOverride>,
|
||||
}
|
||||
|
||||
inst.expect_at_least(4)?;
|
||||
@@ -813,50 +839,96 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
AccessExpression {
|
||||
base_handle: lexp.handle,
|
||||
type_id: lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))?,
|
||||
load_override: self.lookup_load_override.get(&base_id).cloned(),
|
||||
}
|
||||
};
|
||||
for _ in 4..inst.wc {
|
||||
let access_id = self.next()?;
|
||||
log::trace!("\t\t\tlooking up index expr {:?}", access_id);
|
||||
let index_expr = self.lookup_expression.lookup(access_id)?.clone();
|
||||
let index_type_handle = self.lookup_type.lookup(index_expr.type_id)?.handle;
|
||||
match type_arena[index_type_handle].inner {
|
||||
crate::TypeInner::Scalar {
|
||||
kind: crate::ScalarKind::Uint,
|
||||
..
|
||||
let index_expr_data = &expressions[index_expr.handle];
|
||||
let index_maybe = match *index_expr_data {
|
||||
crate::Expression::Constant(const_handle) => {
|
||||
match const_arena[const_handle].inner {
|
||||
crate::ConstantInner::Scalar {
|
||||
width: _,
|
||||
value: crate::ScalarValue::Uint(v),
|
||||
} => Some(v as u32),
|
||||
crate::ConstantInner::Scalar {
|
||||
width: _,
|
||||
value: crate::ScalarValue::Sint(v),
|
||||
} => Some(v as u32),
|
||||
_ => {
|
||||
return Err(Error::InvalidAccess(
|
||||
crate::Expression::Constant(const_handle),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
| crate::TypeInner::Scalar {
|
||||
kind: crate::ScalarKind::Sint,
|
||||
..
|
||||
} => (),
|
||||
ref other => {
|
||||
log::warn!("access index type {:?}", other);
|
||||
return Err(Error::UnsupportedType(index_type_handle));
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
log::trace!("\t\t\tlooking up type {:?}", acex.type_id);
|
||||
let type_lookup = self.lookup_type.lookup(acex.type_id)?;
|
||||
acex = match type_arena[type_lookup.handle].inner {
|
||||
// can only index a struct with a constant
|
||||
crate::TypeInner::Struct { .. } => {
|
||||
let index = match expressions[index_expr.handle] {
|
||||
crate::Expression::Constant(const_handle) => {
|
||||
match const_arena[const_handle].inner {
|
||||
crate::ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: crate::ScalarValue::Uint(v),
|
||||
} => v as u32,
|
||||
crate::ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: crate::ScalarValue::Sint(v),
|
||||
} => v as u32,
|
||||
_ => {
|
||||
return Err(Error::InvalidAccess(
|
||||
crate::Expression::Constant(const_handle),
|
||||
))
|
||||
let index = index_maybe
|
||||
.ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?;
|
||||
let lookup_member = self
|
||||
.lookup_member
|
||||
.get(&(type_lookup.handle, index))
|
||||
.ok_or(Error::InvalidAccessType(acex.type_id))?;
|
||||
let base_handle =
|
||||
expressions.append(crate::Expression::AccessIndex {
|
||||
base: acex.base_handle,
|
||||
index,
|
||||
});
|
||||
AccessExpression {
|
||||
base_handle,
|
||||
type_id: lookup_member.type_id,
|
||||
load_override: if lookup_member.row_major {
|
||||
debug_assert!(acex.load_override.is_none());
|
||||
let sub_type_lookup =
|
||||
self.lookup_type.lookup(lookup_member.type_id)?;
|
||||
Some(match type_arena[sub_type_lookup.handle].inner {
|
||||
// load it transposed, to match column major expectations
|
||||
crate::TypeInner::Matrix { .. } => {
|
||||
let loaded =
|
||||
expressions.append(crate::Expression::Load {
|
||||
pointer: base_handle,
|
||||
});
|
||||
let transposed =
|
||||
expressions.append(crate::Expression::Math {
|
||||
fun: crate::MathFunction::Transpose,
|
||||
arg: loaded,
|
||||
arg1: None,
|
||||
arg2: None,
|
||||
});
|
||||
LookupLoadOverride::Loaded(transposed)
|
||||
}
|
||||
}
|
||||
_ => LookupLoadOverride::Pending,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
// we can't dynamically index matrices, so expecting constant index here
|
||||
crate::TypeInner::Matrix { .. } => {
|
||||
let index = index_maybe
|
||||
.ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?;
|
||||
let load_override = match acex.load_override {
|
||||
// We are indexing inside a row-major matrix
|
||||
Some(LookupLoadOverride::Loaded(load_expr)) => {
|
||||
let sub_expr =
|
||||
expressions.append(crate::Expression::AccessIndex {
|
||||
base: load_expr,
|
||||
index,
|
||||
});
|
||||
Some(LookupLoadOverride::Loaded(sub_expr))
|
||||
}
|
||||
ref other => return Err(Error::InvalidAccess(other.clone())),
|
||||
_ => None,
|
||||
};
|
||||
AccessExpression {
|
||||
base_handle: expressions.append(
|
||||
@@ -865,24 +937,62 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
index,
|
||||
},
|
||||
),
|
||||
type_id: *self
|
||||
.lookup_member_type_id
|
||||
.get(&(type_lookup.handle, index))
|
||||
type_id: type_lookup
|
||||
.base_id
|
||||
.ok_or(Error::InvalidAccessType(acex.type_id))?,
|
||||
load_override,
|
||||
}
|
||||
}
|
||||
_ => AccessExpression {
|
||||
base_handle: expressions.append(crate::Expression::Access {
|
||||
// This must be a vector or an array.
|
||||
_ => {
|
||||
let base_handle = expressions.append(crate::Expression::Access {
|
||||
base: acex.base_handle,
|
||||
index: index_expr.handle,
|
||||
}),
|
||||
type_id: type_lookup
|
||||
.base_id
|
||||
.ok_or(Error::InvalidAccessType(acex.type_id))?,
|
||||
},
|
||||
});
|
||||
let load_override = match acex.load_override {
|
||||
// If there is a load override in place, then we always end up
|
||||
// with a side-loaded value here.
|
||||
Some(lookup_load_override) => {
|
||||
let sub_expr = match lookup_load_override {
|
||||
// We must be indexing into the array of row-major matrices.
|
||||
// Let's load the result of indexing and transpose it.
|
||||
LookupLoadOverride::Pending => {
|
||||
let loaded =
|
||||
expressions.append(crate::Expression::Load {
|
||||
pointer: base_handle,
|
||||
});
|
||||
expressions.append(crate::Expression::Math {
|
||||
fun: crate::MathFunction::Transpose,
|
||||
arg: loaded,
|
||||
arg1: None,
|
||||
arg2: None,
|
||||
})
|
||||
}
|
||||
// We are indexing inside a row-major matrix.
|
||||
LookupLoadOverride::Loaded(load_expr) => expressions
|
||||
.append(crate::Expression::Access {
|
||||
base: load_expr,
|
||||
index: index_expr.handle,
|
||||
}),
|
||||
};
|
||||
Some(LookupLoadOverride::Loaded(sub_expr))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
AccessExpression {
|
||||
base_handle,
|
||||
type_id: type_lookup
|
||||
.base_id
|
||||
.ok_or(Error::InvalidAccessType(acex.type_id))?,
|
||||
load_override,
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(load_expr) = acex.load_override {
|
||||
self.lookup_load_override.insert(result_id, load_expr);
|
||||
}
|
||||
let lookup_expression = LookupExpression {
|
||||
handle: acex.base_handle,
|
||||
type_id: result_type_id,
|
||||
@@ -997,10 +1107,12 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
log::trace!("\t\t\tlooking up type {:?}", lexp.type_id);
|
||||
let type_lookup = self.lookup_type.lookup(lexp.type_id)?;
|
||||
let type_id = match type_arena[type_lookup.handle].inner {
|
||||
crate::TypeInner::Struct { .. } => *self
|
||||
.lookup_member_type_id
|
||||
.get(&(type_lookup.handle, index))
|
||||
.ok_or(Error::InvalidAccessType(lexp.type_id))?,
|
||||
crate::TypeInner::Struct { .. } => {
|
||||
self.lookup_member
|
||||
.get(&(type_lookup.handle, index))
|
||||
.ok_or(Error::InvalidAccessType(lexp.type_id))?
|
||||
.type_id
|
||||
}
|
||||
crate::TypeInner::Array { .. }
|
||||
| crate::TypeInner::Vector { .. }
|
||||
| crate::TypeInner::Matrix { .. } => type_lookup
|
||||
@@ -1100,9 +1212,13 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {
|
||||
base_lexp.handle
|
||||
}
|
||||
_ => expressions.append(crate::Expression::Load {
|
||||
pointer: base_lexp.handle,
|
||||
}),
|
||||
_ => match self.lookup_load_override.get(&pointer_id) {
|
||||
Some(&LookupLoadOverride::Loaded(handle)) => handle,
|
||||
//Note: we aren't handling `LookupLoadOverride::Pending` properly here
|
||||
_ => expressions.append(crate::Expression::Load {
|
||||
pointer: base_lexp.handle,
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
self.lookup_expression.insert(
|
||||
@@ -2465,16 +2581,20 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
let block_decor = parent_decor.as_ref().and_then(|decor| decor.block.clone());
|
||||
|
||||
let mut members = Vec::<crate::StructMember>::with_capacity(inst.wc as usize - 2);
|
||||
let mut member_type_ids = Vec::with_capacity(members.capacity());
|
||||
let mut member_lookups = Vec::with_capacity(members.capacity());
|
||||
for i in 0..u32::from(inst.wc) - 2 {
|
||||
let type_id = self.next()?;
|
||||
member_type_ids.push(type_id);
|
||||
let ty = self.lookup_type.lookup(type_id)?.handle;
|
||||
let decor = self
|
||||
.future_member_decor
|
||||
.remove(&(id, i))
|
||||
.unwrap_or_default();
|
||||
|
||||
member_lookups.push(LookupMember {
|
||||
type_id,
|
||||
row_major: decor.matrix_major == Some(Majority::Row),
|
||||
});
|
||||
|
||||
let binding = decor.io_binding().ok();
|
||||
let offset = match decor.offset {
|
||||
Some(offset) => offset,
|
||||
@@ -2499,10 +2619,6 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
return Err(Error::UnsupportedMatrixStride(stride.get()));
|
||||
}
|
||||
}
|
||||
match decor.matrix_major {
|
||||
None | Some(Majority::Column) => (),
|
||||
Some(Majority::Row) => return Err(Error::UnsupportedRowMajorMatrix),
|
||||
}
|
||||
}
|
||||
|
||||
members.push(crate::StructMember {
|
||||
@@ -2540,9 +2656,9 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
if block_decor == Some(Block { buffer: true }) {
|
||||
self.lookup_storage_buffer_types.insert(ty_handle);
|
||||
}
|
||||
for (i, type_id) in member_type_ids.into_iter().enumerate() {
|
||||
self.lookup_member_type_id
|
||||
.insert((ty_handle, i as u32), type_id);
|
||||
for (i, member_lookup) in member_lookups.into_iter().enumerate() {
|
||||
self.lookup_member
|
||||
.insert((ty_handle, i as u32), member_lookup);
|
||||
}
|
||||
self.lookup_type.insert(
|
||||
id,
|
||||
|
||||
Reference in New Issue
Block a user