diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 9dc021275d..75b5d0c0ae 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -213,11 +213,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { Ok(self.register_evaluated_expr(expr.clone(), span)) } - Expression::Compose { ref components, .. } => { - for component in components { - self.check(*component)?; - } - Ok(self.register_evaluated_expr(expr.clone(), span)) + Expression::Compose { ty, ref components } => { + let components = components + .iter() + .map(|component| self.check_and_get(*component)) + .collect::, _>>()?; + Ok(self.register_evaluated_expr(Expression::Compose { ty, components }, span)) } Expression::Splat { value, .. } => { self.check(value)?; @@ -1343,4 +1344,87 @@ mod tests { Expression::Literal(Literal::F32(5.)) ); } + + #[test] + fn compose_of_constants() { + let mut types = UniqueArena::new(); + let mut constants = Arena::new(); + let mut const_expressions = Arena::new(); + + let i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Sint, + width: 4, + }, + }, + Default::default(), + ); + + let vec2_i32_ty = types.insert( + Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Bi, + kind: ScalarKind::Sint, + width: 4, + }, + }, + Default::default(), + ); + + let h = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: i32_ty, + init: const_expressions + .append(Expression::Literal(Literal::I32(4)), Default::default()), + }, + Default::default(), + ); + + let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + + let mut solver = ConstantEvaluator { + types: &mut types, + constants: &constants, + expressions: &mut const_expressions, + function_local_data: None, + }; + + let solved_compose = solver + .try_eval_and_append( + &Expression::Compose { + ty: vec2_i32_ty, + components: vec![h_expr, h_expr], + }, + Default::default(), + ) + .unwrap(); + let solved_negate = solver + .try_eval_and_append( + &Expression::Unary { + op: UnaryOperator::Negate, + expr: solved_compose, + }, + Default::default(), + ) + .unwrap(); + + let pass = match const_expressions[solved_negate] { + Expression::Compose { ty, ref components } => { + ty == vec2_i32_ty + && components.iter().all(|&component| { + let component = &const_expressions[component]; + matches!(*component, Expression::Literal(Literal::I32(-4))) + }) + } + _ => false, + }; + if !pass { + panic!("unexpected evaluation result") + } + } } diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl index fde4c96156..2d60807e90 100644 --- a/tests/out/glsl/const-exprs.main.Compute.glsl +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -57,7 +57,7 @@ void splat_of_constant() { } void compose_of_constant() { - _group_0_binding_0_cs = -(ivec4(FOUR, FOUR, FOUR, FOUR)); + _group_0_binding_0_cs = ivec4(-4, -4, -4, -4); return; } diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl index 6e325a9bf1..572e95e819 100644 --- a/tests/out/hlsl/const-exprs.hlsl +++ b/tests/out/hlsl/const-exprs.hlsl @@ -55,7 +55,7 @@ void splat_of_constant() void compose_of_constant() { - out_.Store4(0, asuint(-(int4(FOUR, FOUR, FOUR, FOUR)))); + out_.Store4(0, asuint(int4(-4, -4, -4, -4))); return; } diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl index 1d9c8e9caf..e9db1d0fbb 100644 --- a/tests/out/msl/const-exprs.msl +++ b/tests/out/msl/const-exprs.msl @@ -63,7 +63,7 @@ void splat_of_constant( void compose_of_constant( device metal::int4& out ) { - out = -(metal::int4(FOUR, FOUR, FOUR, FOUR)); + out = metal::int4(-4, -4, -4, -4); return; } diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm index 9e0d6c14cf..357ecf9196 100644 --- a/tests/out/spv/const-exprs.spvasm +++ b/tests/out/spv/const-exprs.spvasm @@ -45,7 +45,8 @@ OpMemberDecorate %11 0 Offset 0 %49 = OpConstantNull %4 %51 = OpConstantNull %4 %66 = OpConstantComposite %3 %6 %6 %6 %6 -%72 = OpConstantComposite %3 %6 %6 %6 %6 +%72 = OpConstant %4 -4 +%73 = OpConstantComposite %3 %72 %72 %72 %72 %14 = OpFunction %2 None %15 %13 = OpLabel %19 = OpAccessChain %16 %7 %17 @@ -109,10 +110,9 @@ OpFunctionEnd %70 = OpFunction %2 None %15 %69 = OpLabel %71 = OpAccessChain %16 %7 %17 -OpBranch %73 -%73 = OpLabel -%74 = OpSNegate %3 %72 -OpStore %71 %74 +OpBranch %74 +%74 = OpLabel +OpStore %71 %73 OpReturn OpFunctionEnd %76 = OpFunction %2 None %15 diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl index 47744ede0d..77d24bbdfa 100644 --- a/tests/out/wgsl/const-exprs.wgsl +++ b/tests/out/wgsl/const-exprs.wgsl @@ -51,7 +51,7 @@ fn splat_of_constant() { } fn compose_of_constant() { - out = -(vec4(FOUR, FOUR, FOUR, FOUR)); + out = vec4(-4, -4, -4, -4); return; } diff --git a/tests/out/wgsl/module-scope.wgsl b/tests/out/wgsl/module-scope.wgsl index c30052dd1d..b746ff37ca 100644 --- a/tests/out/wgsl/module-scope.wgsl +++ b/tests/out/wgsl/module-scope.wgsl @@ -14,7 +14,7 @@ fn statement() { } fn returns() -> S { - return S(Value); + return S(1); } fn call() {