Distinguish between expression uniformity and function uniformity

This commit is contained in:
Dzmitry Malyshau
2021-03-11 01:04:55 -05:00
committed by Dzmitry Malyshau
parent 74d1fdbf2a
commit dd273e254a
8 changed files with 218 additions and 109 deletions

View File

@@ -784,7 +784,7 @@ impl<W: Write> Writer<W> {
)?;
}
crate::Statement::Kill => {
writeln!(self.out, "{}discard_fragment();", level)?;
writeln!(self.out, "{}{}::discard_fragment();", level, NAMESPACE)?;
}
crate::Statement::Store { pointer, value } => {
write!(self.out, "{}", level)?;

View File

@@ -73,6 +73,19 @@ impl Uniformity {
fn disruptor(&self) -> Option<UniformityDisruptor> {
self.non_uniform_result.map(UniformityDisruptor::Expression)
}
/// When considering the uniformity at the functional level,
/// we don't care about all of the expression results.
/// We only care about the `return` expressions.
fn to_function(&self) -> FunctionUniformity {
FunctionUniformity {
result: Uniformity {
non_uniform_result: None,
require_uniform: self.require_uniform,
},
exit: ExitFlags::empty(),
}
}
}
bitflags::bitflags! {
@@ -87,6 +100,43 @@ bitflags::bitflags! {
}
}
/// Uniformity characteristics of a function.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
result: Uniformity,
exit: ExitFlags,
}
impl ops::BitOr for FunctionUniformity {
type Output = Self;
fn bitor(self, other: Self) -> Self {
FunctionUniformity {
result: self.result | other.result,
exit: self.exit | other.exit,
}
}
}
impl FunctionUniformity {
fn new() -> Self {
FunctionUniformity {
result: Uniformity::default(),
exit: ExitFlags::empty(),
}
}
/// Returns a disruptor based on the stored exit flags, if any.
fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
if self.exit.contains(ExitFlags::MAY_RETURN) {
Some(UniformityDisruptor::Return)
} else if self.exit.contains(ExitFlags::MAY_KILL) {
Some(UniformityDisruptor::Discard)
} else {
None
}
}
}
bitflags::bitflags! {
/// Indicates how a global variable is used.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
@@ -182,18 +232,6 @@ pub enum UniformityDisruptor {
Discard,
}
impl UniformityDisruptor {
fn from_exit(flags: ExitFlags) -> Option<Self> {
if flags.contains(ExitFlags::MAY_RETURN) {
Some(Self::Return)
} else if flags.contains(ExitFlags::MAY_KILL) {
Some(Self::Discard)
} else {
None
}
}
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AnalysisError {
@@ -248,14 +286,21 @@ impl FunctionInfo {
}
/// Inherit information from a called function.
fn process_call(&mut self, info: &Self) -> Uniformity {
fn process_call(&mut self, info: &Self) -> FunctionUniformity {
for key in info.sampling_set.iter() {
self.sampling_set.insert(key.clone());
}
for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) {
*mine |= *other;
}
info.uniformity.clone()
FunctionUniformity {
result: info.uniformity.clone(),
exit: if info.may_kill {
ExitFlags::MAY_KILL
} else {
ExitFlags::empty()
},
}
}
/// Computes the expression info and stores it in `self.expressions`.
@@ -412,7 +457,7 @@ impl FunctionInfo {
self.add_ref(arg) | arg1_flags | arg2_flags
}
E::As { expr, .. } => self.add_ref(expr),
E::Call(function) => self.process_call(&other_functions[function.index()]),
E::Call(function) => self.process_call(&other_functions[function.index()]).result,
E::ArrayLength(expr) => self.add_ref_impl(expr, GlobalUse::QUERY),
};
@@ -424,8 +469,11 @@ impl FunctionInfo {
Ok(())
}
/// Computes the uniformity and the exit flags on the block
/// (as a sequence of statements), and returns them.
/// Analyzes the uniformity requirements of a block (as a sequence of statements).
/// Returns the uniformity characteristics at the *function* level, i.e.
/// whether or not the function requires to be called in uniform control flow,
/// and whether the produced result is not disrupting the control flow.
///
/// The parent control flow is uniform if `disruptor.is_none()`.
///
/// Returns a `NonUniformControlFlow` error if any of the expressions in the block
@@ -436,14 +484,28 @@ impl FunctionInfo {
statements: &[crate::Statement],
other_functions: &[FunctionInfo],
mut disruptor: Option<UniformityDisruptor>,
) -> Result<(Uniformity, ExitFlags), AnalysisError> {
) -> Result<FunctionUniformity, AnalysisError> {
use crate::Statement as S;
let mut block_uniformity = Uniformity::default();
let mut block_exit = ExitFlags::empty();
let mut combined_uniformity = FunctionUniformity::new();
for statement in statements {
let (cur_uniformity, cur_exit) = match *statement {
S::Emit(_) | S::Break | S::Continue => (Uniformity::default(), ExitFlags::empty()),
S::Kill => (Uniformity::default(), ExitFlags::MAY_KILL),
let uniformity = match *statement {
S::Emit(ref range) => {
for expr in range.clone() {
if let (Some(expr), Some(cause)) = (
self.expressions[expr.index()].uniformity.require_uniform,
disruptor,
) {
return Err(AnalysisError::NonUniformControlFlow(expr, cause));
}
}
FunctionUniformity::new()
}
S::Break | S::Continue => FunctionUniformity::new(),
S::Kill => FunctionUniformity {
result: Uniformity::default(),
exit: ExitFlags::MAY_KILL,
},
S::Block(ref b) => self.process_block(b, other_functions, disruptor)?,
S::If {
condition,
@@ -452,14 +514,11 @@ impl FunctionInfo {
} => {
let condition_uniformity = self.add_ref(condition);
let branch_disruptor = disruptor.or(condition_uniformity.disruptor());
let (accept_uniformity, accept_exit) =
let accept_uniformity =
self.process_block(accept, other_functions, branch_disruptor)?;
let (reject_uniformity, reject_exit) =
let reject_uniformity =
self.process_block(reject, other_functions, branch_disruptor)?;
(
condition_uniformity | accept_uniformity | reject_uniformity,
accept_exit | reject_exit,
)
condition_uniformity.to_function() | accept_uniformity | reject_uniformity
}
S::Switch {
selector,
@@ -468,52 +527,46 @@ impl FunctionInfo {
} => {
let selector_uniformity = self.add_ref(selector);
let branch_disruptor = disruptor.or(selector_uniformity.disruptor());
let mut uniformity = selector_uniformity;
let mut exit = ExitFlags::empty();
let mut case_disruptor = disruptor;
let mut uniformity = FunctionUniformity::new();
let mut case_disruptor = branch_disruptor;
for case in cases.iter() {
let (case_uniformity, case_exit) =
let case_uniformity =
self.process_block(&case.body, other_functions, case_disruptor)?;
uniformity |= case_uniformity;
exit |= case_exit;
case_disruptor = if case.fall_through {
case_disruptor.or(UniformityDisruptor::from_exit(case_exit))
case_disruptor.or(case_uniformity.exit_disruptor())
} else {
branch_disruptor
};
uniformity = uniformity | case_uniformity;
}
let (default_uniformity, default_exit) =
// using the disruptor inherited from the last fall-through chain
let default_exit =
self.process_block(default, other_functions, case_disruptor)?;
(uniformity | default_uniformity, exit | default_exit)
uniformity | default_exit
}
S::Loop {
ref body,
ref continuing,
} => {
let (body_uniformity, body_exit) =
self.process_block(body, other_functions, disruptor)?;
let continuing_disruptor = disruptor
.or(body_uniformity.disruptor())
.or(UniformityDisruptor::from_exit(body_exit));
let (continuing_uniformity, continuing_exit) =
let body_uniformity = self.process_block(body, other_functions, disruptor)?;
let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
let continuing_uniformity =
self.process_block(continuing, other_functions, continuing_disruptor)?;
(
body_uniformity | continuing_uniformity,
body_exit | continuing_exit,
)
body_uniformity | continuing_uniformity
}
S::Return { value } => {
let uniformity = match value {
Some(expr) => self.add_ref(expr),
None => Uniformity::default(),
};
S::Return { value } => FunctionUniformity {
result: if let Some(expr) = value {
self.add_ref(expr)
} else {
Uniformity::default()
},
//TODO: if we are in the uniform control flow, should this still be an exit flag?
(uniformity, ExitFlags::MAY_RETURN)
}
exit: ExitFlags::MAY_RETURN,
},
S::Store { pointer, value } => {
let uniformity =
self.add_ref_impl(pointer, GlobalUse::WRITE) | self.add_ref(value);
(uniformity, ExitFlags::empty())
uniformity.to_function()
}
S::ImageStore {
image,
@@ -529,7 +582,7 @@ impl FunctionInfo {
| self.add_ref_impl(image, GlobalUse::WRITE)
| self.add_ref(coordinate)
| self.add_ref(value);
(uniformity, ExitFlags::empty())
uniformity.to_function()
}
S::Call {
function,
@@ -537,27 +590,19 @@ impl FunctionInfo {
result: _,
} => {
let info = &other_functions[function.index()];
let mut uniformity = self.process_call(info);
let call_uniformity = self.process_call(info);
let mut uniformity = call_uniformity.result.clone();
for &argument in arguments {
uniformity |= self.add_ref(argument);
}
let exit = if info.may_kill {
ExitFlags::MAY_KILL
} else {
ExitFlags::empty()
};
(uniformity, exit)
call_uniformity | uniformity.to_function()
}
};
if let (Some(expr), Some(cause)) = (cur_uniformity.require_uniform, disruptor) {
return Err(AnalysisError::NonUniformControlFlow(expr, cause));
}
disruptor = disruptor.or(UniformityDisruptor::from_exit(cur_exit));
block_uniformity |= cur_uniformity;
block_exit |= cur_exit;
disruptor = disruptor.or(uniformity.exit_disruptor());
combined_uniformity = combined_uniformity | uniformity;
}
Ok((block_uniformity, block_exit))
Ok(combined_uniformity)
}
}
@@ -589,9 +634,9 @@ impl Analysis {
info.process_expression(handle, &fun.expressions, global_var_arena, &self.functions)?;
}
let (uniformity, exit) = info.process_block(&fun.body, &self.functions, None)?;
info.uniformity = uniformity;
info.may_kill = exit.contains(ExitFlags::MAY_KILL);
let uniformity = info.process_block(&fun.body, &self.functions, None)?;
info.uniformity = uniformity.result;
info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
Ok(info)
}
@@ -676,8 +721,11 @@ fn uniform_control_flow() {
axis: crate::DerivativeAxis::X,
expr: constant_expr,
});
let emit_range_constant_derivative = expressions.range_from(0);
let non_uniform_global_expr = expressions.append(E::GlobalVariable(non_uniform_global));
let uniform_global_expr = expressions.append(E::GlobalVariable(uniform_global));
let emit_range_globals = expressions.range_from(2);
// checks the QUERY flag
let query_expr = expressions.append(E::ArrayLength(uniform_global_expr));
// checks the transitive WRITE flag
@@ -685,6 +733,7 @@ fn uniform_control_flow() {
base: non_uniform_global_expr,
index: 1,
});
let emit_range_query_access_globals = expressions.range_from(2);
let mut info = FunctionInfo {
uniformity: Uniformity::default(),
@@ -704,68 +753,87 @@ fn uniform_control_flow() {
assert_eq!(info[non_uniform_global], GlobalUse::empty());
assert_eq!(info[uniform_global], GlobalUse::QUERY);
let stmt_emit1 = S::Emit(emit_range_globals.clone());
let stmt_if_uniform = S::If {
condition: uniform_global_expr,
accept: Vec::new(),
reject: vec![S::Store {
pointer: constant_expr,
value: derivative_expr,
}],
reject: vec![
S::Emit(emit_range_constant_derivative.clone()),
S::Store {
pointer: constant_expr,
value: derivative_expr,
},
],
};
assert_eq!(
info.process_block(&[stmt_if_uniform], &[], None),
Ok((
Uniformity::require_uniform(derivative_expr),
ExitFlags::empty()
)),
info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None),
Ok(FunctionUniformity {
result: Uniformity::require_uniform(derivative_expr),
exit: ExitFlags::empty(),
}),
);
assert_eq!(info[constant_expr].ref_count, 2);
assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
let stmt_emit2 = S::Emit(emit_range_globals.clone());
let stmt_if_non_uniform = S::If {
condition: non_uniform_global_expr,
accept: vec![S::Store {
pointer: constant_expr,
value: derivative_expr,
}],
accept: vec![
S::Emit(emit_range_constant_derivative.clone()),
S::Store {
pointer: constant_expr,
value: derivative_expr,
},
],
reject: Vec::new(),
};
assert_eq!(
info.process_block(&[stmt_if_non_uniform], &[], None),
info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None),
Err(AnalysisError::NonUniformControlFlow(
derivative_expr,
UniformityDisruptor::Expression(non_uniform_global_expr)
)),
);
assert_eq!(info[derivative_expr].ref_count, 2);
assert_eq!(info[derivative_expr].ref_count, 1);
assert_eq!(info[non_uniform_global], GlobalUse::READ);
let stmt_emit3 = S::Emit(emit_range_globals);
let stmt_return_non_uniform = S::Return {
value: Some(non_uniform_global_expr),
};
assert_eq!(
info.process_block(
&[stmt_return_non_uniform],
&[stmt_emit3, stmt_return_non_uniform],
&[],
Some(UniformityDisruptor::Return)
),
Ok((
Uniformity::non_uniform_result(non_uniform_global_expr),
ExitFlags::MAY_RETURN
)),
Ok(FunctionUniformity {
result: Uniformity::non_uniform_result(non_uniform_global_expr),
exit: ExitFlags::MAY_RETURN,
}),
);
assert_eq!(info[non_uniform_global_expr].ref_count, 3);
// Check that uniformity requirements reach through a pointer
let stmt_emit4 = S::Emit(emit_range_query_access_globals);
let stmt_assign = S::Store {
pointer: access_expr,
value: query_expr,
};
let stmt_return_pointer = S::Return {
value: Some(access_expr),
};
let stmt_kill = S::Kill;
assert_eq!(
info.process_block(&[stmt_assign], &[], Some(UniformityDisruptor::Discard)),
Ok((
Uniformity::non_uniform_result(non_uniform_global_expr),
ExitFlags::empty()
)),
info.process_block(
&[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer],
&[],
Some(UniformityDisruptor::Discard)
),
Ok(FunctionUniformity {
result: Uniformity::non_uniform_result(non_uniform_global_expr),
exit: ExitFlags::all(),
}),
);
assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}