[msl] write down expressions

This commit is contained in:
Dzmitry Malyshau
2020-02-26 15:20:01 -05:00
parent 61e0a980c2
commit 20be6876d7
3 changed files with 138 additions and 9 deletions

View File

@@ -319,10 +319,122 @@ fn vector_size_string(size: crate::VectorSize) -> &'static str {
}
}
const NAME_INPUT: &'static str = "input";
const NAME_OUTPUT: &'static str = "output";
const NAME_INPUT: &str = "input";
const NAME_OUTPUT: &str = "output";
const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
impl<W: Write> Writer<W> {
fn put_expression<'a>(
&mut self,
expr_token: crate::Token<crate::Expression>,
expressions: &crate::Storage<crate::Expression>,
module: &crate::Module,
) -> Result<crate::Type, Error> {
let expression = &expressions[expr_token];
log::trace!("expression {:?}", expression);
match *expression {
crate::Expression::AccessIndex { base, index } => {
match self.put_expression(base, expressions, module)? {
crate::Type::Struct(token) => {
let member = &module.complex_types.structs[token].members[index as usize];
let name = member.name.or_index(MemberIndex(index as usize));
write!(self.out, ".{}", name)?;
Ok(member.ty.clone())
}
crate::Type::Matrix { rows, kind, width, .. } => {
write!(self.out, ".{}", COMPONENTS[index as usize])?;
Ok(crate::Type::Vector { size: rows, kind, width })
}
crate::Type::Vector { kind, width, .. } => {
write!(self.out, ".{}", COMPONENTS[index as usize])?;
Ok(crate::Type::Scalar { kind, width })
}
_ => {
write!(self.out, "[{}]", index)?;
Ok(crate::Type::Void) //TODO
}
}
}
crate::Expression::Constant(ref constant) => {
let kind = match *constant {
crate::Constant::Sint(value) => {
write!(self.out, "{}", value)?;
crate::ScalarKind::Sint
}
crate::Constant::Uint(value) => {
write!(self.out, "{}", value)?;
crate::ScalarKind::Uint
}
crate::Constant::Float(value) => {
write!(self.out, "{}", value)?;
crate::ScalarKind::Float
}
};
let width = 32; //TODO: not sure how to get that...
Ok(crate::Type::Scalar { kind, width })
}
crate::Expression::Compose { ref ty, ref components } => {
match *ty {
crate::Type::Vector { size, kind, .. } => {
write!(self.out, "{}{}(", scalar_kind_string(kind), vector_size_string(size))?;
for (i, &token) in components.iter().enumerate() {
if i != 0 {
write!(self.out, ",")?;
}
self.put_expression(token, expressions, module)?;
}
write!(self.out, ")")?;
}
_ => panic!("Unsupported compose {:?}", ty),
}
Ok(ty.clone())
}
crate::Expression::GlobalVariable(token) => {
let var = &module.global_variables[token];
match var.class {
spirv::StorageClass::Output => {
self.out.write_str(NAME_OUTPUT)?;
if let crate::Type::Pointer(pt) = var.ty {
if let crate::Type::Struct(_) = module.complex_types.pointers[pt].base {
return Ok(module.complex_types.pointers[pt].base.clone());
}
}
self.out.write_str(".")?;
}
spirv::StorageClass::Input => {
write!(self.out, "{}.", NAME_INPUT)?;
}
_ => ()
}
let name = var.name.or_index(token);
write!(self.out, "{}", name)?;
Ok(var.ty.clone())
}
crate::Expression::Load { pointer } => {
//write!(self.out, "*")?;
match self.put_expression(pointer, expressions, module)? {
crate::Type::Pointer(token) => {
Ok(module.complex_types.pointers[token].base.clone())
}
other => panic!("Unexpected load pointer {:?}", other),
}
}
crate::Expression::Mul(left, right) => {
write!(self.out, "(")?;
let ty_left = self.put_expression(left, expressions, module)?;
write!(self.out, " * ")?;
let ty_right = self.put_expression(right, expressions, module)?;
write!(self.out, ")")?;
Ok(match (ty_left, ty_right) {
(crate::Type::Vector { size, kind, width }, crate::Type::Scalar { .. }) =>
crate::Type::Vector { size, kind, width },
other => panic!("Unable to infer Mul for {:?}", other),
})
}
ref other => panic!("Unsupported {:?}", other),
}
}
pub fn write(&mut self, module: &crate::Module, options: Options) -> Result<(), Error> {
writeln!(self.out, "#include <metal_stdlib>")?;
writeln!(self.out, "#include <simd/simd.h>")?;
@@ -375,6 +487,8 @@ impl<W: Write> Writer<W> {
let mut uniforms_used = FastHashSet::default();
writeln!(self.out, "")?;
for (fun_token, fun) in module.functions.iter() {
let fun_name = fun.name.or_index(fun_token);
// find the entry point(s) and inputs/outputs
let mut exec_model = None;
let mut var_inputs = FastHashSet::default();
let mut var_outputs = FastHashSet::default();
@@ -393,6 +507,7 @@ impl<W: Write> Writer<W> {
}
let input_name = fun.name.or_index(InputStructIndex(fun_token));
let output_name = fun.name.or_index(OutputStructIndex(fun_token));
// make dedicated input/output structs
if let Some(em) = exec_model {
writeln!(self.out, "struct {} {{", input_name)?;
let (em_str, in_mode, out_mode) = match em {
@@ -440,12 +555,7 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ";")?;
}
writeln!(self.out, "}};")?;
write!(self.out, "{} ", em_str)?;
}
let fun_name = fun.name.or_index(fun_token);
if exec_model.is_some() {
writeln!(self.out, "{} {}(", output_name, fun_name)?;
writeln!(self.out, "{} {} {}(", em_str, output_name, fun_name)?;
writeln!(self.out, "\t{} {} [[stage_in]],", input_name, NAME_INPUT)?;
} else {
let fun_tv = TypedVar(&fun.return_type, &fun_name, &module.complex_types);
@@ -482,8 +592,25 @@ impl<W: Write> Writer<W> {
}
}
writeln!(self.out, ") {{")?;
// write down function body
writeln!(self.out, "\t{} {};", output_name, NAME_OUTPUT)?;
writeln!(self.out, "\treturn {};", NAME_OUTPUT)?;
for statement in fun.body.iter() {
log::trace!("statement {:?}", statement);
match *statement {
crate::Statement::Store { pointer, value } => {
//write!(self.out, "\t*")?;
write!(self.out, "\t")?;
self.put_expression(pointer, &fun.expressions, module)?;
write!(self.out, " = ")?;
self.put_expression(value, &fun.expressions, module)?;
writeln!(self.out, ";")?;
}
crate::Statement::Return { value: None } => {
writeln!(self.out, "\treturn {};", NAME_OUTPUT)?;
}
_ => panic!("Unsupported {:?}", statement),
}
}
writeln!(self.out, "}}")?;
}

View File

@@ -457,6 +457,7 @@ impl<I: Iterator<Item = u32>> Parser<I> {
components.push(lexp.token);
}
let expr = crate::Expression::Compose {
ty: self.lookup_type.lookup(result_type_id)?.value.clone(),
components,
};
self.lookup_expression.insert(id, LookupExpression {

View File

@@ -111,6 +111,7 @@ pub enum Expression {
},
Constant(Constant),
Compose {
ty: Type,
components: Vec<Token<Expression>>,
},
FunctionParameter(u32),