diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 75b5d0c0ae..e22b1275eb 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -220,9 +220,9 @@ impl<'a> ConstantEvaluator<'a> { .collect::, _>>()?; Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span)) } - Expression::Splat { value, .. } => { - self.check(value)?; - Ok(self.register_evaluated_expr(expr.clone(), span)) + Expression::Splat { size, value } => { + let value = self.check_and_get(value)?; + Ok(self.register_evaluated_expr(Expression::Splat { size, value }, span)) } Expression::AccessIndex { base, index } => { let base = self.check_and_get(base)?; @@ -1427,4 +1427,87 @@ mod tests { panic!("unexpected evaluation result") } } + + #[test] + fn splat_of_constant() { + let mut types = UniqueArena::new(); + let mut constants = Arena::new(); + let mut const_expressions = Arena::new(); + + let i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Sint, + width: 4, + }, + }, + Default::default(), + ); + + let vec2_i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Bi, + kind: ScalarKind::Sint, + width: 4, + }, + }, + Default::default(), + ); + + let h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: i32_ty, + init: const_expressions + .append(Expression::Literal(Literal::I32(4)), Default::default()), + }, + Default::default(), + ); + + let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + + let mut solver = ConstantEvaluator { + types: &mut types, + constants: &constants, + expressions: &mut const_expressions, + function_local_data: None, + }; + + let solved_compose = solver + .try_eval_and_append( + &Expression::Splat { + size: VectorSize::Bi, + value: h_expr, + }, + Default::default(), + ) + .unwrap(); + let solved_negate = solver + .try_eval_and_append( + &Expression::Unary { + op: UnaryOperator::Negate, + expr: solved_compose, + }, + Default::default(), + ) + .unwrap(); + + let pass = match const_expressions[solved_negate] { + Expression::Compose { ty, ref components } => { + ty == vec2_i32_ty + && components.iter().all(|&component| { + let component = &const_expressions[component]; + matches!(*component, Expression::Literal(Literal::I32(-4))) + }) + } + _ => false, + }; + if !pass { + panic!("unexpected evaluation result") + } + } } diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl index 2d60807e90..86263b45cf 100644 --- a/tests/out/glsl/const-exprs.main.Compute.glsl +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -52,7 +52,7 @@ void non_constant_initializers() { } void splat_of_constant() { - _group_0_binding_0_cs = -(ivec4(FOUR)); + _group_0_binding_0_cs = ivec4(-4, -4, -4, -4); return; } diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl index 572e95e819..87b2a48d72 100644 --- a/tests/out/hlsl/const-exprs.hlsl +++ b/tests/out/hlsl/const-exprs.hlsl @@ -49,7 +49,7 @@ void non_constant_initializers() void splat_of_constant() { - out_.Store4(0, asuint(-((FOUR).xxxx))); + out_.Store4(0, asuint(int4(-4, -4, -4, -4))); return; } diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl index e9db1d0fbb..8cf66d6179 100644 --- a/tests/out/msl/const-exprs.msl +++ b/tests/out/msl/const-exprs.msl @@ -56,7 +56,7 @@ void non_constant_initializers( void splat_of_constant( device metal::int4& out ) { - out = -(metal::int4(FOUR)); + out = metal::int4(-4, -4, -4, -4); return; } diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm index 357ecf9196..86cb57a918 100644 --- a/tests/out/spv/const-exprs.spvasm +++ b/tests/out/spv/const-exprs.spvasm @@ -1,13 +1,13 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 86 +; Bound: 84 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %76 "main" -OpExecutionMode %76 LocalSize 1 1 1 +OpEntryPoint GLCompute %74 "main" +OpExecutionMode %74 LocalSize 1 1 1 OpDecorate %7 DescriptorSet 0 OpDecorate %7 Binding 0 OpDecorate %8 Block @@ -44,9 +44,8 @@ OpMemberDecorate %11 0 Offset 0 %47 = OpTypePointer Function %4 %49 = OpConstantNull %4 %51 = OpConstantNull %4 -%66 = OpConstantComposite %3 %6 %6 %6 %6 -%72 = OpConstant %4 -4 -%73 = OpConstantComposite %3 %72 %72 %72 %72 +%66 = OpConstant %4 -4 +%67 = OpConstantComposite %3 %66 %66 %66 %66 %14 = OpFunction %2 None %15 %13 = OpLabel %19 = OpAccessChain %16 %7 %17 @@ -101,31 +100,30 @@ OpFunctionEnd %64 = OpFunction %2 None %15 %63 = OpLabel %65 = OpAccessChain %16 %7 %17 -OpBranch %67 -%67 = OpLabel -%68 = OpSNegate %3 %66 -OpStore %65 %68 +OpBranch %68 +%68 = OpLabel +OpStore %65 %67 OpReturn OpFunctionEnd %70 = OpFunction %2 None %15 %69 = OpLabel %71 = OpAccessChain %16 %7 %17 -OpBranch %74 -%74 = OpLabel -OpStore %71 %73 +OpBranch %72 +%72 = OpLabel +OpStore %71 %67 OpReturn OpFunctionEnd -%76 = OpFunction %2 None %15 -%75 = OpLabel -%77 = OpAccessChain %16 %7 %17 -%78 = OpAccessChain %29 %10 %17 -OpBranch %79 -%79 = OpLabel -%80 = OpFunctionCall %2 %14 -%81 = OpFunctionCall %2 %28 -%82 = OpFunctionCall %2 %35 -%83 = OpFunctionCall %2 %42 -%84 = OpFunctionCall %2 %64 -%85 = OpFunctionCall %2 %70 +%74 = OpFunction %2 None %15 +%73 = OpLabel +%75 = OpAccessChain %16 %7 %17 +%76 = OpAccessChain %29 %10 %17 +OpBranch %77 +%77 = OpLabel +%78 = OpFunctionCall %2 %14 +%79 = OpFunctionCall %2 %28 +%80 = OpFunctionCall %2 %35 +%81 = OpFunctionCall %2 %42 +%82 = OpFunctionCall %2 %64 +%83 = OpFunctionCall %2 %70 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl index 77d24bbdfa..ded8e5042e 100644 --- a/tests/out/wgsl/const-exprs.wgsl +++ b/tests/out/wgsl/const-exprs.wgsl @@ -46,7 +46,7 @@ fn non_constant_initializers() { } fn splat_of_constant() { - out = -(vec4(FOUR)); + out = vec4(-4, -4, -4, -4); return; }