diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index dd8b6d5c5f..1963c99c23 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -784,7 +784,7 @@ impl Writer { )?; } crate::Statement::Kill => { - writeln!(self.out, "{}discard_fragment();", level)?; + writeln!(self.out, "{}{}::discard_fragment();", level, NAMESPACE)?; } crate::Statement::Store { pointer, value } => { write!(self.out, "{}", level)?; diff --git a/src/proc/analyzer.rs b/src/proc/analyzer.rs index e63041ae67..1d2833ba5e 100644 --- a/src/proc/analyzer.rs +++ b/src/proc/analyzer.rs @@ -73,6 +73,19 @@ impl Uniformity { fn disruptor(&self) -> Option { self.non_uniform_result.map(UniformityDisruptor::Expression) } + + /// When considering the uniformity at the functional level, + /// we don't care about all of the expression results. + /// We only care about the `return` expressions. + fn to_function(&self) -> FunctionUniformity { + FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + require_uniform: self.require_uniform, + }, + exit: ExitFlags::empty(), + } + } } bitflags::bitflags! { @@ -87,6 +100,43 @@ bitflags::bitflags! { } } +/// Uniformity characteristics of a function. +#[cfg_attr(test, derive(Debug, PartialEq))] +struct FunctionUniformity { + result: Uniformity, + exit: ExitFlags, +} + +impl ops::BitOr for FunctionUniformity { + type Output = Self; + fn bitor(self, other: Self) -> Self { + FunctionUniformity { + result: self.result | other.result, + exit: self.exit | other.exit, + } + } +} + +impl FunctionUniformity { + fn new() -> Self { + FunctionUniformity { + result: Uniformity::default(), + exit: ExitFlags::empty(), + } + } + + /// Returns a disruptor based on the stored exit flags, if any. + fn exit_disruptor(&self) -> Option { + if self.exit.contains(ExitFlags::MAY_RETURN) { + Some(UniformityDisruptor::Return) + } else if self.exit.contains(ExitFlags::MAY_KILL) { + Some(UniformityDisruptor::Discard) + } else { + None + } + } +} + bitflags::bitflags! { /// Indicates how a global variable is used. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -182,18 +232,6 @@ pub enum UniformityDisruptor { Discard, } -impl UniformityDisruptor { - fn from_exit(flags: ExitFlags) -> Option { - if flags.contains(ExitFlags::MAY_RETURN) { - Some(Self::Return) - } else if flags.contains(ExitFlags::MAY_KILL) { - Some(Self::Discard) - } else { - None - } - } -} - #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum AnalysisError { @@ -248,14 +286,21 @@ impl FunctionInfo { } /// Inherit information from a called function. - fn process_call(&mut self, info: &Self) -> Uniformity { + fn process_call(&mut self, info: &Self) -> FunctionUniformity { for key in info.sampling_set.iter() { self.sampling_set.insert(key.clone()); } for (mine, other) in self.global_uses.iter_mut().zip(info.global_uses.iter()) { *mine |= *other; } - info.uniformity.clone() + FunctionUniformity { + result: info.uniformity.clone(), + exit: if info.may_kill { + ExitFlags::MAY_KILL + } else { + ExitFlags::empty() + }, + } } /// Computes the expression info and stores it in `self.expressions`. @@ -412,7 +457,7 @@ impl FunctionInfo { self.add_ref(arg) | arg1_flags | arg2_flags } E::As { expr, .. } => self.add_ref(expr), - E::Call(function) => self.process_call(&other_functions[function.index()]), + E::Call(function) => self.process_call(&other_functions[function.index()]).result, E::ArrayLength(expr) => self.add_ref_impl(expr, GlobalUse::QUERY), }; @@ -424,8 +469,11 @@ impl FunctionInfo { Ok(()) } - /// Computes the uniformity and the exit flags on the block - /// (as a sequence of statements), and returns them. + /// Analyzes the uniformity requirements of a block (as a sequence of statements). + /// Returns the uniformity characteristics at the *function* level, i.e. + /// whether or not the function requires to be called in uniform control flow, + /// and whether the produced result is not disrupting the control flow. + /// /// The parent control flow is uniform if `disruptor.is_none()`. /// /// Returns a `NonUniformControlFlow` error if any of the expressions in the block @@ -436,14 +484,28 @@ impl FunctionInfo { statements: &[crate::Statement], other_functions: &[FunctionInfo], mut disruptor: Option, - ) -> Result<(Uniformity, ExitFlags), AnalysisError> { + ) -> Result { use crate::Statement as S; - let mut block_uniformity = Uniformity::default(); - let mut block_exit = ExitFlags::empty(); + + let mut combined_uniformity = FunctionUniformity::new(); for statement in statements { - let (cur_uniformity, cur_exit) = match *statement { - S::Emit(_) | S::Break | S::Continue => (Uniformity::default(), ExitFlags::empty()), - S::Kill => (Uniformity::default(), ExitFlags::MAY_KILL), + let uniformity = match *statement { + S::Emit(ref range) => { + for expr in range.clone() { + if let (Some(expr), Some(cause)) = ( + self.expressions[expr.index()].uniformity.require_uniform, + disruptor, + ) { + return Err(AnalysisError::NonUniformControlFlow(expr, cause)); + } + } + FunctionUniformity::new() + } + S::Break | S::Continue => FunctionUniformity::new(), + S::Kill => FunctionUniformity { + result: Uniformity::default(), + exit: ExitFlags::MAY_KILL, + }, S::Block(ref b) => self.process_block(b, other_functions, disruptor)?, S::If { condition, @@ -452,14 +514,11 @@ impl FunctionInfo { } => { let condition_uniformity = self.add_ref(condition); let branch_disruptor = disruptor.or(condition_uniformity.disruptor()); - let (accept_uniformity, accept_exit) = + let accept_uniformity = self.process_block(accept, other_functions, branch_disruptor)?; - let (reject_uniformity, reject_exit) = + let reject_uniformity = self.process_block(reject, other_functions, branch_disruptor)?; - ( - condition_uniformity | accept_uniformity | reject_uniformity, - accept_exit | reject_exit, - ) + condition_uniformity.to_function() | accept_uniformity | reject_uniformity } S::Switch { selector, @@ -468,52 +527,46 @@ impl FunctionInfo { } => { let selector_uniformity = self.add_ref(selector); let branch_disruptor = disruptor.or(selector_uniformity.disruptor()); - let mut uniformity = selector_uniformity; - let mut exit = ExitFlags::empty(); - let mut case_disruptor = disruptor; + let mut uniformity = FunctionUniformity::new(); + let mut case_disruptor = branch_disruptor; for case in cases.iter() { - let (case_uniformity, case_exit) = + let case_uniformity = self.process_block(&case.body, other_functions, case_disruptor)?; - uniformity |= case_uniformity; - exit |= case_exit; case_disruptor = if case.fall_through { - case_disruptor.or(UniformityDisruptor::from_exit(case_exit)) + case_disruptor.or(case_uniformity.exit_disruptor()) } else { branch_disruptor }; + uniformity = uniformity | case_uniformity; } - let (default_uniformity, default_exit) = + // using the disruptor inherited from the last fall-through chain + let default_exit = self.process_block(default, other_functions, case_disruptor)?; - (uniformity | default_uniformity, exit | default_exit) + uniformity | default_exit } S::Loop { ref body, ref continuing, } => { - let (body_uniformity, body_exit) = - self.process_block(body, other_functions, disruptor)?; - let continuing_disruptor = disruptor - .or(body_uniformity.disruptor()) - .or(UniformityDisruptor::from_exit(body_exit)); - let (continuing_uniformity, continuing_exit) = + let body_uniformity = self.process_block(body, other_functions, disruptor)?; + let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor()); + let continuing_uniformity = self.process_block(continuing, other_functions, continuing_disruptor)?; - ( - body_uniformity | continuing_uniformity, - body_exit | continuing_exit, - ) + body_uniformity | continuing_uniformity } - S::Return { value } => { - let uniformity = match value { - Some(expr) => self.add_ref(expr), - None => Uniformity::default(), - }; + S::Return { value } => FunctionUniformity { + result: if let Some(expr) = value { + self.add_ref(expr) + } else { + Uniformity::default() + }, //TODO: if we are in the uniform control flow, should this still be an exit flag? - (uniformity, ExitFlags::MAY_RETURN) - } + exit: ExitFlags::MAY_RETURN, + }, S::Store { pointer, value } => { let uniformity = self.add_ref_impl(pointer, GlobalUse::WRITE) | self.add_ref(value); - (uniformity, ExitFlags::empty()) + uniformity.to_function() } S::ImageStore { image, @@ -529,7 +582,7 @@ impl FunctionInfo { | self.add_ref_impl(image, GlobalUse::WRITE) | self.add_ref(coordinate) | self.add_ref(value); - (uniformity, ExitFlags::empty()) + uniformity.to_function() } S::Call { function, @@ -537,27 +590,19 @@ impl FunctionInfo { result: _, } => { let info = &other_functions[function.index()]; - let mut uniformity = self.process_call(info); + let call_uniformity = self.process_call(info); + let mut uniformity = call_uniformity.result.clone(); for &argument in arguments { uniformity |= self.add_ref(argument); } - let exit = if info.may_kill { - ExitFlags::MAY_KILL - } else { - ExitFlags::empty() - }; - (uniformity, exit) + call_uniformity | uniformity.to_function() } }; - if let (Some(expr), Some(cause)) = (cur_uniformity.require_uniform, disruptor) { - return Err(AnalysisError::NonUniformControlFlow(expr, cause)); - } - disruptor = disruptor.or(UniformityDisruptor::from_exit(cur_exit)); - block_uniformity |= cur_uniformity; - block_exit |= cur_exit; + disruptor = disruptor.or(uniformity.exit_disruptor()); + combined_uniformity = combined_uniformity | uniformity; } - Ok((block_uniformity, block_exit)) + Ok(combined_uniformity) } } @@ -589,9 +634,9 @@ impl Analysis { info.process_expression(handle, &fun.expressions, global_var_arena, &self.functions)?; } - let (uniformity, exit) = info.process_block(&fun.body, &self.functions, None)?; - info.uniformity = uniformity; - info.may_kill = exit.contains(ExitFlags::MAY_KILL); + let uniformity = info.process_block(&fun.body, &self.functions, None)?; + info.uniformity = uniformity.result; + info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL); Ok(info) } @@ -676,8 +721,11 @@ fn uniform_control_flow() { axis: crate::DerivativeAxis::X, expr: constant_expr, }); + let emit_range_constant_derivative = expressions.range_from(0); let non_uniform_global_expr = expressions.append(E::GlobalVariable(non_uniform_global)); let uniform_global_expr = expressions.append(E::GlobalVariable(uniform_global)); + let emit_range_globals = expressions.range_from(2); + // checks the QUERY flag let query_expr = expressions.append(E::ArrayLength(uniform_global_expr)); // checks the transitive WRITE flag @@ -685,6 +733,7 @@ fn uniform_control_flow() { base: non_uniform_global_expr, index: 1, }); + let emit_range_query_access_globals = expressions.range_from(2); let mut info = FunctionInfo { uniformity: Uniformity::default(), @@ -704,68 +753,87 @@ fn uniform_control_flow() { assert_eq!(info[non_uniform_global], GlobalUse::empty()); assert_eq!(info[uniform_global], GlobalUse::QUERY); + let stmt_emit1 = S::Emit(emit_range_globals.clone()); let stmt_if_uniform = S::If { condition: uniform_global_expr, accept: Vec::new(), - reject: vec![S::Store { - pointer: constant_expr, - value: derivative_expr, - }], + reject: vec![ + S::Emit(emit_range_constant_derivative.clone()), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ], }; assert_eq!( - info.process_block(&[stmt_if_uniform], &[], None), - Ok(( - Uniformity::require_uniform(derivative_expr), - ExitFlags::empty() - )), + info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None), + Ok(FunctionUniformity { + result: Uniformity::require_uniform(derivative_expr), + exit: ExitFlags::empty(), + }), ); assert_eq!(info[constant_expr].ref_count, 2); assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY); + let stmt_emit2 = S::Emit(emit_range_globals.clone()); let stmt_if_non_uniform = S::If { condition: non_uniform_global_expr, - accept: vec![S::Store { - pointer: constant_expr, - value: derivative_expr, - }], + accept: vec![ + S::Emit(emit_range_constant_derivative.clone()), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ], reject: Vec::new(), }; assert_eq!( - info.process_block(&[stmt_if_non_uniform], &[], None), + info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None), Err(AnalysisError::NonUniformControlFlow( derivative_expr, UniformityDisruptor::Expression(non_uniform_global_expr) )), ); - assert_eq!(info[derivative_expr].ref_count, 2); + assert_eq!(info[derivative_expr].ref_count, 1); assert_eq!(info[non_uniform_global], GlobalUse::READ); + let stmt_emit3 = S::Emit(emit_range_globals); let stmt_return_non_uniform = S::Return { value: Some(non_uniform_global_expr), }; assert_eq!( info.process_block( - &[stmt_return_non_uniform], + &[stmt_emit3, stmt_return_non_uniform], &[], Some(UniformityDisruptor::Return) ), - Ok(( - Uniformity::non_uniform_result(non_uniform_global_expr), - ExitFlags::MAY_RETURN - )), + Ok(FunctionUniformity { + result: Uniformity::non_uniform_result(non_uniform_global_expr), + exit: ExitFlags::MAY_RETURN, + }), ); assert_eq!(info[non_uniform_global_expr].ref_count, 3); + // Check that uniformity requirements reach through a pointer + let stmt_emit4 = S::Emit(emit_range_query_access_globals); let stmt_assign = S::Store { pointer: access_expr, value: query_expr, }; + let stmt_return_pointer = S::Return { + value: Some(access_expr), + }; + let stmt_kill = S::Kill; assert_eq!( - info.process_block(&[stmt_assign], &[], Some(UniformityDisruptor::Discard)), - Ok(( - Uniformity::non_uniform_result(non_uniform_global_expr), - ExitFlags::empty() - )), + info.process_block( + &[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer], + &[], + Some(UniformityDisruptor::Discard) + ), + Ok(FunctionUniformity { + result: Uniformity::non_uniform_result(non_uniform_global_expr), + exit: ExitFlags::all(), + }), ); assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE); } diff --git a/tests/in/quad.wgsl b/tests/in/quad.wgsl index bc619d5fc3..033eaa3fda 100644 --- a/tests/in/quad.wgsl +++ b/tests/in/quad.wgsl @@ -19,5 +19,9 @@ fn main() { [[stage(fragment)]] fn main() { - o_color = textureSample(u_texture, u_sampler, v_uv); + const color: vec4 = textureSample(u_texture, u_sampler, v_uv); + if (color.a == 0.0) { + discard; + } + o_color = color; } diff --git a/tests/out/collatz.info.ron.snap b/tests/out/collatz.info.ron.snap index b64e3fa7c8..4d8696219a 100644 --- a/tests/out/collatz.info.ron.snap +++ b/tests/out/collatz.info.ron.snap @@ -6,7 +6,7 @@ expression: output functions: [ ( uniformity: ( - non_uniform_result: Some(4), + non_uniform_result: Some(6), require_uniform: None, ), may_kill: false, @@ -234,7 +234,7 @@ expression: output entry_points: [ ( uniformity: ( - non_uniform_result: Some(4), + non_uniform_result: Some(6), require_uniform: None, ), may_kill: false, @@ -338,7 +338,7 @@ expression: output ), ( uniformity: ( - non_uniform_result: Some(4), + non_uniform_result: Some(6), require_uniform: None, ), ref_count: 1, diff --git a/tests/out/quad-Fragment.glsl.snap b/tests/out/quad-Fragment.glsl.snap index 90c6723433..9f879ec214 100644 --- a/tests/out/quad-Fragment.glsl.snap +++ b/tests/out/quad-Fragment.glsl.snap @@ -13,6 +13,11 @@ uniform highp sampler2D _group_0_binding_0; out vec4 _location_0; void main() { + if((texture(_group_0_binding_0, vec2(_location_0_vs))[3] == 0.0)) { + discard; + } _location_0 = texture(_group_0_binding_0, vec2(_location_0_vs)); return; } + + diff --git a/tests/out/quad.dot.snap b/tests/out/quad.dot.snap index f00e3c15eb..4f52819190 100644 --- a/tests/out/quad.dot.snap +++ b/tests/out/quad.dot.snap @@ -83,17 +83,38 @@ digraph Module { ep1_e1 -> ep1_e9 [ label="sampler" ] ep1_e2 -> ep1_e9 [ label="image" ] ep1_e8 -> ep1_e9 [ label="coordinate" ] + ep1_e10 [ color="#8dd3c7" label="[11] AccessIndex[3]" ] + ep1_e9 -> ep1_e10 [ label="base" ] + ep1_e11 [ color="#ffffb3" label="[12] Constant" ] + ep1_e12 [ color="#fdb462" label="[13] Equal" ] + ep1_e11 -> ep1_e12 [ label="right" ] + ep1_e10 -> ep1_e12 [ label="left" ] ep1_s0 [ shape=square label="Root" ] ep1_s1 [ shape=square label="Emit" ] - ep1_s2 [ shape=square label="Store" ] - ep1_s3 [ shape=square label="Return" ] + ep1_s2 [ shape=square label="Emit" ] + ep1_s3 [ shape=square label="Emit" ] + ep1_s4 [ shape=square label="If" ] + ep1_s5 [ shape=square label="Node" ] + ep1_s6 [ shape=square label="Kill" ] + ep1_s7 [ shape=square label="Node" ] + ep1_s8 [ shape=square label="Store" ] + ep1_s9 [ shape=square label="Return" ] ep1_s0 -> ep1_s1 [ arrowhead=tee label="" ] ep1_s1 -> ep1_s2 [ arrowhead=tee label="" ] ep1_s2 -> ep1_s3 [ arrowhead=tee label="" ] - ep1_e9 -> ep1_s2 [ label="value" ] + ep1_s3 -> ep1_s4 [ arrowhead=tee label="" ] + ep1_s5 -> ep1_s6 [ arrowhead=tee label="" ] + ep1_s4 -> ep1_s5 [ arrowhead=tee label="accept" ] + ep1_s4 -> ep1_s7 [ arrowhead=tee label="reject" ] + ep1_s7 -> ep1_s8 [ arrowhead=tee label="" ] + ep1_s8 -> ep1_s9 [ arrowhead=tee label="" ] + ep1_e12 -> ep1_s4 [ label="condition" ] + ep1_e9 -> ep1_s8 [ label="value" ] ep1_s1 -> ep1_e8 [ style=dotted ] ep1_s1 -> ep1_e9 [ style=dotted ] - ep1_s2 -> ep1_e7 [ style=dotted ] + ep1_s2 -> ep1_e10 [ style=dotted ] + ep1_s3 -> ep1_e12 [ style=dotted ] + ep1_s8 -> ep1_e7 [ style=dotted ] } } diff --git a/tests/out/quad.msl.snap b/tests/out/quad.msl.snap index e9b69c0cd7..f695543781 100644 --- a/tests/out/quad.msl.snap +++ b/tests/out/quad.msl.snap @@ -52,6 +52,9 @@ fragment main2Output main2( ) { main2Output output; metal::float4 _expr9 = u_texture.sample(u_sampler, input.v_uv1); + if ((_expr9.w == const_0f)) { + metal::discard_fragment(); + } output.o_color = _expr9; return output; } diff --git a/tests/out/quad.spvasm.snap b/tests/out/quad.spvasm.snap index 99f99eee97..669073886c 100644 --- a/tests/out/quad.spvasm.snap +++ b/tests/out/quad.spvasm.snap @@ -5,7 +5,7 @@ expression: dis ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 41 +; Bound: 46 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -60,6 +60,7 @@ OpDecorate %23 Location 0 %23 = OpVariable %15 Output %25 = OpTypeFunction %2 %38 = OpTypeSampledImage %18 +%42 = OpTypeBool %24 = OpFunction %2 None %25 %26 = OpLabel OpBranch %27 @@ -81,6 +82,13 @@ OpBranch %36 %37 = OpLoad %8 %16 %39 = OpSampledImage %38 %34 %35 %40 = OpImageSampleImplicitLod %14 %39 %37 +%41 = OpCompositeExtract %4 %40 3 +%43 = OpFOrdEqual %42 %41 %5 +OpSelectionMerge %44 None +OpBranchConditional %43 %45 %44 +%45 = OpLabel +OpKill +%44 = OpLabel OpStore %23 %40 OpReturn OpFunctionEnd