diff --git a/src/back/msl.rs b/src/back/msl.rs index 75b4a38417..6a9bff10a5 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -1,4 +1,10 @@ -use std::fmt::{Display, Error as FmtError, Formatter, Write}; +use std::{ + fmt::{ + Display, Error as FmtError, Formatter, Write, + }, +}; + +use crate::FastHashSet; pub struct Options { @@ -6,7 +12,10 @@ pub struct Options { #[derive(Debug)] pub enum Error { - Format(FmtError) + Format(FmtError), + UnsupportedExecutionModel(spirv::ExecutionModel), + MixedExecutionModels(crate::Token), + BadName(String), } impl From for Error { @@ -17,6 +26,7 @@ impl From for Error { trait Indexed { const CLASS: &'static str; + const PREFIX: bool = false; fn id(&self) -> usize; } @@ -38,37 +48,75 @@ impl Indexed for MemberIndex { const CLASS: &'static str = "field"; fn id(&self) -> usize { self.0 } } - struct ParameterIndex(usize); impl Indexed for ParameterIndex { const CLASS: &'static str = "param"; fn id(&self) -> usize { self.0 } } - - -enum Name<'a, I> { - Custom(&'a str), - Index(I), +struct InputStructIndex(crate::Token); +impl Indexed for InputStructIndex { + const CLASS: &'static str = "Input"; + const PREFIX: bool = true; + fn id(&self) -> usize { self.0.index() } +} +struct OutputStructIndex(crate::Token); +impl Indexed for OutputStructIndex { + const CLASS: &'static str = "Output"; + const PREFIX: bool = true; + fn id(&self) -> usize { self.0.index() } } -impl Display for Name<'_, I> { +enum NameSource<'a> { + Custom { name: &'a str, prefix: bool }, + Index(usize), +} + +struct Name<'a> { + class: &'static str, + source: NameSource<'a>, +} + +const RESERVED_NAMES: &[&str] = &[ + "main", +]; + +impl Display for Name<'_> { fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> { - match *self { - Name::Custom(name) => formatter.write_str(name), - Name::Index(ref index) => write!(formatter, "{}{}", I::CLASS, index.id()), + match self.source { + NameSource::Custom { name, prefix: false } if RESERVED_NAMES.contains(&name) => { + write!(formatter, "{}_", name) + } + NameSource::Custom { name, prefix: false } => formatter.write_str(name), + NameSource::Custom { name, prefix: true } => { + let (head, tail) = name.split_at(1); + write!(formatter, "{}{}{}", self.class, head.to_uppercase(), tail) + } + NameSource::Index(index) => write!(formatter, "{}{}", self.class, index), + } + } +} + +impl From for Name<'_> { + fn from(index: I) -> Self { + Name { + class: I::CLASS, + source: NameSource::Index(index.id()), } } } trait AsName { - fn or_index(&self, index: I) -> Name; + fn or_index(&self, index: I) -> Name; } impl AsName for Option { - fn or_index(&self, index: I) -> Name { - match *self { - Some(ref name) => Name::Custom(name), - None => Name::Index(index), + fn or_index(&self, index: I) -> Name { + Name { + class: I::CLASS, + source: match *self { + Some(ref name) if !name.is_empty() => NameSource::Custom { name, prefix: I::PREFIX }, + _ => NameSource::Index(index.id()), + }, } } } @@ -103,6 +151,26 @@ impl Display for TypedVar<'_, T> { } } +struct TypedGlobalVariable<'a> { + module: &'a crate::Module, + token: crate::Token, +} + +impl Display for TypedGlobalVariable<'_> { + fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> { + let var = &self.module.global_variables[self.token]; + let name = var.name.or_index(self.token); + let tv = match var.ty { + crate::Type::Pointer { ref base, .. } => { + TypedVar(base, &name, &self.module.struct_declarations) + } + _ => panic!("Unexpected global type {:?}", var.ty), + }; + write!(formatter, "{}", tv) + } +} + + pub struct Writer { out: W, } @@ -123,6 +191,9 @@ fn vector_size_string(size: crate::VectorSize) -> &'static str { } } +const NAME_INPUT: &'static str = "input"; +const NAME_OUTPUT: &'static str = "output"; + impl Writer { pub fn write(&mut self, module: &crate::Module) -> Result<(), Error> { writeln!(self.out, "#include ")?; @@ -140,29 +211,90 @@ impl Writer { writeln!(self.out, "}};")?; } - let mut globals_used = Vec::new(); + let mut uniforms_used = FastHashSet::default(); writeln!(self.out, "")?; for (fun_token, fun) in module.functions.iter() { - let fun_name = fun.name.or_index(fun_token); - let fun_tv = TypedVar(&fun.return_type, &fun_name, &module.struct_declarations); - writeln!(self.out, "{}(", fun_tv)?; - for (index, ty) in fun.parameter_types.iter().enumerate() { - let name = Name::Index(ParameterIndex(index)); - let tv = TypedVar(ty, &name, &module.struct_declarations); - writeln!(self.out, "\t{},", tv)?; - } - for (_, expr) in fun.expressions.iter() { - if let crate::Expression::GlobalVariable(token) = *expr { - if !globals_used.contains(&token) { - globals_used.push(token); - let var = &module.global_variables[token]; - let name = var.name.or_index(token); - let tv = TypedVar(&var.ty, &name, &module.struct_declarations); - writeln!(self.out, "\t{},", tv)?; + let mut exec_model = None; + let mut var_inputs = FastHashSet::default(); + let mut var_outputs = FastHashSet::default(); + for ep in module.entry_points.iter() { + if ep.function == fun_token { + var_inputs.extend(ep.inputs.iter().cloned()); + var_outputs.extend(ep.outputs.iter().cloned()); + if exec_model.is_some() { + if exec_model != Some(ep.exec_model) { + return Err(Error::MixedExecutionModels(fun_token)); + } + } else { + exec_model = Some(ep.exec_model); } } } + let input_name = fun.name.or_index(InputStructIndex(fun_token)); + let output_name = fun.name.or_index(OutputStructIndex(fun_token)); + if let Some(em) = exec_model { + writeln!(self.out, "struct {} {{", input_name)?; + for &token in var_inputs.iter() { + let var = TypedGlobalVariable { module, token }; + writeln!(self.out, "\t{};", var)?; + } + writeln!(self.out, "}};")?; + writeln!(self.out, "struct {} {{", output_name)?; + for &token in var_outputs.iter() { + let var = TypedGlobalVariable { module, token }; + writeln!(self.out, "\t{};", var)?; + } + writeln!(self.out, "}};")?; + let em_str = match em { + spirv::ExecutionModel::Vertex => "vertex", + spirv::ExecutionModel::Fragment => "fragment", + spirv::ExecutionModel::GLCompute => "compute", + _ => return Err(Error::UnsupportedExecutionModel(em)), + }; + write!(self.out, "{} ", em_str)?; + } + + let fun_name = fun.name.or_index(fun_token); + if exec_model.is_some() { + writeln!(self.out, "{} {}(", output_name, fun_name)?; + writeln!(self.out, "\t{} {} [[stage_in]],", input_name, NAME_INPUT)?; + } else { + let fun_tv = TypedVar(&fun.return_type, &fun_name, &module.struct_declarations); + writeln!(self.out, "{}(", fun_tv)?; + for (index, ty) in fun.parameter_types.iter().enumerate() { + let name = Name::from(ParameterIndex(index)); + let tv = TypedVar(ty, &name, &module.struct_declarations); + writeln!(self.out, "\t{},", tv)?; + } + } + for (_, expr) in fun.expressions.iter() { + if let crate::Expression::GlobalVariable(token) = *expr { + let var = &module.global_variables[token]; + if var.class == spirv::StorageClass::Uniform && !uniforms_used.contains(&token) { + uniforms_used.insert(token); + let var = TypedGlobalVariable { module, token }; + writeln!(self.out, "\t{},", var)?; + } + } + } + // add an extra parameter to make Metal happy about the comma + match exec_model { + Some(spirv::ExecutionModel::Vertex) => { + writeln!(self.out, "\tunsigned _dummy [[vertex_id]]")?; + } + Some(spirv::ExecutionModel::Fragment) => { + writeln!(self.out, "\tbool _dummy [[front_facing]]")?; + } + Some(spirv::ExecutionModel::GLCompute) => { + writeln!(self.out, "\tunsigned _dummy [[threads_per_grid]]")?; + } + _ => { + writeln!(self.out, "\tint _dummy")?; + } + } writeln!(self.out, ") {{")?; + writeln!(self.out, "\t{} {};", output_name, NAME_OUTPUT)?; + writeln!(self.out, "\treturn {};", NAME_OUTPUT)?; writeln!(self.out, "}}")?; } diff --git a/src/lib.rs b/src/lib.rs index bfb0e63a1d..8858dbdb27 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,17 +4,16 @@ pub mod back; pub mod front; mod storage; - use crate::storage::{Storage, Token}; use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, hash::BuildHasherDefault, }; + type FastHashMap = HashMap>; - - +type FastHashSet = HashSet>; #[derive(Debug)] pub struct Header {