From a6240b498874beb31f06d80a22712aa6f3bed95f Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 11 Dec 2020 18:10:48 -0500 Subject: [PATCH] [spv] support transpose and access index on matrices --- src/back/spv/writer.rs | 137 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 126 insertions(+), 11 deletions(-) diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index f688c44811..b8d8c46741 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -67,6 +67,11 @@ enum LocalType { kind: crate::ScalarKind, width: crate::Bytes, }, + Matrix { + columns: crate::VectorSize, + rows: crate::VectorSize, + width: crate::Bytes, + }, Pointer { base: crate::Handle, class: crate::StorageClass, @@ -201,6 +206,10 @@ impl Writer { LookupType::Handle(handle) => match arena[handle].inner { crate::TypeInner::Scalar { kind, width } => self .get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width })), + crate::TypeInner::Vector { size, kind, width } => self.get_type_id( + arena, + LookupType::Local(LocalType::Vector { size, kind, width }), + ), _ => self.write_type_declaration_arena(arena, handle), }, LookupType::Local(local_ty) => self.write_type_declaration_local(arena, local_ty), @@ -491,6 +500,21 @@ impl Writer { self.get_type_id(arena, LookupType::Local(LocalType::Scalar { kind, width }))?; super::instructions::instruction_type_vector(id, scalar_id, size) } + LocalType::Matrix { + columns, + rows, + width, + } => { + let vector_id = self.get_type_id( + arena, + LookupType::Local(LocalType::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }), + )?; + super::instructions::instruction_type_matrix(id, vector_id, columns) + } LocalType::Pointer { .. } => { return Err(Error::FeatureNotImplemented("pointer declaration")) } @@ -530,7 +554,7 @@ impl Writer { } crate::TypeInner::Matrix { columns, - rows: _, + rows, width, } => { let vector_id = self.get_type_id( @@ -541,6 +565,14 @@ impl Writer { width, }), )?; + self.lookup_type.insert( + LookupType::Local(LocalType::Matrix { + columns, + rows, + width, + }), + id, + ); super::instructions::instruction_type_matrix(id, vector_id, columns) } crate::TypeInner::Image { @@ -880,10 +912,19 @@ impl Writer { LocalType::Vector { size, kind, width } => { MaybeOwned::Owned(crate::TypeInner::Vector { size, kind, width }) } + LocalType::Matrix { + columns, + rows, + width, + } => MaybeOwned::Owned(crate::TypeInner::Matrix { + columns, + rows, + width, + }), LocalType::Pointer { base, class } => { MaybeOwned::Owned(crate::TypeInner::Pointer { base, class }) } - _ => unreachable!(), + LocalType::Void | LocalType::SampledImage { .. } => unreachable!(), }, } } @@ -961,15 +1002,34 @@ impl Writer { let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); let (pointer_id, type_id, lookup_ty) = match *base_ty_inner { - crate::TypeInner::Vector { kind, width, .. } => { - let scalar_id = self.get_type_id( - &ir_module.types, - LookupType::Local(LocalType::Scalar { kind, width }), - )?; + crate::TypeInner::Vector { + size: _, + kind, + width, + } => { + let lookup_type = LookupType::Local(LocalType::Scalar { kind, width }); + let scalar_id = self.get_type_id(&ir_module.types, lookup_type)?; ( self.create_pointer(scalar_id, spirv::StorageClass::Function), scalar_id, - LookupType::Local(LocalType::Scalar { kind, width }), + lookup_type, + ) + } + crate::TypeInner::Matrix { + columns: _, + rows, + width, + } => { + let lookup_type = LookupType::Local(LocalType::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }); + let vector_id = self.get_type_id(&ir_module.types, lookup_type)?; + ( + self.create_pointer(vector_id, spirv::StorageClass::Function), + vector_id, + lookup_type, ) } crate::TypeInner::Struct { @@ -1252,9 +1312,64 @@ impl Writer { )); Ok((id, result_lookup_ty)) } - crate::Expression::Math { fun, .. } => { - log::error!("unimplemented math function {:?}", fun); - Err(Error::FeatureNotImplemented("math function")) + crate::Expression::Math { fun, arg, .. } => { + use crate::MathFunction as Mf; + + let arg0_expression = &ir_function.expressions[arg]; + let (arg0_id, arg0_lookup_ty) = self.write_expression( + ir_module, + ir_function, + arg0_expression, + block, + function, + )?; + let arg0_id = match *arg0_expression { + crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { + let load_id = self.generate_id(); + let arg_result_id = self.get_type_id(&ir_module.types, arg0_lookup_ty)?; + block.body.push(super::instructions::instruction_load( + arg_result_id, + load_id, + arg0_id, + None, + )); + load_id + } + _ => arg0_id, + }; + + let id = self.generate_id(); + match fun { + Mf::Transpose => { + let result_lookup_ty = + match *self.get_type_inner(&ir_module.types, arg0_lookup_ty) { + crate::TypeInner::Matrix { + columns, + rows, + width, + } => LookupType::Local(LocalType::Matrix { + columns: rows, + rows: columns, + width, + }), + _ => unreachable!(), + }; + let result_type_id = + self.get_type_id(&ir_module.types, result_lookup_ty)?; + + block.body.push(super::instructions::instruction_unary( + spirv::Op::Transpose, + result_type_id, + id, + arg0_id, + )); + Ok((id, result_lookup_ty)) + } + _ => { + log::error!("unimplemented math function {:?}", fun); + Err(Error::FeatureNotImplemented("math function")) + } + } } crate::Expression::LocalVariable(variable) => { let var = &ir_function.local_variables[variable];