mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Distinguish between expression uniformity and function uniformity
This commit is contained in:
committed by
Dzmitry Malyshau
parent
74d1fdbf2a
commit
dd273e254a
@@ -784,7 +784,7 @@ impl<W: Write> Writer<W> {
|
||||
)?;
|
||||
}
|
||||
crate::Statement::Kill => {
|
||||
writeln!(self.out, "{}discard_fragment();", level)?;
|
||||
writeln!(self.out, "{}{}::discard_fragment();", level, NAMESPACE)?;
|
||||
}
|
||||
crate::Statement::Store { pointer, value } => {
|
||||
write!(self.out, "{}", level)?;
|
||||
|
||||
@@ -73,6 +73,19 @@ impl Uniformity {
|
||||
fn disruptor(&self) -> Option<UniformityDisruptor> {
|
||||
self.non_uniform_result.map(UniformityDisruptor::Expression)
|
||||
}
|
||||
|
||||
/// When considering the uniformity at the functional level,
|
||||
/// we don't care about all of the expression results.
|
||||
/// We only care about the `return` expressions.
|
||||
fn to_function(&self) -> FunctionUniformity {
|
||||
FunctionUniformity {
|
||||
result: Uniformity {
|
||||
non_uniform_result: None,
|
||||
require_uniform: self.require_uniform,
|
||||
},
|
||||
exit: ExitFlags::empty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
@@ -87,6 +100,43 @@ bitflags::bitflags! {
|
||||
}
|
||||
}
|
||||
|
||||
/// Uniformity characteristics of a function.
|
||||
#[cfg_attr(test, derive(Debug, PartialEq))]
|
||||
struct FunctionUniformity {
|
||||
result: Uniformity,
|
||||
exit: ExitFlags,
|
||||
}
|
||||
|
||||
impl ops::BitOr for FunctionUniformity {
|
||||
type Output = Self;
|
||||
fn bitor(self, other: Self) -> Self {
|
||||
FunctionUniformity {
|
||||
result: self.result | other.result,
|
||||
exit: self.exit | other.exit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FunctionUniformity {
|
||||
fn new() -> Self {
|
||||
FunctionUniformity {
|
||||
result: Uniformity::default(),
|
||||
exit: ExitFlags::empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a disruptor based on the stored exit flags, if any.
|
||||
fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
|
||||
if self.exit.contains(ExitFlags::MAY_RETURN) {
|
||||
Some(UniformityDisruptor::Return)
|
||||
} else if self.exit.contains(ExitFlags::MAY_KILL) {
|
||||
Some(UniformityDisruptor::Discard)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
/// Indicates how a global variable is used.
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
@@ -182,18 +232,6 @@ pub enum UniformityDisruptor {
|
||||
Discard,
|
||||
}
|
||||
|
||||
impl UniformityDisruptor {
|
||||
fn from_exit(flags: ExitFlags) -> Option<Self> {
|
||||
if flags.contains(ExitFlags::MAY_RETURN) {
|
||||
Some(Self::Return)
|
||||
} else if flags.contains(ExitFlags::MAY_KILL) {
|
||||
Some(Self::Discard)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum AnalysisError {
|
||||
@@ -248,14 +286,21 @@ impl FunctionInfo {
|
||||
}
|
||||
|
||||
/// Inherit information from a called function.
|
||||
fn process_call(&mut self, info: &Self) -> Uniformity {
|
||||
fn process_call(&mut self, info: &Self) -> FunctionUniformity {
|
||||
for key in info.sampling_set.iter() {
|
||||
self.sampling_set.insert(key.clone());
|
||||
}
|
||||
for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) {
|
||||
*mine |= *other;
|
||||
}
|
||||
info.uniformity.clone()
|
||||
FunctionUniformity {
|
||||
result: info.uniformity.clone(),
|
||||
exit: if info.may_kill {
|
||||
ExitFlags::MAY_KILL
|
||||
} else {
|
||||
ExitFlags::empty()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the expression info and stores it in `self.expressions`.
|
||||
@@ -412,7 +457,7 @@ impl FunctionInfo {
|
||||
self.add_ref(arg) | arg1_flags | arg2_flags
|
||||
}
|
||||
E::As { expr, .. } => self.add_ref(expr),
|
||||
E::Call(function) => self.process_call(&other_functions[function.index()]),
|
||||
E::Call(function) => self.process_call(&other_functions[function.index()]).result,
|
||||
E::ArrayLength(expr) => self.add_ref_impl(expr, GlobalUse::QUERY),
|
||||
};
|
||||
|
||||
@@ -424,8 +469,11 @@ impl FunctionInfo {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Computes the uniformity and the exit flags on the block
|
||||
/// (as a sequence of statements), and returns them.
|
||||
/// Analyzes the uniformity requirements of a block (as a sequence of statements).
|
||||
/// Returns the uniformity characteristics at the *function* level, i.e.
|
||||
/// whether or not the function requires to be called in uniform control flow,
|
||||
/// and whether the produced result is not disrupting the control flow.
|
||||
///
|
||||
/// The parent control flow is uniform if `disruptor.is_none()`.
|
||||
///
|
||||
/// Returns a `NonUniformControlFlow` error if any of the expressions in the block
|
||||
@@ -436,14 +484,28 @@ impl FunctionInfo {
|
||||
statements: &[crate::Statement],
|
||||
other_functions: &[FunctionInfo],
|
||||
mut disruptor: Option<UniformityDisruptor>,
|
||||
) -> Result<(Uniformity, ExitFlags), AnalysisError> {
|
||||
) -> Result<FunctionUniformity, AnalysisError> {
|
||||
use crate::Statement as S;
|
||||
let mut block_uniformity = Uniformity::default();
|
||||
let mut block_exit = ExitFlags::empty();
|
||||
|
||||
let mut combined_uniformity = FunctionUniformity::new();
|
||||
for statement in statements {
|
||||
let (cur_uniformity, cur_exit) = match *statement {
|
||||
S::Emit(_) | S::Break | S::Continue => (Uniformity::default(), ExitFlags::empty()),
|
||||
S::Kill => (Uniformity::default(), ExitFlags::MAY_KILL),
|
||||
let uniformity = match *statement {
|
||||
S::Emit(ref range) => {
|
||||
for expr in range.clone() {
|
||||
if let (Some(expr), Some(cause)) = (
|
||||
self.expressions[expr.index()].uniformity.require_uniform,
|
||||
disruptor,
|
||||
) {
|
||||
return Err(AnalysisError::NonUniformControlFlow(expr, cause));
|
||||
}
|
||||
}
|
||||
FunctionUniformity::new()
|
||||
}
|
||||
S::Break | S::Continue => FunctionUniformity::new(),
|
||||
S::Kill => FunctionUniformity {
|
||||
result: Uniformity::default(),
|
||||
exit: ExitFlags::MAY_KILL,
|
||||
},
|
||||
S::Block(ref b) => self.process_block(b, other_functions, disruptor)?,
|
||||
S::If {
|
||||
condition,
|
||||
@@ -452,14 +514,11 @@ impl FunctionInfo {
|
||||
} => {
|
||||
let condition_uniformity = self.add_ref(condition);
|
||||
let branch_disruptor = disruptor.or(condition_uniformity.disruptor());
|
||||
let (accept_uniformity, accept_exit) =
|
||||
let accept_uniformity =
|
||||
self.process_block(accept, other_functions, branch_disruptor)?;
|
||||
let (reject_uniformity, reject_exit) =
|
||||
let reject_uniformity =
|
||||
self.process_block(reject, other_functions, branch_disruptor)?;
|
||||
(
|
||||
condition_uniformity | accept_uniformity | reject_uniformity,
|
||||
accept_exit | reject_exit,
|
||||
)
|
||||
condition_uniformity.to_function() | accept_uniformity | reject_uniformity
|
||||
}
|
||||
S::Switch {
|
||||
selector,
|
||||
@@ -468,52 +527,46 @@ impl FunctionInfo {
|
||||
} => {
|
||||
let selector_uniformity = self.add_ref(selector);
|
||||
let branch_disruptor = disruptor.or(selector_uniformity.disruptor());
|
||||
let mut uniformity = selector_uniformity;
|
||||
let mut exit = ExitFlags::empty();
|
||||
let mut case_disruptor = disruptor;
|
||||
let mut uniformity = FunctionUniformity::new();
|
||||
let mut case_disruptor = branch_disruptor;
|
||||
for case in cases.iter() {
|
||||
let (case_uniformity, case_exit) =
|
||||
let case_uniformity =
|
||||
self.process_block(&case.body, other_functions, case_disruptor)?;
|
||||
uniformity |= case_uniformity;
|
||||
exit |= case_exit;
|
||||
case_disruptor = if case.fall_through {
|
||||
case_disruptor.or(UniformityDisruptor::from_exit(case_exit))
|
||||
case_disruptor.or(case_uniformity.exit_disruptor())
|
||||
} else {
|
||||
branch_disruptor
|
||||
};
|
||||
uniformity = uniformity | case_uniformity;
|
||||
}
|
||||
let (default_uniformity, default_exit) =
|
||||
// using the disruptor inherited from the last fall-through chain
|
||||
let default_exit =
|
||||
self.process_block(default, other_functions, case_disruptor)?;
|
||||
(uniformity | default_uniformity, exit | default_exit)
|
||||
uniformity | default_exit
|
||||
}
|
||||
S::Loop {
|
||||
ref body,
|
||||
ref continuing,
|
||||
} => {
|
||||
let (body_uniformity, body_exit) =
|
||||
self.process_block(body, other_functions, disruptor)?;
|
||||
let continuing_disruptor = disruptor
|
||||
.or(body_uniformity.disruptor())
|
||||
.or(UniformityDisruptor::from_exit(body_exit));
|
||||
let (continuing_uniformity, continuing_exit) =
|
||||
let body_uniformity = self.process_block(body, other_functions, disruptor)?;
|
||||
let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
|
||||
let continuing_uniformity =
|
||||
self.process_block(continuing, other_functions, continuing_disruptor)?;
|
||||
(
|
||||
body_uniformity | continuing_uniformity,
|
||||
body_exit | continuing_exit,
|
||||
)
|
||||
body_uniformity | continuing_uniformity
|
||||
}
|
||||
S::Return { value } => {
|
||||
let uniformity = match value {
|
||||
Some(expr) => self.add_ref(expr),
|
||||
None => Uniformity::default(),
|
||||
};
|
||||
S::Return { value } => FunctionUniformity {
|
||||
result: if let Some(expr) = value {
|
||||
self.add_ref(expr)
|
||||
} else {
|
||||
Uniformity::default()
|
||||
},
|
||||
//TODO: if we are in the uniform control flow, should this still be an exit flag?
|
||||
(uniformity, ExitFlags::MAY_RETURN)
|
||||
}
|
||||
exit: ExitFlags::MAY_RETURN,
|
||||
},
|
||||
S::Store { pointer, value } => {
|
||||
let uniformity =
|
||||
self.add_ref_impl(pointer, GlobalUse::WRITE) | self.add_ref(value);
|
||||
(uniformity, ExitFlags::empty())
|
||||
uniformity.to_function()
|
||||
}
|
||||
S::ImageStore {
|
||||
image,
|
||||
@@ -529,7 +582,7 @@ impl FunctionInfo {
|
||||
| self.add_ref_impl(image, GlobalUse::WRITE)
|
||||
| self.add_ref(coordinate)
|
||||
| self.add_ref(value);
|
||||
(uniformity, ExitFlags::empty())
|
||||
uniformity.to_function()
|
||||
}
|
||||
S::Call {
|
||||
function,
|
||||
@@ -537,27 +590,19 @@ impl FunctionInfo {
|
||||
result: _,
|
||||
} => {
|
||||
let info = &other_functions[function.index()];
|
||||
let mut uniformity = self.process_call(info);
|
||||
let call_uniformity = self.process_call(info);
|
||||
let mut uniformity = call_uniformity.result.clone();
|
||||
for &argument in arguments {
|
||||
uniformity |= self.add_ref(argument);
|
||||
}
|
||||
let exit = if info.may_kill {
|
||||
ExitFlags::MAY_KILL
|
||||
} else {
|
||||
ExitFlags::empty()
|
||||
};
|
||||
(uniformity, exit)
|
||||
call_uniformity | uniformity.to_function()
|
||||
}
|
||||
};
|
||||
|
||||
if let (Some(expr), Some(cause)) = (cur_uniformity.require_uniform, disruptor) {
|
||||
return Err(AnalysisError::NonUniformControlFlow(expr, cause));
|
||||
}
|
||||
disruptor = disruptor.or(UniformityDisruptor::from_exit(cur_exit));
|
||||
block_uniformity |= cur_uniformity;
|
||||
block_exit |= cur_exit;
|
||||
disruptor = disruptor.or(uniformity.exit_disruptor());
|
||||
combined_uniformity = combined_uniformity | uniformity;
|
||||
}
|
||||
Ok((block_uniformity, block_exit))
|
||||
Ok(combined_uniformity)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -589,9 +634,9 @@ impl Analysis {
|
||||
info.process_expression(handle, &fun.expressions, global_var_arena, &self.functions)?;
|
||||
}
|
||||
|
||||
let (uniformity, exit) = info.process_block(&fun.body, &self.functions, None)?;
|
||||
info.uniformity = uniformity;
|
||||
info.may_kill = exit.contains(ExitFlags::MAY_KILL);
|
||||
let uniformity = info.process_block(&fun.body, &self.functions, None)?;
|
||||
info.uniformity = uniformity.result;
|
||||
info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
|
||||
|
||||
Ok(info)
|
||||
}
|
||||
@@ -676,8 +721,11 @@ fn uniform_control_flow() {
|
||||
axis: crate::DerivativeAxis::X,
|
||||
expr: constant_expr,
|
||||
});
|
||||
let emit_range_constant_derivative = expressions.range_from(0);
|
||||
let non_uniform_global_expr = expressions.append(E::GlobalVariable(non_uniform_global));
|
||||
let uniform_global_expr = expressions.append(E::GlobalVariable(uniform_global));
|
||||
let emit_range_globals = expressions.range_from(2);
|
||||
|
||||
// checks the QUERY flag
|
||||
let query_expr = expressions.append(E::ArrayLength(uniform_global_expr));
|
||||
// checks the transitive WRITE flag
|
||||
@@ -685,6 +733,7 @@ fn uniform_control_flow() {
|
||||
base: non_uniform_global_expr,
|
||||
index: 1,
|
||||
});
|
||||
let emit_range_query_access_globals = expressions.range_from(2);
|
||||
|
||||
let mut info = FunctionInfo {
|
||||
uniformity: Uniformity::default(),
|
||||
@@ -704,68 +753,87 @@ fn uniform_control_flow() {
|
||||
assert_eq!(info[non_uniform_global], GlobalUse::empty());
|
||||
assert_eq!(info[uniform_global], GlobalUse::QUERY);
|
||||
|
||||
let stmt_emit1 = S::Emit(emit_range_globals.clone());
|
||||
let stmt_if_uniform = S::If {
|
||||
condition: uniform_global_expr,
|
||||
accept: Vec::new(),
|
||||
reject: vec![S::Store {
|
||||
pointer: constant_expr,
|
||||
value: derivative_expr,
|
||||
}],
|
||||
reject: vec![
|
||||
S::Emit(emit_range_constant_derivative.clone()),
|
||||
S::Store {
|
||||
pointer: constant_expr,
|
||||
value: derivative_expr,
|
||||
},
|
||||
],
|
||||
};
|
||||
assert_eq!(
|
||||
info.process_block(&[stmt_if_uniform], &[], None),
|
||||
Ok((
|
||||
Uniformity::require_uniform(derivative_expr),
|
||||
ExitFlags::empty()
|
||||
)),
|
||||
info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None),
|
||||
Ok(FunctionUniformity {
|
||||
result: Uniformity::require_uniform(derivative_expr),
|
||||
exit: ExitFlags::empty(),
|
||||
}),
|
||||
);
|
||||
assert_eq!(info[constant_expr].ref_count, 2);
|
||||
assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
|
||||
|
||||
let stmt_emit2 = S::Emit(emit_range_globals.clone());
|
||||
let stmt_if_non_uniform = S::If {
|
||||
condition: non_uniform_global_expr,
|
||||
accept: vec![S::Store {
|
||||
pointer: constant_expr,
|
||||
value: derivative_expr,
|
||||
}],
|
||||
accept: vec![
|
||||
S::Emit(emit_range_constant_derivative.clone()),
|
||||
S::Store {
|
||||
pointer: constant_expr,
|
||||
value: derivative_expr,
|
||||
},
|
||||
],
|
||||
reject: Vec::new(),
|
||||
};
|
||||
assert_eq!(
|
||||
info.process_block(&[stmt_if_non_uniform], &[], None),
|
||||
info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None),
|
||||
Err(AnalysisError::NonUniformControlFlow(
|
||||
derivative_expr,
|
||||
UniformityDisruptor::Expression(non_uniform_global_expr)
|
||||
)),
|
||||
);
|
||||
assert_eq!(info[derivative_expr].ref_count, 2);
|
||||
assert_eq!(info[derivative_expr].ref_count, 1);
|
||||
assert_eq!(info[non_uniform_global], GlobalUse::READ);
|
||||
|
||||
let stmt_emit3 = S::Emit(emit_range_globals);
|
||||
let stmt_return_non_uniform = S::Return {
|
||||
value: Some(non_uniform_global_expr),
|
||||
};
|
||||
assert_eq!(
|
||||
info.process_block(
|
||||
&[stmt_return_non_uniform],
|
||||
&[stmt_emit3, stmt_return_non_uniform],
|
||||
&[],
|
||||
Some(UniformityDisruptor::Return)
|
||||
),
|
||||
Ok((
|
||||
Uniformity::non_uniform_result(non_uniform_global_expr),
|
||||
ExitFlags::MAY_RETURN
|
||||
)),
|
||||
Ok(FunctionUniformity {
|
||||
result: Uniformity::non_uniform_result(non_uniform_global_expr),
|
||||
exit: ExitFlags::MAY_RETURN,
|
||||
}),
|
||||
);
|
||||
assert_eq!(info[non_uniform_global_expr].ref_count, 3);
|
||||
|
||||
// Check that uniformity requirements reach through a pointer
|
||||
let stmt_emit4 = S::Emit(emit_range_query_access_globals);
|
||||
let stmt_assign = S::Store {
|
||||
pointer: access_expr,
|
||||
value: query_expr,
|
||||
};
|
||||
let stmt_return_pointer = S::Return {
|
||||
value: Some(access_expr),
|
||||
};
|
||||
let stmt_kill = S::Kill;
|
||||
assert_eq!(
|
||||
info.process_block(&[stmt_assign], &[], Some(UniformityDisruptor::Discard)),
|
||||
Ok((
|
||||
Uniformity::non_uniform_result(non_uniform_global_expr),
|
||||
ExitFlags::empty()
|
||||
)),
|
||||
info.process_block(
|
||||
&[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer],
|
||||
&[],
|
||||
Some(UniformityDisruptor::Discard)
|
||||
),
|
||||
Ok(FunctionUniformity {
|
||||
result: Uniformity::non_uniform_result(non_uniform_global_expr),
|
||||
exit: ExitFlags::all(),
|
||||
}),
|
||||
);
|
||||
assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user