[spv-in] support for void function calls

This commit is contained in:
Dzmitry Malyshau
2021-02-06 02:20:54 -05:00
committed by Dzmitry Malyshau
parent 11a3ed9837
commit d0c84a5ffa
4 changed files with 125 additions and 87 deletions

View File

@@ -126,6 +126,16 @@ impl<T> Arena<T> {
})
}
/// Returns a iterator over the items stored in this arena,
/// returning both the item's handle and a mutable reference to it.
pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {
self.data.iter_mut().enumerate().map(|(i, v)| {
let position = i + 1;
let index = unsafe { Index::new_unchecked(position as u32) };
(Handle::new(index), v)
})
}
/// Adds a new value to the arena, returning a typed handle.
///
/// The value is not linked to any SPIR-V module.

View File

@@ -57,23 +57,23 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
) -> Result<(), Error> {
self.switch(ModuleState::Function, inst.op)?;
inst.expect(5)?;
let result_type = self.next()?;
let result_type_id = self.next()?;
let fun_id = self.next()?;
let _fun_control = self.next()?;
let fun_type = self.next()?;
let fun_type_id = self.next()?;
let mut fun = {
let ft = self.lookup_function_type.lookup(fun_type)?;
if ft.return_type_id != result_type {
return Err(Error::WrongFunctionResultType(result_type));
let ft = self.lookup_function_type.lookup(fun_type_id)?;
if ft.return_type_id != result_type_id {
return Err(Error::WrongFunctionResultType(result_type_id));
}
crate::Function {
name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name),
arguments: Vec::with_capacity(ft.parameter_type_ids.len()),
return_type: if self.lookup_void_type.contains(&result_type) {
return_type: if self.lookup_void_type == Some(result_type_id) {
None
} else {
Some(self.lookup_type.lookup(result_type)?.handle)
Some(self.lookup_type.lookup(result_type_id)?.handle)
},
global_usage: Vec::new(),
local_variables: Arena::new(),
@@ -101,7 +101,7 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
if type_id
!= self
.lookup_function_type
.lookup(fun_type)?
.lookup(fun_type_id)?
.parameter_type_ids[i]
{
return Err(Error::WrongFunctionArgumentType(type_id));
@@ -116,7 +116,6 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
// Read body
let mut flow_graph = FlowGraph::new();
let base_deferred_call_index = self.deferred_function_calls.len();
// Scan the blocks and add them as nodes
loop {
@@ -156,39 +155,29 @@ impl<I: Iterator<Item = u32>> super::Parser<I> {
// done
fun.fill_global_use(&module.global_variables);
let source = match self.lookup_entry_point.remove(&fun_id) {
let dump_suffix = match self.lookup_entry_point.remove(&fun_id) {
Some(ep) => {
let dump_name = format!("flow.{:?}-{}.dot", ep.stage, ep.name);
module.entry_points.insert(
(ep.stage, ep.name.clone()),
(ep.stage, ep.name),
crate::EntryPoint {
early_depth_test: ep.early_depth_test,
workgroup_size: ep.workgroup_size,
function: fun,
},
);
DeferredSource::EntryPoint(ep.stage, ep.name)
dump_name
}
None => {
let handle = module.functions.append(fun);
self.lookup_function.insert(fun_id, handle);
DeferredSource::Function(handle)
format!("flow.Fun-{}.dot", handle.index())
}
};
for dfc in self.deferred_function_calls[base_deferred_call_index..].iter_mut() {
dfc.source = source.clone();
}
if let Some(ref prefix) = self.options.flow_graph_dump_prefix {
let dump = flow_graph.to_graphviz().unwrap_or_default();
let suffix = match source {
DeferredSource::Undefined => unreachable!(),
DeferredSource::EntryPoint(stage, ref name) => {
format!("flow.{:?}-{}.dot", stage, name)
}
DeferredSource::Function(handle) => format!("flow.Fun-{}.dot", handle.index()),
};
let _ = std::fs::write(prefix.join(suffix), dump);
let _ = std::fs::write(prefix.join(dump_suffix), dump);
}
self.lookup_expression.clear();

View File

@@ -268,18 +268,6 @@ struct LookupSampledImage {
image: Handle<crate::Expression>,
sampler: Handle<crate::Expression>,
}
#[derive(Clone, Debug)]
enum DeferredSource {
Undefined,
EntryPoint(crate::ShaderStage, String),
Function(Handle<crate::Function>),
}
struct DeferredFunctionCall {
source: DeferredSource,
expr_handle: Handle<crate::Expression>,
dst_id: spirv::Word,
arguments: Vec<Handle<crate::Expression>>,
}
#[derive(Clone, Debug)]
pub struct Assignment {
@@ -302,7 +290,7 @@ pub struct Parser<I> {
lookup_member_type_id: FastHashMap<(Handle<crate::Type>, MemberIndex), spirv::Word>,
handle_sampling: FastHashMap<Handle<crate::Type>, SamplingFlags>,
lookup_type: FastHashMap<spirv::Word, LookupType>,
lookup_void_type: FastHashSet<spirv::Word>,
lookup_void_type: Option<spirv::Word>,
lookup_storage_buffer_types: FastHashSet<Handle<crate::Type>>,
// Lookup for samplers and sampled images, storing flags on how they are used.
lookup_constant: FastHashMap<spirv::Word, LookupConstant>,
@@ -312,7 +300,9 @@ pub struct Parser<I> {
lookup_function_type: FastHashMap<spirv::Word, LookupFunctionType>,
lookup_function: FastHashMap<spirv::Word, Handle<crate::Function>>,
lookup_entry_point: FastHashMap<spirv::Word, EntryPoint>,
deferred_function_calls: Vec<DeferredFunctionCall>,
//Note: the key here is fully artificial, has nothing to do with the module
deferred_function_calls: FastHashMap<Handle<crate::Function>, spirv::Word>,
dummy_functions: Arena<crate::Function>,
options: Options,
}
@@ -328,7 +318,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
handle_sampling: FastHashMap::default(),
lookup_member_type_id: FastHashMap::default(),
lookup_type: FastHashMap::default(),
lookup_void_type: FastHashSet::default(),
lookup_void_type: None,
lookup_storage_buffer_types: FastHashSet::default(),
lookup_constant: FastHashMap::default(),
lookup_variable: FastHashMap::default(),
@@ -337,7 +327,8 @@ impl<I: Iterator<Item = u32>> Parser<I> {
lookup_function_type: FastHashMap::default(),
lookup_function: FastHashMap::default(),
lookup_entry_point: FastHashMap::default(),
deferred_function_calls: Vec::new(),
deferred_function_calls: FastHashMap::default(),
dummy_functions: Arena::new(),
options: options.clone(),
}
}
@@ -510,7 +501,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
const_arena: &Arena<crate::Constant>,
global_arena: &Arena<crate::GlobalVariable>,
) -> Result<ControlFlowNode, Error> {
let mut assignments = Vec::new();
let mut block = Vec::new();
let mut phis = Vec::new();
let mut merge = None;
let terminator = loop {
@@ -782,8 +773,8 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}
let base_expr = self.lookup_expression.lookup(pointer_id)?;
let value_expr = self.lookup_expression.lookup(value_id)?;
assignments.push(Assignment {
to: base_expr.handle,
block.push(crate::Statement::Store {
pointer: base_expr.handle,
value: value_expr.handle,
});
}
@@ -1315,22 +1306,29 @@ impl<I: Iterator<Item = u32>> Parser<I> {
let arg_id = self.next()?;
arguments.push(self.lookup_expression.lookup(arg_id)?.handle);
}
// will be replaced by the actual expression
let expr = crate::Expression::FunctionArgument(!0);
let expr_handle = expressions.append(expr);
self.deferred_function_calls.push(DeferredFunctionCall {
source: DeferredSource::Undefined,
expr_handle,
dst_id: func_id,
arguments,
});
self.lookup_expression.insert(
result_id,
LookupExpression {
handle: expr_handle,
type_id: result_type_id,
},
);
// We just need an unique handle here, nothing more.
let function = self.dummy_functions.append(crate::Function::default());
self.deferred_function_calls.insert(function, func_id);
if self.lookup_void_type == Some(result_type_id) {
block.push(crate::Statement::Call {
function,
arguments,
});
} else {
let expr_handle = expressions.append(crate::Expression::Call {
function,
arguments,
});
self.lookup_expression.insert(
result_id,
LookupExpression {
handle: expr_handle,
type_id: result_type_id,
},
);
}
}
Op::ExtInst => {
use crate::MathFunction as Mf;
@@ -1551,14 +1549,6 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}
};
let mut block = Vec::new();
for assignment in assignments.iter() {
block.push(crate::Statement::Store {
pointer: assignment.to,
value: assignment.value,
});
}
Ok(ControlFlowNode {
id: block_id,
ty: None,
@@ -1610,6 +1600,67 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}
}
fn patch_function_call_statements(
&self,
statements: &mut [crate::Statement],
) -> Result<(), Error> {
use crate::Statement as S;
for statement in statements.iter_mut() {
match *statement {
S::Block(ref mut block) => {
self.patch_function_call_statements(block)?;
}
S::If {
condition: _,
ref mut accept,
ref mut reject,
} => {
self.patch_function_call_statements(accept)?;
self.patch_function_call_statements(reject)?;
}
S::Switch {
selector: _,
ref mut cases,
ref mut default,
} => {
for case in cases.iter_mut() {
self.patch_function_call_statements(&mut case.body)?;
}
self.patch_function_call_statements(default)?;
}
S::Loop {
ref mut body,
ref mut continuing,
} => {
self.patch_function_call_statements(body)?;
self.patch_function_call_statements(continuing)?;
}
S::Break | S::Continue | S::Return { .. } | S::Kill | S::Store { .. } => {}
S::Call {
ref mut function, ..
} => {
let fun_id = self.deferred_function_calls[function];
*function = *self.lookup_function.lookup(fun_id)?;
}
}
}
Ok(())
}
fn patch_function_calls(&self, fun: &mut crate::Function) -> Result<(), Error> {
for (_, expr) in fun.expressions.iter_mut() {
if let crate::Expression::Call {
ref mut function, ..
} = *expr
{
let fun_id = self.deferred_function_calls[function];
*function = *self.lookup_function.lookup(fun_id)?;
}
}
self.patch_function_call_statements(&mut fun.body)?;
Ok(())
}
pub fn parse(mut self) -> Result<crate::Module, Error> {
let mut module = {
if self.next()? != spirv::MAGIC_NUMBER {
@@ -1680,23 +1731,11 @@ impl<I: Iterator<Item = u32>> Parser<I> {
}
}
for dfc in self.deferred_function_calls.drain(..) {
let dst_handle = *self.lookup_function.lookup(dfc.dst_id)?;
let fun = match dfc.source {
DeferredSource::Undefined => unreachable!(),
DeferredSource::Function(fun_handle) => module.functions.get_mut(fun_handle),
DeferredSource::EntryPoint(stage, name) => {
&mut module
.entry_points
.get_mut(&(stage, name))
.unwrap()
.function
}
};
*fun.expressions.get_mut(dfc.expr_handle) = crate::Expression::Call {
function: dst_handle,
arguments: dfc.arguments,
};
for (_, func) in module.functions.iter_mut() {
self.patch_function_calls(func)?;
}
for (_, ep) in module.entry_points.iter_mut() {
self.patch_function_calls(&mut ep.function)?;
}
if !self.future_decor.is_empty() {
@@ -1912,7 +1951,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
self.switch(ModuleState::Type, inst.op)?;
inst.expect(2)?;
let id = self.next()?;
self.lookup_void_type.insert(id);
self.lookup_void_type = Some(id);
Ok(())
}

View File

@@ -821,7 +821,7 @@ pub struct FunctionArgument {
}
/// A function defined in the module.
#[derive(Debug)]
#[derive(Debug, Default)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub struct Function {