Distinguish between expression uniformity and function uniformity

This commit is contained in:
Dzmitry Malyshau
2021-03-11 01:04:55 -05:00
committed by Dzmitry Malyshau
parent 74d1fdbf2a
commit dd273e254a
8 changed files with 218 additions and 109 deletions

View File

@@ -784,7 +784,7 @@ impl<W: Write> Writer<W> {
)?;
}
crate::Statement::Kill => {
writeln!(self.out, "{}discard_fragment();", level)?;
writeln!(self.out, "{}{}::discard_fragment();", level, NAMESPACE)?;
}
crate::Statement::Store { pointer, value } => {
write!(self.out, "{}", level)?;

View File

@@ -73,6 +73,19 @@ impl Uniformity {
fn disruptor(&self) -> Option<UniformityDisruptor> {
self.non_uniform_result.map(UniformityDisruptor::Expression)
}
/// When considering the uniformity at the functional level,
/// we don't care about all of the expression results.
/// We only care about the `return` expressions.
fn to_function(&self) -> FunctionUniformity {
FunctionUniformity {
result: Uniformity {
non_uniform_result: None,
require_uniform: self.require_uniform,
},
exit: ExitFlags::empty(),
}
}
}
bitflags::bitflags! {
@@ -87,6 +100,43 @@ bitflags::bitflags! {
}
}
/// Uniformity characteristics of a function.
#[cfg_attr(test, derive(Debug, PartialEq))]
struct FunctionUniformity {
result: Uniformity,
exit: ExitFlags,
}
impl ops::BitOr for FunctionUniformity {
type Output = Self;
fn bitor(self, other: Self) -> Self {
FunctionUniformity {
result: self.result | other.result,
exit: self.exit | other.exit,
}
}
}
impl FunctionUniformity {
fn new() -> Self {
FunctionUniformity {
result: Uniformity::default(),
exit: ExitFlags::empty(),
}
}
/// Returns a disruptor based on the stored exit flags, if any.
fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
if self.exit.contains(ExitFlags::MAY_RETURN) {
Some(UniformityDisruptor::Return)
} else if self.exit.contains(ExitFlags::MAY_KILL) {
Some(UniformityDisruptor::Discard)
} else {
None
}
}
}
bitflags::bitflags! {
/// Indicates how a global variable is used.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
@@ -182,18 +232,6 @@ pub enum UniformityDisruptor {
Discard,
}
impl UniformityDisruptor {
fn from_exit(flags: ExitFlags) -> Option<Self> {
if flags.contains(ExitFlags::MAY_RETURN) {
Some(Self::Return)
} else if flags.contains(ExitFlags::MAY_KILL) {
Some(Self::Discard)
} else {
None
}
}
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AnalysisError {
@@ -248,14 +286,21 @@ impl FunctionInfo {
}
/// Inherit information from a called function.
fn process_call(&mut self, info: &Self) -> Uniformity {
fn process_call(&mut self, info: &Self) -> FunctionUniformity {
for key in info.sampling_set.iter() {
self.sampling_set.insert(key.clone());
}
for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) {
*mine |= *other;
}
info.uniformity.clone()
FunctionUniformity {
result: info.uniformity.clone(),
exit: if info.may_kill {
ExitFlags::MAY_KILL
} else {
ExitFlags::empty()
},
}
}
/// Computes the expression info and stores it in `self.expressions`.
@@ -412,7 +457,7 @@ impl FunctionInfo {
self.add_ref(arg) | arg1_flags | arg2_flags
}
E::As { expr, .. } => self.add_ref(expr),
E::Call(function) => self.process_call(&other_functions[function.index()]),
E::Call(function) => self.process_call(&other_functions[function.index()]).result,
E::ArrayLength(expr) => self.add_ref_impl(expr, GlobalUse::QUERY),
};
@@ -424,8 +469,11 @@ impl FunctionInfo {
Ok(())
}
/// Computes the uniformity and the exit flags on the block
/// (as a sequence of statements), and returns them.
/// Analyzes the uniformity requirements of a block (as a sequence of statements).
/// Returns the uniformity characteristics at the *function* level, i.e.
/// whether or not the function requires to be called in uniform control flow,
/// and whether the produced result is not disrupting the control flow.
///
/// The parent control flow is uniform if `disruptor.is_none()`.
///
/// Returns a `NonUniformControlFlow` error if any of the expressions in the block
@@ -436,14 +484,28 @@ impl FunctionInfo {
statements: &[crate::Statement],
other_functions: &[FunctionInfo],
mut disruptor: Option<UniformityDisruptor>,
) -> Result<(Uniformity, ExitFlags), AnalysisError> {
) -> Result<FunctionUniformity, AnalysisError> {
use crate::Statement as S;
let mut block_uniformity = Uniformity::default();
let mut block_exit = ExitFlags::empty();
let mut combined_uniformity = FunctionUniformity::new();
for statement in statements {
let (cur_uniformity, cur_exit) = match *statement {
S::Emit(_) | S::Break | S::Continue => (Uniformity::default(), ExitFlags::empty()),
S::Kill => (Uniformity::default(), ExitFlags::MAY_KILL),
let uniformity = match *statement {
S::Emit(ref range) => {
for expr in range.clone() {
if let (Some(expr), Some(cause)) = (
self.expressions[expr.index()].uniformity.require_uniform,
disruptor,
) {
return Err(AnalysisError::NonUniformControlFlow(expr, cause));
}
}
FunctionUniformity::new()
}
S::Break | S::Continue => FunctionUniformity::new(),
S::Kill => FunctionUniformity {
result: Uniformity::default(),
exit: ExitFlags::MAY_KILL,
},
S::Block(ref b) => self.process_block(b, other_functions, disruptor)?,
S::If {
condition,
@@ -452,14 +514,11 @@ impl FunctionInfo {
} => {
let condition_uniformity = self.add_ref(condition);
let branch_disruptor = disruptor.or(condition_uniformity.disruptor());
let (accept_uniformity, accept_exit) =
let accept_uniformity =
self.process_block(accept, other_functions, branch_disruptor)?;
let (reject_uniformity, reject_exit) =
let reject_uniformity =
self.process_block(reject, other_functions, branch_disruptor)?;
(
condition_uniformity | accept_uniformity | reject_uniformity,
accept_exit | reject_exit,
)
condition_uniformity.to_function() | accept_uniformity | reject_uniformity
}
S::Switch {
selector,
@@ -468,52 +527,46 @@ impl FunctionInfo {
} => {
let selector_uniformity = self.add_ref(selector);
let branch_disruptor = disruptor.or(selector_uniformity.disruptor());
let mut uniformity = selector_uniformity;
let mut exit = ExitFlags::empty();
let mut case_disruptor = disruptor;
let mut uniformity = FunctionUniformity::new();
let mut case_disruptor = branch_disruptor;
for case in cases.iter() {
let (case_uniformity, case_exit) =
let case_uniformity =
self.process_block(&case.body, other_functions, case_disruptor)?;
uniformity |= case_uniformity;
exit |= case_exit;
case_disruptor = if case.fall_through {
case_disruptor.or(UniformityDisruptor::from_exit(case_exit))
case_disruptor.or(case_uniformity.exit_disruptor())
} else {
branch_disruptor
};
uniformity = uniformity | case_uniformity;
}
let (default_uniformity, default_exit) =
// using the disruptor inherited from the last fall-through chain
let default_exit =
self.process_block(default, other_functions, case_disruptor)?;
(uniformity | default_uniformity, exit | default_exit)
uniformity | default_exit
}
S::Loop {
ref body,
ref continuing,
} => {
let (body_uniformity, body_exit) =
self.process_block(body, other_functions, disruptor)?;
let continuing_disruptor = disruptor
.or(body_uniformity.disruptor())
.or(UniformityDisruptor::from_exit(body_exit));
let (continuing_uniformity, continuing_exit) =
let body_uniformity = self.process_block(body, other_functions, disruptor)?;
let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
let continuing_uniformity =
self.process_block(continuing, other_functions, continuing_disruptor)?;
(
body_uniformity | continuing_uniformity,
body_exit | continuing_exit,
)
body_uniformity | continuing_uniformity
}
S::Return { value } => {
let uniformity = match value {
Some(expr) => self.add_ref(expr),
None => Uniformity::default(),
};
S::Return { value } => FunctionUniformity {
result: if let Some(expr) = value {
self.add_ref(expr)
} else {
Uniformity::default()
},
//TODO: if we are in the uniform control flow, should this still be an exit flag?
(uniformity, ExitFlags::MAY_RETURN)
}
exit: ExitFlags::MAY_RETURN,
},
S::Store { pointer, value } => {
let uniformity =
self.add_ref_impl(pointer, GlobalUse::WRITE) | self.add_ref(value);
(uniformity, ExitFlags::empty())
uniformity.to_function()
}
S::ImageStore {
image,
@@ -529,7 +582,7 @@ impl FunctionInfo {
| self.add_ref_impl(image, GlobalUse::WRITE)
| self.add_ref(coordinate)
| self.add_ref(value);
(uniformity, ExitFlags::empty())
uniformity.to_function()
}
S::Call {
function,
@@ -537,27 +590,19 @@ impl FunctionInfo {
result: _,
} => {
let info = &other_functions[function.index()];
let mut uniformity = self.process_call(info);
let call_uniformity = self.process_call(info);
let mut uniformity = call_uniformity.result.clone();
for &argument in arguments {
uniformity |= self.add_ref(argument);
}
let exit = if info.may_kill {
ExitFlags::MAY_KILL
} else {
ExitFlags::empty()
};
(uniformity, exit)
call_uniformity | uniformity.to_function()
}
};
if let (Some(expr), Some(cause)) = (cur_uniformity.require_uniform, disruptor) {
return Err(AnalysisError::NonUniformControlFlow(expr, cause));
}
disruptor = disruptor.or(UniformityDisruptor::from_exit(cur_exit));
block_uniformity |= cur_uniformity;
block_exit |= cur_exit;
disruptor = disruptor.or(uniformity.exit_disruptor());
combined_uniformity = combined_uniformity | uniformity;
}
Ok((block_uniformity, block_exit))
Ok(combined_uniformity)
}
}
@@ -589,9 +634,9 @@ impl Analysis {
info.process_expression(handle, &fun.expressions, global_var_arena, &self.functions)?;
}
let (uniformity, exit) = info.process_block(&fun.body, &self.functions, None)?;
info.uniformity = uniformity;
info.may_kill = exit.contains(ExitFlags::MAY_KILL);
let uniformity = info.process_block(&fun.body, &self.functions, None)?;
info.uniformity = uniformity.result;
info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
Ok(info)
}
@@ -676,8 +721,11 @@ fn uniform_control_flow() {
axis: crate::DerivativeAxis::X,
expr: constant_expr,
});
let emit_range_constant_derivative = expressions.range_from(0);
let non_uniform_global_expr = expressions.append(E::GlobalVariable(non_uniform_global));
let uniform_global_expr = expressions.append(E::GlobalVariable(uniform_global));
let emit_range_globals = expressions.range_from(2);
// checks the QUERY flag
let query_expr = expressions.append(E::ArrayLength(uniform_global_expr));
// checks the transitive WRITE flag
@@ -685,6 +733,7 @@ fn uniform_control_flow() {
base: non_uniform_global_expr,
index: 1,
});
let emit_range_query_access_globals = expressions.range_from(2);
let mut info = FunctionInfo {
uniformity: Uniformity::default(),
@@ -704,68 +753,87 @@ fn uniform_control_flow() {
assert_eq!(info[non_uniform_global], GlobalUse::empty());
assert_eq!(info[uniform_global], GlobalUse::QUERY);
let stmt_emit1 = S::Emit(emit_range_globals.clone());
let stmt_if_uniform = S::If {
condition: uniform_global_expr,
accept: Vec::new(),
reject: vec![S::Store {
pointer: constant_expr,
value: derivative_expr,
}],
reject: vec![
S::Emit(emit_range_constant_derivative.clone()),
S::Store {
pointer: constant_expr,
value: derivative_expr,
},
],
};
assert_eq!(
info.process_block(&[stmt_if_uniform], &[], None),
Ok((
Uniformity::require_uniform(derivative_expr),
ExitFlags::empty()
)),
info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None),
Ok(FunctionUniformity {
result: Uniformity::require_uniform(derivative_expr),
exit: ExitFlags::empty(),
}),
);
assert_eq!(info[constant_expr].ref_count, 2);
assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
let stmt_emit2 = S::Emit(emit_range_globals.clone());
let stmt_if_non_uniform = S::If {
condition: non_uniform_global_expr,
accept: vec![S::Store {
pointer: constant_expr,
value: derivative_expr,
}],
accept: vec![
S::Emit(emit_range_constant_derivative.clone()),
S::Store {
pointer: constant_expr,
value: derivative_expr,
},
],
reject: Vec::new(),
};
assert_eq!(
info.process_block(&[stmt_if_non_uniform], &[], None),
info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None),
Err(AnalysisError::NonUniformControlFlow(
derivative_expr,
UniformityDisruptor::Expression(non_uniform_global_expr)
)),
);
assert_eq!(info[derivative_expr].ref_count, 2);
assert_eq!(info[derivative_expr].ref_count, 1);
assert_eq!(info[non_uniform_global], GlobalUse::READ);
let stmt_emit3 = S::Emit(emit_range_globals);
let stmt_return_non_uniform = S::Return {
value: Some(non_uniform_global_expr),
};
assert_eq!(
info.process_block(
&[stmt_return_non_uniform],
&[stmt_emit3, stmt_return_non_uniform],
&[],
Some(UniformityDisruptor::Return)
),
Ok((
Uniformity::non_uniform_result(non_uniform_global_expr),
ExitFlags::MAY_RETURN
)),
Ok(FunctionUniformity {
result: Uniformity::non_uniform_result(non_uniform_global_expr),
exit: ExitFlags::MAY_RETURN,
}),
);
assert_eq!(info[non_uniform_global_expr].ref_count, 3);
// Check that uniformity requirements reach through a pointer
let stmt_emit4 = S::Emit(emit_range_query_access_globals);
let stmt_assign = S::Store {
pointer: access_expr,
value: query_expr,
};
let stmt_return_pointer = S::Return {
value: Some(access_expr),
};
let stmt_kill = S::Kill;
assert_eq!(
info.process_block(&[stmt_assign], &[], Some(UniformityDisruptor::Discard)),
Ok((
Uniformity::non_uniform_result(non_uniform_global_expr),
ExitFlags::empty()
)),
info.process_block(
&[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer],
&[],
Some(UniformityDisruptor::Discard)
),
Ok(FunctionUniformity {
result: Uniformity::non_uniform_result(non_uniform_global_expr),
exit: ExitFlags::all(),
}),
);
assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
}

View File

@@ -19,5 +19,9 @@ fn main() {
[[stage(fragment)]]
fn main() {
o_color = textureSample(u_texture, u_sampler, v_uv);
const color: vec4<f32> = textureSample(u_texture, u_sampler, v_uv);
if (color.a == 0.0) {
discard;
}
o_color = color;
}

View File

@@ -6,7 +6,7 @@ expression: output
functions: [
(
uniformity: (
non_uniform_result: Some(4),
non_uniform_result: Some(6),
require_uniform: None,
),
may_kill: false,
@@ -234,7 +234,7 @@ expression: output
entry_points: [
(
uniformity: (
non_uniform_result: Some(4),
non_uniform_result: Some(6),
require_uniform: None,
),
may_kill: false,
@@ -338,7 +338,7 @@ expression: output
),
(
uniformity: (
non_uniform_result: Some(4),
non_uniform_result: Some(6),
require_uniform: None,
),
ref_count: 1,

View File

@@ -13,6 +13,11 @@ uniform highp sampler2D _group_0_binding_0;
out vec4 _location_0;
void main() {
if((texture(_group_0_binding_0, vec2(_location_0_vs))[3] == 0.0)) {
discard;
}
_location_0 = texture(_group_0_binding_0, vec2(_location_0_vs));
return;
}

View File

@@ -83,17 +83,38 @@ digraph Module {
ep1_e1 -> ep1_e9 [ label="sampler" ]
ep1_e2 -> ep1_e9 [ label="image" ]
ep1_e8 -> ep1_e9 [ label="coordinate" ]
ep1_e10 [ color="#8dd3c7" label="[11] AccessIndex[3]" ]
ep1_e9 -> ep1_e10 [ label="base" ]
ep1_e11 [ color="#ffffb3" label="[12] Constant" ]
ep1_e12 [ color="#fdb462" label="[13] Equal" ]
ep1_e11 -> ep1_e12 [ label="right" ]
ep1_e10 -> ep1_e12 [ label="left" ]
ep1_s0 [ shape=square label="Root" ]
ep1_s1 [ shape=square label="Emit" ]
ep1_s2 [ shape=square label="Store" ]
ep1_s3 [ shape=square label="Return" ]
ep1_s2 [ shape=square label="Emit" ]
ep1_s3 [ shape=square label="Emit" ]
ep1_s4 [ shape=square label="If" ]
ep1_s5 [ shape=square label="Node" ]
ep1_s6 [ shape=square label="Kill" ]
ep1_s7 [ shape=square label="Node" ]
ep1_s8 [ shape=square label="Store" ]
ep1_s9 [ shape=square label="Return" ]
ep1_s0 -> ep1_s1 [ arrowhead=tee label="" ]
ep1_s1 -> ep1_s2 [ arrowhead=tee label="" ]
ep1_s2 -> ep1_s3 [ arrowhead=tee label="" ]
ep1_e9 -> ep1_s2 [ label="value" ]
ep1_s3 -> ep1_s4 [ arrowhead=tee label="" ]
ep1_s5 -> ep1_s6 [ arrowhead=tee label="" ]
ep1_s4 -> ep1_s5 [ arrowhead=tee label="accept" ]
ep1_s4 -> ep1_s7 [ arrowhead=tee label="reject" ]
ep1_s7 -> ep1_s8 [ arrowhead=tee label="" ]
ep1_s8 -> ep1_s9 [ arrowhead=tee label="" ]
ep1_e12 -> ep1_s4 [ label="condition" ]
ep1_e9 -> ep1_s8 [ label="value" ]
ep1_s1 -> ep1_e8 [ style=dotted ]
ep1_s1 -> ep1_e9 [ style=dotted ]
ep1_s2 -> ep1_e7 [ style=dotted ]
ep1_s2 -> ep1_e10 [ style=dotted ]
ep1_s3 -> ep1_e12 [ style=dotted ]
ep1_s8 -> ep1_e7 [ style=dotted ]
}
}

View File

@@ -52,6 +52,9 @@ fragment main2Output main2(
) {
main2Output output;
metal::float4 _expr9 = u_texture.sample(u_sampler, input.v_uv1);
if ((_expr9.w == const_0f)) {
metal::discard_fragment();
}
output.o_color = _expr9;
return output;
}

View File

@@ -5,7 +5,7 @@ expression: dis
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 41
; Bound: 46
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@@ -60,6 +60,7 @@ OpDecorate %23 Location 0
%23 = OpVariable %15 Output
%25 = OpTypeFunction %2
%38 = OpTypeSampledImage %18
%42 = OpTypeBool
%24 = OpFunction %2 None %25
%26 = OpLabel
OpBranch %27
@@ -81,6 +82,13 @@ OpBranch %36
%37 = OpLoad %8 %16
%39 = OpSampledImage %38 %34 %35
%40 = OpImageSampleImplicitLod %14 %39 %37
%41 = OpCompositeExtract %4 %40 3
%43 = OpFOrdEqual %42 %41 %5
OpSelectionMerge %44 None
OpBranchConditional %43 %45 %44
%45 = OpLabel
OpKill
%44 = OpLabel
OpStore %23 %40
OpReturn
OpFunctionEnd