diff --git a/src/front/spv/error.rs b/src/front/spv/error.rs index 0893f9cd36..ea5bbb3774 100644 --- a/src/front/spv/error.rs +++ b/src/front/spv/error.rs @@ -54,6 +54,5 @@ pub enum Error { InvalidEdgeClassification, FunctionCallCycle(spirv::Word), // incomplete implementation error - UnsupportedRowMajorMatrix, UnsupportedMatrixStride(spirv::Word), } diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 2a6f83bad7..5a57256320 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -63,6 +63,9 @@ impl> super::Parser { } pub(super) fn parse_function(&mut self, module: &mut crate::Module) -> Result<(), Error> { + self.lookup_expression.clear(); + self.lookup_load_override.clear(); + let result_type_id = self.next()?; let fun_id = self.next()?; let _fun_control = self.next()?; diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index c3df4a8af8..6589d1bb73 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -16,6 +16,13 @@ populated at the start of an entry point. The outputs are saved at the end. The function associated with an entry point is wrapped in another function, such that we can handle any `Return` statements without problems. +## Row-major matrices + +We don't handle them natively, since the IR only expects column majority. +Instead, we detect when such matrix is accessed in the `OpAccessChain`, +and we generate a parallel expression that loads the value, but transposed. +This value then gets used instead of `OpLoad` result later on. + !*/ #![allow(dead_code)] @@ -192,7 +199,7 @@ bitflags::bitflags! { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] enum Majority { Column, Row, @@ -300,6 +307,21 @@ struct LookupExpression { type_id: spirv::Word, } +#[derive(Debug)] +struct LookupMember { + type_id: spirv::Word, + // This is true for either matrices, or arrays of matrices (yikes). + row_major: bool, +} + +#[derive(Clone, Debug)] +enum LookupLoadOverride { + /// For arrays of matrices, we track them but not loading yet. + Pending, + /// For matrices, vectors, and scalars, we pre-load the data. + Loaded(Handle), +} + #[derive(Clone, Debug)] struct Assignment { to: Handle, @@ -337,7 +359,7 @@ pub struct Parser { ext_glsl_id: Option, future_decor: FastHashMap, future_member_decor: FastHashMap<(spirv::Word, MemberIndex), Decoration>, - lookup_member_type_id: FastHashMap<(Handle, MemberIndex), spirv::Word>, + lookup_member: FastHashMap<(Handle, MemberIndex), LookupMember>, handle_sampling: FastHashMap, image::SamplingFlags>, lookup_type: FastHashMap, lookup_void_type: Option, @@ -346,6 +368,8 @@ pub struct Parser { lookup_constant: FastHashMap, lookup_variable: FastHashMap, lookup_expression: FastHashMap, + // Load overrides are used to work around row-major matrices + lookup_load_override: FastHashMap, lookup_sampled_image: FastHashMap, lookup_function_type: FastHashMap, lookup_function: FastHashMap>, @@ -373,13 +397,14 @@ impl> Parser { future_decor: FastHashMap::default(), future_member_decor: FastHashMap::default(), handle_sampling: FastHashMap::default(), - lookup_member_type_id: FastHashMap::default(), + lookup_member: FastHashMap::default(), lookup_type: FastHashMap::default(), lookup_void_type: None, lookup_storage_buffer_types: FastHashSet::default(), lookup_constant: FastHashMap::default(), lookup_variable: FastHashMap::default(), lookup_expression: FastHashMap::default(), + lookup_load_override: FastHashMap::default(), lookup_sampled_image: FastHashMap::default(), lookup_function_type: FastHashMap::default(), lookup_function: FastHashMap::default(), @@ -636,11 +661,11 @@ impl> Parser { let root_lookup = self.lookup_type.lookup(root_type_id)?; let (count, child_type_id) = match type_arena[root_lookup.handle].inner { crate::TypeInner::Struct { ref members, .. } => { - let child_type_id = *self - .lookup_member_type_id + let child_member = self + .lookup_member .get(&(root_lookup.handle, selection)) .ok_or(Error::InvalidAccessType(root_type_id))?; - (members.len(), child_type_id) + (members.len(), child_member.type_id) } // crate::TypeInner::Array //TODO? crate::TypeInner::Vector { size, .. } @@ -797,6 +822,7 @@ impl> Parser { struct AccessExpression { base_handle: Handle, type_id: spirv::Word, + load_override: Option, } inst.expect_at_least(4)?; @@ -813,50 +839,96 @@ impl> Parser { AccessExpression { base_handle: lexp.handle, type_id: lty.base_id.ok_or(Error::InvalidAccessType(lexp.type_id))?, + load_override: self.lookup_load_override.get(&base_id).cloned(), } }; for _ in 4..inst.wc { let access_id = self.next()?; log::trace!("\t\t\tlooking up index expr {:?}", access_id); let index_expr = self.lookup_expression.lookup(access_id)?.clone(); - let index_type_handle = self.lookup_type.lookup(index_expr.type_id)?.handle; - match type_arena[index_type_handle].inner { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. + let index_expr_data = &expressions[index_expr.handle]; + let index_maybe = match *index_expr_data { + crate::Expression::Constant(const_handle) => { + match const_arena[const_handle].inner { + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Uint(v), + } => Some(v as u32), + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Sint(v), + } => Some(v as u32), + _ => { + return Err(Error::InvalidAccess( + crate::Expression::Constant(const_handle), + )) + } + } } - | crate::TypeInner::Scalar { - kind: crate::ScalarKind::Sint, - .. - } => (), - ref other => { - log::warn!("access index type {:?}", other); - return Err(Error::UnsupportedType(index_type_handle)); - } - } + _ => None, + }; + log::trace!("\t\t\tlooking up type {:?}", acex.type_id); let type_lookup = self.lookup_type.lookup(acex.type_id)?; acex = match type_arena[type_lookup.handle].inner { + // can only index a struct with a constant crate::TypeInner::Struct { .. } => { - let index = match expressions[index_expr.handle] { - crate::Expression::Constant(const_handle) => { - match const_arena[const_handle].inner { - crate::ConstantInner::Scalar { - width: 4, - value: crate::ScalarValue::Uint(v), - } => v as u32, - crate::ConstantInner::Scalar { - width: 4, - value: crate::ScalarValue::Sint(v), - } => v as u32, - _ => { - return Err(Error::InvalidAccess( - crate::Expression::Constant(const_handle), - )) + let index = index_maybe + .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?; + let lookup_member = self + .lookup_member + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(acex.type_id))?; + let base_handle = + expressions.append(crate::Expression::AccessIndex { + base: acex.base_handle, + index, + }); + AccessExpression { + base_handle, + type_id: lookup_member.type_id, + load_override: if lookup_member.row_major { + debug_assert!(acex.load_override.is_none()); + let sub_type_lookup = + self.lookup_type.lookup(lookup_member.type_id)?; + Some(match type_arena[sub_type_lookup.handle].inner { + // load it transposed, to match column major expectations + crate::TypeInner::Matrix { .. } => { + let loaded = + expressions.append(crate::Expression::Load { + pointer: base_handle, + }); + let transposed = + expressions.append(crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: loaded, + arg1: None, + arg2: None, + }); + LookupLoadOverride::Loaded(transposed) } - } + _ => LookupLoadOverride::Pending, + }) + } else { + None + }, + } + } + // we can't dynamically index matrices, so expecting constant index here + crate::TypeInner::Matrix { .. } => { + let index = index_maybe + .ok_or_else(|| Error::InvalidAccess(index_expr_data.clone()))?; + let load_override = match acex.load_override { + // We are indexing inside a row-major matrix + Some(LookupLoadOverride::Loaded(load_expr)) => { + let sub_expr = + expressions.append(crate::Expression::AccessIndex { + base: load_expr, + index, + }); + Some(LookupLoadOverride::Loaded(sub_expr)) } - ref other => return Err(Error::InvalidAccess(other.clone())), + _ => None, }; AccessExpression { base_handle: expressions.append( @@ -865,24 +937,62 @@ impl> Parser { index, }, ), - type_id: *self - .lookup_member_type_id - .get(&(type_lookup.handle, index)) + type_id: type_lookup + .base_id .ok_or(Error::InvalidAccessType(acex.type_id))?, + load_override, } } - _ => AccessExpression { - base_handle: expressions.append(crate::Expression::Access { + // This must be a vector or an array. + _ => { + let base_handle = expressions.append(crate::Expression::Access { base: acex.base_handle, index: index_expr.handle, - }), - type_id: type_lookup - .base_id - .ok_or(Error::InvalidAccessType(acex.type_id))?, - }, + }); + let load_override = match acex.load_override { + // If there is a load override in place, then we always end up + // with a side-loaded value here. + Some(lookup_load_override) => { + let sub_expr = match lookup_load_override { + // We must be indexing into the array of row-major matrices. + // Let's load the result of indexing and transpose it. + LookupLoadOverride::Pending => { + let loaded = + expressions.append(crate::Expression::Load { + pointer: base_handle, + }); + expressions.append(crate::Expression::Math { + fun: crate::MathFunction::Transpose, + arg: loaded, + arg1: None, + arg2: None, + }) + } + // We are indexing inside a row-major matrix. + LookupLoadOverride::Loaded(load_expr) => expressions + .append(crate::Expression::Access { + base: load_expr, + index: index_expr.handle, + }), + }; + Some(LookupLoadOverride::Loaded(sub_expr)) + } + None => None, + }; + AccessExpression { + base_handle, + type_id: type_lookup + .base_id + .ok_or(Error::InvalidAccessType(acex.type_id))?, + load_override, + } + } }; } + if let Some(load_expr) = acex.load_override { + self.lookup_load_override.insert(result_id, load_expr); + } let lookup_expression = LookupExpression { handle: acex.base_handle, type_id: result_type_id, @@ -997,10 +1107,12 @@ impl> Parser { log::trace!("\t\t\tlooking up type {:?}", lexp.type_id); let type_lookup = self.lookup_type.lookup(lexp.type_id)?; let type_id = match type_arena[type_lookup.handle].inner { - crate::TypeInner::Struct { .. } => *self - .lookup_member_type_id - .get(&(type_lookup.handle, index)) - .ok_or(Error::InvalidAccessType(lexp.type_id))?, + crate::TypeInner::Struct { .. } => { + self.lookup_member + .get(&(type_lookup.handle, index)) + .ok_or(Error::InvalidAccessType(lexp.type_id))? + .type_id + } crate::TypeInner::Array { .. } | crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => type_lookup @@ -1100,9 +1212,13 @@ impl> Parser { crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => { base_lexp.handle } - _ => expressions.append(crate::Expression::Load { - pointer: base_lexp.handle, - }), + _ => match self.lookup_load_override.get(&pointer_id) { + Some(&LookupLoadOverride::Loaded(handle)) => handle, + //Note: we aren't handling `LookupLoadOverride::Pending` properly here + _ => expressions.append(crate::Expression::Load { + pointer: base_lexp.handle, + }), + }, }; self.lookup_expression.insert( @@ -2465,16 +2581,20 @@ impl> Parser { let block_decor = parent_decor.as_ref().and_then(|decor| decor.block.clone()); let mut members = Vec::::with_capacity(inst.wc as usize - 2); - let mut member_type_ids = Vec::with_capacity(members.capacity()); + let mut member_lookups = Vec::with_capacity(members.capacity()); for i in 0..u32::from(inst.wc) - 2 { let type_id = self.next()?; - member_type_ids.push(type_id); let ty = self.lookup_type.lookup(type_id)?.handle; let decor = self .future_member_decor .remove(&(id, i)) .unwrap_or_default(); + member_lookups.push(LookupMember { + type_id, + row_major: decor.matrix_major == Some(Majority::Row), + }); + let binding = decor.io_binding().ok(); let offset = match decor.offset { Some(offset) => offset, @@ -2499,10 +2619,6 @@ impl> Parser { return Err(Error::UnsupportedMatrixStride(stride.get())); } } - match decor.matrix_major { - None | Some(Majority::Column) => (), - Some(Majority::Row) => return Err(Error::UnsupportedRowMajorMatrix), - } } members.push(crate::StructMember { @@ -2540,9 +2656,9 @@ impl> Parser { if block_decor == Some(Block { buffer: true }) { self.lookup_storage_buffer_types.insert(ty_handle); } - for (i, type_id) in member_type_ids.into_iter().enumerate() { - self.lookup_member_type_id - .insert((ty_handle, i as u32), type_id); + for (i, member_lookup) in member_lookups.into_iter().enumerate() { + self.lookup_member + .insert((ty_handle, i as u32), member_lookup); } self.lookup_type.insert( id,