[msl] re-use expression results based on the analysis

This commit is contained in:
Dzmitry Malyshau
2021-02-12 01:39:08 -05:00
parent eae40383d0
commit c1fc5d6424
11 changed files with 230 additions and 52 deletions

View File

@@ -147,7 +147,8 @@ fn main() {
};
// validate the IR
naga::proc::Validator::new()
#[cfg_attr(not(feature = "msl-out"), allow(unused_variables))]
let analysis = naga::proc::Validator::new()
.validate(&module)
.unwrap_pretty();
@@ -190,7 +191,7 @@ fn main() {
spirv_cross_compatibility: false,
binding_map,
};
let (msl, _) = msl::write_string(&module, &options).unwrap();
let (msl, _) = msl::write_string(&module, &analysis, &options).unwrap();
fs::write(&args[2], msl).unwrap();
}
#[cfg(feature = "spv-out")]

View File

@@ -1908,7 +1908,7 @@ struct TextureMappingVisitor<'a> {
}
impl<'a> Visitor for TextureMappingVisitor<'a> {
fn visit_expr(&mut self, expr: &crate::Expression) {
fn visit_expr(&mut self, _: Handle<crate::Expression>, expr: &crate::Expression) {
// We only care about `ImageSample` and `ImageLoad`
//
// Both `image` and `sampler` are `Expression::GlobalVariable` otherwise the module is

View File

@@ -14,7 +14,11 @@ the output struct. If there is a structure in the outputs, and it contains any b
we move them up to the root output structure that we define ourselves.
!*/
use crate::{arena::Handle, proc::ResolveError, FastHashMap};
use crate::{
arena::Handle,
proc::{analyzer::Analysis, ResolveError},
FastHashMap,
};
use std::{
io::{Error as IoError, Write},
string::FromUtf8Error,
@@ -66,6 +70,7 @@ pub enum Error {
UnexpectedSampleLevel(crate::SampleLevel),
UnsupportedCall(String),
UnsupportedDynamicArrayLength,
FeatureNotImplemented(String),
/// The source IR is not valid.
Validation,
}
@@ -222,10 +227,11 @@ pub struct TranslationInfo {
pub fn write_string(
module: &crate::Module,
analysis: &Analysis,
options: &Options,
) -> Result<(String, TranslationInfo), Error> {
let mut w = writer::Writer::new(Vec::new());
let info = w.write(module, options)?;
let info = w.write(module, analysis, options)?;
let string = String::from_utf8(w.finish())?;
Ok((string, info))
}

View File

@@ -1,18 +1,23 @@
use super::{keywords::RESERVED, Error, LocationMode, Options, TranslationInfo};
use crate::{
arena::Handle,
proc::{EntryPointIndex, NameKey, Namer, ResolveContext, Typifier},
proc::{
analyzer::{Analysis, FunctionInfo},
EntryPointIndex, Interface, NameKey, Namer, ResolveContext, Typifier, Visitor,
},
FastHashMap,
};
use bit_set::BitSet;
use std::{
fmt::{Display, Error as FmtError, Formatter},
io::Write,
iter,
iter, mem,
};
const NAMESPACE: &str = "metal";
const INDENT: &str = " ";
#[derive(Clone)]
struct Level(usize);
impl Level {
fn next(&self) -> Self {
@@ -64,8 +69,10 @@ impl<'a> TypedGlobalVariable<'a> {
pub struct Writer<W> {
out: W,
names: FastHashMap<NameKey, String>,
named_expressions: BitSet,
typifier: Typifier,
namer: Namer,
temp_bake_handles: Vec<Handle<crate::Expression>>,
}
fn scalar_kind_string(kind: crate::ScalarKind) -> &'static str {
@@ -113,6 +120,39 @@ impl crate::StorageClass {
}
}
struct BakeExpressionVisitor<'a> {
named_expressions: &'a mut BitSet,
bake_handles: &'a mut Vec<Handle<crate::Expression>>,
fun_info: &'a FunctionInfo,
exclude: Option<Handle<crate::Expression>>,
}
impl Visitor for BakeExpressionVisitor<'_> {
fn visit_expr(&mut self, handle: Handle<crate::Expression>, expr: &crate::Expression) {
use crate::Expression as E;
// filter out the expressions that don't need to bake
let min_ref_count = match *expr {
// The following expressions can be inlined nicely.
E::AccessIndex { .. }
| E::Constant(_)
| E::FunctionArgument(_)
| E::GlobalVariable(_)
| E::LocalVariable(_) => !0,
// Image sampling and function calling are nice to isolate
// into separate statements even when done only once.
E::ImageSample { .. } | E::ImageLoad { .. } | E::Call { .. } => 1,
// Bake only expressions referenced more than once.
_ => 2,
};
let modifier = if self.exclude == Some(handle) { 1 } else { 0 };
if self.fun_info[handle].ref_count - modifier >= min_ref_count
&& self.named_expressions.insert(handle.index())
{
self.bake_handles.push(handle);
}
}
}
enum FunctionOrigin {
Handle(Handle<crate::Function>),
EntryPoint(EntryPointIndex),
@@ -124,14 +164,22 @@ struct ExpressionContext<'a> {
module: &'a crate::Module,
}
struct StatementContext<'a> {
expression: ExpressionContext<'a>,
fun_info: &'a FunctionInfo,
return_value: Option<&'a str>,
}
impl<W: Write> Writer<W> {
/// Creates a new `Writer` instance.
pub fn new(out: W) -> Self {
Writer {
out,
names: FastHashMap::default(),
named_expressions: BitSet::new(),
typifier: Typifier::new(),
namer: Namer::default(),
temp_bake_handles: Vec::new(),
}
}
@@ -216,6 +264,11 @@ impl<W: Write> Writer<W> {
expr_handle: Handle<crate::Expression>,
context: &ExpressionContext,
) -> Result<(), Error> {
if self.named_expressions.contains(expr_handle.index()) {
write!(self.out, "expr{}", expr_handle.index())?;
return Ok(());
}
let expression = &context.function.expressions[expr_handle];
log::trace!("expression {:?} = {:?}", expr_handle, expression);
match *expression {
@@ -684,12 +737,76 @@ impl<W: Write> Writer<W> {
Ok(())
}
// Write down any required intermediate results
fn prepare_expression(
&mut self,
level: Level,
root_handle: Handle<crate::Expression>,
context: &StatementContext,
exclude_root: bool,
) -> Result<(), Error> {
// set up the search
let mut interface = Interface {
expressions: &context.expression.function.expressions,
local_variables: &context.expression.function.local_variables,
visitor: BakeExpressionVisitor {
named_expressions: &mut self.named_expressions,
bake_handles: &mut self.temp_bake_handles,
fun_info: context.fun_info,
exclude: if exclude_root {
Some(root_handle)
} else {
None
},
},
};
// populate the bake handles
interface.traverse_expr(root_handle);
// bake
let mut temp_bake_handles = mem::replace(&mut self.temp_bake_handles, Vec::new());
for handle in temp_bake_handles.drain(..).rev() {
write!(self.out, "{}", level)?;
match self.typifier.get_handle(handle) {
Ok(ty_handle) => {
let ty_name = &self.names[&NameKey::Type(ty_handle)];
write!(self.out, "{}", ty_name)?;
}
Err(&crate::TypeInner::Scalar { kind, .. }) => {
write!(self.out, "{}", scalar_kind_string(kind))?;
}
Err(&crate::TypeInner::Vector { size, kind, .. }) => {
write!(
self.out,
"{}::{}{}",
NAMESPACE,
scalar_kind_string(kind),
vector_size_string(size)
)?;
}
Err(other) => {
log::error!("Type {:?} isn't a known local", other);
return Err(Error::FeatureNotImplemented("weird local type".to_string()));
}
}
//TODO: figure out the naming scheme that wouldn't collide with user names.
write!(self.out, " expr{} = ", handle.index())?;
// Make sure to temporarily unblock the expression before writing it down.
self.named_expressions.remove(handle.index());
self.put_expression(handle, &context.expression)?;
self.named_expressions.insert(handle.index());
writeln!(self.out, ";")?;
}
self.temp_bake_handles = temp_bake_handles;
Ok(())
}
fn put_block(
&mut self,
level: Level,
statements: &[crate::Statement],
context: &ExpressionContext,
return_value: Option<&str>,
context: &StatementContext,
) -> Result<(), Error> {
for statement in statements {
log::trace!("statement[{}] {:?}", level.0, statement);
@@ -697,7 +814,7 @@ impl<W: Write> Writer<W> {
crate::Statement::Block(ref block) => {
if !block.is_empty() {
writeln!(self.out, "{}{{", level)?;
self.put_block(level.next(), block, context, return_value)?;
self.put_block(level.next(), block, context)?;
writeln!(self.out, "{}}}", level)?;
}
}
@@ -706,13 +823,14 @@ impl<W: Write> Writer<W> {
ref accept,
ref reject,
} => {
self.prepare_expression(level.clone(), condition, context, false)?;
write!(self.out, "{}if (", level)?;
self.put_expression(condition, context)?;
self.put_expression(condition, &context.expression)?;
writeln!(self.out, ") {{")?;
self.put_block(level.next(), accept, context, return_value)?;
self.put_block(level.next(), accept, context)?;
if !reject.is_empty() {
writeln!(self.out, "{}}} else {{", level)?;
self.put_block(level.next(), reject, context, return_value)?;
self.put_block(level.next(), reject, context)?;
}
writeln!(self.out, "{}}}", level)?;
}
@@ -721,20 +839,21 @@ impl<W: Write> Writer<W> {
ref cases,
ref default,
} => {
self.prepare_expression(level.clone(), selector, context, false)?;
write!(self.out, "{}switch(", level)?;
self.put_expression(selector, context)?;
self.put_expression(selector, &context.expression)?;
writeln!(self.out, ") {{")?;
let lcase = level.next();
for case in cases.iter() {
writeln!(self.out, "{}case {}: {{", lcase, case.value)?;
self.put_block(lcase.next(), &case.body, context, return_value)?;
self.put_block(lcase.next(), &case.body, context)?;
if case.fall_through {
writeln!(self.out, "{}break;", lcase.next())?;
}
writeln!(self.out, "{}}}", lcase)?;
}
writeln!(self.out, "{}default: {{", lcase)?;
self.put_block(lcase.next(), default, context, return_value)?;
self.put_block(lcase.next(), default, context)?;
writeln!(self.out, "{}}}", lcase)?;
writeln!(self.out, "{}}}", level)?;
}
@@ -748,13 +867,13 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "{}while(true) {{", level)?;
let lif = level.next();
writeln!(self.out, "{}if (!{}) {{", lif, gate_name)?;
self.put_block(lif.next(), continuing, context, return_value)?;
self.put_block(lif.next(), continuing, context)?;
writeln!(self.out, "{}}}", lif)?;
writeln!(self.out, "{}{} = false;", lif, gate_name)?;
} else {
writeln!(self.out, "{}while(true) {{", level)?;
}
self.put_block(level.next(), body, context, return_value)?;
self.put_block(level.next(), body, context)?;
writeln!(self.out, "{}}}", level)?;
}
crate::Statement::Break => {
@@ -766,8 +885,9 @@ impl<W: Write> Writer<W> {
crate::Statement::Return {
value: Some(expr_handle),
} => {
self.prepare_expression(level.clone(), expr_handle, context, true)?;
write!(self.out, "{}return ", level)?;
self.put_expression(expr_handle, context)?;
self.put_expression(expr_handle, &context.expression)?;
writeln!(self.out, ";")?;
}
crate::Statement::Return { value: None } => {
@@ -775,25 +895,28 @@ impl<W: Write> Writer<W> {
self.out,
"{}return {};",
level,
return_value.unwrap_or_default(),
context.return_value.unwrap_or_default(),
)?;
}
crate::Statement::Kill => {
writeln!(self.out, "{}discard_fragment();", level)?;
}
crate::Statement::Store { pointer, value } => {
//write!(self.out, "{}*", INDENT)?;
self.prepare_expression(level.clone(), value, context, true)?;
write!(self.out, "{}", level)?;
self.put_expression(pointer, context)?;
self.put_expression(pointer, &context.expression)?;
write!(self.out, " = ")?;
self.put_expression(value, context)?;
self.put_expression(value, &context.expression)?;
writeln!(self.out, ";")?;
}
crate::Statement::Call {
function,
ref arguments,
} => {
self.put_local_call(function, arguments, context)?;
for &arg in arguments {
self.prepare_expression(level.clone(), arg, context, false)?;
}
self.put_local_call(function, arguments, &context.expression)?;
writeln!(self.out, ";")?;
}
}
@@ -804,6 +927,7 @@ impl<W: Write> Writer<W> {
pub fn write(
&mut self,
module: &crate::Module,
analysis: &Analysis,
options: &Options,
) -> Result<TranslationInfo, Error> {
self.names.clear();
@@ -814,7 +938,7 @@ impl<W: Write> Writer<W> {
writeln!(self.out)?;
self.write_type_defs(module)?;
self.write_functions(module, options)
self.write_functions(module, analysis, options)
}
fn write_type_defs(&mut self, module: &crate::Module) -> Result<(), Error> {
@@ -956,6 +1080,7 @@ impl<W: Write> Writer<W> {
fn write_functions(
&mut self,
module: &crate::Module,
analysis: &Analysis,
options: &Options,
) -> Result<TranslationInfo, Error> {
let mut pass_through_globals = Vec::new();
@@ -1022,12 +1147,17 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ";")?;
}
let context = ExpressionContext {
function: fun,
origin: FunctionOrigin::Handle(fun_handle),
module,
let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::Handle(fun_handle),
module,
},
fun_info: &analysis[fun_handle],
return_value: None,
};
self.put_block(Level(1), &fun.body, &context, None)?;
self.named_expressions.clear();
self.put_block(Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
@@ -1035,7 +1165,7 @@ impl<W: Write> Writer<W> {
let mut info = TranslationInfo {
entry_point_names: Vec::with_capacity(module.entry_points.len()),
};
for (ep_index, (&(stage, _), ep)) in module.entry_points.iter().enumerate() {
for (ep_index, (&(stage, ref ep_name), ep)) in module.entry_points.iter().enumerate() {
let fun = &ep.function;
self.typifier.resolve_all(
&fun.expressions,
@@ -1222,12 +1352,17 @@ impl<W: Write> Writer<W> {
writeln!(self.out, ";")?;
}
let context = ExpressionContext {
function: fun,
origin: FunctionOrigin::EntryPoint(ep_index as _),
module,
let context = StatementContext {
expression: ExpressionContext {
function: fun,
origin: FunctionOrigin::EntryPoint(ep_index as _),
module,
},
fun_info: analysis.get_entry_point(stage, ep_name),
return_value,
};
self.put_block(Level(1), &fun.body, &context, return_value)?;
self.named_expressions.clear();
self.put_block(Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
let is_last = ep_index == module.entry_points.len() - 1;
if !is_last {

View File

@@ -9,13 +9,13 @@ struct GlobalUseVisitor<'a> {
}
impl Visitor for GlobalUseVisitor<'_> {
fn visit_expr(&mut self, expr: &crate::Expression) {
fn visit_expr(&mut self, _: Handle<crate::Expression>, expr: &crate::Expression) {
if let crate::Expression::GlobalVariable(handle) = expr {
self.usage[handle.index()] |= crate::GlobalUse::READ;
}
}
fn visit_lhs_expr(&mut self, expr: &crate::Expression) {
fn visit_lhs_expr(&mut self, _: Handle<crate::Expression>, expr: &crate::Expression) {
if let crate::Expression::GlobalVariable(handle) = expr {
self.usage[handle.index()] |= crate::GlobalUse::WRITE;
}

View File

@@ -640,6 +640,8 @@ pub enum ImageQuery {
}
/// An expression that can be evaluated to obtain a value.
///
/// This is a Single Static Assignment (SSA) scheme similar to SPIR-V.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(PartialEq))]
#[cfg_attr(feature = "serialize", derive(Serialize))]

View File

@@ -7,6 +7,7 @@ Figures out the following properties:
!*/
use crate::arena::{Arena, Handle};
use std::ops;
bitflags::bitflags! {
#[derive(Default)]
@@ -35,7 +36,14 @@ pub struct ExpressionInfo {
pub struct FunctionInfo {
pub control_flags: ControlFlags,
pub sampling_set: crate::FastHashSet<SamplingKey>,
pub expressions: Box<[ExpressionInfo]>,
expressions: Box<[ExpressionInfo]>,
}
impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
type Output = ExpressionInfo;
fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
&self.expressions[handle.index()]
}
}
#[derive(Clone, Debug, thiserror::Error)]
@@ -355,6 +363,22 @@ impl Analysis {
Ok(this)
}
pub fn get_entry_point(&self, stage: crate::ShaderStage, name: &str) -> &FunctionInfo {
let (_, info) = self
.entry_points
.iter()
.find(|(key, _)| key.0 == stage && key.1 == name)
.unwrap();
info
}
}
impl ops::Index<Handle<crate::Function>> for Analysis {
type Output = FunctionInfo;
fn index(&self, handle: Handle<crate::Function>) -> &FunctionInfo {
&self.functions[handle.index()]
}
}
#[test]

View File

@@ -7,8 +7,8 @@ pub struct Interface<'a, T> {
}
pub trait Visitor {
fn visit_expr(&mut self, _: &crate::Expression) {}
fn visit_lhs_expr(&mut self, _: &crate::Expression) {}
fn visit_expr(&mut self, _: Handle<crate::Expression>, _: &crate::Expression) {}
fn visit_lhs_expr(&mut self, _: Handle<crate::Expression>, _: &crate::Expression) {}
fn visit_fun(&mut self, _: Handle<crate::Function>) {}
}
@@ -16,12 +16,12 @@ impl<'a, T> Interface<'a, T>
where
T: Visitor,
{
fn traverse_expr(&mut self, handle: Handle<crate::Expression>) {
pub fn traverse_expr(&mut self, handle: Handle<crate::Expression>) {
use crate::Expression as E;
let expr = &self.expressions[handle];
self.visitor.visit_expr(expr);
self.visitor.visit_expr(handle, expr);
match *expr {
E::Access { base, index } => {
@@ -200,7 +200,7 @@ where
_ => break,
}
}
self.visitor.visit_lhs_expr(&self.expressions[left]);
self.visitor.visit_lhs_expr(left, &self.expressions[left]);
self.traverse_expr(value);
}
S::Call {

View File

@@ -52,7 +52,8 @@ type7 fetch_shadow(
if ((homogeneous_coords.w <= 0.0)) {
return 1.0;
}
return t_shadow.sample_compare(sampler_shadow, (((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(0.5, -0.5)) * (1.0 / homogeneous_coords.w)) + metal::float2(0.5, 0.5)), static_cast<int>(light_id), (homogeneous_coords.z * (1.0 / homogeneous_coords.w)));
float expr15 = (1.0 / homogeneous_coords.w);
return t_shadow.sample_compare(sampler_shadow, (((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(0.5, -0.5)) * expr15) + metal::float2(0.5, 0.5)), static_cast<int>(light_id), (homogeneous_coords.z * expr15));
}
struct fs_mainInput {
@@ -83,7 +84,9 @@ fragment fs_mainOutput fs_main(
if ((i >= metal::min(u_globals.num_lights.x, 10u))) {
break;
}
color1 = (color1 + ((fetch_shadow(i, (s_lights.data[i].proj * input.in_position_fs), t_shadow, sampler_shadow) * metal::max(0.0, metal::dot(metal::normalize(input.in_normal_fs), metal::normalize((metal::float3(s_lights.data[i].pos.x, s_lights.data[i].pos.y, s_lights.data[i].pos.z) - metal::float3(input.in_position_fs.x, input.in_position_fs.y, input.in_position_fs.z)))))) * metal::float3(s_lights.data[i].color.x, s_lights.data[i].color.y, s_lights.data[i].color.z)));
Light expr18 = s_lights.data[i];
type7 expr21 = fetch_shadow(i, (expr18.proj * input.in_position_fs), t_shadow, sampler_shadow);
color1 = (color1 + ((expr21 * metal::max(0.0, metal::dot(metal::normalize(input.in_normal_fs), metal::normalize((metal::float3(expr18.pos.x, expr18.pos.y, expr18.pos.z) - metal::float3(input.in_position_fs.x, input.in_position_fs.y, input.in_position_fs.z)))))) * metal::float3(expr18.color.x, expr18.color.y, expr18.color.z)));
}
output.out_color_fs = metal::float4(color1, 1.0);
return output;

View File

@@ -47,9 +47,10 @@ vertex vs_mainOutput vs_main(
type unprojected;
tmp1_ = (static_cast<int>(in_vertex_index) / 2);
tmp2_ = (static_cast<int>(in_vertex_index) & 1);
unprojected = (r_data.proj_inv * metal::float4(((static_cast<float>(tmp1_) * 4.0) - 1.0), ((static_cast<float>(tmp2_) * 4.0) - 1.0), 0.0, 1.0));
type expr24 = metal::float4(((static_cast<float>(tmp1_) * 4.0) - 1.0), ((static_cast<float>(tmp2_) * 4.0) - 1.0), 0.0, 1.0);
unprojected = (r_data.proj_inv * expr24);
output.out_uv = (metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(unprojected.x, unprojected.y, unprojected.z));
output.out_position = metal::float4(((static_cast<float>(tmp1_) * 4.0) - 1.0), ((static_cast<float>(tmp2_) * 4.0) - 1.0), 0.0, 1.0);
output.out_position = expr24;
return output;
}

View File

@@ -75,7 +75,12 @@ fn check_output_spv(module: &naga::Module, name: &str, params: &Parameters) {
}
#[cfg(feature = "msl-out")]
fn check_output_msl(module: &naga::Module, name: &str, params: &Parameters) {
fn check_output_msl(
module: &naga::Module,
analysis: &naga::proc::analyzer::Analysis,
name: &str,
params: &Parameters,
) {
use naga::back::msl;
let mut binding_map = msl::BindingMap::default();
@@ -104,7 +109,7 @@ fn check_output_msl(module: &naga::Module, name: &str, params: &Parameters) {
binding_map,
};
let (msl, _) = msl::write_string(&module, &options).unwrap();
let (msl, _) = msl::write_string(module, analysis, &options).unwrap();
with_snapshot_settings(|| {
insta::assert_snapshot!(format!("{}.msl", name), msl);
@@ -143,7 +148,8 @@ fn convert_wgsl(name: &str, language: Language) {
.expect("Couldn't find wgsl file"),
)
.unwrap();
naga::proc::Validator::new().validate(&module).unwrap();
#[cfg_attr(not(feature = "msl-out"), allow(unused_variables))]
let analysis = naga::proc::Validator::new().validate(&module).unwrap();
#[cfg(feature = "spv-out")]
{
@@ -154,7 +160,7 @@ fn convert_wgsl(name: &str, language: Language) {
#[cfg(feature = "msl-out")]
{
if language.contains(Language::METAL) {
check_output_msl(&module, name, &params);
check_output_msl(&module, &analysis, name, &params);
}
}
#[cfg(feature = "glsl-out")]