mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
[naga wgsl-in] Implement any() and all() during const evaluation (#7166)
This commit is contained in:
@@ -4,8 +4,8 @@ use arrayvec::ArrayVec;
|
||||
|
||||
use crate::{
|
||||
arena::{Arena, Handle, HandleVec, UniqueArena},
|
||||
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
|
||||
TypeInner, UnaryOperator,
|
||||
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, RelationalFunction,
|
||||
ScalarKind, Span, Type, TypeInner, UnaryOperator,
|
||||
};
|
||||
|
||||
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
|
||||
@@ -547,6 +547,8 @@ pub enum ConstantEvaluatorError {
|
||||
InvalidMathArg,
|
||||
#[error("{0:?} built-in function expects {1:?} arguments but {2:?} were supplied")]
|
||||
InvalidMathArgCount(crate::MathFunction, usize, usize),
|
||||
#[error("Cannot apply relational function to type")]
|
||||
InvalidRelationalArg(RelationalFunction),
|
||||
#[error("value of `low` is greater than `high` for clamp built-in function")]
|
||||
InvalidClamp,
|
||||
#[error("Splat is defined only on scalar values")]
|
||||
@@ -931,9 +933,10 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Expression::Select { .. } => Err(ConstantEvaluatorError::NotImplemented(
|
||||
"select built-in function".into(),
|
||||
)),
|
||||
Expression::Relational { fun, .. } => Err(ConstantEvaluatorError::NotImplemented(
|
||||
format!("{fun:?} built-in function"),
|
||||
)),
|
||||
Expression::Relational { fun, argument } => {
|
||||
let argument = self.check_and_get(argument)?;
|
||||
self.relational(fun, argument, span)
|
||||
}
|
||||
Expression::ArrayLength(expr) => match self.behavior {
|
||||
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
|
||||
Behavior::Glsl(_) => {
|
||||
@@ -2103,6 +2106,41 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Ok(Expression::Compose { ty, components })
|
||||
}
|
||||
|
||||
fn relational(
|
||||
&mut self,
|
||||
fun: RelationalFunction,
|
||||
arg: Handle<Expression>,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
let arg = self.eval_zero_value_and_splat(arg, span)?;
|
||||
match fun {
|
||||
RelationalFunction::All | RelationalFunction::Any => match self.expressions[arg] {
|
||||
Expression::Literal(Literal::Bool(_)) => Ok(arg),
|
||||
Expression::Compose { ty, ref components }
|
||||
if matches!(self.types[ty].inner, TypeInner::Vector { .. }) =>
|
||||
{
|
||||
let components =
|
||||
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
|
||||
.map(|component| match self.expressions[component] {
|
||||
Expression::Literal(Literal::Bool(val)) => Ok(val),
|
||||
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
|
||||
})
|
||||
.collect::<Result<ArrayVec<bool, { crate::VectorSize::MAX }>, _>>()?;
|
||||
let result = match fun {
|
||||
RelationalFunction::All => components.iter().all(|c| *c),
|
||||
RelationalFunction::Any => components.iter().any(|c| *c),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.register_evaluated_expr(Expression::Literal(Literal::Bool(result)), span)
|
||||
}
|
||||
_ => Err(ConstantEvaluatorError::InvalidRelationalArg(fun)),
|
||||
},
|
||||
_ => Err(ConstantEvaluatorError::NotImplemented(format!(
|
||||
"{fun:?} built-in function"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Deep copy `expr` from `expressions` into `self.expressions`.
|
||||
///
|
||||
/// Return the root of the new copy.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
const TWO: u32 = 2u;
|
||||
const THREE: i32 = 3i;
|
||||
const TRUE = true;
|
||||
const FALSE = false;
|
||||
|
||||
@compute @workgroup_size(TWO, THREE, TWO - 1u)
|
||||
fn main() {
|
||||
@@ -94,3 +96,16 @@ fn compose_vector_zero_val_binop() {
|
||||
var b = vec3(vec2i(), 0) + vec3(0, 1, 2);
|
||||
var c = vec3(vec2i(), 2) + vec3(1, vec2i());
|
||||
}
|
||||
|
||||
fn relational() {
|
||||
// Test scalar and vector forms of any() and all(), with a mixture of
|
||||
// consts, literals, zero-values, composes, and splats.
|
||||
var scalar_any_false = any(false);
|
||||
var scalar_any_true = any(true);
|
||||
var scalar_all_false = all(false);
|
||||
var scalar_all_true = all(true);
|
||||
var vec_any_false = any(vec4<bool>());
|
||||
var vec_any_true = any(vec4(bool(), true, vec2(FALSE)));
|
||||
var vec_all_false = all(vec4(vec3(vec2<bool>(), TRUE), false));
|
||||
var vec_all_true = all(vec4(true));
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ layout(local_size_x = 2, local_size_y = 3, local_size_z = 1) in;
|
||||
|
||||
const uint TWO = 2u;
|
||||
const int THREE = 3;
|
||||
const bool TRUE = true;
|
||||
const bool FALSE = false;
|
||||
const int FOUR = 4;
|
||||
const int FOUR_ALIAS = 4;
|
||||
const int TEST_CONSTANT_ADDITION = 8;
|
||||
@@ -93,6 +95,18 @@ void compose_vector_zero_val_binop() {
|
||||
return;
|
||||
}
|
||||
|
||||
void relational() {
|
||||
bool scalar_any_false = false;
|
||||
bool scalar_any_true = true;
|
||||
bool scalar_all_false = false;
|
||||
bool scalar_all_true = true;
|
||||
bool vec_any_false = false;
|
||||
bool vec_any_true = true;
|
||||
bool vec_all_false = false;
|
||||
bool vec_all_true = true;
|
||||
return;
|
||||
}
|
||||
|
||||
void main() {
|
||||
swizzle_of_compose();
|
||||
index_of_compose();
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
static const uint TWO = 2u;
|
||||
static const int THREE = int(3);
|
||||
static const bool TRUE = true;
|
||||
static const bool FALSE = false;
|
||||
static const int FOUR = int(4);
|
||||
static const int FOUR_ALIAS = int(4);
|
||||
static const int TEST_CONSTANT_ADDITION = int(8);
|
||||
@@ -102,6 +104,20 @@ void compose_vector_zero_val_binop()
|
||||
return;
|
||||
}
|
||||
|
||||
void relational()
|
||||
{
|
||||
bool scalar_any_false = false;
|
||||
bool scalar_any_true = true;
|
||||
bool scalar_all_false = false;
|
||||
bool scalar_all_true = true;
|
||||
bool vec_any_false = false;
|
||||
bool vec_any_true = true;
|
||||
bool vec_all_false = false;
|
||||
bool vec_all_true = true;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
[numthreads(2, 3, 1)]
|
||||
void main()
|
||||
{
|
||||
|
||||
@@ -6,6 +6,8 @@ using metal::uint;
|
||||
|
||||
constant uint TWO = 2u;
|
||||
constant int THREE = 3;
|
||||
constant bool TRUE = true;
|
||||
constant bool FALSE = false;
|
||||
constant int FOUR = 4;
|
||||
constant int FOUR_ALIAS = 4;
|
||||
constant int TEST_CONSTANT_ADDITION = 8;
|
||||
@@ -101,6 +103,19 @@ void compose_vector_zero_val_binop(
|
||||
return;
|
||||
}
|
||||
|
||||
void relational(
|
||||
) {
|
||||
bool scalar_any_false = false;
|
||||
bool scalar_any_true = true;
|
||||
bool scalar_all_false = false;
|
||||
bool scalar_all_true = true;
|
||||
bool vec_any_false = false;
|
||||
bool vec_any_true = true;
|
||||
bool vec_all_false = false;
|
||||
bool vec_all_true = true;
|
||||
return;
|
||||
}
|
||||
|
||||
kernel void main_(
|
||||
) {
|
||||
swizzle_of_compose();
|
||||
|
||||
@@ -1,66 +1,67 @@
|
||||
; SPIR-V
|
||||
; Version: 1.1
|
||||
; Generator: rspirv
|
||||
; Bound: 120
|
||||
; Bound: 132
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %111 "main"
|
||||
OpExecutionMode %111 LocalSize 2 3 1
|
||||
OpEntryPoint GLCompute %123 "main"
|
||||
OpExecutionMode %123 LocalSize 2 3 1
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeInt 32 0
|
||||
%4 = OpTypeInt 32 1
|
||||
%5 = OpTypeVector %4 4
|
||||
%6 = OpTypeFloat 32
|
||||
%7 = OpTypeVector %6 4
|
||||
%8 = OpTypeVector %6 2
|
||||
%10 = OpTypeBool
|
||||
%9 = OpTypeVector %10 2
|
||||
%5 = OpTypeBool
|
||||
%6 = OpTypeVector %4 4
|
||||
%7 = OpTypeFloat 32
|
||||
%8 = OpTypeVector %7 4
|
||||
%9 = OpTypeVector %7 2
|
||||
%10 = OpTypeVector %5 2
|
||||
%11 = OpTypeVector %4 3
|
||||
%12 = OpConstant %3 2
|
||||
%13 = OpConstant %4 3
|
||||
%14 = OpConstant %4 4
|
||||
%15 = OpConstant %4 8
|
||||
%16 = OpConstant %6 3.141
|
||||
%17 = OpConstant %6 6.282
|
||||
%18 = OpConstant %6 0.44444445
|
||||
%19 = OpConstant %6 0.0
|
||||
%20 = OpConstantComposite %7 %18 %19 %19 %19
|
||||
%21 = OpConstant %4 0
|
||||
%22 = OpConstant %4 1
|
||||
%23 = OpConstant %4 2
|
||||
%24 = OpConstant %6 4.0
|
||||
%25 = OpConstant %6 5.0
|
||||
%26 = OpConstantComposite %8 %24 %25
|
||||
%27 = OpConstantTrue %10
|
||||
%28 = OpConstantFalse %10
|
||||
%29 = OpConstantComposite %9 %27 %28
|
||||
%14 = OpConstantTrue %5
|
||||
%15 = OpConstantFalse %5
|
||||
%16 = OpConstant %4 4
|
||||
%17 = OpConstant %4 8
|
||||
%18 = OpConstant %7 3.141
|
||||
%19 = OpConstant %7 6.282
|
||||
%20 = OpConstant %7 0.44444445
|
||||
%21 = OpConstant %7 0.0
|
||||
%22 = OpConstantComposite %8 %20 %21 %21 %21
|
||||
%23 = OpConstant %4 0
|
||||
%24 = OpConstant %4 1
|
||||
%25 = OpConstant %4 2
|
||||
%26 = OpConstant %7 4.0
|
||||
%27 = OpConstant %7 5.0
|
||||
%28 = OpConstantComposite %9 %26 %27
|
||||
%29 = OpConstantComposite %10 %14 %15
|
||||
%32 = OpTypeFunction %2
|
||||
%33 = OpConstantComposite %5 %14 %13 %23 %22
|
||||
%35 = OpTypePointer Function %5
|
||||
%33 = OpConstantComposite %6 %16 %13 %25 %24
|
||||
%35 = OpTypePointer Function %6
|
||||
%40 = OpTypePointer Function %4
|
||||
%44 = OpConstant %4 6
|
||||
%49 = OpConstant %4 30
|
||||
%50 = OpConstant %4 70
|
||||
%53 = OpConstantNull %4
|
||||
%55 = OpConstantNull %4
|
||||
%58 = OpConstantNull %5
|
||||
%58 = OpConstantNull %6
|
||||
%69 = OpConstant %4 -4
|
||||
%70 = OpConstantComposite %5 %69 %69 %69 %69
|
||||
%79 = OpConstant %6 1.0
|
||||
%80 = OpConstant %6 2.0
|
||||
%81 = OpConstantComposite %7 %80 %79 %79 %79
|
||||
%83 = OpTypePointer Function %7
|
||||
%70 = OpConstantComposite %6 %69 %69 %69 %69
|
||||
%79 = OpConstant %7 1.0
|
||||
%80 = OpConstant %7 2.0
|
||||
%81 = OpConstantComposite %8 %80 %79 %79 %79
|
||||
%83 = OpTypePointer Function %8
|
||||
%88 = OpTypeFunction %3 %4
|
||||
%89 = OpConstant %3 10
|
||||
%90 = OpConstant %3 20
|
||||
%91 = OpConstant %3 30
|
||||
%92 = OpConstant %3 0
|
||||
%99 = OpConstantNull %3
|
||||
%102 = OpConstantComposite %11 %22 %22 %22
|
||||
%103 = OpConstantComposite %11 %21 %22 %23
|
||||
%104 = OpConstantComposite %11 %22 %21 %23
|
||||
%102 = OpConstantComposite %11 %24 %24 %24
|
||||
%103 = OpConstantComposite %11 %23 %24 %25
|
||||
%104 = OpConstantComposite %11 %24 %23 %25
|
||||
%106 = OpTypePointer Function %11
|
||||
%113 = OpTypePointer Function %5
|
||||
%31 = OpFunction %2 None %32
|
||||
%30 = OpLabel
|
||||
%34 = OpVariable %35 Function %33
|
||||
@@ -70,7 +71,7 @@ OpReturn
|
||||
OpFunctionEnd
|
||||
%38 = OpFunction %2 None %32
|
||||
%37 = OpLabel
|
||||
%39 = OpVariable %40 Function %23
|
||||
%39 = OpVariable %40 Function %25
|
||||
OpBranch %41
|
||||
%41 = OpLabel
|
||||
OpReturn
|
||||
@@ -99,7 +100,7 @@ OpStore %54 %61
|
||||
%63 = OpLoad %4 %52
|
||||
%64 = OpLoad %4 %54
|
||||
%65 = OpLoad %4 %56
|
||||
%66 = OpCompositeConstruct %5 %62 %63 %64 %65
|
||||
%66 = OpCompositeConstruct %6 %62 %63 %64 %65
|
||||
OpStore %57 %66
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
@@ -153,14 +154,28 @@ OpReturn
|
||||
OpFunctionEnd
|
||||
%111 = OpFunction %2 None %32
|
||||
%110 = OpLabel
|
||||
OpBranch %112
|
||||
%112 = OpLabel
|
||||
%113 = OpFunctionCall %2 %31
|
||||
%114 = OpFunctionCall %2 %38
|
||||
%115 = OpFunctionCall %2 %43
|
||||
%116 = OpFunctionCall %2 %48
|
||||
%117 = OpFunctionCall %2 %68
|
||||
%118 = OpFunctionCall %2 %74
|
||||
%119 = OpFunctionCall %2 %78
|
||||
%119 = OpVariable %113 Function %15
|
||||
%116 = OpVariable %113 Function %14
|
||||
%112 = OpVariable %113 Function %15
|
||||
%120 = OpVariable %113 Function %14
|
||||
%117 = OpVariable %113 Function %15
|
||||
%114 = OpVariable %113 Function %14
|
||||
%118 = OpVariable %113 Function %14
|
||||
%115 = OpVariable %113 Function %15
|
||||
OpBranch %121
|
||||
%121 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%123 = OpFunction %2 None %32
|
||||
%122 = OpLabel
|
||||
OpBranch %124
|
||||
%124 = OpLabel
|
||||
%125 = OpFunctionCall %2 %31
|
||||
%126 = OpFunctionCall %2 %38
|
||||
%127 = OpFunctionCall %2 %43
|
||||
%128 = OpFunctionCall %2 %48
|
||||
%129 = OpFunctionCall %2 %68
|
||||
%130 = OpFunctionCall %2 %74
|
||||
%131 = OpFunctionCall %2 %78
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
@@ -1,5 +1,7 @@
|
||||
const TWO: u32 = 2u;
|
||||
const THREE: i32 = 3i;
|
||||
const TRUE: bool = true;
|
||||
const FALSE: bool = false;
|
||||
const FOUR: i32 = 4i;
|
||||
const FOUR_ALIAS: i32 = 4i;
|
||||
const TEST_CONSTANT_ADDITION: i32 = 8i;
|
||||
@@ -93,6 +95,19 @@ fn compose_vector_zero_val_binop() {
|
||||
return;
|
||||
}
|
||||
|
||||
fn relational() {
|
||||
var scalar_any_false: bool = false;
|
||||
var scalar_any_true: bool = true;
|
||||
var scalar_all_false: bool = false;
|
||||
var scalar_all_true: bool = true;
|
||||
var vec_any_false: bool = false;
|
||||
var vec_any_true: bool = true;
|
||||
var vec_all_false: bool = false;
|
||||
var vec_all_true: bool = true;
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@compute @workgroup_size(2, 3, 1)
|
||||
fn main() {
|
||||
swizzle_of_compose();
|
||||
|
||||
Reference in New Issue
Block a user