Let ConstantEvaluator see through Constant exprs in Splat exprs.

This commit is contained in:
Jim Blandy
2023-09-28 16:19:25 -07:00
committed by Teodor Tanasoaia
parent 3e4d565576
commit c16a298cac
6 changed files with 113 additions and 32 deletions

View File

@@ -220,9 +220,9 @@ impl<'a> ConstantEvaluator<'a> {
.collect::<Result<Vec<_>, _>>()?;
Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span))
}
Expression::Splat { value, .. } => {
self.check(value)?;
Ok(self.register_evaluated_expr(expr.clone(), span))
Expression::Splat { size, value } => {
let value = self.check_and_get(value)?;
Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span))
}
Expression::AccessIndex { base, index } => {
let base = self.check_and_get(base)?;
@@ -1427,4 +1427,87 @@ mod tests {
panic!("unexpected evaluation result")
}
}
#[test]
fn splat_of_constant() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
let mut const_expressions = Arena::new();
let i32_ty = types.insert(
Type {
name: None,
inner: TypeInner::Scalar {
kind: ScalarKind::Sint,
width: 4,
},
},
Default::default(),
);
let vec2_i32_ty = types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size: VectorSize::Bi,
kind: ScalarKind::Sint,
width: 4,
},
},
Default::default(),
);
let h = constants.append(
Constant {
name: None,
r#override: crate::Override::None,
ty: i32_ty,
init: const_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);
let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
let mut solver = ConstantEvaluator {
types: &mut types,
constants: &constants,
expressions: &mut const_expressions,
function_local_data: None,
};
let solved_compose = solver
.try_eval_and_append(
&Expression::Splat {
size: VectorSize::Bi,
value: h_expr,
},
Default::default(),
)
.unwrap();
let solved_negate = solver
.try_eval_and_append(
&Expression::Unary {
op: UnaryOperator::Negate,
expr: solved_compose,
},
Default::default(),
)
.unwrap();
let pass = match const_expressions[solved_negate] {
Expression::Compose { ty, ref components } => {
ty == vec2_i32_ty
&& components.iter().all(|&component| {
let component = &const_expressions[component];
matches!(*component, Expression::Literal(Literal::I32(-4)))
})
}
_ => false,
};
if !pass {
panic!("unexpected evaluation result")
}
}
}

View File

@@ -52,7 +52,7 @@ void non_constant_initializers() {
}
void splat_of_constant() {
_group_0_binding_0_cs = -(ivec4(FOUR));
_group_0_binding_0_cs = ivec4(-4, -4, -4, -4);
return;
}

View File

@@ -49,7 +49,7 @@ void non_constant_initializers()
void splat_of_constant()
{
out_.Store4(0, asuint(-((FOUR).xxxx)));
out_.Store4(0, asuint(int4(-4, -4, -4, -4)));
return;
}

View File

@@ -56,7 +56,7 @@ void non_constant_initializers(
void splat_of_constant(
device metal::int4& out
) {
out = -(metal::int4(FOUR));
out = metal::int4(-4, -4, -4, -4);
return;
}

View File

@@ -1,13 +1,13 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 86
; Bound: 84
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %76 "main"
OpExecutionMode %76 LocalSize 1 1 1
OpEntryPoint GLCompute %74 "main"
OpExecutionMode %74 LocalSize 1 1 1
OpDecorate %7 DescriptorSet 0
OpDecorate %7 Binding 0
OpDecorate %8 Block
@@ -44,9 +44,8 @@ OpMemberDecorate %11 0 Offset 0
%47 = OpTypePointer Function %4
%49 = OpConstantNull %4
%51 = OpConstantNull %4
%66 = OpConstantComposite %3 %6 %6 %6 %6
%72 = OpConstant %4 -4
%73 = OpConstantComposite %3 %72 %72 %72 %72
%66 = OpConstant %4 -4
%67 = OpConstantComposite %3 %66 %66 %66 %66
%14 = OpFunction %2 None %15
%13 = OpLabel
%19 = OpAccessChain %16 %7 %17
@@ -101,31 +100,30 @@ OpFunctionEnd
%64 = OpFunction %2 None %15
%63 = OpLabel
%65 = OpAccessChain %16 %7 %17
OpBranch %67
%67 = OpLabel
%68 = OpSNegate %3 %66
OpStore %65 %68
OpBranch %68
%68 = OpLabel
OpStore %65 %67
OpReturn
OpFunctionEnd
%70 = OpFunction %2 None %15
%69 = OpLabel
%71 = OpAccessChain %16 %7 %17
OpBranch %74
%74 = OpLabel
OpStore %71 %73
OpBranch %72
%72 = OpLabel
OpStore %71 %67
OpReturn
OpFunctionEnd
%76 = OpFunction %2 None %15
%75 = OpLabel
%77 = OpAccessChain %16 %7 %17
%78 = OpAccessChain %29 %10 %17
OpBranch %79
%79 = OpLabel
%80 = OpFunctionCall %2 %14
%81 = OpFunctionCall %2 %28
%82 = OpFunctionCall %2 %35
%83 = OpFunctionCall %2 %42
%84 = OpFunctionCall %2 %64
%85 = OpFunctionCall %2 %70
%74 = OpFunction %2 None %15
%73 = OpLabel
%75 = OpAccessChain %16 %7 %17
%76 = OpAccessChain %29 %10 %17
OpBranch %77
%77 = OpLabel
%78 = OpFunctionCall %2 %14
%79 = OpFunctionCall %2 %28
%80 = OpFunctionCall %2 %35
%81 = OpFunctionCall %2 %42
%82 = OpFunctionCall %2 %64
%83 = OpFunctionCall %2 %70
OpReturn
OpFunctionEnd

View File

@@ -46,7 +46,7 @@ fn non_constant_initializers() {
}
fn splat_of_constant() {
out = -(vec4(FOUR));
out = vec4<i32>(-4, -4, -4, -4);
return;
}