diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index f0041251a4..9f704dba4c 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -5,6 +5,65 @@ use std::collections::hash_map::Entry; const BITS_PER_BYTE: crate::Bytes = 8; +struct Block { + label: Option, + body: Vec, + termination: Option, +} + +impl Block { + pub fn new() -> Self { + Block { + label: None, + body: vec![], + termination: None, + } + } +} + +struct LocalVariable { + id: Word, + name: Option, + instruction: Instruction, +} + +struct Function { + signature: Option, + parameters: Vec, + variables: Vec, + blocks: Vec, +} + +impl Function { + pub fn new() -> Self { + Function { + signature: None, + parameters: vec![], + variables: vec![], + blocks: vec![], + } + } + + fn to_words(&self, sink: &mut impl Extend) { + self.signature.as_ref().unwrap().to_words(sink); + for instruction in self.parameters.iter() { + instruction.to_words(sink); + } + for (index, block) in self.blocks.iter().enumerate() { + block.label.as_ref().unwrap().to_words(sink); + if index == 0 { + for local_var in self.variables.iter() { + local_var.instruction.to_words(sink); + } + } + for instruction in block.body.iter() { + instruction.to_words(sink); + } + block.termination.as_ref().unwrap().to_words(sink); + } + } +} + enum Signedness { Unsigned = 0, Signed = 1, @@ -583,14 +642,14 @@ impl Writer { fn instruction_variable( &self, - pointer_type_id: Word, - id: Word, + result_type_id: Word, + result_id: Word, storage_class: spirv::StorageClass, initializer_id: Option, ) -> Instruction { let mut instruction = Instruction::new(Op::Variable); - instruction.set_type(pointer_type_id); - instruction.set_result(id); + instruction.set_type(result_type_id); + instruction.set_result(result_id); instruction.add_operand(storage_class as u32); if let Some(initializer_id) = initializer_id { @@ -658,6 +717,21 @@ impl Writer { Instruction::new(Op::FunctionEnd) } + fn instruction_function_call( + &self, + result_type_id: Word, + id: Word, + function_id: Word, + argument_ids: Vec, + ) -> Instruction { + let mut instruction = Instruction::new(Op::FunctionCall); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(function_id); + instruction.add_operands(argument_ids); + instruction + } + /// /// Image Instructions /// @@ -1089,121 +1163,104 @@ impl Writer { fn write_expression<'a>( &mut self, ir_module: &'a crate::Module, - function: &crate::Function, + ir_function: &crate::Function, expression: &crate::Expression, - output: &mut Vec, - parameter_type_ids: &Vec, - ) -> Option<(Word, &'a crate::TypeInner)> { + block: &mut Block, + function: &mut Function, + ) -> Option<(Word, Option>)> { match expression { crate::Expression::GlobalVariable(handle) => { let var = &ir_module.global_variables[*handle]; - let inner = &ir_module.types[var.ty].inner; let id = self.get_global_variable_id( &ir_module.types, &ir_module.global_variables, *handle, ); - Some((id, inner)) + Some((id, Some(var.ty))) } crate::Expression::Constant(handle) => { let var = &ir_module.constants[*handle]; - let inner = &ir_module.types[var.ty].inner; let id = self.get_constant_id(*handle, ir_module); - Some((id, inner)) + Some((id, Some(var.ty))) } crate::Expression::Compose { ty, components } => { - let var = &ir_module.types[*ty]; - let inner = &var.inner; let id = self.generate_id(); let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(*ty)); let mut constituent_ids = Vec::with_capacity(components.len()); for component in components { - let expression = &function.expressions[*component]; + let expression = &ir_function.expressions[*component]; let (component_id, _) = self - .write_expression( - ir_module, - &function, - expression, - output, - parameter_type_ids, - ) + .write_expression(ir_module, &ir_function, expression, block, function) .unwrap(); constituent_ids.push(component_id); } let instruction = self.instruction_composite_construct(type_id, id, constituent_ids); - output.push(instruction); - - Some((id, inner)) + block.body.push(instruction); + Some((id, Some(*ty))) } crate::Expression::Binary { op, left, right } => { match op { crate::BinaryOperator::Multiply => { // TODO OpVectorTimesScalar is only supported let id = self.generate_id(); - let left_expression = &function.expressions[*left]; - let right_expression = &function.expressions[*right]; - let (left_id, left_inner) = self + let left_expression = &ir_function.expressions[*left]; + let right_expression = &ir_function.expressions[*right]; + let (left_id, left_ty) = self .write_expression( ir_module, - function, + ir_function, left_expression, - output, - parameter_type_ids, + block, + function, ) .unwrap(); - let (right_id, right_inner) = self + let (right_id, right_ty) = self .write_expression( ir_module, - function, + ir_function, right_expression, - output, - parameter_type_ids, + block, + function, ) .unwrap(); - let (result_type_id, vector_id, scalar_id) = match (left_inner, right_inner) - { - ( - crate::TypeInner::Vector { size, kind, width }, - crate::TypeInner::Scalar { .. }, - ) => { - let result_type_id = *self - .lookup_type - .get(&LookupType::Local(LocalType::Vector { - size: *size, - kind: *kind, - width: *width, - })) - .unwrap(); + let left_ty_inner = &ir_module.types[left_ty.unwrap()].inner; + let right_ty_inner = &ir_module.types[right_ty.unwrap()].inner; - (result_type_id, left_id, right_id) - } - ( - crate::TypeInner::Scalar { .. }, - crate::TypeInner::Vector { size, kind, width }, - ) => { - let result_type_id = *self - .lookup_type - .get(&LookupType::Local(LocalType::Vector { - size: *size, - kind: *kind, - width: *width, - })) - .unwrap(); - (result_type_id, right_id, left_id) - } - _ => unreachable!("Expression requires both a scalar and vector"), - }; + let (result_type_id, vector_id, scalar_id) = + match (left_ty_inner, right_ty_inner) { + ( + crate::TypeInner::Vector { .. }, + crate::TypeInner::Scalar { .. }, + ) => ( + self.get_type_id( + &ir_module.types, + LookupType::Handle(left_ty.unwrap()), + ), + left_id, + right_id, + ), + ( + crate::TypeInner::Scalar { .. }, + crate::TypeInner::Vector { .. }, + ) => ( + self.get_type_id( + &ir_module.types, + LookupType::Handle(right_ty.unwrap()), + ), + right_id, + left_id, + ), + _ => unreachable!("Expression requires both a scalar and vector"), + }; - // TODO Quick fix let load_id = self.generate_id(); - let load_instruction = self.instruction_load(result_type_id, load_id, vector_id, None); - output.push(load_instruction); + block.body.push(load_instruction); let instruction = self.instruction_vector_times_scalar( result_type_id, @@ -1211,102 +1268,161 @@ impl Writer { load_id, scalar_id, ); - output.push(instruction); - Some((id, &crate::TypeInner::Scalar { - kind: crate::ScalarKind::Float, - width: 10, - })) + block.body.push(instruction); + Some((id, None)) } _ => unimplemented!("{:?}", op), } } crate::Expression::LocalVariable(variable) => { - let id = self.generate_id(); - let var = &function.local_variables[*variable]; - let ty = &ir_module.types[var.ty]; + let var = &ir_function.local_variables[*variable]; + let id = if let Some(local_var) = function + .variables + .iter() + .find(|&v| v.name.as_ref().unwrap() == var.name.as_ref().unwrap()) + { + local_var.id + } else { + panic!("Could not find: {:?}", var) + }; - let pointer_id = - self.get_pointer_id(&ir_module.types, var.ty, spirv::StorageClass::Function); - - let instruction = - self.instruction_variable(pointer_id, id, spirv::StorageClass::Function, None); - output.push(instruction); - Some((id, &ty.inner)) + Some((id, Some(var.ty))) } crate::Expression::FunctionParameter(index) => { - let handle = function.parameter_types.get(*index as usize).unwrap(); + let handle = ir_function.parameter_types.get(*index as usize).unwrap(); let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(*handle)); let load_id = self.generate_id(); - output.push(self.instruction_load( + block.body.push(self.instruction_load( type_id, load_id, - parameter_type_ids[*index as usize], + function.parameters[*index as usize].result_id.unwrap(), None, )); - Some(( - load_id, - &crate::TypeInner::Scalar { - kind: crate::ScalarKind::Float, - width: 10, - }, - )) + Some((load_id, Some(*handle))) } + crate::Expression::Call { origin, arguments } => match origin { + crate::FunctionOrigin::Local(local_function) => { + let origin_function = &ir_module.functions[*local_function]; + let id = self.generate_id(); + let mut argument_ids = vec![]; + + for argument in arguments { + let expression = &ir_function.expressions[*argument]; + let (id, ty) = self + .write_expression(ir_module, ir_function, expression, block, function) + .unwrap(); + + // Create variable - OpVariable + // Store value to variable - OpStore + // Use id of variable + + let pointer_id = self.get_pointer_id( + &ir_module.types, + ty.unwrap(), + spirv::StorageClass::Function, + ); + + let variable_id = self.generate_id(); + function.variables.push(LocalVariable { + id: variable_id, + name: None, + instruction: self.instruction_variable( + pointer_id, + variable_id, + spirv::StorageClass::Function, + None, + ), + }); + block.body.push(self.instruction_store(variable_id, id)); + argument_ids.push(variable_id); + } + + let return_type_id = self + .get_function_return_type(origin_function.return_type, &ir_module.types); + + block.body.push(self.instruction_function_call( + return_type_id, + id, + *self.lookup_function.get(local_function).unwrap(), + argument_ids, + )); + Some((id, None)) + } + _ => unimplemented!("{:?}", origin), + }, _ => unimplemented!("{:?}", expression), } } - fn write_function_block( + fn write_function_statement( &mut self, ir_module: &crate::Module, - function: &crate::Function, + ir_function: &crate::Function, statement: &crate::Statement, - output: &mut Vec, - parameter_type_ids: &Vec, + block: &mut Block, + function: &mut Function, ) { match statement { - crate::Statement::Return { value } => match function.return_type { - Some(ty) => { - let expression = &function.expressions[value.unwrap()]; - let (id, _) = self - .write_expression( - ir_module, - function, - expression, - output, - parameter_type_ids, - ) + crate::Statement::Return { value } => match ir_function.return_type { + Some(_) => { + let expression = &ir_function.expressions[value.unwrap()]; + let (id, ty) = self + .write_expression(ir_module, ir_function, expression, block, function) .unwrap(); - output.push(self.instruction_return_value(id)) + + let id = match expression { + crate::Expression::LocalVariable(_) => { + let load_id = self.generate_id(); + let value_ty_id = self + .get_type_id(&ir_module.types, LookupType::Handle(ty.unwrap())); + block.body.push(self.instruction_load( + value_ty_id, + load_id, + id, + None, + )); + load_id + } + _ => id + }; + block.termination = Some(self.instruction_return_value(id)); } - None => output.push(self.instruction_return()), + None => block.termination = Some(self.instruction_return()), }, crate::Statement::Store { pointer, value } => { - let pointer_expression = &function.expressions[*pointer]; - let value_expression = &function.expressions[*value]; + let pointer_expression = &ir_function.expressions[*pointer]; + let value_expression = &ir_function.expressions[*value]; let (pointer_id, _) = self - .write_expression( - ir_module, - function, - pointer_expression, - output, - parameter_type_ids, - ) + .write_expression(ir_module, ir_function, pointer_expression, block, function) .unwrap(); - let (value_id, _) = self - .write_expression( - ir_module, - function, - value_expression, - output, - parameter_type_ids, - ) + let (value_id, value_ty) = self + .write_expression(ir_module, ir_function, value_expression, block, function) .unwrap(); - output.push(self.instruction_store(pointer_id, value_id)); + let value_id = match value_expression { + crate::Expression::LocalVariable(_) => { + let load_id = self.generate_id(); + let value_ty_id = self + .get_type_id(&ir_module.types, LookupType::Handle(value_ty.unwrap())); + block.body.push(self.instruction_load( + value_ty_id, + load_id, + value_id, + None, + )); + load_id + } + _ => value_id, + }; + + block + .body + .push(self.instruction_store(pointer_id, value_id)); } - _ => unimplemented!(), + crate::Statement::Empty => {} + _ => unimplemented!("{:?}", statement), } } @@ -1342,19 +1458,46 @@ impl Writer { annotation.to_words(&mut self.logical_layout.annotations); } - for (handle, function) in ir_module.functions.iter() { - let mut function_instructions: Vec = vec![]; - let id = self.generate_id(); + for (handle, ir_function) in ir_module.functions.iter() { + let mut function = Function::new(); + + for (_, variable) in ir_function.local_variables.iter() { + let id = self.generate_id(); + + let init_word = match variable.init { + Some(exp) => match &ir_function.expressions[exp] { + crate::Expression::Constant(handle) => { + Some(self.get_constant_id(*handle, ir_module)) + } + _ => unreachable!(), + }, + None => None, + }; + + let pointer_id = self.get_pointer_id( + &ir_module.types, + variable.ty, + spirv::StorageClass::Function, + ); + function.variables.push(LocalVariable { + id, + name: variable.name.clone(), + instruction: self.instruction_variable( + pointer_id, + id, + spirv::StorageClass::Function, + init_word, + ), + }); + } let return_type_id = - self.get_function_return_type(function.return_type, &ir_module.types); - let mut parameter_type_ids = Vec::with_capacity(function.parameter_types.len()); + self.get_function_return_type(ir_function.return_type, &ir_module.types); + let mut parameter_type_ids = Vec::with_capacity(ir_function.parameter_types.len()); - let mut function_parameter_ids = vec![]; let mut function_parameter_pointer_ids = vec![]; - let mut function_parameter_instructions = vec![]; - for parameter_type in function.parameter_types.iter() { + for parameter_type in ir_function.parameter_types.iter() { let id = self.generate_id(); let pointer_id = self.get_pointer_id( &ir_module.types, @@ -1362,12 +1505,12 @@ impl Writer { spirv::StorageClass::Function, ); - function_parameter_ids.push(id); function_parameter_pointer_ids.push(pointer_id); - function_parameter_instructions - .push(self.instruction_function_parameter(pointer_id, id)); parameter_type_ids .push(self.get_type_id(&ir_module.types, LookupType::Handle(*parameter_type))); + function + .parameters + .push(self.instruction_function_parameter(pointer_id, id)); } let lookup_function_type = LookupFunctionType { @@ -1375,38 +1518,35 @@ impl Writer { parameter_type_ids, }; - let type_function_id = + let id = self.generate_id(); + let function_type = self.get_function_type(lookup_function_type, function_parameter_pointer_ids); - - function_instructions.push(self.instruction_function( + function.signature = Some(self.instruction_function( return_type_id, id, spirv::FunctionControl::empty(), - type_function_id, + function_type, )); - function_instructions.append(&mut function_parameter_instructions); self.lookup_function.insert(handle, id); + let mut block = Block::new(); let id = self.generate_id(); - function_instructions.push(self.instruction_label(id)); - - for block in function.body.iter() { - let mut output: Vec = vec![]; - self.write_function_block( + block.label = Some(self.instruction_label(id)); + for statement in ir_function.body.iter() { + self.write_function_statement( ir_module, - function, - &block, - &mut output, - &function_parameter_ids, + ir_function, + &statement, + &mut block, + &mut function, ); - function_instructions.append(&mut output); } + function.blocks.push(block); - function_instructions.push(self.instruction_function_end()); - for instruction in function_instructions.iter() { - instruction.to_words(&mut self.logical_layout.function_definitions); - } + function.to_words(&mut self.logical_layout.function_definitions); + self.instruction_function_end() + .to_words(&mut self.logical_layout.function_definitions); } for entry_point in ir_module.entry_points.iter() {