From 80a9254dbedb32934dc92665efce379c6d57b59a Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 19 Mar 2020 09:40:34 -0400 Subject: [PATCH] Interface processor --- src/back/msl.rs | 115 +++++++++++++++++++++------------ src/front/spirv.rs | 10 +-- src/front/wgsl.rs | 49 ++++++++++---- src/lib.rs | 4 +- src/proc/interface.rs | 145 ++++++++++++++++++++++++++++++++++++++++++ src/proc/mod.rs | 2 + 6 files changed, 267 insertions(+), 58 deletions(-) create mode 100644 src/proc/interface.rs diff --git a/src/back/msl.rs b/src/back/msl.rs index a2a5df7267..ea08a4149c 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -25,6 +25,10 @@ use crate::{ FastHashMap, FastHashSet }; +/// Expect all the global variables to have a pointer type, +/// like in SPIR-V. +const GLOBAL_POINTERS: bool = false; + #[derive(Clone, Debug, PartialEq)] pub struct BindTarget { pub buffer: Option, @@ -224,19 +228,24 @@ impl Display for TypedGlobalVariable<'_> { fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> { let var = &self.module.global_variables[self.handle]; let name = var.name.or_index(self.handle); - let ty = &self.module.types[var.ty]; - match ty.inner { - crate::TypeInner::Pointer { base, class } => { - let ty_handle= match class { - spirv::StorageClass::Input | - spirv::StorageClass::Output | - spirv::StorageClass::UniformConstant => base, - _ => var.ty - }; - let ty_name = self.module.types[ty_handle].name.or_index(ty_handle); - write!(formatter, "{} {}", ty_name, name) + if GLOBAL_POINTERS { + let ty = &self.module.types[var.ty]; + match ty.inner { + crate::TypeInner::Pointer { base, class } => { + let ty_handle = match class { + spirv::StorageClass::Input | + spirv::StorageClass::Output | + spirv::StorageClass::UniformConstant => base, + _ => var.ty + }; + let ty_name = self.module.types[ty_handle].name.or_index(ty_handle); + write!(formatter, "{} {}", ty_name, name) + } + _ => panic!("Unexpected global type {:?} = {:?}", var.ty, ty), } - _ => panic!("Unexpected global type {:?}", var.ty), + } else { + let ty_name = self.module.types[var.ty].name.or_index(var.ty); + write!(formatter, "{} {}", ty_name, name) } } } @@ -247,6 +256,7 @@ impl Display for ResolvedBinding { ResolvedBinding::BuiltIn(built_in) => { let name = match built_in { spirv::BuiltIn::ClipDistance => "clip_distance", + spirv::BuiltIn::GlobalInvocationId => "thread_position_in_grid", spirv::BuiltIn::PointSize => "point_size", spirv::BuiltIn::Position => "position", _ => panic!("Built in {:?} is not implemented", built_in), @@ -416,10 +426,16 @@ impl Writer { match var.class { spirv::StorageClass::Output => { self.out.write_str(NAME_OUTPUT)?; - if let crate::TypeInner::Pointer { base, .. } = *inner { - let base_inner = &module.types[base].inner; - if let crate::TypeInner::Struct { .. } = *base_inner { - return Ok(MaybeOwned::Borrowed(base_inner)); + if GLOBAL_POINTERS { + if let crate::TypeInner::Pointer { base, .. } = *inner { + let base_inner = &module.types[base].inner; + if let crate::TypeInner::Struct { .. } = *base_inner { + return Ok(MaybeOwned::Borrowed(base_inner)); + } + } + } else { + if let crate::TypeInner::Struct { .. } = *inner { + return Ok(MaybeOwned::Borrowed(inner)); } } self.out.write_str(".")?; @@ -506,6 +522,23 @@ impl Writer { } crate::Expression::Call { ref name, ref arguments } => { match name.as_str() { + "cos" | + "fclamp" | + "normalize" | + "sin" => { + write!(self.out, "{}(", name)?; + let result = self.put_expression(arguments[0], function, module)?; + write!(self.out, ")")?; + Ok(result) + } + "atan2" => { + write!(self.out, "{}(", name)?; + let result = self.put_expression(arguments[0], function, module)?; + write!(self.out, ", ")?; + self.put_expression(arguments[1], function, module)?; + write!(self.out, ")")?; + Ok(result) + } "distance" => { write!(self.out, "distance(")?; let result = match *self.put_expression(arguments[0], function, module)?.borrow() { @@ -526,18 +559,7 @@ impl Writer { write!(self.out, ")")?; Ok(MaybeOwned::Owned(result)) } - "normalize" => { - write!(self.out, "normalize(")?; - let result = self.put_expression(arguments[0], function, module)?; - write!(self.out, ")")?; - Ok(result) - } - "fclamp" => { - write!(self.out, "fclamp(")?; - let result = self.put_expression(arguments[0], function, module)?; - write!(self.out, ")")?; - Ok(result) - } + _ => panic!("Unsupported call to '{}'", name), } } @@ -757,20 +779,24 @@ impl Writer { for &handle in var_outputs.iter() { let var = &module.global_variables[handle]; // if it's a struct, lift all the built-in contents up to the root - if let crate::TypeInner::Pointer { base, .. } = module.types[var.ty].inner { - if let crate::TypeInner::Struct { ref members } = module.types[base].inner { - for (index, member) in members.iter().enumerate() { - let name = member.name.or_index(MemberIndex(index)); - let ty_name = module.types[member.ty].name.or_index(member.ty); - let binding = member.binding - .as_ref() - .ok_or(Error::MissingBinding(handle))?; - let resolved = options.resolve_binding(binding, out_mode)?; - writeln!(self.out, "\t{} {} [[{}]];", ty_name, name, resolved)?; - } - continue + let mut ty_handle = var.ty; + if GLOBAL_POINTERS { + if let crate::TypeInner::Pointer { base, .. } = module.types[var.ty].inner { + ty_handle = base; } } + if let crate::TypeInner::Struct { ref members } = module.types[ty_handle].inner { + for (index, member) in members.iter().enumerate() { + let name = member.name.or_index(MemberIndex(index)); + let ty_name = module.types[member.ty].name.or_index(member.ty); + let binding = member.binding + .as_ref() + .ok_or(Error::MissingBinding(handle))?; + let resolved = options.resolve_binding(binding, out_mode)?; + writeln!(self.out, "\t{} {} [[{}]];", ty_name, name, resolved)?; + } + continue + } let tyvar = TypedGlobalVariable { module, handle }; write!(self.out, "\t{}", tyvar)?; if let Some(ref binding) = var.binding { @@ -831,6 +857,15 @@ impl Writer { if exec_model.is_some() { writeln!(self.out, "\t{} {};", output_name, NAME_OUTPUT)?; } + for (local_handle, local) in fun.local_variables.iter() { + let ty_name = module.types[local.ty].name.or_index(local.ty); + write!(self.out, "\t{} {}", ty_name, local.name.or_index(local_handle))?; + if let Some(value) = local.init { + write!(self.out, " = ")?; + self.put_expression(value, fun, module)?; + } + writeln!(self.out, ";")?; + } for statement in fun.body.iter() { self.put_statement(Level(1), statement, fun, exec_model.is_some(), module)?; } diff --git a/src/front/spirv.rs b/src/front/spirv.rs index ab2fd54973..7ee620bf81 100644 --- a/src/front/spirv.rs +++ b/src/front/spirv.rs @@ -736,16 +736,16 @@ impl> Parser { exec_model: raw.exec_model, name: raw.name, function: *self.lookup_function.lookup(raw.function_id)?, - inputs: Vec::new(), - outputs: Vec::new(), + inputs: FastHashSet::default(), + outputs: FastHashSet::default(), }; for var_id in raw.variable_ids { let handle = self.lookup_variable.lookup(var_id)?.handle; match module.global_variables[handle].class { - spirv::StorageClass::Input => ep.inputs.push(handle), - spirv::StorageClass::Output => ep.outputs.push(handle), + spirv::StorageClass::Input => ep.inputs.insert(handle), + spirv::StorageClass::Output => ep.outputs.insert(handle), other => return Err(Error::InvalidVariableClass(other)), - } + }; } module.entry_points.push(ep); } diff --git a/src/front/wgsl.rs b/src/front/wgsl.rs index 303ff1bdf2..9c1bf02bf8 100644 --- a/src/front/wgsl.rs +++ b/src/front/wgsl.rs @@ -1,6 +1,6 @@ use crate::{ arena::{Arena, Handle}, - proc::{Typifier, ResolveError}, + proc::{Interface, Typifier, ResolveError}, FastHashMap, }; @@ -1039,21 +1039,37 @@ impl Parser { self.scopes.push(Scope::Statement); let statement = match word { "var" => { + enum Init { + Empty, + Uniform(Handle), + Variable(Handle), + } let (name, ty) = self.parse_variable_ident_decl(lexer, context.types)?; let init = if lexer.skip(Token::Operation('=')) { - Some(self.parse_general_expression(lexer, context.as_expression())?) + let value = self.parse_general_expression(lexer, context.as_expression())?; + if let crate::Expression::Constant(_) = context.expressions[value] { + Init::Uniform(value) + } else { + Init::Variable(value) + } } else { - None + Init::Empty }; lexer.expect(Token::Separator(';'))?; let var_id = context.variables.append(crate::LocalVariable { name: Some(name.to_owned()), ty, - init, + init: match init { + Init::Uniform(value) => Some(value), + _ => None, + } }); let expr_id = context.expressions.append(crate::Expression::LocalVariable(var_id)); context.lookup_ident.insert(name, expr_id); - crate::Statement::Empty + match init { + Init::Variable(value) => crate::Statement::Store { pointer: expr_id, value }, + _ => crate::Statement::Empty, + } } "return" => { let value = if lexer.peek() != Token::Separator(';') { @@ -1292,7 +1308,17 @@ impl Parser { let (name, class, ty) = self.parse_variable_decl(lexer, &mut module.types, &mut module.constants)?; let var_handle = module.global_variables.append(crate::GlobalVariable { name: Some(name.to_owned()), - class: class.unwrap_or(spirv::StorageClass::Private), + class: match class { + Some(c) => c, + None => match binding { + Some(crate::Binding::BuiltIn(builtin)) => match builtin { + spirv::BuiltIn::GlobalInvocationId => spirv::StorageClass::Input, + spirv::BuiltIn::Position => spirv::StorageClass::Output, + _ => unimplemented!(), + }, + _ => spirv::StorageClass::Private, + }, + }, binding: binding.take(), ty, }); @@ -1314,17 +1340,18 @@ impl Parser { lexer.expect(Token::Operation('='))?; let fun_ident = lexer.next_ident()?; lexer.expect(Token::Separator(';'))?; - let function = module.functions + let (fun_handle, function) = module.functions .iter() .find(|(_, fun)| fun.name.as_ref().map(|s| s.as_str()) == Some(fun_ident)) - .map(|(handle, _)| handle) .ok_or(Error::UnknownFunction(fun_ident))?; + + let io = Interface::new(function, &module.global_variables); module.entry_points.push(crate::EntryPoint { exec_model, name: export_name.unwrap_or(fun_ident).to_owned(), - inputs: Vec::new(), //TODO - outputs: Vec::new(), //TODO - function, + inputs: io.inputs, + outputs: io.outputs, + function: fun_handle, }); } Token::End => return Ok(false), diff --git a/src/lib.rs b/src/lib.rs index c1cf027b54..f7b091958d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -280,8 +280,8 @@ pub struct Function { pub struct EntryPoint { pub exec_model: spirv::ExecutionModel, pub name: String, - pub inputs: Vec>, - pub outputs: Vec>, + pub inputs: FastHashSet>, + pub outputs: FastHashSet>, pub function: Handle, } diff --git a/src/proc/interface.rs b/src/proc/interface.rs new file mode 100644 index 0000000000..a68baee8c1 --- /dev/null +++ b/src/proc/interface.rs @@ -0,0 +1,145 @@ +use crate::{ + arena::{Arena, Handle}, + FastHashSet, +}; + +pub struct Interface<'a> { + expressions: &'a Arena, + globals: &'a Arena, + pub inputs: FastHashSet>, + pub outputs: FastHashSet>, +} + +impl<'a> Interface<'a> { + fn add_inputs(&mut self, handle: Handle) { + use crate::Expression as E; + match self.expressions[handle] { + E::Access { base, index } => { + self.add_inputs(base); + self.add_inputs(index); + } + E::AccessIndex { base, .. } => { + self.add_inputs(base); + } + E::Constant(_) => {} + E::Compose { ref components, .. } => { + for &comp in components { + self.add_inputs(comp); + } + } + E::FunctionParameter(_) => {}, + E::GlobalVariable(handle) => { + if self.globals[handle].class == spirv::StorageClass::Input { + self.inputs.insert(handle); + } + } + E::LocalVariable(_) => {} + E::Load { pointer } => { + self.add_inputs(pointer); + } + E::ImageSample { image, sampler, coordinate } => { + self.add_inputs(image); + self.add_inputs(sampler); + self.add_inputs(coordinate); + } + E::Unary { expr, .. } => { + self.add_inputs(expr); + } + E::Binary { left, right, .. } => { + self.add_inputs(left); + self.add_inputs(right); + } + E::Intrinsic { argument, .. } => { + self.add_inputs(argument); + } + E::DotProduct(left, right) => { + self.add_inputs(left); + self.add_inputs(right); + } + E::CrossProduct(left, right) => { + self.add_inputs(left); + self.add_inputs(right); + } + E::Derivative { expr, .. } => { + self.add_inputs(expr); + } + E::Call { ref arguments, .. } => { + for &argument in arguments { + self.add_inputs(argument); + } + } + } + } + + fn collect(&mut self, block: &[crate::Statement]) { + for statement in block { + use crate::Statement as S; + match *statement { + S::Empty | + S::Break | + S::Continue | + S::Kill => (), + S::Block(ref b) => { + self.collect(b); + } + S::If { condition, ref accept, ref reject } => { + self.add_inputs(condition); + self.collect(accept); + self.collect(reject); + } + S::Switch { selector, ref cases, ref default } => { + self.add_inputs(selector); + for &(ref case, _) in cases.values() { + self.collect(case); + } + self.collect(default); + } + S::Loop { ref body, ref continuing } => { + self.collect(body); + self.collect(continuing); + } + S::Return { value } => { + if let Some(expr) = value { + self.add_inputs(expr); + } + } + S::Store { pointer, value } => { + let mut left = pointer; + loop { + match self.expressions[left] { + crate::Expression::Access { base, index } => { + self.add_inputs(index); + left = base; + } + crate::Expression::AccessIndex { base, .. } => { + left = base; + } + crate::Expression::GlobalVariable(handle) => { + if self.globals[handle].class == spirv::StorageClass::Output { + self.outputs.insert(handle); + } + break; + } + _ => break, + } + } + self.add_inputs(value); + } + } + } + } + + pub fn new( + fun: &'a crate::Function, + globals: &'a Arena, + ) -> Self { + let mut io = Interface { + expressions: &fun.expressions, + globals, + inputs: FastHashSet::default(), + outputs: FastHashSet::default(), + }; + io.collect(&fun.body); + io + } +} diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 38ac9304be..1237df94d1 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -1,3 +1,5 @@ +mod interface; mod typifier; +pub use interface::Interface; pub use typifier::{ResolveError, Typifier};