diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index d2d64afd42..182b16e7bb 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -119,6 +119,8 @@ pub enum Error<'a> { UnknownLocalFunction(&'a str), #[error("builtin {0:?} is not implemented")] UnimplementedBuiltin(crate::BuiltIn), + #[error("expression {0} doesn't match its given type {1:?}")] + ConstTypeMismatch(&'a str, Handle), #[error("other error")] Other, } @@ -1984,12 +1986,34 @@ impl Parser { match word { "const" => { emitter.start(context.expressions); - let (name, _ty, _access) = - self.parse_variable_ident_decl(lexer, context.types, context.constants)?; + let name = lexer.next_ident()?; + let given_ty = if lexer.skip(Token::Separator(':')) { + let (ty, _access) = + self.parse_type_decl(lexer, None, context.types, context.constants)?; + Some(ty) + } else { + None + }; lexer.expect(Token::Operation('='))?; let expr_id = self .parse_general_expression(lexer, context.as_expression(block, &mut emitter))?; lexer.expect(Token::Separator(';'))?; + if let Some(ty) = given_ty { + // prepare the typifier, but work around mutable borrowing... + let _ = context + .as_expression(block, &mut emitter) + .resolve_type(expr_id)?; + let expr_inner = context.typifier.get(expr_id, context.types); + let given_inner = &context.types[ty].inner; + if given_inner != expr_inner { + log::error!( + "Given type {:?} doesn't match expected {:?}", + given_inner, + expr_inner + ); + return Err(Error::ConstTypeMismatch(name, ty)); + } + } block.extend(emitter.finish(context.expressions)); context.lookup_ident.insert(name, expr_id); } diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index e24b4f88d5..847e3dea7e 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -25,6 +25,23 @@ fn parse_types() { parse_str("var t: [[access(read)]] texture_storage_3d;").unwrap(); } +#[test] +fn parse_type_inference() { + parse_str( + " + fn foo() { + const a = 2u; + const b: u32 = a; + }", + ) + .unwrap(); + assert!(parse_str( + " + fn foo() { const c : i32 = 2.0; }", + ) + .is_err()); +} + #[test] fn parse_type_cast() { parse_str( diff --git a/tests/in/quad.wgsl b/tests/in/quad.wgsl index 514f11e9be..6283687ef7 100644 --- a/tests/in/quad.wgsl +++ b/tests/in/quad.wgsl @@ -20,10 +20,12 @@ fn main([[location(0)]] pos : vec2, [[location(1)]] uv : vec2) -> Vert [[stage(fragment)]] fn main([[location(0)]] uv : vec2) -> [[location(0)]] vec4 { - const color: vec4 = textureSample(u_texture, u_sampler, uv); + const color = textureSample(u_texture, u_sampler, uv); if (color.a == 0.0) { discard; } - const premultiplied: vec4 = color.a * color; + // forcing the expression here to be emitted in order to check the + // uniformity of the control flow a bit more strongly. + const premultiplied = color.a * color; return premultiplied; } diff --git a/tests/in/shadow.wgsl b/tests/in/shadow.wgsl index e6c2408a57..d68b93ef5c 100644 --- a/tests/in/shadow.wgsl +++ b/tests/in/shadow.wgsl @@ -29,9 +29,9 @@ fn fetch_shadow(light_id: u32, homogeneous_coords: vec4) -> f32 { if (homogeneous_coords.w <= 0.0) { return 1.0; } - const flip_correction: vec2 = vec2(0.5, -0.5); - const proj_correction: f32 = 1.0 / homogeneous_coords.w; - const light_local: vec2 = homogeneous_coords.xy * flip_correction * proj_correction + vec2(0.5, 0.5); + const flip_correction = vec2(0.5, -0.5); + const proj_correction = 1.0 / homogeneous_coords.w; + const light_local = homogeneous_coords.xy * flip_correction * proj_correction + vec2(0.5, 0.5); return textureSampleCompare(t_shadow, sampler_shadow, light_local, i32(light_id), homogeneous_coords.z * proj_correction); } @@ -51,10 +51,10 @@ fn fs_main( if (i >= min(u_globals.num_lights.x, c_max_lights)) { break; } - const light: Light = s_lights.data[i]; - const shadow: f32 = fetch_shadow(i, light.proj * position); - const light_dir: vec3 = normalize(light.pos.xyz - position.xyz); - const diffuse: f32 = max(0.0, dot(normal, light_dir)); + const light = s_lights.data[i]; + const shadow = fetch_shadow(i, light.proj * position); + const light_dir = normalize(light.pos.xyz - position.xyz); + const diffuse = max(0.0, dot(normal, light_dir)); color = color + shadow * diffuse * light.color.xyz; continuing { i = i + 1u; diff --git a/tests/in/skybox.wgsl b/tests/in/skybox.wgsl index bdb9ea8209..cd6d46721d 100644 --- a/tests/in/skybox.wgsl +++ b/tests/in/skybox.wgsl @@ -16,15 +16,15 @@ fn vs_main([[builtin(vertex_index)]] vertex_index: u32) -> VertexOutput { // hacky way to draw a large triangle var tmp1: i32 = i32(vertex_index) / 2; var tmp2: i32 = i32(vertex_index) & 1; - const pos: vec4 = vec4( + const pos = vec4( f32(tmp1) * 4.0 - 1.0, f32(tmp2) * 4.0 - 1.0, 0.0, 1.0 ); - const inv_model_view: mat3x3 = transpose(mat3x3(r_data.view.x.xyz, r_data.view.y.xyz, r_data.view.z.xyz)); - var unprojected: vec4 = r_data.proj_inv * pos; //TODO: const + const inv_model_view = transpose(mat3x3(r_data.view.x.xyz, r_data.view.y.xyz, r_data.view.z.xyz)); + const unprojected = r_data.proj_inv * pos; var out: VertexOutput; out.uv = inv_model_view * unprojected.xyz; out.position = pos; diff --git a/tests/out/skybox-Vertex.glsl.snap b/tests/out/skybox-Vertex.glsl.snap index 50667e689e..4b71a76fad 100644 --- a/tests/out/skybox-Vertex.glsl.snap +++ b/tests/out/skybox-Vertex.glsl.snap @@ -22,14 +22,12 @@ void main() { uint vertex_index = uint(gl_VertexID); int tmp1_; int tmp2_; - vec4 unprojected; VertexOutput out1; tmp1_ = (int(vertex_index) / 2); tmp2_ = (int(vertex_index) & 1); vec4 _expr24 = vec4(((float(tmp1_) * 4.0) - 1.0), ((float(tmp2_) * 4.0) - 1.0), 0.0, 1.0); - unprojected = (_group_0_binding_0.proj_inv * _expr24); - vec4 _expr54 = unprojected; - out1.uv = (transpose(mat3x3(vec3(_group_0_binding_0.view[0][0], _group_0_binding_0.view[0][1], _group_0_binding_0.view[0][2]), vec3(_group_0_binding_0.view[1][0], _group_0_binding_0.view[1][1], _group_0_binding_0.view[1][2]), vec3(_group_0_binding_0.view[2][0], _group_0_binding_0.view[2][1], _group_0_binding_0.view[2][2]))) * vec3(_expr54[0], _expr54[1], _expr54[2])); + vec4 _expr50 = (_group_0_binding_0.proj_inv * _expr24); + out1.uv = (transpose(mat3x3(vec3(_group_0_binding_0.view[0][0], _group_0_binding_0.view[0][1], _group_0_binding_0.view[0][2]), vec3(_group_0_binding_0.view[1][0], _group_0_binding_0.view[1][1], _group_0_binding_0.view[1][2]), vec3(_group_0_binding_0.view[2][0], _group_0_binding_0.view[2][1], _group_0_binding_0.view[2][2]))) * vec3(_expr50[0], _expr50[1], _expr50[2])); out1.position = _expr24; gl_Position = out1.position; _out_location_0 = out1.uv; diff --git a/tests/out/skybox.msl.snap b/tests/out/skybox.msl.snap index effce33ea9..4345f58138 100644 --- a/tests/out/skybox.msl.snap +++ b/tests/out/skybox.msl.snap @@ -48,14 +48,12 @@ vertex vs_mainOutput vs_main( ) { type4 tmp1_; type4 tmp2_; - type unprojected; VertexOutput out; tmp1_ = (static_cast(vertex_index) / const_2i); tmp2_ = (static_cast(vertex_index) & const_1i); type _expr24 = metal::float4(((static_cast(tmp1_) * const_4f) - const_1f), ((static_cast(tmp2_) * const_4f) - const_1f), const_0f, const_1f); - unprojected = (r_data.proj_inv * _expr24); - type _expr54 = unprojected; - out.uv = (metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(_expr54.x, _expr54.y, _expr54.z)); + metal::float4 _expr50 = (r_data.proj_inv * _expr24); + out.uv = (metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(_expr50.x, _expr50.y, _expr50.z)); out.position = _expr24; const auto _tmp = out; return vs_mainOutput { _tmp.position, _tmp.uv }; diff --git a/tests/out/skybox.spvasm.snap b/tests/out/skybox.spvasm.snap index 19e3b10700..8b197e2db0 100644 --- a/tests/out/skybox.spvasm.snap +++ b/tests/out/skybox.spvasm.snap @@ -5,13 +5,13 @@ expression: dis ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 113 +; Bound: 111 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %39 "vs_main" %32 %35 %37 -OpEntryPoint Fragment %106 "fs_main" %102 %105 -OpExecutionMode %106 OriginUpperLeft +OpEntryPoint Vertex %37 "vs_main" %30 %33 %35 +OpEntryPoint Fragment %104 "fs_main" %100 %103 +OpExecutionMode %104 OriginUpperLeft OpSource GLSL 450 OpName %11 "Data" OpMemberName %11 0 "proj_inv" @@ -21,17 +21,16 @@ OpName %15 "r_texture" OpName %18 "r_sampler" OpName %21 "tmp1" OpName %23 "tmp2" -OpName %24 "unprojected" -OpName %26 "out" -OpName %27 "VertexOutput" -OpName %32 "vertex_index" -OpName %35 "position" -OpName %37 "uv" -OpName %39 "vs_main" -OpName %39 "vs_main" -OpName %102 "uv" -OpName %106 "fs_main" -OpName %106 "fs_main" +OpName %24 "out" +OpName %25 "VertexOutput" +OpName %30 "vertex_index" +OpName %33 "position" +OpName %35 "uv" +OpName %37 "vs_main" +OpName %37 "vs_main" +OpName %100 "uv" +OpName %104 "fs_main" +OpName %104 "fs_main" OpDecorate %11 Block OpMemberDecorate %11 0 Offset 0 OpMemberDecorate %11 0 ColMajor @@ -45,11 +44,11 @@ OpDecorate %15 DescriptorSet 0 OpDecorate %15 Binding 1 OpDecorate %18 DescriptorSet 0 OpDecorate %18 Binding 2 -OpDecorate %32 BuiltIn VertexIndex -OpDecorate %35 BuiltIn Position -OpDecorate %37 Location 0 -OpDecorate %102 Location 0 -OpDecorate %105 Location 0 +OpDecorate %30 BuiltIn VertexIndex +OpDecorate %33 BuiltIn Position +OpDecorate %35 Location 0 +OpDecorate %100 Location 0 +OpDecorate %103 Location 0 %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpConstant %4 2 @@ -70,108 +69,105 @@ OpDecorate %105 Location 0 %20 = OpTypePointer UniformConstant %19 %18 = OpVariable %20 UniformConstant %22 = OpTypePointer Function %4 -%25 = OpTypePointer Function %13 -%28 = OpTypeVector %7 3 -%27 = OpTypeStruct %13 %28 -%29 = OpTypePointer Function %27 -%31 = OpTypeInt 32 0 -%33 = OpTypePointer Input %31 -%32 = OpVariable %33 Input -%36 = OpTypePointer Output %13 +%26 = OpTypeVector %7 3 +%25 = OpTypeStruct %13 %26 +%27 = OpTypePointer Function %25 +%29 = OpTypeInt 32 0 +%31 = OpTypePointer Input %29 +%30 = OpVariable %31 Input +%34 = OpTypePointer Output %13 +%33 = OpVariable %34 Output +%36 = OpTypePointer Output %26 %35 = OpVariable %36 Output -%38 = OpTypePointer Output %28 -%37 = OpVariable %38 Output -%40 = OpTypeFunction %2 -%55 = OpTypePointer Uniform %12 -%56 = OpConstant %4 1 -%64 = OpConstant %4 1 -%72 = OpConstant %4 1 -%80 = OpTypeMatrix %28 3 -%83 = OpConstant %4 0 -%87 = OpTypePointer Function %28 -%94 = OpConstant %4 1 -%96 = OpConstant %4 0 -%103 = OpTypePointer Input %28 -%102 = OpVariable %103 Input -%105 = OpVariable %36 Output -%110 = OpTypeSampledImage %16 -%39 = OpFunction %2 None %40 -%30 = OpLabel -%23 = OpVariable %22 Function -%26 = OpVariable %29 Function +%38 = OpTypeFunction %2 +%53 = OpTypePointer Uniform %12 +%54 = OpConstant %4 1 +%62 = OpConstant %4 1 +%70 = OpConstant %4 1 +%78 = OpTypeMatrix %26 3 +%81 = OpConstant %4 0 +%85 = OpTypePointer Function %26 +%91 = OpConstant %4 1 +%93 = OpTypePointer Function %13 +%94 = OpConstant %4 0 +%101 = OpTypePointer Input %26 +%100 = OpVariable %101 Input +%103 = OpVariable %34 Output +%108 = OpTypeSampledImage %16 +%37 = OpFunction %2 None %38 +%28 = OpLabel %21 = OpVariable %22 Function -%24 = OpVariable %25 Function -%34 = OpLoad %31 %32 -OpBranch %41 -%41 = OpLabel -%42 = OpBitcast %4 %34 -%43 = OpSDiv %4 %42 %3 -OpStore %21 %43 -%44 = OpBitcast %4 %34 -%45 = OpBitwiseAnd %4 %44 %5 -OpStore %23 %45 -%46 = OpLoad %4 %21 -%47 = OpConvertSToF %7 %46 -%48 = OpFMul %7 %47 %6 -%49 = OpFSub %7 %48 %8 -%50 = OpLoad %4 %23 -%51 = OpConvertSToF %7 %50 -%52 = OpFMul %7 %51 %6 -%53 = OpFSub %7 %52 %8 -%54 = OpCompositeConstruct %13 %49 %53 %9 %8 -%57 = OpAccessChain %55 %10 %56 -%58 = OpLoad %12 %57 -%59 = OpCompositeExtract %13 %58 0 -%60 = OpCompositeExtract %7 %59 0 -%61 = OpCompositeExtract %7 %59 1 -%62 = OpCompositeExtract %7 %59 2 -%63 = OpCompositeConstruct %28 %60 %61 %62 -%65 = OpAccessChain %55 %10 %64 -%66 = OpLoad %12 %65 -%67 = OpCompositeExtract %13 %66 1 -%68 = OpCompositeExtract %7 %67 0 -%69 = OpCompositeExtract %7 %67 1 -%70 = OpCompositeExtract %7 %67 2 -%71 = OpCompositeConstruct %28 %68 %69 %70 -%73 = OpAccessChain %55 %10 %72 -%74 = OpLoad %12 %73 -%75 = OpCompositeExtract %13 %74 2 -%76 = OpCompositeExtract %7 %75 0 -%77 = OpCompositeExtract %7 %75 1 -%78 = OpCompositeExtract %7 %75 2 -%79 = OpCompositeConstruct %28 %76 %77 %78 -%81 = OpCompositeConstruct %80 %63 %71 %79 -%82 = OpTranspose %80 %81 -%84 = OpAccessChain %55 %10 %83 -%85 = OpLoad %12 %84 -%86 = OpMatrixTimesVector %13 %85 %54 -OpStore %24 %86 -%88 = OpLoad %13 %24 -%89 = OpCompositeExtract %7 %88 0 -%90 = OpCompositeExtract %7 %88 1 -%91 = OpCompositeExtract %7 %88 2 -%92 = OpCompositeConstruct %28 %89 %90 %91 -%93 = OpMatrixTimesVector %28 %82 %92 -%95 = OpAccessChain %87 %26 %94 -OpStore %95 %93 -%97 = OpAccessChain %25 %26 %96 -OpStore %97 %54 -%98 = OpLoad %27 %26 -%99 = OpCompositeExtract %13 %98 0 -OpStore %35 %99 -%100 = OpCompositeExtract %28 %98 1 -OpStore %37 %100 +%23 = OpVariable %22 Function +%24 = OpVariable %27 Function +%32 = OpLoad %29 %30 +OpBranch %39 +%39 = OpLabel +%40 = OpBitcast %4 %32 +%41 = OpSDiv %4 %40 %3 +OpStore %21 %41 +%42 = OpBitcast %4 %32 +%43 = OpBitwiseAnd %4 %42 %5 +OpStore %23 %43 +%44 = OpLoad %4 %21 +%45 = OpConvertSToF %7 %44 +%46 = OpFMul %7 %45 %6 +%47 = OpFSub %7 %46 %8 +%48 = OpLoad %4 %23 +%49 = OpConvertSToF %7 %48 +%50 = OpFMul %7 %49 %6 +%51 = OpFSub %7 %50 %8 +%52 = OpCompositeConstruct %13 %47 %51 %9 %8 +%55 = OpAccessChain %53 %10 %54 +%56 = OpLoad %12 %55 +%57 = OpCompositeExtract %13 %56 0 +%58 = OpCompositeExtract %7 %57 0 +%59 = OpCompositeExtract %7 %57 1 +%60 = OpCompositeExtract %7 %57 2 +%61 = OpCompositeConstruct %26 %58 %59 %60 +%63 = OpAccessChain %53 %10 %62 +%64 = OpLoad %12 %63 +%65 = OpCompositeExtract %13 %64 1 +%66 = OpCompositeExtract %7 %65 0 +%67 = OpCompositeExtract %7 %65 1 +%68 = OpCompositeExtract %7 %65 2 +%69 = OpCompositeConstruct %26 %66 %67 %68 +%71 = OpAccessChain %53 %10 %70 +%72 = OpLoad %12 %71 +%73 = OpCompositeExtract %13 %72 2 +%74 = OpCompositeExtract %7 %73 0 +%75 = OpCompositeExtract %7 %73 1 +%76 = OpCompositeExtract %7 %73 2 +%77 = OpCompositeConstruct %26 %74 %75 %76 +%79 = OpCompositeConstruct %78 %61 %69 %77 +%80 = OpTranspose %78 %79 +%82 = OpAccessChain %53 %10 %81 +%83 = OpLoad %12 %82 +%84 = OpMatrixTimesVector %13 %83 %52 +%86 = OpCompositeExtract %7 %84 0 +%87 = OpCompositeExtract %7 %84 1 +%88 = OpCompositeExtract %7 %84 2 +%89 = OpCompositeConstruct %26 %86 %87 %88 +%90 = OpMatrixTimesVector %26 %80 %89 +%92 = OpAccessChain %85 %24 %91 +OpStore %92 %90 +%95 = OpAccessChain %93 %24 %94 +OpStore %95 %52 +%96 = OpLoad %25 %24 +%97 = OpCompositeExtract %13 %96 0 +OpStore %33 %97 +%98 = OpCompositeExtract %26 %96 1 +OpStore %35 %98 OpReturn OpFunctionEnd -%106 = OpFunction %2 None %40 -%101 = OpLabel -%104 = OpLoad %28 %102 -%107 = OpLoad %16 %15 -%108 = OpLoad %19 %18 -OpBranch %109 -%109 = OpLabel -%111 = OpSampledImage %110 %107 %108 -%112 = OpImageSampleImplicitLod %13 %111 %104 -OpStore %105 %112 +%104 = OpFunction %2 None %38 +%99 = OpLabel +%102 = OpLoad %26 %100 +%105 = OpLoad %16 %15 +%106 = OpLoad %19 %18 +OpBranch %107 +%107 = OpLabel +%109 = OpSampledImage %108 %105 %106 +%110 = OpImageSampleImplicitLod %13 %109 %102 +OpStore %103 %110 OpReturn OpFunctionEnd