[msl] improve function I/O declarations

This commit is contained in:
Dzmitry Malyshau
2020-02-25 14:26:23 -05:00
parent 96a12a0c9a
commit 34c59d0371
2 changed files with 169 additions and 38 deletions

View File

@@ -1,4 +1,10 @@
use std::fmt::{Display, Error as FmtError, Formatter, Write};
use std::{
fmt::{
Display, Error as FmtError, Formatter, Write,
},
};
use crate::FastHashSet;
pub struct Options {
@@ -6,7 +12,10 @@ pub struct Options {
#[derive(Debug)]
pub enum Error {
Format(FmtError)
Format(FmtError),
UnsupportedExecutionModel(spirv::ExecutionModel),
MixedExecutionModels(crate::Token<crate::Function>),
BadName(String),
}
impl From<FmtError> for Error {
@@ -17,6 +26,7 @@ impl From<FmtError> for Error {
trait Indexed {
const CLASS: &'static str;
const PREFIX: bool = false;
fn id(&self) -> usize;
}
@@ -38,37 +48,75 @@ impl Indexed for MemberIndex {
const CLASS: &'static str = "field";
fn id(&self) -> usize { self.0 }
}
struct ParameterIndex(usize);
impl Indexed for ParameterIndex {
const CLASS: &'static str = "param";
fn id(&self) -> usize { self.0 }
}
enum Name<'a, I> {
Custom(&'a str),
Index(I),
struct InputStructIndex(crate::Token<crate::Function>);
impl Indexed for InputStructIndex {
const CLASS: &'static str = "Input";
const PREFIX: bool = true;
fn id(&self) -> usize { self.0.index() }
}
struct OutputStructIndex(crate::Token<crate::Function>);
impl Indexed for OutputStructIndex {
const CLASS: &'static str = "Output";
const PREFIX: bool = true;
fn id(&self) -> usize { self.0.index() }
}
impl<I: Indexed> Display for Name<'_, I> {
enum NameSource<'a> {
Custom { name: &'a str, prefix: bool },
Index(usize),
}
struct Name<'a> {
class: &'static str,
source: NameSource<'a>,
}
const RESERVED_NAMES: &[&str] = &[
"main",
];
impl Display for Name<'_> {
fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> {
match *self {
Name::Custom(name) => formatter.write_str(name),
Name::Index(ref index) => write!(formatter, "{}{}", I::CLASS, index.id()),
match self.source {
NameSource::Custom { name, prefix: false } if RESERVED_NAMES.contains(&name) => {
write!(formatter, "{}_", name)
}
NameSource::Custom { name, prefix: false } => formatter.write_str(name),
NameSource::Custom { name, prefix: true } => {
let (head, tail) = name.split_at(1);
write!(formatter, "{}{}{}", self.class, head.to_uppercase(), tail)
}
NameSource::Index(index) => write!(formatter, "{}{}", self.class, index),
}
}
}
impl<I: Indexed> From<I> for Name<'_> {
fn from(index: I) -> Self {
Name {
class: I::CLASS,
source: NameSource::Index(index.id()),
}
}
}
trait AsName {
fn or_index<I>(&self, index: I) -> Name<I>;
fn or_index<I: Indexed>(&self, index: I) -> Name;
}
impl AsName for Option<String> {
fn or_index<I>(&self, index: I) -> Name<I> {
match *self {
Some(ref name) => Name::Custom(name),
None => Name::Index(index),
fn or_index<I: Indexed>(&self, index: I) -> Name {
Name {
class: I::CLASS,
source: match *self {
Some(ref name) if !name.is_empty() => NameSource::Custom { name, prefix: I::PREFIX },
_ => NameSource::Index(index.id()),
},
}
}
}
@@ -103,6 +151,26 @@ impl<T: Display> Display for TypedVar<'_, T> {
}
}
struct TypedGlobalVariable<'a> {
module: &'a crate::Module,
token: crate::Token<crate::GlobalVariable>,
}
impl Display for TypedGlobalVariable<'_> {
fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> {
let var = &self.module.global_variables[self.token];
let name = var.name.or_index(self.token);
let tv = match var.ty {
crate::Type::Pointer { ref base, .. } => {
TypedVar(base, &name, &self.module.struct_declarations)
}
_ => panic!("Unexpected global type {:?}", var.ty),
};
write!(formatter, "{}", tv)
}
}
pub struct Writer<W> {
out: W,
}
@@ -123,6 +191,9 @@ fn vector_size_string(size: crate::VectorSize) -> &'static str {
}
}
const NAME_INPUT: &'static str = "input";
const NAME_OUTPUT: &'static str = "output";
impl<W: Write> Writer<W> {
pub fn write(&mut self, module: &crate::Module) -> Result<(), Error> {
writeln!(self.out, "#include <metal_stdlib>")?;
@@ -140,29 +211,90 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "}};")?;
}
let mut globals_used = Vec::new();
let mut uniforms_used = FastHashSet::default();
writeln!(self.out, "")?;
for (fun_token, fun) in module.functions.iter() {
let fun_name = fun.name.or_index(fun_token);
let fun_tv = TypedVar(&fun.return_type, &fun_name, &module.struct_declarations);
writeln!(self.out, "{}(", fun_tv)?;
for (index, ty) in fun.parameter_types.iter().enumerate() {
let name = Name::Index(ParameterIndex(index));
let tv = TypedVar(ty, &name, &module.struct_declarations);
writeln!(self.out, "\t{},", tv)?;
}
for (_, expr) in fun.expressions.iter() {
if let crate::Expression::GlobalVariable(token) = *expr {
if !globals_used.contains(&token) {
globals_used.push(token);
let var = &module.global_variables[token];
let name = var.name.or_index(token);
let tv = TypedVar(&var.ty, &name, &module.struct_declarations);
writeln!(self.out, "\t{},", tv)?;
let mut exec_model = None;
let mut var_inputs = FastHashSet::default();
let mut var_outputs = FastHashSet::default();
for ep in module.entry_points.iter() {
if ep.function == fun_token {
var_inputs.extend(ep.inputs.iter().cloned());
var_outputs.extend(ep.outputs.iter().cloned());
if exec_model.is_some() {
if exec_model != Some(ep.exec_model) {
return Err(Error::MixedExecutionModels(fun_token));
}
} else {
exec_model = Some(ep.exec_model);
}
}
}
let input_name = fun.name.or_index(InputStructIndex(fun_token));
let output_name = fun.name.or_index(OutputStructIndex(fun_token));
if let Some(em) = exec_model {
writeln!(self.out, "struct {} {{", input_name)?;
for &token in var_inputs.iter() {
let var = TypedGlobalVariable { module, token };
writeln!(self.out, "\t{};", var)?;
}
writeln!(self.out, "}};")?;
writeln!(self.out, "struct {} {{", output_name)?;
for &token in var_outputs.iter() {
let var = TypedGlobalVariable { module, token };
writeln!(self.out, "\t{};", var)?;
}
writeln!(self.out, "}};")?;
let em_str = match em {
spirv::ExecutionModel::Vertex => "vertex",
spirv::ExecutionModel::Fragment => "fragment",
spirv::ExecutionModel::GLCompute => "compute",
_ => return Err(Error::UnsupportedExecutionModel(em)),
};
write!(self.out, "{} ", em_str)?;
}
let fun_name = fun.name.or_index(fun_token);
if exec_model.is_some() {
writeln!(self.out, "{} {}(", output_name, fun_name)?;
writeln!(self.out, "\t{} {} [[stage_in]],", input_name, NAME_INPUT)?;
} else {
let fun_tv = TypedVar(&fun.return_type, &fun_name, &module.struct_declarations);
writeln!(self.out, "{}(", fun_tv)?;
for (index, ty) in fun.parameter_types.iter().enumerate() {
let name = Name::from(ParameterIndex(index));
let tv = TypedVar(ty, &name, &module.struct_declarations);
writeln!(self.out, "\t{},", tv)?;
}
}
for (_, expr) in fun.expressions.iter() {
if let crate::Expression::GlobalVariable(token) = *expr {
let var = &module.global_variables[token];
if var.class == spirv::StorageClass::Uniform && !uniforms_used.contains(&token) {
uniforms_used.insert(token);
let var = TypedGlobalVariable { module, token };
writeln!(self.out, "\t{},", var)?;
}
}
}
// add an extra parameter to make Metal happy about the comma
match exec_model {
Some(spirv::ExecutionModel::Vertex) => {
writeln!(self.out, "\tunsigned _dummy [[vertex_id]]")?;
}
Some(spirv::ExecutionModel::Fragment) => {
writeln!(self.out, "\tbool _dummy [[front_facing]]")?;
}
Some(spirv::ExecutionModel::GLCompute) => {
writeln!(self.out, "\tunsigned _dummy [[threads_per_grid]]")?;
}
_ => {
writeln!(self.out, "\tint _dummy")?;
}
}
writeln!(self.out, ") {{")?;
writeln!(self.out, "\t{} {};", output_name, NAME_OUTPUT)?;
writeln!(self.out, "\treturn {};", NAME_OUTPUT)?;
writeln!(self.out, "}}")?;
}

View File

@@ -4,17 +4,16 @@ pub mod back;
pub mod front;
mod storage;
use crate::storage::{Storage, Token};
use std::{
collections::HashMap,
collections::{HashMap, HashSet},
hash::BuildHasherDefault,
};
type FastHashMap<K, T> = HashMap<K, T, BuildHasherDefault<fxhash::FxHasher>>;
type FastHashSet<K> = HashSet<K, BuildHasherDefault<fxhash::FxHasher>>;
#[derive(Debug)]
pub struct Header {