diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 064bfa0014..7df29ec7fb 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -375,13 +375,13 @@ pub(super) fn instruction_variable( pub(super) fn instruction_load( result_type_id: Word, id: Word, - pointer_type_id: Word, + pointer_id: Word, memory_access: Option, ) -> Instruction { let mut instruction = Instruction::new(Op::Load); instruction.set_type(result_type_id); instruction.set_result(id); - instruction.add_operand(pointer_type_id); + instruction.add_operand(pointer_id); if let Some(memory_access) = memory_access { instruction.add_operand(memory_access.bits()); diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index b8d8c46741..e14b773c75 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -26,6 +26,11 @@ struct LocalVariable { instruction: Instruction, } +enum RawExpression { + Value(Word), + Pointer(Word), +} + #[derive(Default)] struct Function { signature: Option, @@ -87,6 +92,12 @@ enum LookupType { Local(LocalType), } +impl From for LookupType { + fn from(local: LocalType) -> Self { + Self::Local(local) + } +} + fn map_dim(dim: crate::ImageDimension) -> spirv::Dim { match dim { crate::ImageDimension::D1 => spirv::Dim::Dim1D, @@ -272,40 +283,48 @@ impl Writer { handle: crate::Handle, class: crate::StorageClass, ) -> Result { - let ty = &arena[handle]; let ty_id = self.get_type_id(arena, LookupType::Handle(handle))?; - Ok(match ty.inner { - crate::TypeInner::Pointer { .. } => ty_id, - _ => { - match self - .lookup_type - .entry(LookupType::Local(LocalType::Pointer { - base: handle, - class, - })) { - Entry::Occupied(e) => *e.get(), - _ => { - let id = - self.create_pointer(ty_id, self.parse_to_spirv_storage_class(class)); - self.lookup_type.insert( - LookupType::Local(LocalType::Pointer { - base: handle, - class, - }), - id, - ); - id - } + if let crate::TypeInner::Pointer { .. } = arena[handle].inner { + return Ok(ty_id); + } + Ok( + match self + .lookup_type + .entry(LookupType::Local(LocalType::Pointer { + base: handle, + class, + })) { + Entry::Occupied(e) => *e.get(), + _ => { + let storage_class = self.parse_to_spirv_storage_class(class); + let id = self.generate_id(); + let instruction = + super::instructions::instruction_type_pointer(id, storage_class, ty_id); + instruction.to_words(&mut self.logical_layout.declarations); + self.lookup_type.insert( + LookupType::Local(LocalType::Pointer { + base: handle, + class, + }), + id, + ); + id } - } - }) + }, + ) } - fn create_pointer(&mut self, ty_id: Word, class: spirv::StorageClass) -> Word { + fn create_pointer_type( + &mut self, + lookup_type: LookupType, + class: spirv::StorageClass, + type_arena: &crate::Arena, + ) -> Result<(Word, LookupType), Error> { + let type_id = self.get_type_id(type_arena, lookup_type)?; let id = self.generate_id(); - let instruction = super::instructions::instruction_type_pointer(id, class, ty_id); + let instruction = super::instructions::instruction_type_pointer(id, class, type_id); instruction.to_words(&mut self.logical_layout.declarations); - id + Ok((id, lookup_type)) } fn create_constant(&mut self, type_id: Word, value: &[Word]) -> Word { @@ -329,10 +348,10 @@ impl Writer { .init .map(|constant| self.get_constant_id(constant, ir_module)) .transpose()?; - let pointer_id = + let pointer_type_id = self.get_pointer_id(&ir_module.types, variable.ty, crate::StorageClass::Function)?; let instruction = super::instructions::instruction_variable( - pointer_id, + pointer_type_id, id, spirv::StorageClass::Function, init_word, @@ -350,17 +369,18 @@ impl Writer { for argument in ir_function.arguments.iter() { let id = self.generate_id(); - let pointer_id = + let pointer_type_id = self.get_pointer_id(&ir_module.types, argument.ty, crate::StorageClass::Function)?; - function_parameter_pointer_ids.push(pointer_id); + function_parameter_pointer_ids.push(pointer_type_id); let parameter_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(argument.ty))?; parameter_type_ids.push(parameter_type_id); function .parameters .push(super::instructions::instruction_function_parameter( - pointer_id, id, + pointer_type_id, + id, )); } @@ -768,10 +788,10 @@ impl Writer { .init .map(|constant| self.get_constant_id(constant, ir_module)) .transpose()?; - let pointer_id = + let pointer_type_id = self.get_pointer_id(&ir_module.types, global_variable.ty, global_variable.class)?; let instruction = - super::instructions::instruction_variable(pointer_id, id, class, init_word); + super::instructions::instruction_variable(pointer_type_id, id, class, init_word); if self.flags.contains(WriterFlags::DEBUG) { if let Some(ref name) = global_variable.name { @@ -799,7 +819,7 @@ impl Writer { } } - match *global_variable.binding.as_ref().unwrap() { + match global_variable.binding.clone().unwrap() { crate::Binding::Location(location) => { self.annotations .push(super::instructions::instruction_decorate( @@ -929,47 +949,75 @@ impl Writer { } } + /// Write an expression and return a value ID. fn write_expression<'a>( &mut self, ir_module: &'a crate::Module, ir_function: &crate::Function, - expression: &crate::Expression, + handle: crate::Handle, block: &mut Block, function: &mut Function, ) -> Result { - match *expression { + let (raw_expression, lookup_ty) = + self.write_expression_raw(ir_module, ir_function, handle, block, function)?; + Ok(match raw_expression { + RawExpression::Value(id) => (id, lookup_ty), + RawExpression::Pointer(id) => { + let load_id = self.generate_id(); + let type_id = self.get_type_id(&ir_module.types, lookup_ty)?; + block.body.push(super::instructions::instruction_load( + type_id, load_id, id, None, + )); + (load_id, lookup_ty) + } + }) + } + + /// Write an expression and return a pointer ID to the result. + fn write_expression_pointer<'a>( + &mut self, + ir_module: &'a crate::Module, + ir_function: &crate::Function, + handle: crate::Handle, + block: &mut Block, + function: &mut Function, + ) -> Result { + let (raw_expression, lookup_ty) = + self.write_expression_raw(ir_module, ir_function, handle, block, function)?; + Ok(match raw_expression { + RawExpression::Value(_id) => { + //TODO: create a local variable? + return Err(Error::FeatureNotImplemented("getting pointer of a value")); + } + RawExpression::Pointer(id) => (id, lookup_ty), + }) + } + + /// Write an expression, and the result may be either a pointer, or a value. + fn write_expression_raw<'a>( + &mut self, + ir_module: &'a crate::Module, + ir_function: &crate::Function, + expr_handle: crate::Handle, + block: &mut Block, + function: &mut Function, + ) -> Result<(RawExpression, LookupType), Error> { + match ir_function.expressions[expr_handle] { crate::Expression::Access { base, index } => { let id = self.generate_id(); - let (base_id, base_lookup_ty) = self.write_expression( - ir_module, - ir_function, - &ir_function.expressions[base], - block, - function, - )?; - let (index_id, _) = self.write_expression( - ir_module, - ir_function, - &ir_function.expressions[index], - block, - function, - )?; + let (base_id, base_lookup_ty) = + self.write_expression_pointer(ir_module, ir_function, base, block, function)?; + let (index_id, _) = + self.write_expression(ir_module, ir_function, index, block, function)?; let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); - - let (pointer_id, type_id, lookup_ty) = match *base_ty_inner { - crate::TypeInner::Vector { kind, width, .. } => { - let scalar_id = self.get_type_id( - &ir_module.types, - LookupType::Local(LocalType::Scalar { kind, width }), - )?; - ( - self.create_pointer(scalar_id, spirv::StorageClass::Function), - scalar_id, - LookupType::Local(LocalType::Scalar { kind, width }), - ) - } + let (pointer_id, lookup_ty) = match *base_ty_inner { + crate::TypeInner::Vector { kind, width, .. } => self.create_pointer_type( + LocalType::Scalar { kind, width }.into(), + spirv::StorageClass::Function, + &ir_module.types, + )?, _ => return Err(Error::FeatureNotImplemented("accessing of non-vector")), }; @@ -982,68 +1030,50 @@ impl Writer { &[index_id], )); - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - type_id, load_id, id, None, - )); - - Ok((load_id, lookup_ty)) + Ok((RawExpression::Pointer(id), lookup_ty)) } crate::Expression::AccessIndex { base, index } => { let id = self.generate_id(); - let (base_id, base_lookup_ty) = self.write_expression( - ir_module, - ir_function, - &ir_function.expressions[base], - block, - function, - )?; + let (base_id, base_lookup_ty) = + self.write_expression_pointer(ir_module, ir_function, base, block, function)?; let base_ty_inner = self.get_type_inner(&ir_module.types, base_lookup_ty); - - let (pointer_id, type_id, lookup_ty) = match *base_ty_inner { + let (pointer_id, lookup_ty) = match *base_ty_inner { crate::TypeInner::Vector { size: _, kind, width, - } => { - let lookup_type = LookupType::Local(LocalType::Scalar { kind, width }); - let scalar_id = self.get_type_id(&ir_module.types, lookup_type)?; - ( - self.create_pointer(scalar_id, spirv::StorageClass::Function), - scalar_id, - lookup_type, - ) - } + } => self.create_pointer_type( + LocalType::Scalar { kind, width }.into(), + spirv::StorageClass::Function, + &ir_module.types, + )?, crate::TypeInner::Matrix { columns: _, rows, width, } => { - let lookup_type = LookupType::Local(LocalType::Vector { + let local_type = LocalType::Vector { size: rows, kind: crate::ScalarKind::Float, width, - }); - let vector_id = self.get_type_id(&ir_module.types, lookup_type)?; - ( - self.create_pointer(vector_id, spirv::StorageClass::Function), - vector_id, - lookup_type, - ) + }; + self.create_pointer_type( + local_type.into(), + spirv::StorageClass::Function, + &ir_module.types, + )? } crate::TypeInner::Struct { block: _, ref members, } => { let member = &members[index as usize]; - let type_id = - self.get_type_id(&ir_module.types, LookupType::Handle(member.ty))?; - ( - self.create_pointer(type_id, spirv::StorageClass::Uniform), - type_id, + self.create_pointer_type( LookupType::Handle(member.ty), - ) + spirv::StorageClass::Uniform, + &ir_module.types, + )? } _ => { return Err(Error::FeatureNotImplemented( @@ -1070,53 +1100,31 @@ impl Writer { &[const_id], )); - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - type_id, load_id, id, None, - )); - - Ok((load_id, lookup_ty)) + Ok((RawExpression::Pointer(id), lookup_ty)) } crate::Expression::GlobalVariable(handle) => { let var = &ir_module.global_variables[handle]; let id = self.get_global_variable_id(&ir_module, handle)?; - Ok((id, LookupType::Handle(var.ty))) + Ok((RawExpression::Pointer(id), LookupType::Handle(var.ty))) } crate::Expression::Constant(handle) => { let var = &ir_module.constants[handle]; let id = self.get_constant_id(handle, ir_module)?; - Ok((id, LookupType::Handle(var.ty))) + Ok((RawExpression::Value(id), LookupType::Handle(var.ty))) } crate::Expression::Compose { ty, ref components } => { let base_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(ty))?; let mut constituent_ids = Vec::with_capacity(components.len()); for component in components { - let expression = &ir_function.expressions[*component]; - let (component_id, component_local_ty) = self.write_expression( + let (component_id, _) = self.write_expression( ir_module, &ir_function, - expression, + *component, block, function, )?; - - let component_id = match expression { - crate::Expression::LocalVariable(_) - | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - self.get_type_id(&ir_module.types, component_local_ty)?, - load_id, - component_id, - None, - )); - load_id - } - _ => component_id, - }; - constituent_ids.push(component_id); } let constituent_ids_slice = constituent_ids.as_slice(); @@ -1161,26 +1169,14 @@ impl Writer { _ => unreachable!(), }; - Ok((id, LookupType::Handle(ty))) + Ok((RawExpression::Value(id), LookupType::Handle(ty))) } crate::Expression::Binary { op, left, right } => { let id = self.generate_id(); - let left_expression = &ir_function.expressions[left]; - let right_expression = &ir_function.expressions[right]; - let (left_id, left_lookup_ty) = self.write_expression( - ir_module, - ir_function, - left_expression, - block, - function, - )?; - let (right_id, right_lookup_ty) = self.write_expression( - ir_module, - ir_function, - right_expression, - block, - function, - )?; + let (left_id, left_lookup_ty) = + self.write_expression(ir_module, ir_function, left, block, function)?; + let (right_id, right_lookup_ty) = + self.write_expression(ir_module, ir_function, right, block, function)?; let left_ty_inner = self.get_type_inner(&ir_module.types, left_lookup_ty); let right_ty_inner = self.get_type_inner(&ir_module.types, right_lookup_ty); @@ -1188,35 +1184,6 @@ impl Writer { let left_result_type_id = self.get_type_id(&ir_module.types, left_lookup_ty)?; let right_result_type_id = self.get_type_id(&ir_module.types, right_lookup_ty)?; - let left_id = match *left_expression { - crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - left_result_type_id, - load_id, - left_id, - None, - )); - load_id - } - _ => left_id, - }; - - let right_id = match *right_expression { - crate::Expression::LocalVariable(..) - | crate::Expression::GlobalVariable(..) => { - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - right_result_type_id, - load_id, - right_id, - None, - )); - load_id - } - _ => right_id, - }; - let left_dimension = get_dimension(&left_ty_inner); let right_dimension = get_dimension(&right_ty_inner); @@ -1310,33 +1277,13 @@ impl Writer { if preserve_order { left_id } else { right_id }, if preserve_order { right_id } else { left_id }, )); - Ok((id, result_lookup_ty)) + Ok((RawExpression::Value(id), result_lookup_ty)) } crate::Expression::Math { fun, arg, .. } => { use crate::MathFunction as Mf; - let arg0_expression = &ir_function.expressions[arg]; - let (arg0_id, arg0_lookup_ty) = self.write_expression( - ir_module, - ir_function, - arg0_expression, - block, - function, - )?; - let arg0_id = match *arg0_expression { - crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - let arg_result_id = self.get_type_id(&ir_module.types, arg0_lookup_ty)?; - block.body.push(super::instructions::instruction_load( - arg_result_id, - load_id, - arg0_id, - None, - )); - load_id - } - _ => arg0_id, - }; + let (arg0_id, arg0_lookup_ty) = + self.write_expression(ir_module, ir_function, arg, block, function)?; let id = self.generate_id(); match fun { @@ -1363,7 +1310,7 @@ impl Writer { id, arg0_id, )); - Ok((id, result_lookup_ty)) + Ok((RawExpression::Value(id), result_lookup_ty)) } _ => { log::error!("unimplemented math function {:?}", fun); @@ -1374,20 +1321,15 @@ impl Writer { crate::Expression::LocalVariable(variable) => { let var = &ir_function.local_variables[variable]; let local_var = &function.variables[&variable]; - Ok((local_var.id, LookupType::Handle(var.ty))) + Ok(( + RawExpression::Pointer(local_var.id), + LookupType::Handle(var.ty), + )) } crate::Expression::FunctionArgument(index) => { let handle = ir_function.arguments[index as usize].ty; - let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(handle))?; - let load_id = self.generate_id(); - - block.body.push(super::instructions::instruction_load( - type_id, - load_id, - function.parameters[index as usize].result_id.unwrap(), - None, - )); - Ok((load_id, LookupType::Handle(handle))) + let id = function.parameters[index as usize].result_id.unwrap(); + Ok((RawExpression::Pointer(id), LookupType::Handle(handle))) } crate::Expression::Call { function: local_function, @@ -1398,9 +1340,8 @@ impl Writer { let mut argument_ids = vec![]; for argument in arguments { - let expression = &ir_function.expressions[*argument]; let (arg_id, _) = - self.write_expression(ir_module, ir_function, expression, block, function)?; + self.write_expression(ir_module, ir_function, *argument, block, function)?; argument_ids.push(arg_id); } @@ -1420,7 +1361,7 @@ impl Writer { Some(ty_handle) => LookupType::Handle(ty_handle), None => LookupType::Local(LocalType::Void), }; - Ok((id, result_type)) + Ok((RawExpression::Value(id), result_type)) } crate::Expression::As { expr, @@ -1431,13 +1372,8 @@ impl Writer { return Err(Error::FeatureNotImplemented("bitcast")); } - let (expr_id, expr_type) = self.write_expression( - ir_module, - ir_function, - &ir_function.expressions[expr], - block, - function, - )?; + let (expr_id, expr_type) = + self.write_expression(ir_module, ir_function, expr, block, function)?; let expr_type_inner = self.get_type_inner(&ir_module.types, expr_type); @@ -1463,24 +1399,7 @@ impl Writer { (crate::ScalarKind::Uint, crate::ScalarKind::Float) => spirv::Op::ConvertUToF, // We assume it's either an identity cast, or int-uint. // In both cases no SPIR-V instructions need to be generated. - _ => { - let id = match ir_function.expressions[expr] { - crate::Expression::LocalVariable(_) - | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - let kind_type_id = self.get_type_id(&ir_module.types, expr_type)?; - block.body.push(super::instructions::instruction_load( - kind_type_id, - load_id, - expr_id, - None, - )); - load_id - } - _ => expr_id, - }; - return Ok((id, lookup_type)); - } + _ => return Ok((RawExpression::Value(expr_id), lookup_type)), }; let id = self.generate_id(); @@ -1489,7 +1408,7 @@ impl Writer { super::instructions::instruction_unary(op, kind_type_id, id, expr_id); block.body.push(instruction); - Ok((id, lookup_type)) + Ok((RawExpression::Value(id), lookup_type)) } crate::Expression::ImageSample { image, @@ -1501,29 +1420,8 @@ impl Writer { depth_ref: _, } => { // image - let image_expression = &ir_function.expressions[image]; - let (image_id, image_lookup_ty) = self.write_expression( - ir_module, - ir_function, - image_expression, - block, - function, - )?; - - let image_result_type_id = self.get_type_id(&ir_module.types, image_lookup_ty)?; - let image_id = match *image_expression { - crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - image_result_type_id, - load_id, - image_id, - None, - )); - load_id - } - _ => image_id, - }; + let (image_id, image_lookup_ty) = + self.write_expression(ir_module, ir_function, image, block, function)?; let image_ty = match image_lookup_ty { LookupType::Handle(handle) => handle, @@ -1539,56 +1437,12 @@ impl Writer { )?; // sampler - let sampler_expression = &ir_function.expressions[sampler]; - let (sampler_id, sampler_lookup_ty) = self.write_expression( - ir_module, - ir_function, - sampler_expression, - block, - function, - )?; - - let sampler_result_type_id = - self.get_type_id(&ir_module.types, sampler_lookup_ty)?; - let sampler_id = match *sampler_expression { - crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - sampler_result_type_id, - load_id, - sampler_id, - None, - )); - load_id - } - _ => sampler_id, - }; + let (sampler_id, _) = + self.write_expression(ir_module, ir_function, sampler, block, function)?; // coordinate - let coordinate_expression = &ir_function.expressions[coordinate]; - let (coordinate_id, coordinate_lookup_ty) = self.write_expression( - ir_module, - ir_function, - coordinate_expression, - block, - function, - )?; - - let coordinate_result_type_id = - self.get_type_id(&ir_module.types, coordinate_lookup_ty)?; - let coordinate_id = match *coordinate_expression { - crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - block.body.push(super::instructions::instruction_load( - coordinate_result_type_id, - load_id, - coordinate_id, - None, - )); - load_id - } - _ => coordinate_id, - }; + let (coordinate_id, _) = + self.write_expression(ir_module, ir_function, coordinate, block, function)?; // component kind let image_type = &ir_module.types[image_ty]; @@ -1639,10 +1493,10 @@ impl Writer { }; block.body.push(main_instruction); - Ok((id, image_sample_result_type)) + Ok((RawExpression::Value(id), image_sample_result_type)) } - _ => { - log::error!("unimplemented {:?}", expression); + ref other => { + log::error!("unimplemented {:?}", other); Err(Error::FeatureNotImplemented("expression")) } } @@ -1670,71 +1524,28 @@ impl Writer { crate::Statement::Return { value } => { block.termination = Some(match ir_function.return_type { Some(_) => { - let expression = &ir_function.expressions[value.unwrap()]; - let (id, lookup_ty) = self.write_expression( + let (id, _) = self.write_expression( ir_module, ir_function, - expression, + value.unwrap(), &mut block, function, )?; - - let id = match *expression { - crate::Expression::LocalVariable(_) - | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - let value_ty_id = - self.get_type_id(&ir_module.types, lookup_ty)?; - block.body.push(super::instructions::instruction_load( - value_ty_id, - load_id, - id, - None, - )); - load_id - } - - _ => id, - }; super::instructions::instruction_return_value(id) } None => super::instructions::instruction_return(), }); } crate::Statement::Store { pointer, value } => { - let pointer_expression = &ir_function.expressions[pointer]; - let value_expression = &ir_function.expressions[value]; - let (pointer_id, _) = self.write_expression( + let (pointer_id, _) = self.write_expression_pointer( ir_module, ir_function, - pointer_expression, + pointer, &mut block, function, )?; - let (value_id, value_lookup_ty) = self.write_expression( - ir_module, - ir_function, - value_expression, - &mut block, - function, - )?; - - let value_id = match value_expression { - crate::Expression::LocalVariable(_) - | crate::Expression::GlobalVariable(_) => { - let load_id = self.generate_id(); - let value_ty_id = - self.get_type_id(&ir_module.types, value_lookup_ty)?; - block.body.push(super::instructions::instruction_load( - value_ty_id, - load_id, - value_id, - None, - )); - load_id - } - _ => value_id, - }; + let (value_id, _) = + self.write_expression(ir_module, ir_function, value, &mut block, function)?; block.body.push(super::instructions::instruction_store( pointer_id, value_id, None,