[naga wgsl-in] Implement any() and all() during const evaluation (#7166)

This commit is contained in:
Jamie Nicol
2025-02-17 19:13:49 +00:00
committed by GitHub
parent c03176f3eb
commit d625d083c3
7 changed files with 180 additions and 52 deletions

View File

@@ -4,8 +4,8 @@ use arrayvec::ArrayVec;
use crate::{
arena::{Arena, Handle, HandleVec, UniqueArena},
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
TypeInner, UnaryOperator,
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
ScalarKind, Span, Type, TypeInner, UnaryOperator,
};
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
@@ -547,6 +547,8 @@ pub enum ConstantEvaluatorError {
InvalidMathArg,
#[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
InvalidMathArgCount(crate::MathFunction, usize, usize),
#[error("Cannot apply relational function to type")]
InvalidRelationalArg(RelationalFunction),
#[error("value of `low` is greater than `high` for clamp built-in function")]
InvalidClamp,
#[error("Splat is defined only on scalar values")]
@@ -931,9 +933,10 @@ impl<'a> ConstantEvaluator<'a> {
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
"select built-in function".into(),
)),
Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
format!("{fun:?} built-in function"),
)),
Expression::Relational { fun, argument } => {
let argument = self.check_and_get(argument)?;
self.relational(fun, argument, span)
}
Expression::ArrayLength(expr) => match self.behavior {
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
Behavior::Glsl(_) => {
@@ -2103,6 +2106,41 @@ impl<'a> ConstantEvaluator<'a> {
Ok(Expression::Compose { ty, components })
}
fn relational(
&mut self,
fun: RelationalFunction,
arg: Handle<Expression>,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let arg = self.eval_zero_value_and_splat(arg, span)?;
match fun {
RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
Expression::Literal(Literal::Bool(_)) => Ok(arg),
Expression::Compose { ty, ref components }
if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
{
let components =
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
.map(|component| match self.expressions[component] {
Expression::Literal(Literal::Bool(val)) => Ok(val),
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
})
.collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
let result = match fun {
RelationalFunction::All => components.iter().all(|c| *c),
RelationalFunction::Any => components.iter().any(|c| *c),
_ => unreachable!(),
};
self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
}
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
},
_ => Err(ConstantEvaluatorError::NotImplemented(format!(
"{fun:?} built-in function"
))),
}
}
/// Deep copy `expr` from `expressions` into `self.expressions`.
///
/// Return the root of the new copy.

View File

@@ -1,5 +1,7 @@
const TWO: u32 = 2u;
const THREE: i32 = 3i;
const TRUE = true;
const FALSE = false;
@compute @workgroup_size(TWO, THREE, TWO - 1u)
fn main() {
@@ -94,3 +96,16 @@ fn compose_vector_zero_val_binop() {
var b = vec3(vec2i(), 0) + vec3(0, 1, 2);
var c = vec3(vec2i(), 2) + vec3(1, vec2i());
}
fn relational() {
// Test scalar and vector forms of any() and all(), with a mixture of
// consts, literals, zero-values, composes, and splats.
var scalar_any_false = any(false);
var scalar_any_true = any(true);
var scalar_all_false = all(false);
var scalar_all_true = all(true);
var vec_any_false = any(vec4<bool>());
var vec_any_true = any(vec4(bool(), true, vec2(FALSE)));
var vec_all_false = all(vec4(vec3(vec2<bool>(), TRUE), false));
var vec_all_true = all(vec4(true));
}

View File

@@ -7,6 +7,8 @@ layout(local_size_x = 2, local_size_y = 3, local_size_z = 1) in;
const uint TWO = 2u;
const int THREE = 3;
const bool TRUE = true;
const bool FALSE = false;
const int FOUR = 4;
const int FOUR_ALIAS = 4;
const int TEST_CONSTANT_ADDITION = 8;
@@ -93,6 +95,18 @@ void compose_vector_zero_val_binop() {
return;
}
void relational() {
bool scalar_any_false = false;
bool scalar_any_true = true;
bool scalar_all_false = false;
bool scalar_all_true = true;
bool vec_any_false = false;
bool vec_any_true = true;
bool vec_all_false = false;
bool vec_all_true = true;
return;
}
void main() {
swizzle_of_compose();
index_of_compose();

View File

@@ -1,5 +1,7 @@
static const uint TWO = 2u;
static const int THREE = int(3);
static const bool TRUE = true;
static const bool FALSE = false;
static const int FOUR = int(4);
static const int FOUR_ALIAS = int(4);
static const int TEST_CONSTANT_ADDITION = int(8);
@@ -102,6 +104,20 @@ void compose_vector_zero_val_binop()
return;
}
void relational()
{
bool scalar_any_false = false;
bool scalar_any_true = true;
bool scalar_all_false = false;
bool scalar_all_true = true;
bool vec_any_false = false;
bool vec_any_true = true;
bool vec_all_false = false;
bool vec_all_true = true;
return;
}
[numthreads(2, 3, 1)]
void main()
{

View File

@@ -6,6 +6,8 @@ using metal::uint;
constant uint TWO = 2u;
constant int THREE = 3;
constant bool TRUE = true;
constant bool FALSE = false;
constant int FOUR = 4;
constant int FOUR_ALIAS = 4;
constant int TEST_CONSTANT_ADDITION = 8;
@@ -101,6 +103,19 @@ void compose_vector_zero_val_binop(
return;
}
void relational(
) {
bool scalar_any_false = false;
bool scalar_any_true = true;
bool scalar_all_false = false;
bool scalar_all_true = true;
bool vec_any_false = false;
bool vec_any_true = true;
bool vec_all_false = false;
bool vec_all_true = true;
return;
}
kernel void main_(
) {
swizzle_of_compose();

View File

@@ -1,66 +1,67 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 120
; Bound: 132
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %111 "main"
OpExecutionMode %111 LocalSize 2 3 1
OpEntryPoint GLCompute %123 "main"
OpExecutionMode %123 LocalSize 2 3 1
%2 = OpTypeVoid
%3 = OpTypeInt 32 0
%4 = OpTypeInt 32 1
%5 = OpTypeVector %4 4
%6 = OpTypeFloat 32
%7 = OpTypeVector %6 4
%8 = OpTypeVector %6 2
%10 = OpTypeBool
%9 = OpTypeVector %10 2
%5 = OpTypeBool
%6 = OpTypeVector %4 4
%7 = OpTypeFloat 32
%8 = OpTypeVector %7 4
%9 = OpTypeVector %7 2
%10 = OpTypeVector %5 2
%11 = OpTypeVector %4 3
%12 = OpConstant %3 2
%13 = OpConstant %4 3
%14 = OpConstant %4 4
%15 = OpConstant %4 8
%16 = OpConstant %6 3.141
%17 = OpConstant %6 6.282
%18 = OpConstant %6 0.44444445
%19 = OpConstant %6 0.0
%20 = OpConstantComposite %7 %18 %19 %19 %19
%21 = OpConstant %4 0
%22 = OpConstant %4 1
%23 = OpConstant %4 2
%24 = OpConstant %6 4.0
%25 = OpConstant %6 5.0
%26 = OpConstantComposite %8 %24 %25
%27 = OpConstantTrue %10
%28 = OpConstantFalse %10
%29 = OpConstantComposite %9 %27 %28
%14 = OpConstantTrue %5
%15 = OpConstantFalse %5
%16 = OpConstant %4 4
%17 = OpConstant %4 8
%18 = OpConstant %7 3.141
%19 = OpConstant %7 6.282
%20 = OpConstant %7 0.44444445
%21 = OpConstant %7 0.0
%22 = OpConstantComposite %8 %20 %21 %21 %21
%23 = OpConstant %4 0
%24 = OpConstant %4 1
%25 = OpConstant %4 2
%26 = OpConstant %7 4.0
%27 = OpConstant %7 5.0
%28 = OpConstantComposite %9 %26 %27
%29 = OpConstantComposite %10 %14 %15
%32 = OpTypeFunction %2
%33 = OpConstantComposite %5 %14 %13 %23 %22
%35 = OpTypePointer Function %5
%33 = OpConstantComposite %6 %16 %13 %25 %24
%35 = OpTypePointer Function %6
%40 = OpTypePointer Function %4
%44 = OpConstant %4 6
%49 = OpConstant %4 30
%50 = OpConstant %4 70
%53 = OpConstantNull %4
%55 = OpConstantNull %4
%58 = OpConstantNull %5
%58 = OpConstantNull %6
%69 = OpConstant %4 -4
%70 = OpConstantComposite %5 %69 %69 %69 %69
%79 = OpConstant %6 1.0
%80 = OpConstant %6 2.0
%81 = OpConstantComposite %7 %80 %79 %79 %79
%83 = OpTypePointer Function %7
%70 = OpConstantComposite %6 %69 %69 %69 %69
%79 = OpConstant %7 1.0
%80 = OpConstant %7 2.0
%81 = OpConstantComposite %8 %80 %79 %79 %79
%83 = OpTypePointer Function %8
%88 = OpTypeFunction %3 %4
%89 = OpConstant %3 10
%90 = OpConstant %3 20
%91 = OpConstant %3 30
%92 = OpConstant %3 0
%99 = OpConstantNull %3
%102 = OpConstantComposite %11 %22 %22 %22
%103 = OpConstantComposite %11 %21 %22 %23
%104 = OpConstantComposite %11 %22 %21 %23
%102 = OpConstantComposite %11 %24 %24 %24
%103 = OpConstantComposite %11 %23 %24 %25
%104 = OpConstantComposite %11 %24 %23 %25
%106 = OpTypePointer Function %11
%113 = OpTypePointer Function %5
%31 = OpFunction %2 None %32
%30 = OpLabel
%34 = OpVariable %35 Function %33
@@ -70,7 +71,7 @@ OpReturn
OpFunctionEnd
%38 = OpFunction %2 None %32
%37 = OpLabel
%39 = OpVariable %40 Function %23
%39 = OpVariable %40 Function %25
OpBranch %41
%41 = OpLabel
OpReturn
@@ -99,7 +100,7 @@ OpStore %54 %61
%63 = OpLoad %4 %52
%64 = OpLoad %4 %54
%65 = OpLoad %4 %56
%66 = OpCompositeConstruct %5 %62 %63 %64 %65
%66 = OpCompositeConstruct %6 %62 %63 %64 %65
OpStore %57 %66
OpReturn
OpFunctionEnd
@@ -153,14 +154,28 @@ OpReturn
OpFunctionEnd
%111 = OpFunction %2 None %32
%110 = OpLabel
OpBranch %112
%112 = OpLabel
%113 = OpFunctionCall %2 %31
%114 = OpFunctionCall %2 %38
%115 = OpFunctionCall %2 %43
%116 = OpFunctionCall %2 %48
%117 = OpFunctionCall %2 %68
%118 = OpFunctionCall %2 %74
%119 = OpFunctionCall %2 %78
%119 = OpVariable %113 Function %15
%116 = OpVariable %113 Function %14
%112 = OpVariable %113 Function %15
%120 = OpVariable %113 Function %14
%117 = OpVariable %113 Function %15
%114 = OpVariable %113 Function %14
%118 = OpVariable %113 Function %14
%115 = OpVariable %113 Function %15
OpBranch %121
%121 = OpLabel
OpReturn
OpFunctionEnd
%123 = OpFunction %2 None %32
%122 = OpLabel
OpBranch %124
%124 = OpLabel
%125 = OpFunctionCall %2 %31
%126 = OpFunctionCall %2 %38
%127 = OpFunctionCall %2 %43
%128 = OpFunctionCall %2 %48
%129 = OpFunctionCall %2 %68
%130 = OpFunctionCall %2 %74
%131 = OpFunctionCall %2 %78
OpReturn
OpFunctionEnd

View File

@@ -1,5 +1,7 @@
const TWO: u32 = 2u;
const THREE: i32 = 3i;
const TRUE: bool = true;
const FALSE: bool = false;
const FOUR: i32 = 4i;
const FOUR_ALIAS: i32 = 4i;
const TEST_CONSTANT_ADDITION: i32 = 8i;
@@ -93,6 +95,19 @@ fn compose_vector_zero_val_binop() {
return;
}
fn relational() {
var scalar_any_false: bool = false;
var scalar_any_true: bool = true;
var scalar_all_false: bool = false;
var scalar_all_true: bool = true;
var vec_any_false: bool = false;
var vec_any_true: bool = true;
var vec_all_false: bool = false;
var vec_all_true: bool = true;
return;
}
@compute @workgroup_size(2, 3, 1)
fn main() {
swizzle_of_compose();