From c2f70aaf809b6ba5733dbf07878304d003d9f9fa Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 19 Mar 2021 00:40:48 -0400 Subject: [PATCH] Validate binary expressions --- src/front/wgsl/mod.rs | 16 ++- src/proc/validator.rs | 190 +++++++++++++++++++++++++++- tests/in/boids.wgsl | 2 +- tests/in/texture-array.wgsl | 2 +- tests/out/boids.msl.snap | 2 +- tests/out/boids.spvasm.snap | 96 +++++++------- tests/out/texture-array.spvasm.snap | 74 +++++------ 7 files changed, 287 insertions(+), 95 deletions(-) diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 7586dc1aff..e54fac72a2 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -2575,7 +2575,7 @@ impl Parser { lexer.expect(Token::Separator(';'))?; } (Token::Word("const"), _) => { - let (name, _ty, _access) = self.parse_variable_ident_decl( + let (name, explicit_ty, _access) = self.parse_variable_ident_decl( lexer, &mut module.types, &mut module.constants, @@ -2589,6 +2589,20 @@ impl Parser { &mut module.types, &mut module.constants, )?; + let con = &module.constants[const_handle]; + let type_match = match con.inner { + crate::ConstantInner::Scalar { width, value } => { + module.types[explicit_ty].inner + == crate::TypeInner::Scalar { + kind: value.scalar_kind(), + width, + } + } + crate::ConstantInner::Composite { ty, components: _ } => ty == explicit_ty, + }; + if !type_match { + return Err(Error::ConstTypeMismatch(name, explicit_ty)); + } //TODO: check `ty` against `const_handle`. lexer.expect(Token::Separator(';'))?; lookup_global_expression.insert(name, crate::Expression::Constant(const_handle)); diff --git a/src/proc/validator.rs b/src/proc/validator.rs index 49466d83b0..d18240c872 100644 --- a/src/proc/validator.rs +++ b/src/proc/validator.rs @@ -181,6 +181,14 @@ pub enum ExpressionError { InvalidComposeCount { given: u32, expected: u32 }, #[error("Composing {0}'s component {1:?} is not expected")] InvalidComponentType(u32, Handle), + #[error("Operation {0:?} can't work with {1:?}")] + InvalidUnaryOperandType(crate::UnaryOperator, Handle), + #[error("Operation {0:?} can't work with {1:?} and {2:?}")] + InvalidBinaryOperandTypes( + crate::BinaryOperator, + Handle, + Handle, + ), } #[derive(Clone, Debug, Error)] @@ -887,7 +895,7 @@ impl Validator { stage: Option, module: &crate::Module, ) -> Result<(), ExpressionError> { - use crate::{Expression as E, TypeInner as Ti}; + use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti}; let resolver = ExpressionTypeResolver { root, @@ -910,11 +918,11 @@ impl Validator { match *resolver.resolve(index)? { //TODO: only allow one of these Ti::Scalar { - kind: crate::ScalarKind::Sint, + kind: Sk::Sint, width: _, } | Ti::Scalar { - kind: crate::ScalarKind::Uint, + kind: Sk::Uint, width: _, } => {} ref other => { @@ -998,7 +1006,7 @@ impl Validator { } => { let inner = Ti::Vector { size: rows, - kind: crate::ScalarKind::Float, + kind: Sk::Float, width, }; if columns as usize != components.len() { @@ -1109,8 +1117,178 @@ impl Validator { index, } => {} E::ImageQuery { image, query } => {} - E::Unary { op, expr } => {} - E::Binary { op, left, right } => {} + E::Unary { op, expr } => { + use crate::UnaryOperator as Uo; + let inner = resolver.resolve(expr)?; + match (op, inner.scalar_kind()) { + (_, Some(Sk::Sint)) + | (_, Some(Sk::Bool)) + | (Uo::Negate, Some(Sk::Float)) + | (Uo::Not, Some(Sk::Uint)) => {} + other => { + log::error!("Op {:?} kind {:?}", op, other); + return Err(ExpressionError::InvalidUnaryOperandType(op, expr)); + } + } + } + E::Binary { op, left, right } => { + use crate::BinaryOperator as Bo; + let left_inner = resolver.resolve(left)?; + let right_inner = resolver.resolve(right)?; + let good = match op { + Bo::Add | Bo::Subtract | Bo::Divide | Bo::Modulo => match *left_inner { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool => false, + }, + _ => false, + }, + Bo::Multiply => { + let kind_match = match left_inner.scalar_kind() { + Some(Sk::Uint) | Some(Sk::Sint) | Some(Sk::Float) => true, + Some(Sk::Bool) | None => false, + }; + //TODO: should we be more restrictive here? I.e. expect scalar only to the left. + let types_match = match (left_inner, right_inner) { + (&Ti::Scalar { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. }) + | (&Ti::Vector { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. }) + | (&Ti::Scalar { kind: kind1, .. }, &Ti::Vector { kind: kind2, .. }) => { + kind1 == kind2 + } + ( + &Ti::Scalar { + kind: Sk::Float, .. + }, + &Ti::Matrix { .. }, + ) + | ( + &Ti::Matrix { .. }, + &Ti::Scalar { + kind: Sk::Float, .. + }, + ) => true, + ( + &Ti::Vector { + kind: kind1, + size: size1, + .. + }, + &Ti::Vector { + kind: kind2, + size: size2, + .. + }, + ) => kind1 == kind2 && size1 == size2, + ( + &Ti::Matrix { columns, .. }, + &Ti::Vector { + kind: Sk::Float, + size, + .. + }, + ) => columns == size, + ( + &Ti::Vector { + kind: Sk::Float, + size, + .. + }, + &Ti::Matrix { rows, .. }, + ) => size == rows, + (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { + columns == rows + } + _ => false, + }; + let left_width = match *left_inner { + Ti::Scalar { width, .. } + | Ti::Vector { width, .. } + | Ti::Matrix { width, .. } => width, + _ => 0, + }; + let right_width = match *right_inner { + Ti::Scalar { width, .. } + | Ti::Vector { width, .. } + | Ti::Matrix { width, .. } => width, + _ => 0, + }; + kind_match && types_match && left_width == right_width + } + Bo::Equal | Bo::NotEqual => match *left_inner { + Ti::Scalar { .. } + | Ti::Vector { .. } + | Ti::Matrix { .. } + | Ti::Array { + size: crate::ArraySize::Constant(_), + .. + } + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } + | Ti::Struct { .. } => left_inner == right_inner, + Ti::Array { .. } | Ti::Image { .. } | Ti::Sampler { .. } => false, + }, + Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { + match *left_inner { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + } + } + Bo::LogicalAnd | Bo::LogicalOr => match *left_inner { + Ti::Scalar { kind: Sk::Bool, .. } | Ti::Vector { kind: Sk::Bool, .. } => { + left_inner == right_inner + } + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::And | Bo::ExclusiveOr | Bo::InclusiveOr => match *left_inner { + Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { + Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Bool | Sk::Float => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::ShiftLeft | Bo::ShiftRight => { + let (base_size, base_kind) = match *left_inner { + Ti::Scalar { kind, .. } => (Ok(None), kind), + Ti::Vector { size, kind, .. } => (Ok(Some(size)), kind), + ref other => { + log::error!("Op {:?} base type {:?}", op, other); + (Err(()), Sk::Bool) + } + }; + let shift_size = match *right_inner { + Ti::Scalar { kind: Sk::Uint, .. } => Ok(None), + Ti::Vector { + size, + kind: Sk::Uint, + .. + } => Ok(Some(size)), + ref other => { + log::error!("Op {:?} shift type {:?}", op, other); + Err(()) + } + }; + match base_kind { + Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, + Sk::Float | Sk::Bool => false, + } + } + }; + if !good { + return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right)); + } + } E::Select { condition, accept, diff --git a/tests/in/boids.wgsl b/tests/in/boids.wgsl index 3b404f9f06..f41a6ffa0e 100644 --- a/tests/in/boids.wgsl +++ b/tests/in/boids.wgsl @@ -1,4 +1,4 @@ -const NUM_PARTICLES: u32 = 1500; +const NUM_PARTICLES: u32 = 1500u; [[block]] struct Particle { diff --git a/tests/in/texture-array.wgsl b/tests/in/texture-array.wgsl index ca14d8c23a..5f2e09b35f 100644 --- a/tests/in/texture-array.wgsl +++ b/tests/in/texture-array.wgsl @@ -10,7 +10,7 @@ var pc: PushConstants; [[stage(fragment)]] fn main([[location(0)]] tex_coord: vec2) -> [[location(1)]] vec4 { - if (pc.index == 0) { + if (pc.index == 0u) { return textureSample(texture0, sampler, tex_coord); } else { return textureSample(texture1, sampler, tex_coord); diff --git a/tests/out/boids.msl.snap b/tests/out/boids.msl.snap index 7fe3c2c824..816ad79a49 100644 --- a/tests/out/boids.msl.snap +++ b/tests/out/boids.msl.snap @@ -27,7 +27,7 @@ struct Particles { }; typedef metal::uint3 type4; typedef int type5; -constexpr constant int NUM_PARTICLES = 1500; +constexpr constant unsigned NUM_PARTICLES = 1500u; constexpr constant float const_0f = 0.0; constexpr constant int const_0i = 0; constexpr constant unsigned const_0u = 0u; diff --git a/tests/out/boids.spvasm.snap b/tests/out/boids.spvasm.snap index 80e3f0860d..1564f5c013 100644 --- a/tests/out/boids.spvasm.snap +++ b/tests/out/boids.spvasm.snap @@ -64,15 +64,15 @@ OpDecorate %24 DescriptorSet 0 OpDecorate %24 Binding 2 OpDecorate %40 BuiltIn GlobalInvocationId %2 = OpTypeVoid -%4 = OpTypeInt 32 1 +%4 = OpTypeInt 32 0 %3 = OpConstant %4 1500 %6 = OpTypeFloat 32 %5 = OpConstant %6 0.0 -%7 = OpConstant %4 0 -%9 = OpTypeInt 32 0 -%8 = OpConstant %9 0 -%10 = OpConstant %4 1 -%11 = OpConstant %9 1 +%8 = OpTypeInt 32 1 +%7 = OpConstant %8 0 +%9 = OpConstant %4 0 +%10 = OpConstant %8 1 +%11 = OpConstant %4 1 %12 = OpConstant %6 1.0 %13 = OpConstant %6 0.1 %14 = OpConstant %6 -1.0 @@ -87,9 +87,9 @@ OpDecorate %40 BuiltIn GlobalInvocationId %18 = OpVariable %23 Uniform %24 = OpVariable %23 Uniform %26 = OpTypePointer Function %22 -%32 = OpTypePointer Function %4 -%37 = OpTypePointer Function %9 -%39 = OpTypeVector %9 3 +%32 = OpTypePointer Function %8 +%37 = OpTypePointer Function %4 +%39 = OpTypeVector %4 3 %41 = OpTypePointer Input %39 %40 = OpVariable %41 Input %44 = OpTypeFunction %2 @@ -97,34 +97,34 @@ OpDecorate %40 BuiltIn GlobalInvocationId %51 = OpTypePointer Uniform %20 %52 = OpTypePointer Uniform %21 %53 = OpTypePointer Uniform %22 -%54 = OpConstant %4 0 -%55 = OpConstant %4 0 -%58 = OpConstant %4 1 -%59 = OpConstant %4 0 -%78 = OpConstant %4 0 -%79 = OpConstant %4 0 -%83 = OpConstant %4 1 -%84 = OpConstant %4 0 +%54 = OpConstant %8 0 +%55 = OpConstant %8 0 +%58 = OpConstant %8 1 +%59 = OpConstant %8 0 +%78 = OpConstant %8 0 +%79 = OpConstant %8 0 +%83 = OpConstant %8 1 +%84 = OpConstant %8 0 %90 = OpTypePointer Uniform %6 -%91 = OpConstant %4 1 -%105 = OpConstant %4 2 -%119 = OpConstant %4 3 -%154 = OpConstant %4 4 -%160 = OpConstant %4 5 -%166 = OpConstant %4 6 -%179 = OpConstant %4 0 +%91 = OpConstant %8 1 +%105 = OpConstant %8 2 +%119 = OpConstant %8 3 +%154 = OpConstant %8 4 +%160 = OpConstant %8 5 +%166 = OpConstant %8 6 +%179 = OpConstant %8 0 %189 = OpTypePointer Function %6 -%190 = OpConstant %4 0 -%197 = OpConstant %4 0 -%204 = OpConstant %4 1 -%211 = OpConstant %4 1 -%214 = OpConstant %4 0 -%215 = OpConstant %4 0 -%218 = OpConstant %4 1 -%219 = OpConstant %4 0 +%190 = OpConstant %8 0 +%197 = OpConstant %8 0 +%204 = OpConstant %8 1 +%211 = OpConstant %8 1 +%214 = OpConstant %8 0 +%215 = OpConstant %8 0 +%218 = OpConstant %8 1 +%219 = OpConstant %8 0 %43 = OpFunction %2 None %44 %38 = OpLabel -%36 = OpVariable %37 Function %8 +%36 = OpVariable %37 Function %9 %33 = OpVariable %32 Function %7 %29 = OpVariable %26 Function %25 = OpVariable %26 Function @@ -137,7 +137,7 @@ OpDecorate %40 BuiltIn GlobalInvocationId %42 = OpLoad %39 %40 OpBranch %45 %45 = OpLabel -%46 = OpCompositeExtract %9 %42 0 +%46 = OpCompositeExtract %4 %42 0 %48 = OpUGreaterThanEqual %47 %46 %3 OpSelectionMerge %49 None OpBranchConditional %48 %50 %49 @@ -161,25 +161,25 @@ OpBranch %65 OpLoopMerge %66 %68 None OpBranch %67 %67 = OpLabel -%69 = OpLoad %9 %36 +%69 = OpLoad %4 %36 %70 = OpUGreaterThanEqual %47 %69 %3 OpSelectionMerge %71 None OpBranchConditional %70 %72 %71 %72 = OpLabel OpBranch %66 %71 = OpLabel -%73 = OpLoad %9 %36 +%73 = OpLoad %4 %36 %74 = OpIEqual %47 %73 %46 OpSelectionMerge %75 None OpBranchConditional %74 %76 %75 %76 = OpLabel OpBranch %68 %75 = OpLabel -%77 = OpLoad %9 %36 +%77 = OpLoad %4 %36 %80 = OpAccessChain %53 %18 %79 %77 %78 %81 = OpLoad %22 %80 OpStore %34 %81 -%82 = OpLoad %9 %36 +%82 = OpLoad %4 %36 %85 = OpAccessChain %53 %18 %84 %82 %83 %86 = OpLoad %22 %85 OpStore %35 %86 @@ -196,8 +196,8 @@ OpBranchConditional %94 %96 %95 %98 = OpLoad %22 %34 %99 = OpFAdd %22 %97 %98 OpStore %28 %99 -%100 = OpLoad %4 %31 -%101 = OpIAdd %4 %100 %10 +%100 = OpLoad %8 %31 +%101 = OpIAdd %8 %100 %10 OpStore %31 %101 OpBranch %95 %95 = OpLabel @@ -231,25 +231,25 @@ OpBranchConditional %122 %124 %123 %126 = OpLoad %22 %35 %127 = OpFAdd %22 %125 %126 OpStore %29 %127 -%128 = OpLoad %4 %33 -%129 = OpIAdd %4 %128 %10 +%128 = OpLoad %8 %33 +%129 = OpIAdd %8 %128 %10 OpStore %33 %129 OpBranch %123 %123 = OpLabel OpBranch %68 %68 = OpLabel -%130 = OpLoad %9 %36 -%131 = OpIAdd %9 %130 %11 +%130 = OpLoad %4 %36 +%131 = OpIAdd %4 %130 %11 OpStore %36 %131 OpBranch %65 %66 = OpLabel -%132 = OpLoad %4 %31 +%132 = OpLoad %8 %31 %133 = OpSGreaterThan %47 %132 %7 OpSelectionMerge %134 None OpBranchConditional %133 %135 %134 %135 = OpLabel %136 = OpLoad %22 %28 -%137 = OpLoad %4 %31 +%137 = OpLoad %8 %31 %138 = OpConvertSToF %6 %137 %139 = OpFDiv %6 %12 %138 %140 = OpVectorTimesScalar %22 %136 %139 @@ -258,13 +258,13 @@ OpBranchConditional %133 %135 %134 OpStore %28 %142 OpBranch %134 %134 = OpLabel -%143 = OpLoad %4 %33 +%143 = OpLoad %8 %33 %144 = OpSGreaterThan %47 %143 %7 OpSelectionMerge %145 None OpBranchConditional %144 %146 %145 %146 = OpLabel %147 = OpLoad %22 %29 -%148 = OpLoad %4 %33 +%148 = OpLoad %8 %33 %149 = OpConvertSToF %6 %148 %150 = OpFDiv %6 %12 %149 %151 = OpVectorTimesScalar %22 %147 %150 diff --git a/tests/out/texture-array.spvasm.snap b/tests/out/texture-array.spvasm.snap index aa9942f94b..ae41626ab2 100644 --- a/tests/out/texture-array.spvasm.snap +++ b/tests/out/texture-array.spvasm.snap @@ -9,8 +9,8 @@ expression: dis OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Fragment %25 "main" %19 %23 -OpExecutionMode %25 OriginUpperLeft +OpEntryPoint Fragment %24 "main" %18 %22 +OpExecutionMode %24 OriginUpperLeft OpSource GLSL 450 OpName %5 "texture0" OpName %9 "texture1" @@ -18,9 +18,9 @@ OpName %10 "sampler" OpName %14 "PushConstants" OpMemberName %14 0 "index" OpName %13 "pc" -OpName %19 "tex_coord" -OpName %25 "main" -OpName %25 "main" +OpName %18 "tex_coord" +OpName %24 "main" +OpName %24 "main" OpDecorate %5 DescriptorSet 0 OpDecorate %5 Binding 0 OpDecorate %9 DescriptorSet 0 @@ -29,10 +29,10 @@ OpDecorate %10 DescriptorSet 0 OpDecorate %10 Binding 2 OpDecorate %14 Block OpMemberDecorate %14 0 Offset 0 -OpDecorate %19 Location 0 -OpDecorate %23 Location 1 +OpDecorate %18 Location 0 +OpDecorate %22 Location 1 %2 = OpTypeVoid -%4 = OpTypeInt 32 1 +%4 = OpTypeInt 32 0 %3 = OpConstant %4 0 %7 = OpTypeFloat 32 %6 = OpTypeImage %7 2D 0 0 0 1 Unknown @@ -42,43 +42,43 @@ OpDecorate %23 Location 1 %11 = OpTypeSampler %12 = OpTypePointer UniformConstant %11 %10 = OpVariable %12 UniformConstant -%15 = OpTypeInt 32 0 -%14 = OpTypeStruct %15 -%16 = OpTypePointer PushConstant %14 -%13 = OpVariable %16 PushConstant -%18 = OpTypeVector %7 2 -%20 = OpTypePointer Input %18 -%19 = OpVariable %20 Input -%22 = OpTypeVector %7 4 -%24 = OpTypePointer Output %22 -%23 = OpVariable %24 Output -%26 = OpTypeFunction %2 -%31 = OpTypePointer PushConstant %15 -%32 = OpConstant %4 0 +%14 = OpTypeStruct %4 +%15 = OpTypePointer PushConstant %14 +%13 = OpVariable %15 PushConstant +%17 = OpTypeVector %7 2 +%19 = OpTypePointer Input %17 +%18 = OpVariable %19 Input +%21 = OpTypeVector %7 4 +%23 = OpTypePointer Output %21 +%22 = OpVariable %23 Output +%25 = OpTypeFunction %2 +%30 = OpTypePointer PushConstant %4 +%31 = OpTypeInt 32 1 +%32 = OpConstant %31 0 %35 = OpTypeBool %40 = OpTypeSampledImage %6 -%25 = OpFunction %2 None %26 -%17 = OpLabel -%21 = OpLoad %18 %19 -%27 = OpLoad %6 %5 -%28 = OpLoad %6 %9 -%29 = OpLoad %11 %10 -OpBranch %30 -%30 = OpLabel -%33 = OpAccessChain %31 %13 %32 -%34 = OpLoad %15 %33 +%24 = OpFunction %2 None %25 +%16 = OpLabel +%20 = OpLoad %17 %18 +%26 = OpLoad %6 %5 +%27 = OpLoad %6 %9 +%28 = OpLoad %11 %10 +OpBranch %29 +%29 = OpLabel +%33 = OpAccessChain %30 %13 %32 +%34 = OpLoad %4 %33 %36 = OpIEqual %35 %34 %3 OpSelectionMerge %37 None OpBranchConditional %36 %38 %39 %38 = OpLabel -%41 = OpSampledImage %40 %27 %29 -%42 = OpImageSampleImplicitLod %22 %41 %21 -OpStore %23 %42 +%41 = OpSampledImage %40 %26 %28 +%42 = OpImageSampleImplicitLod %21 %41 %20 +OpStore %22 %42 OpReturn %39 = OpLabel -%43 = OpSampledImage %40 %28 %29 -%44 = OpImageSampleImplicitLod %22 %43 %21 -OpStore %23 %44 +%43 = OpSampledImage %40 %27 %28 +%44 = OpImageSampleImplicitLod %21 %43 %20 +OpStore %22 %44 OpReturn %37 = OpLabel OpReturn