diff --git a/Cargo.toml b/Cargo.toml index ee33f41c7a..3939aa9a4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ glsl = { version = "4.1", optional = true } pomelo = { version = "0.1.4", optional = true } thiserror = "1.0" serde = { version = "1.0", features = ["derive"], optional = true } +petgraph = { version ="0.5", optional = true } [features] default = [] @@ -28,6 +29,7 @@ glsl-validate = [] glsl-out = [] serialize = ["serde"] deserialize = ["serde"] +spirv-in = ["petgraph", "spirv"] [dev-dependencies] env_logger = "0.6" diff --git a/src/front/mod.rs b/src/front/mod.rs index 7e169017a0..ec4c36cf01 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -4,7 +4,7 @@ pub mod glsl; #[cfg(feature = "glsl-new")] pub mod glsl_new; -#[cfg(feature = "spirv")] +#[cfg(feature = "spirv-in")] pub mod spv; pub mod wgsl; diff --git a/src/front/spv/convert.rs b/src/front/spv/convert.rs new file mode 100644 index 0000000000..6329975e43 --- /dev/null +++ b/src/front/spv/convert.rs @@ -0,0 +1,75 @@ +use super::error::Error; +use num_traits::cast::FromPrimitive; +use std::convert::TryInto; + +pub fn map_binary_operator(word: spirv::Op) -> Result { + use crate::BinaryOperator; + use spirv::Op; + + match word { + // Arithmetic Instructions +, -, *, /, % + Op::IAdd | Op::FAdd => Ok(BinaryOperator::Add), + Op::ISub | Op::FSub => Ok(BinaryOperator::Subtract), + Op::IMul | Op::FMul => Ok(BinaryOperator::Multiply), + Op::UDiv | Op::SDiv | Op::FDiv => Ok(BinaryOperator::Divide), + Op::UMod | Op::SMod | Op::FMod => Ok(BinaryOperator::Modulo), + // Relational and Logical Instructions + Op::IEqual | Op::FOrdEqual | Op::FUnordEqual => Ok(BinaryOperator::Equal), + Op::INotEqual | Op::FOrdNotEqual | Op::FUnordNotEqual => Ok(BinaryOperator::NotEqual), + Op::ULessThan | Op::SLessThan | Op::FOrdLessThan | Op::FUnordLessThan => { + Ok(BinaryOperator::Less) + } + Op::ULessThanEqual + | Op::SLessThanEqual + | Op::FOrdLessThanEqual + | Op::FUnordLessThanEqual => Ok(BinaryOperator::LessEqual), + Op::UGreaterThan | Op::SGreaterThan | Op::FOrdGreaterThan | Op::FUnordGreaterThan => { + Ok(BinaryOperator::Greater) + } + Op::UGreaterThanEqual + | Op::SGreaterThanEqual + | Op::FOrdGreaterThanEqual + | Op::FUnordGreaterThanEqual => Ok(BinaryOperator::GreaterEqual), + _ => Err(Error::UnknownInstruction(word as u16)), + } +} + +pub fn map_vector_size(word: spirv::Word) -> Result { + match word { + 2 => Ok(crate::VectorSize::Bi), + 3 => Ok(crate::VectorSize::Tri), + 4 => Ok(crate::VectorSize::Quad), + _ => Err(Error::InvalidVectorSize(word)), + } +} + +pub fn map_storage_class(word: spirv::Word) -> Result { + use spirv::StorageClass as Sc; + match Sc::from_u32(word) { + Some(Sc::UniformConstant) => Ok(crate::StorageClass::Constant), + Some(Sc::Function) => Ok(crate::StorageClass::Function), + Some(Sc::Input) => Ok(crate::StorageClass::Input), + Some(Sc::Output) => Ok(crate::StorageClass::Output), + Some(Sc::Private) => Ok(crate::StorageClass::Private), + Some(Sc::StorageBuffer) => Ok(crate::StorageClass::StorageBuffer), + Some(Sc::Uniform) => Ok(crate::StorageClass::Uniform), + Some(Sc::Workgroup) => Ok(crate::StorageClass::WorkGroup), + _ => Err(Error::UnsupportedStorageClass(word)), + } +} + +pub fn map_image_dim(word: spirv::Word) -> Result { + match spirv::Dim::from_u32(word) { + Some(spirv::Dim::Dim1D) => Ok(crate::ImageDimension::D1), + Some(spirv::Dim::Dim2D) => Ok(crate::ImageDimension::D2), + Some(spirv::Dim::Dim3D) => Ok(crate::ImageDimension::D3), + Some(spirv::Dim::DimCube) => Ok(crate::ImageDimension::Cube), + _ => Err(Error::UnsupportedImageDim(word)), + } +} + +pub fn map_width(word: spirv::Word) -> Result { + (word >> 3) // bits to bytes + .try_into() + .map_err(|_| Error::InvalidTypeWidth(word)) +} diff --git a/src/front/spv/error.rs b/src/front/spv/error.rs new file mode 100644 index 0000000000..f4384cd892 --- /dev/null +++ b/src/front/spv/error.rs @@ -0,0 +1,52 @@ +use super::ModuleState; +use crate::arena::Handle; + +#[derive(Debug)] +pub enum Error { + InvalidHeader, + InvalidWordCount, + UnknownInstruction(u16), + UnknownCapability(spirv::Word), + UnsupportedInstruction(ModuleState, spirv::Op), + UnsupportedCapability(spirv::Capability), + UnsupportedExtension(String), + UnsupportedExtSet(String), + UnsupportedExtInstSet(spirv::Word), + UnsupportedExtInst(spirv::Word), + UnsupportedType(Handle), + UnsupportedExecutionModel(spirv::Word), + UnsupportedStorageClass(spirv::Word), + UnsupportedImageDim(spirv::Word), + UnsupportedBuiltIn(spirv::Word), + UnsupportedControlFlow(spirv::Word), + UnsupportedBinaryOperator(spirv::Word), + InvalidParameter(spirv::Op), + InvalidOperandCount(spirv::Op, u16), + InvalidOperand, + InvalidId(spirv::Word), + InvalidDecoration(spirv::Word), + InvalidTypeWidth(spirv::Word), + InvalidSign(spirv::Word), + InvalidInnerType(spirv::Word), + InvalidVectorSize(spirv::Word), + InvalidVariableClass(spirv::StorageClass), + InvalidAccessType(spirv::Word), + InvalidAccess(Handle), + InvalidAccessIndex(spirv::Word), + InvalidLoadType(spirv::Word), + InvalidStoreType(spirv::Word), + InvalidBinding(spirv::Word), + InvalidImageExpression(Handle), + InvalidSamplerExpression(Handle), + InvalidSampleImage(Handle), + InvalidSampleSampler(Handle), + InvalidSampleCoordinates(Handle), + InvalidDepthReference(Handle), + InconsistentComparisonSampling(Handle), + WrongFunctionResultType(spirv::Word), + WrongFunctionParameterType(spirv::Word), + MissingDecoration(spirv::Decoration), + BadString, + IncompleteData, + InvalidTerminator, +} diff --git a/src/front/spv/flow.rs b/src/front/spv/flow.rs new file mode 100644 index 0000000000..a2f35b8464 --- /dev/null +++ b/src/front/spv/flow.rs @@ -0,0 +1,404 @@ +#![allow(dead_code)] + +use super::error::Error; +///! see https://en.wikipedia.org/wiki/Control-flow_graph +///! see https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_structuredcontrolflow_a_structured_control_flow +use super::function::{BlockId, MergeInstruction, Terminator}; + +use crate::FastHashMap; + +use petgraph::{ + graph::{node_index, NodeIndex}, + visit::EdgeRef, + Directed, Direction, +}; + +use std::fmt::Write; + +/// Index of a block node in the `ControlFlowGraph`. +type BlockNodeIndex = NodeIndex; + +/// Internal representation of a CFG constisting of function's basic blocks. +type ControlFlowGraph = petgraph::Graph; + +/// Control flow graph (CFG) containing relationships between blocks. +pub struct FlowGraph { + /// + flow: ControlFlowGraph, + + /// Block ID to Node index mapping. Internal helper to speed up the classification. + block_to_node: FastHashMap, +} + +impl FlowGraph { + /// Creates empty flow graph. + pub fn new() -> Self { + Self { + flow: ControlFlowGraph::default(), + block_to_node: FastHashMap::default(), + } + } + + /// Add a control flow node. + pub fn add_node(&mut self, node: ControlFlowNode) { + let block_id = node.id; + let node_index = self.flow.add_node(node); + self.block_to_node.insert(block_id, node_index); + } + + /// + /// 1. Creates edges in the CFG. + /// 2. Classifies types of blocks and edges in the CFG. + pub fn classify(&mut self) { + let block_to_node = &mut self.block_to_node; + + // 1. + // Add all edges + // Classify Nodes as one of [Header, Loop, Kill, Return] + for source_node_index in self.flow.node_indices() { + // Merge edges + if let Some(merge) = self.flow[source_node_index].merge { + let merge_block_index = block_to_node[&merge.merge_block_id]; + + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Header); + self.flow[merge_block_index].ty = Some(ControlFlowNodeType::Merge); + self.flow.add_edge( + source_node_index, + merge_block_index, + ControlFlowEdgeType::ForwardMerge, + ); + + if let Some(continue_block_id) = merge.continue_block_id { + let continue_block_index = block_to_node[&continue_block_id]; + + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Loop); + self.flow.add_edge( + source_node_index, + continue_block_index, + ControlFlowEdgeType::ForwardContinue, + ); + } + } + + // Branch Edges + match self.flow[source_node_index].terminator { + Terminator::Branch { target_id } => { + let target_node_index = block_to_node[&target_id]; + self.flow.add_edge( + source_node_index, + target_node_index, + ControlFlowEdgeType::Forward, + ); + } + Terminator::BranchConditional { + true_id, false_id, .. + } => { + let true_node_index = block_to_node[&true_id]; + let false_node_index = block_to_node[&false_id]; + + self.flow.add_edge( + source_node_index, + true_node_index, + ControlFlowEdgeType::IfTrue, + ); + self.flow.add_edge( + source_node_index, + false_node_index, + ControlFlowEdgeType::IfFalse, + ); + } + Terminator::Switch { .. } => { + // TODO + } + Terminator::Return { .. } => { + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Return) + } + Terminator::Kill => { + self.flow[source_node_index].ty = Some(ControlFlowNodeType::Kill) + } + _ => {} + }; + } + + // Classify Nodes/Edges as one of [Break, Continue, Back] + for edge_index in self.flow.edge_indices() { + let (node_source_index, node_target_index) = + self.flow.edge_endpoints(edge_index).unwrap(); + + // Back + if self.flow[node_target_index].ty == Some(ControlFlowNodeType::Loop) + && self.flow[node_source_index].id > self.flow[node_target_index].id + { + self.flow[node_source_index].ty = Some(ControlFlowNodeType::Back); + self.flow[edge_index] = ControlFlowEdgeType::Back; + } + + let mut target_incoming_edges = self + .flow + .neighbors_directed(node_target_index, Direction::Incoming) + .detach(); + while let Some((incoming_edge, incoming_source)) = + target_incoming_edges.next(&self.flow) + { + // Loop continue + if self.flow[incoming_edge] == ControlFlowEdgeType::ForwardContinue { + self.flow[node_source_index].ty = Some(ControlFlowNodeType::Continue); + self.flow[edge_index] = ControlFlowEdgeType::LoopContinue; + } + // Loop break + if self.flow[incoming_source].ty == Some(ControlFlowNodeType::Loop) + && self.flow[incoming_edge] == ControlFlowEdgeType::ForwardMerge + { + self.flow[node_source_index].ty = Some(ControlFlowNodeType::Break); + self.flow[edge_index] = ControlFlowEdgeType::LoopBreak; + } + } + } + } + + /// TODO + /// Removes OpPhi instructions from the control flow graph and turns them into ordinary variables. + /// + /// Phi instructions are not supported inside Naga nor do they exist as instructions on CPUs. It is neccessary + /// to remove them and turn into ordinary variables before converting to Naga's IR and shader code. + pub fn remove_phi_instructions() { + unimplemented!(); + } + + /// Traverses the flow graph and returns a list of Naga's statements. + pub fn to_naga(&self) -> Result { + self.naga_traverse(node_index(0)) + } + + fn naga_traverse(&self, node_index: BlockNodeIndex) -> Result { + let node = &self.flow[node_index]; + + match node.ty { + Some(ControlFlowNodeType::Header) => { + match node.terminator { + Terminator::BranchConditional { + condition, + true_id, + false_id, + } => { + let mut result = node.block.clone(); + result.push(crate::Statement::If { + condition, + accept: self.naga_traverse(self.block_to_node[&true_id])?, + reject: self.naga_traverse(self.block_to_node[&false_id])?, + }); + Ok(result) + } + Terminator::Switch { .. } => { + // TODO + Ok(node.block.clone()) + } + _ => Err(Error::InvalidTerminator), + } + } + Some(ControlFlowNodeType::Loop) => { + let continuing: crate::Block = { + let continue_edge = self + .flow + .edges_directed(node_index, Direction::Outgoing) + .find(|&ty| *ty.weight() == ControlFlowEdgeType::ForwardContinue) + .unwrap(); + + self.flow[continue_edge.target()].block.clone() + }; + + let mut body: crate::Block = node.block.clone(); + match node.terminator { + Terminator::BranchConditional { + condition, + true_id, + false_id, + } => body.push(crate::Statement::If { + condition, + accept: self.naga_traverse(self.block_to_node[&true_id])?, + reject: self.naga_traverse(self.block_to_node[&false_id])?, + }), + Terminator::Branch { target_id } => { + body.extend(self.naga_traverse(self.block_to_node[&target_id])?) + } + _ => return Err(Error::InvalidTerminator), + }; + + Ok(vec![crate::Statement::Loop { body, continuing }]) + } + Some(ControlFlowNodeType::Break) => { + let mut result = node.block.clone(); + match node.terminator { + Terminator::BranchConditional { + condition, + true_id, + false_id, + } => result.push(crate::Statement::If { + condition, + accept: self.naga_traverse(self.block_to_node[&true_id])?, + reject: self.naga_traverse(self.block_to_node[&false_id])?, + }), + _ => return Err(Error::InvalidTerminator), + }; + Ok(result) + } + Some(ControlFlowNodeType::Continue) => { + let mut result = node.block.clone(); + result.push(crate::Statement::Continue); + Ok(result) + } + Some(ControlFlowNodeType::Back) | Some(ControlFlowNodeType::Merge) => { + Ok(node.block.clone()) + } + Some(ControlFlowNodeType::Kill) => { + let mut result = node.block.clone(); + result.push(crate::Statement::Kill); + Ok(result) + } + Some(ControlFlowNodeType::Return) => { + let value = match node.terminator { + Terminator::Return { value } => value, + _ => return Err(Error::InvalidTerminator), + }; + let mut result = node.block.clone(); + result.push(crate::Statement::Return { value }); + Ok(result) + } + None => match node.terminator { + Terminator::Branch { target_id } => { + let mut result = node.block.clone(); + result.extend(self.naga_traverse(self.block_to_node[&target_id])?); + Ok(result) + } + _ => Ok(node.block.clone()), + }, + } + } + + /// Get the entire graph in a graphviz dot format for visualization. Useful for debugging purposes. + pub fn to_graphviz(&self) -> Result { + let mut output = String::new(); + + output += "digraph ControlFlowGraph {"; + + for node_index in self.flow.node_indices() { + let node = &self.flow[node_index]; + writeln!( + output, + "{} [ label = \"%{} {:?}\" ]", + node_index.index(), + node.id, + node.ty + )?; + } + + for edge in self.flow.raw_edges() { + let source = edge.source(); + let target = edge.target(); + + let style = match edge.weight { + ControlFlowEdgeType::IfTrue => "color=blue", + ControlFlowEdgeType::IfFalse => "color=red", + ControlFlowEdgeType::ForwardMerge => "style=dotted", + _ => "", + }; + + writeln!( + &mut output, + "{} -> {} [ {} ]", + source.index(), + target.index(), + style + )?; + } + + output += "}\n"; + + Ok(output) + } +} + +/// Type of an edge(flow) in the `ControlFlowGraph`. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub enum ControlFlowEdgeType { + /// Default + Forward, + + /// Forward edge to a merge block. + ForwardMerge, + + /// Forward edge to a OpLoopMerge continue's instruction. + ForwardContinue, + + /// A back-edge: An edge from a node to one of its ancestors in a depth-first + /// search from the entry block. + /// Can only be to a ControlFlowNodeType::Loop. + Back, + + /// An edge from a node to the merge block of the nearest enclosing loop, where + /// there is no intervening switch. + /// The source block is a "break block" as defined by SPIR-V. + LoopBreak, + + /// An edge from a node in a loop body to the associated continue target, where + /// there are no other intervening loops or switches. + /// The source block is a "continue block" as defined by SPIR-V. + LoopContinue, + + /// An edge from a node with OpBranchConditional to the block of true operand. + IfTrue, + + /// An edge from a node with OpBranchConditional to the block of false operand. + IfFalse, + + /// An edge from a node to the merge block of the nearest enclosing switch, + /// where there is no intervening loop. + SwitchBreak, + + /// An edge from one switch case to the next sibling switch case. + CaseFallThrough, +} +/// Type of a node(block) in the `ControlFlowGraph`. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ControlFlowNodeType { + /// A block whose merge instruction is an OpSelectionMerge. + Header, + + /// A header block whose merge instruction is an OpLoopMerge. + Loop, + + /// A block declared by the Merge Block operand of a merge instruction. + Merge, + + /// A block containing a branch to the Merge Block of a loop header’s merge instruction. + Break, + + /// A block containing a branch to an OpLoopMerge instruction’s Continue Target. + Continue, + + /// A block containing an OpBranch to a Loop block. + Back, + + /// A block containing an OpKill instruction. + Kill, + + /// A block containing an OpReturn or OpReturnValue branch. + Return, +} +/// ControlFlowGraph's node representing a block in the control flow. +pub struct ControlFlowNode { + /// SPIR-V ID. + pub id: BlockId, + + /// Type of the node. See *ControlFlowNodeType*. + pub ty: Option, + + /// Naga's statements inside this block. + pub block: crate::Block, + + /// Termination instruction of the block. + pub terminator: Terminator, + + /// Merge Instruction + pub merge: Option, +} diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs new file mode 100644 index 0000000000..4012242edf --- /dev/null +++ b/src/front/spv/function.rs @@ -0,0 +1,163 @@ +use crate::arena::Handle; + +use super::flow::*; +use super::*; + +pub type BlockId = u32; + +#[derive(Copy, Clone, Debug)] +pub struct MergeInstruction { + pub merge_block_id: BlockId, + pub continue_block_id: Option, +} +/// Terminator instruction of a SPIR-V's block. +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub enum Terminator { + /// + Return { + value: Option>, + }, + /// + Branch { target_id: BlockId }, + /// + BranchConditional { + condition: Handle, + true_id: BlockId, + false_id: BlockId, + }, + /// + /// switch(SELECTOR) { + /// case TARGET_LITERAL#: { + /// TARGET_BLOCK# + /// } + /// default: { + /// DEFAULT + /// } + /// } + Switch { + /// + selector: Handle, + /// Default block of the switch case. + default: BlockId, + /// Tuples of (literal, target block) + targets: Vec<(i32, BlockId)>, + }, + /// Fragment shader discard + Kill, + /// + Unreachable, +} + +pub fn parse_function>( + parser: &mut super::Parser, + inst: Instruction, + module: &mut crate::Module, +) -> Result<(), Error> { + parser.switch(ModuleState::Function, inst.op)?; + inst.expect(5)?; + let result_type = parser.next()?; + let fun_id = parser.next()?; + let _fun_control = parser.next()?; + let fun_type = parser.next()?; + let mut fun = { + let ft = parser.lookup_function_type.lookup(fun_type)?; + if ft.return_type_id != result_type { + return Err(Error::WrongFunctionResultType(result_type)); + } + crate::Function { + name: parser.future_decor.remove(&fun_id).and_then(|dec| dec.name), + parameter_types: Vec::with_capacity(ft.parameter_type_ids.len()), + return_type: if parser.lookup_void_type.contains(&result_type) { + None + } else { + Some(parser.lookup_type.lookup(result_type)?.handle) + }, + global_usage: Vec::new(), + local_variables: Arena::new(), + expressions: parser.make_expression_storage(), + body: Vec::new(), + } + }; + + // read parameters + for i in 0..fun.parameter_types.capacity() { + match parser.next_inst()? { + Instruction { + op: spirv::Op::FunctionParameter, + wc: 3, + } => { + let type_id = parser.next()?; + let _id = parser.next()?; + //Note: we redo the lookup in order to work around `parser` borrowing + if type_id + != parser + .lookup_function_type + .lookup(fun_type)? + .parameter_type_ids[i] + { + return Err(Error::WrongFunctionParameterType(type_id)); + } + let ty = parser.lookup_type.lookup(type_id)?.handle; + fun.parameter_types.push(ty); + } + Instruction { op, .. } => return Err(Error::InvalidParameter(op)), + } + } + + // Read body + let mut local_function_calls = FastHashMap::default(); + let mut flow_graph = FlowGraph::new(); + + // Scan the blocks and add them as nodes + loop { + let fun_inst = parser.next_inst()?; + log::debug!("\t\t{:?}", fun_inst.op); + match fun_inst.op { + spirv::Op::Label => { + // Read the label ID + fun_inst.expect(2)?; + let block_id = parser.next()?; + + let node = parser.next_block( + block_id, + &mut fun.expressions, + &mut fun.local_variables, + &module.types, + &module.constants, + &module.global_variables, + &mut local_function_calls, + )?; + + flow_graph.add_node(node); + } + spirv::Op::FunctionEnd => { + fun_inst.expect(1)?; + break; + } + _ => { + return Err(Error::UnsupportedInstruction(parser.state, fun_inst.op)); + } + } + } + + flow_graph.classify(); + fun.body = flow_graph.to_naga()?; + + // done + fun.global_usage = + crate::GlobalUse::scan(&fun.expressions, &fun.body, &module.global_variables); + let handle = module.functions.append(fun); + for (expr_handle, dst_id) in local_function_calls { + parser.deferred_function_calls.push(DeferredFunctionCall { + source_handle: handle, + expr_handle, + dst_id, + }); + } + + parser.lookup_function.insert(fun_id, handle); + parser.lookup_expression.clear(); + parser.lookup_sampled_image.clear(); + Ok(()) +} diff --git a/src/front/spv.rs b/src/front/spv/mod.rs similarity index 89% rename from src/front/spv.rs rename to src/front/spv/mod.rs index e98c14ef33..701c150db5 100644 --- a/src/front/spv.rs +++ b/src/front/spv/mod.rs @@ -9,6 +9,17 @@ extra info, such as the related SPIR-V type ID. TODO: would be nice to find ways that avoid looking up as much !*/ +#![allow(dead_code)] + +mod convert; +mod error; +mod flow; +mod function; + +use convert::*; +use error::Error; +use flow::*; +use function::*; use crate::{ arena::{Arena, Handle}, @@ -22,55 +33,8 @@ pub const SUPPORTED_CAPABILITIES: &[spirv::Capability] = &[spirv::Capability::Sh pub const SUPPORTED_EXTENSIONS: &[&str] = &[]; pub const SUPPORTED_EXT_SETS: &[&str] = &["GLSL.std.450"]; -#[derive(Debug)] -pub enum Error { - InvalidHeader, - InvalidWordCount, - UnknownInstruction(u16), - UnknownCapability(spirv::Word), - UnsupportedInstruction(ModuleState, spirv::Op), - UnsupportedCapability(spirv::Capability), - UnsupportedExtension(String), - UnsupportedExtSet(String), - UnsupportedExtInstSet(spirv::Word), - UnsupportedExtInst(spirv::Word), - UnsupportedType(Handle), - UnsupportedExecutionModel(spirv::Word), - UnsupportedStorageClass(spirv::Word), - UnsupportedImageDim(spirv::Word), - UnsupportedBuiltIn(spirv::Word), - UnsupportedControlFlow(spirv::Word), - InvalidParameter(spirv::Op), - InvalidOperandCount(spirv::Op, u16), - InvalidOperand, - InvalidId(spirv::Word), - InvalidDecoration(spirv::Word), - InvalidTypeWidth(spirv::Word), - InvalidSign(spirv::Word), - InvalidInnerType(spirv::Word), - InvalidVectorSize(spirv::Word), - InvalidVariableClass(spirv::StorageClass), - InvalidAccessType(spirv::Word), - InvalidAccess(Handle), - InvalidAccessIndex(spirv::Word), - InvalidLoadType(spirv::Word), - InvalidStoreType(spirv::Word), - InvalidBinding(spirv::Word), - InvalidImageExpression(Handle), - InvalidSamplerExpression(Handle), - InvalidSampleImage(Handle), - InvalidSampleSampler(Handle), - InvalidSampleCoordinates(Handle), - InvalidDepthReference(Handle), - InconsistentComparisonSampling(Handle), - WrongFunctionResultType(spirv::Word), - WrongFunctionParameterType(spirv::Word), - MissingDecoration(spirv::Decoration), - BadString, - IncompleteData, -} - -struct Instruction { +#[derive(Copy, Clone)] +pub struct Instruction { op: spirv::Op, wc: u16, } @@ -122,46 +86,6 @@ impl LookupHelper for FastHashMap { } } -fn map_vector_size(word: spirv::Word) -> Result { - match word { - 2 => Ok(crate::VectorSize::Bi), - 3 => Ok(crate::VectorSize::Tri), - 4 => Ok(crate::VectorSize::Quad), - _ => Err(Error::InvalidVectorSize(word)), - } -} - -fn map_storage_class(word: spirv::Word) -> Result { - use spirv::StorageClass as Sc; - match Sc::from_u32(word) { - Some(Sc::UniformConstant) => Ok(crate::StorageClass::Constant), - Some(Sc::Function) => Ok(crate::StorageClass::Function), - Some(Sc::Input) => Ok(crate::StorageClass::Input), - Some(Sc::Output) => Ok(crate::StorageClass::Output), - Some(Sc::Private) => Ok(crate::StorageClass::Private), - Some(Sc::StorageBuffer) => Ok(crate::StorageClass::StorageBuffer), - Some(Sc::Uniform) => Ok(crate::StorageClass::Uniform), - Some(Sc::Workgroup) => Ok(crate::StorageClass::WorkGroup), - _ => Err(Error::UnsupportedStorageClass(word)), - } -} - -fn map_image_dim(word: spirv::Word) -> Result { - match spirv::Dim::from_u32(word) { - Some(spirv::Dim::Dim1D) => Ok(crate::ImageDimension::D1), - Some(spirv::Dim::Dim2D) => Ok(crate::ImageDimension::D2), - Some(spirv::Dim::Dim3D) => Ok(crate::ImageDimension::D3), - Some(spirv::Dim::DimCube) => Ok(crate::ImageDimension::Cube), - _ => Err(Error::UnsupportedImageDim(word)), - } -} - -fn map_width(word: spirv::Word) -> Result { - (word >> 3) // bits to bytes - .try_into() - .map_err(|_| Error::InvalidTypeWidth(word)) -} - //TODO: this method may need to be gone, depending on whether // WGSL allows treating images and samplers as expressions and pass them around. fn reach_global_type( @@ -331,33 +255,18 @@ struct LookupSampledImage { image: Handle, sampler: Handle, } - struct DeferredFunctionCall { source_handle: Handle, expr_handle: Handle, dst_id: spirv::Word, } -enum Terminator { - Return { - value: Option>, - }, - Branch { - label_id: spirv::Word, - condition: Option>, - }, -} - -struct Assignment { +#[derive(Clone, Debug)] +pub struct Assignment { to: Handle, value: Handle, } -struct ControlFlowNode { - assignments: Vec, - terminator: Terminator, -} - pub struct Parser { data: I, state: ModuleState, @@ -550,8 +459,10 @@ impl> Parser { Ok(()) } + #[allow(clippy::too_many_arguments)] fn next_block( &mut self, + block_id: spirv::Word, expressions: &mut Arena, local_arena: &mut Arena, type_arena: &Arena, @@ -560,10 +471,12 @@ impl> Parser { local_function_calls: &mut FastHashMap, spirv::Word>, ) -> Result { let mut assignments = Vec::new(); + let mut merge = None; let terminator = loop { use spirv::Op; let inst = self.next_inst()?; log::debug!("\t\t{:?} [{}]", inst.op, inst.wc); + match inst.op { Op::Variable => { inst.expect_at_least(4)?; @@ -679,13 +592,11 @@ impl> Parser { }; } - self.lookup_expression.insert( - result_id, - LookupExpression { - handle: acex.base_handle, - type_id: result_type_id, - }, - ); + let lookup_expression = LookupExpression { + handle: acex.base_handle, + type_id: result_type_id, + }; + self.lookup_expression.insert(result_id, lookup_expression); } Op::CompositeExtract => { inst.expect_at_least(4)?; @@ -808,26 +719,11 @@ impl> Parser { value: value_expr.handle, }); } - Op::Return => { - inst.expect(1)?; - break Terminator::Return { value: None }; - } - Op::Branch => { - inst.expect(2)?; - let label_id = self.next()?; - break Terminator::Branch { - label_id, - condition: None, - }; - } - Op::FSub => { + // Arithmetic Instructions +, -, *, /, % + _ if inst.op >= Op::IAdd && inst.op <= Op::FMod => { inst.expect(5)?; self.parse_expr_binary_op(expressions, crate::BinaryOperator::Subtract)?; } - Op::FMul => { - inst.expect(5)?; - self.parse_expr_binary_op(expressions, crate::BinaryOperator::Multiply)?; - } Op::VectorTimesScalar => { inst.expect(5)?; let result_type_id = self.next()?; @@ -1134,12 +1030,99 @@ impl> Parser { }, ); } + // Relational and Logical Instructions + op if inst.op >= Op::IEqual && inst.op <= Op::FUnordGreaterThanEqual => { + inst.expect(5)?; + self.parse_expr_binary_op(expressions, map_binary_operator(op)?)?; + } + Op::Return => { + inst.expect(1)?; + break Terminator::Return { value: None }; + } + Op::Branch => { + inst.expect(2)?; + let target_id = self.next()?; + break Terminator::Branch { target_id }; + } + Op::BranchConditional => { + inst.expect_at_least(4)?; + + let condition_id = self.next()?; + let condition = self.lookup_expression.lookup(condition_id)?.handle; + + let true_id = self.next()?; + let false_id = self.next()?; + + break Terminator::BranchConditional { + condition, + true_id, + false_id, + }; + } + Op::Switch => { + inst.expect_at_least(3)?; + + let selector = self.next()?; + let selector = self.lookup_expression[&selector].handle; + let default = self.next()?; + + let mut targets = Vec::new(); + for _ in 0..inst.wc - 3 { + let literal = self.next()?; + let target = self.next()?; + targets.push((literal as i32, target)); + } + + break Terminator::Switch { + selector, + default, + targets, + }; + } + Op::SelectionMerge => { + inst.expect(3)?; + let merge_block_id = self.next()?; + // TODO: Selection Control Mask + let _selection_control = self.next()?; + let continue_block_id = None; + merge = Some(MergeInstruction { + merge_block_id, + continue_block_id, + }); + } + Op::LoopMerge => { + inst.expect_at_least(4)?; + let merge_block_id = self.next()?; + let continue_block_id = Some(self.next()?); + + // TODO: Loop Control Parameters + for _ in 0..inst.wc - 3 { + self.next()?; + } + + merge = Some(MergeInstruction { + merge_block_id, + continue_block_id, + }); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; + + let mut block = Vec::new(); + for assignment in assignments.iter() { + block.push(crate::Statement::Store { + pointer: assignment.to, + value: assignment.value, + }); + } + Ok(ControlFlowNode { - assignments, + id: block_id, + ty: None, + block, terminator, + merge, }) } @@ -1231,7 +1214,7 @@ impl> Parser { Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), Op::Variable => self.parse_global_variable(inst, &mut module), - Op::Function => self.parse_function(inst, &mut module), + Op::Function => parse_function(&mut self, inst, &mut module), _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO }?; } @@ -1871,7 +1854,7 @@ impl> Parser { let id = self.next()?; let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; - let inner = match module.types[type_lookup.handle].inner { + let inner = match module.types[ty].inner { crate::TypeInner::Scalar { kind: crate::ScalarKind::Uint, width, @@ -1897,7 +1880,7 @@ impl> Parser { inst.expect(4)?; self.next()? } - Ordering::Equal => !0, + Ordering::Equal => 0, }; crate::ConstantInner::Sint(((u64::from(high) << 32) | u64::from(low)) as i64) } @@ -2019,123 +2002,6 @@ impl> Parser { ); Ok(()) } - - fn parse_function( - &mut self, - inst: Instruction, - module: &mut crate::Module, - ) -> Result<(), Error> { - self.switch(ModuleState::Function, inst.op)?; - inst.expect(5)?; - let result_type = self.next()?; - let fun_id = self.next()?; - let _fun_control = self.next()?; - let fun_type = self.next()?; - let mut fun = { - let ft = self.lookup_function_type.lookup(fun_type)?; - if ft.return_type_id != result_type { - return Err(Error::WrongFunctionResultType(result_type)); - } - crate::Function { - name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name), - parameter_types: Vec::with_capacity(ft.parameter_type_ids.len()), - return_type: if self.lookup_void_type.contains(&result_type) { - None - } else { - Some(self.lookup_type.lookup(result_type)?.handle) - }, - global_usage: Vec::new(), - local_variables: Arena::new(), - expressions: self.make_expression_storage(), - body: Vec::new(), - } - }; - // read parameters - for i in 0..fun.parameter_types.capacity() { - match self.next_inst()? { - Instruction { - op: spirv::Op::FunctionParameter, - wc: 3, - } => { - let type_id = self.next()?; - let _id = self.next()?; - //Note: we redo the lookup in order to work around `self` borrowing - if type_id - != self - .lookup_function_type - .lookup(fun_type)? - .parameter_type_ids[i] - { - return Err(Error::WrongFunctionParameterType(type_id)); - } - let ty = self.lookup_type.lookup(type_id)?.handle; - fun.parameter_types.push(ty); - } - Instruction { op, .. } => return Err(Error::InvalidParameter(op)), - } - } - // read body - let mut local_function_calls = FastHashMap::default(); - let mut control_flow_graph = FastHashMap::default(); - loop { - let fun_inst = self.next_inst()?; - log::debug!("\t\t{:?}", fun_inst.op); - match fun_inst.op { - spirv::Op::Label => { - fun_inst.expect(2)?; - let label_id = self.next()?; - let node = self.next_block( - &mut fun.expressions, - &mut fun.local_variables, - &module.types, - &module.constants, - &module.global_variables, - &mut local_function_calls, - )?; - // temp until the CFG is fully processed - for assign in node.assignments.iter() { - fun.body.push(crate::Statement::Store { - pointer: assign.to, - value: assign.value, - }); - } - match node.terminator { - Terminator::Return { value } => { - fun.body.push(crate::Statement::Return { value }); - } - Terminator::Branch { - label_id, - condition, - } => { - let _ = (label_id, condition); //TODO - } - } - control_flow_graph.insert(label_id, node); - } - spirv::Op::FunctionEnd => { - fun_inst.expect(1)?; - break; - } - _ => return Err(Error::UnsupportedInstruction(self.state, fun_inst.op)), - } - } - // done - fun.global_usage = - crate::GlobalUse::scan(&fun.expressions, &fun.body, &module.global_variables); - let handle = module.functions.append(fun); - for (expr_handle, dst_id) in local_function_calls { - self.deferred_function_calls.push(DeferredFunctionCall { - source_handle: handle, - expr_handle, - dst_id, - }); - } - - self.lookup_function.insert(fun_id, handle); - self.lookup_expression.clear(); - self.lookup_sampled_image.clear(); - Ok(()) - } } pub fn parse_u8_slice(data: &[u8]) -> Result {