[wgsl] type inference for constants

This commit is contained in:
Dzmitry Malyshau
2021-03-15 01:24:40 -04:00
committed by Dzmitry Malyshau
parent 4fb82bd955
commit 20979b4800
8 changed files with 176 additions and 141 deletions

View File

@@ -119,6 +119,8 @@ pub enum Error<'a> {
UnknownLocalFunction(&'a str),
#[error("builtin {0:?} is not implemented")]
UnimplementedBuiltin(crate::BuiltIn),
#[error("expression {0} doesn't match its given type {1:?}")]
ConstTypeMismatch(&'a str, Handle<crate::Type>),
#[error("other error")]
Other,
}
@@ -1984,12 +1986,34 @@ impl Parser {
match word {
"const" => {
emitter.start(context.expressions);
let (name, _ty, _access) =
self.parse_variable_ident_decl(lexer, context.types, context.constants)?;
let name = lexer.next_ident()?;
let given_ty = if lexer.skip(Token::Separator(':')) {
let (ty, _access) =
self.parse_type_decl(lexer, None, context.types, context.constants)?;
Some(ty)
} else {
None
};
lexer.expect(Token::Operation('='))?;
let expr_id = self
.parse_general_expression(lexer, context.as_expression(block, &mut emitter))?;
lexer.expect(Token::Separator(';'))?;
if let Some(ty) = given_ty {
// prepare the typifier, but work around mutable borrowing...
let _ = context
.as_expression(block, &mut emitter)
.resolve_type(expr_id)?;
let expr_inner = context.typifier.get(expr_id, context.types);
let given_inner = &context.types[ty].inner;
if given_inner != expr_inner {
log::error!(
"Given type {:?} doesn't match expected {:?}",
given_inner,
expr_inner
);
return Err(Error::ConstTypeMismatch(name, ty));
}
}
block.extend(emitter.finish(context.expressions));
context.lookup_ident.insert(name, expr_id);
}

View File

@@ -25,6 +25,23 @@ fn parse_types() {
parse_str("var t: [[access(read)]] texture_storage_3d<r32float>;").unwrap();
}
#[test]
fn parse_type_inference() {
parse_str(
"
fn foo() {
const a = 2u;
const b: u32 = a;
}",
)
.unwrap();
assert!(parse_str(
"
fn foo() { const c : i32 = 2.0; }",
)
.is_err());
}
#[test]
fn parse_type_cast() {
parse_str(

View File

@@ -20,10 +20,12 @@ fn main([[location(0)]] pos : vec2<f32>, [[location(1)]] uv : vec2<f32>) -> Vert
[[stage(fragment)]]
fn main([[location(0)]] uv : vec2<f32>) -> [[location(0)]] vec4<f32> {
const color: vec4<f32> = textureSample(u_texture, u_sampler, uv);
const color = textureSample(u_texture, u_sampler, uv);
if (color.a == 0.0) {
discard;
}
const premultiplied: vec4<f32> = color.a * color;
// forcing the expression here to be emitted in order to check the
// uniformity of the control flow a bit more strongly.
const premultiplied = color.a * color;
return premultiplied;
}

View File

@@ -29,9 +29,9 @@ fn fetch_shadow(light_id: u32, homogeneous_coords: vec4<f32>) -> f32 {
if (homogeneous_coords.w <= 0.0) {
return 1.0;
}
const flip_correction: vec2<f32> = vec2<f32>(0.5, -0.5);
const proj_correction: f32 = 1.0 / homogeneous_coords.w;
const light_local: vec2<f32> = homogeneous_coords.xy * flip_correction * proj_correction + vec2<f32>(0.5, 0.5);
const flip_correction = vec2<f32>(0.5, -0.5);
const proj_correction = 1.0 / homogeneous_coords.w;
const light_local = homogeneous_coords.xy * flip_correction * proj_correction + vec2<f32>(0.5, 0.5);
return textureSampleCompare(t_shadow, sampler_shadow, light_local, i32(light_id), homogeneous_coords.z * proj_correction);
}
@@ -51,10 +51,10 @@ fn fs_main(
if (i >= min(u_globals.num_lights.x, c_max_lights)) {
break;
}
const light: Light = s_lights.data[i];
const shadow: f32 = fetch_shadow(i, light.proj * position);
const light_dir: vec3<f32> = normalize(light.pos.xyz - position.xyz);
const diffuse: f32 = max(0.0, dot(normal, light_dir));
const light = s_lights.data[i];
const shadow = fetch_shadow(i, light.proj * position);
const light_dir = normalize(light.pos.xyz - position.xyz);
const diffuse = max(0.0, dot(normal, light_dir));
color = color + shadow * diffuse * light.color.xyz;
continuing {
i = i + 1u;

View File

@@ -16,15 +16,15 @@ fn vs_main([[builtin(vertex_index)]] vertex_index: u32) -> VertexOutput {
// hacky way to draw a large triangle
var tmp1: i32 = i32(vertex_index) / 2;
var tmp2: i32 = i32(vertex_index) & 1;
const pos: vec4<f32> = vec4<f32>(
const pos = vec4<f32>(
f32(tmp1) * 4.0 - 1.0,
f32(tmp2) * 4.0 - 1.0,
0.0,
1.0
);
const inv_model_view: mat3x3<f32> = transpose(mat3x3<f32>(r_data.view.x.xyz, r_data.view.y.xyz, r_data.view.z.xyz));
var unprojected: vec4<f32> = r_data.proj_inv * pos; //TODO: const
const inv_model_view = transpose(mat3x3<f32>(r_data.view.x.xyz, r_data.view.y.xyz, r_data.view.z.xyz));
const unprojected = r_data.proj_inv * pos;
var out: VertexOutput;
out.uv = inv_model_view * unprojected.xyz;
out.position = pos;

View File

@@ -22,14 +22,12 @@ void main() {
uint vertex_index = uint(gl_VertexID);
int tmp1_;
int tmp2_;
vec4 unprojected;
VertexOutput out1;
tmp1_ = (int(vertex_index) / 2);
tmp2_ = (int(vertex_index) & 1);
vec4 _expr24 = vec4(((float(tmp1_) * 4.0) - 1.0), ((float(tmp2_) * 4.0) - 1.0), 0.0, 1.0);
unprojected = (_group_0_binding_0.proj_inv * _expr24);
vec4 _expr54 = unprojected;
out1.uv = (transpose(mat3x3(vec3(_group_0_binding_0.view[0][0], _group_0_binding_0.view[0][1], _group_0_binding_0.view[0][2]), vec3(_group_0_binding_0.view[1][0], _group_0_binding_0.view[1][1], _group_0_binding_0.view[1][2]), vec3(_group_0_binding_0.view[2][0], _group_0_binding_0.view[2][1], _group_0_binding_0.view[2][2]))) * vec3(_expr54[0], _expr54[1], _expr54[2]));
vec4 _expr50 = (_group_0_binding_0.proj_inv * _expr24);
out1.uv = (transpose(mat3x3(vec3(_group_0_binding_0.view[0][0], _group_0_binding_0.view[0][1], _group_0_binding_0.view[0][2]), vec3(_group_0_binding_0.view[1][0], _group_0_binding_0.view[1][1], _group_0_binding_0.view[1][2]), vec3(_group_0_binding_0.view[2][0], _group_0_binding_0.view[2][1], _group_0_binding_0.view[2][2]))) * vec3(_expr50[0], _expr50[1], _expr50[2]));
out1.position = _expr24;
gl_Position = out1.position;
_out_location_0 = out1.uv;

View File

@@ -48,14 +48,12 @@ vertex vs_mainOutput vs_main(
) {
type4 tmp1_;
type4 tmp2_;
type unprojected;
VertexOutput out;
tmp1_ = (static_cast<int>(vertex_index) / const_2i);
tmp2_ = (static_cast<int>(vertex_index) & const_1i);
type _expr24 = metal::float4(((static_cast<float>(tmp1_) * const_4f) - const_1f), ((static_cast<float>(tmp2_) * const_4f) - const_1f), const_0f, const_1f);
unprojected = (r_data.proj_inv * _expr24);
type _expr54 = unprojected;
out.uv = (metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(_expr54.x, _expr54.y, _expr54.z));
metal::float4 _expr50 = (r_data.proj_inv * _expr24);
out.uv = (metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(_expr50.x, _expr50.y, _expr50.z));
out.position = _expr24;
const auto _tmp = out;
return vs_mainOutput { _tmp.position, _tmp.uv };

View File

@@ -5,13 +5,13 @@ expression: dis
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 113
; Bound: 111
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %39 "vs_main" %32 %35 %37
OpEntryPoint Fragment %106 "fs_main" %102 %105
OpExecutionMode %106 OriginUpperLeft
OpEntryPoint Vertex %37 "vs_main" %30 %33 %35
OpEntryPoint Fragment %104 "fs_main" %100 %103
OpExecutionMode %104 OriginUpperLeft
OpSource GLSL 450
OpName %11 "Data"
OpMemberName %11 0 "proj_inv"
@@ -21,17 +21,16 @@ OpName %15 "r_texture"
OpName %18 "r_sampler"
OpName %21 "tmp1"
OpName %23 "tmp2"
OpName %24 "unprojected"
OpName %26 "out"
OpName %27 "VertexOutput"
OpName %32 "vertex_index"
OpName %35 "position"
OpName %37 "uv"
OpName %39 "vs_main"
OpName %39 "vs_main"
OpName %102 "uv"
OpName %106 "fs_main"
OpName %106 "fs_main"
OpName %24 "out"
OpName %25 "VertexOutput"
OpName %30 "vertex_index"
OpName %33 "position"
OpName %35 "uv"
OpName %37 "vs_main"
OpName %37 "vs_main"
OpName %100 "uv"
OpName %104 "fs_main"
OpName %104 "fs_main"
OpDecorate %11 Block
OpMemberDecorate %11 0 Offset 0
OpMemberDecorate %11 0 ColMajor
@@ -45,11 +44,11 @@ OpDecorate %15 DescriptorSet 0
OpDecorate %15 Binding 1
OpDecorate %18 DescriptorSet 0
OpDecorate %18 Binding 2
OpDecorate %32 BuiltIn VertexIndex
OpDecorate %35 BuiltIn Position
OpDecorate %37 Location 0
OpDecorate %102 Location 0
OpDecorate %105 Location 0
OpDecorate %30 BuiltIn VertexIndex
OpDecorate %33 BuiltIn Position
OpDecorate %35 Location 0
OpDecorate %100 Location 0
OpDecorate %103 Location 0
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpConstant %4 2
@@ -70,108 +69,105 @@ OpDecorate %105 Location 0
%20 = OpTypePointer UniformConstant %19
%18 = OpVariable %20 UniformConstant
%22 = OpTypePointer Function %4
%25 = OpTypePointer Function %13
%28 = OpTypeVector %7 3
%27 = OpTypeStruct %13 %28
%29 = OpTypePointer Function %27
%31 = OpTypeInt 32 0
%33 = OpTypePointer Input %31
%32 = OpVariable %33 Input
%36 = OpTypePointer Output %13
%26 = OpTypeVector %7 3
%25 = OpTypeStruct %13 %26
%27 = OpTypePointer Function %25
%29 = OpTypeInt 32 0
%31 = OpTypePointer Input %29
%30 = OpVariable %31 Input
%34 = OpTypePointer Output %13
%33 = OpVariable %34 Output
%36 = OpTypePointer Output %26
%35 = OpVariable %36 Output
%38 = OpTypePointer Output %28
%37 = OpVariable %38 Output
%40 = OpTypeFunction %2
%55 = OpTypePointer Uniform %12
%56 = OpConstant %4 1
%64 = OpConstant %4 1
%72 = OpConstant %4 1
%80 = OpTypeMatrix %28 3
%83 = OpConstant %4 0
%87 = OpTypePointer Function %28
%94 = OpConstant %4 1
%96 = OpConstant %4 0
%103 = OpTypePointer Input %28
%102 = OpVariable %103 Input
%105 = OpVariable %36 Output
%110 = OpTypeSampledImage %16
%39 = OpFunction %2 None %40
%30 = OpLabel
%23 = OpVariable %22 Function
%26 = OpVariable %29 Function
%38 = OpTypeFunction %2
%53 = OpTypePointer Uniform %12
%54 = OpConstant %4 1
%62 = OpConstant %4 1
%70 = OpConstant %4 1
%78 = OpTypeMatrix %26 3
%81 = OpConstant %4 0
%85 = OpTypePointer Function %26
%91 = OpConstant %4 1
%93 = OpTypePointer Function %13
%94 = OpConstant %4 0
%101 = OpTypePointer Input %26
%100 = OpVariable %101 Input
%103 = OpVariable %34 Output
%108 = OpTypeSampledImage %16
%37 = OpFunction %2 None %38
%28 = OpLabel
%21 = OpVariable %22 Function
%24 = OpVariable %25 Function
%34 = OpLoad %31 %32
OpBranch %41
%41 = OpLabel
%42 = OpBitcast %4 %34
%43 = OpSDiv %4 %42 %3
OpStore %21 %43
%44 = OpBitcast %4 %34
%45 = OpBitwiseAnd %4 %44 %5
OpStore %23 %45
%46 = OpLoad %4 %21
%47 = OpConvertSToF %7 %46
%48 = OpFMul %7 %47 %6
%49 = OpFSub %7 %48 %8
%50 = OpLoad %4 %23
%51 = OpConvertSToF %7 %50
%52 = OpFMul %7 %51 %6
%53 = OpFSub %7 %52 %8
%54 = OpCompositeConstruct %13 %49 %53 %9 %8
%57 = OpAccessChain %55 %10 %56
%58 = OpLoad %12 %57
%59 = OpCompositeExtract %13 %58 0
%60 = OpCompositeExtract %7 %59 0
%61 = OpCompositeExtract %7 %59 1
%62 = OpCompositeExtract %7 %59 2
%63 = OpCompositeConstruct %28 %60 %61 %62
%65 = OpAccessChain %55 %10 %64
%66 = OpLoad %12 %65
%67 = OpCompositeExtract %13 %66 1
%68 = OpCompositeExtract %7 %67 0
%69 = OpCompositeExtract %7 %67 1
%70 = OpCompositeExtract %7 %67 2
%71 = OpCompositeConstruct %28 %68 %69 %70
%73 = OpAccessChain %55 %10 %72
%74 = OpLoad %12 %73
%75 = OpCompositeExtract %13 %74 2
%76 = OpCompositeExtract %7 %75 0
%77 = OpCompositeExtract %7 %75 1
%78 = OpCompositeExtract %7 %75 2
%79 = OpCompositeConstruct %28 %76 %77 %78
%81 = OpCompositeConstruct %80 %63 %71 %79
%82 = OpTranspose %80 %81
%84 = OpAccessChain %55 %10 %83
%85 = OpLoad %12 %84
%86 = OpMatrixTimesVector %13 %85 %54
OpStore %24 %86
%88 = OpLoad %13 %24
%89 = OpCompositeExtract %7 %88 0
%90 = OpCompositeExtract %7 %88 1
%91 = OpCompositeExtract %7 %88 2
%92 = OpCompositeConstruct %28 %89 %90 %91
%93 = OpMatrixTimesVector %28 %82 %92
%95 = OpAccessChain %87 %26 %94
OpStore %95 %93
%97 = OpAccessChain %25 %26 %96
OpStore %97 %54
%98 = OpLoad %27 %26
%99 = OpCompositeExtract %13 %98 0
OpStore %35 %99
%100 = OpCompositeExtract %28 %98 1
OpStore %37 %100
%23 = OpVariable %22 Function
%24 = OpVariable %27 Function
%32 = OpLoad %29 %30
OpBranch %39
%39 = OpLabel
%40 = OpBitcast %4 %32
%41 = OpSDiv %4 %40 %3
OpStore %21 %41
%42 = OpBitcast %4 %32
%43 = OpBitwiseAnd %4 %42 %5
OpStore %23 %43
%44 = OpLoad %4 %21
%45 = OpConvertSToF %7 %44
%46 = OpFMul %7 %45 %6
%47 = OpFSub %7 %46 %8
%48 = OpLoad %4 %23
%49 = OpConvertSToF %7 %48
%50 = OpFMul %7 %49 %6
%51 = OpFSub %7 %50 %8
%52 = OpCompositeConstruct %13 %47 %51 %9 %8
%55 = OpAccessChain %53 %10 %54
%56 = OpLoad %12 %55
%57 = OpCompositeExtract %13 %56 0
%58 = OpCompositeExtract %7 %57 0
%59 = OpCompositeExtract %7 %57 1
%60 = OpCompositeExtract %7 %57 2
%61 = OpCompositeConstruct %26 %58 %59 %60
%63 = OpAccessChain %53 %10 %62
%64 = OpLoad %12 %63
%65 = OpCompositeExtract %13 %64 1
%66 = OpCompositeExtract %7 %65 0
%67 = OpCompositeExtract %7 %65 1
%68 = OpCompositeExtract %7 %65 2
%69 = OpCompositeConstruct %26 %66 %67 %68
%71 = OpAccessChain %53 %10 %70
%72 = OpLoad %12 %71
%73 = OpCompositeExtract %13 %72 2
%74 = OpCompositeExtract %7 %73 0
%75 = OpCompositeExtract %7 %73 1
%76 = OpCompositeExtract %7 %73 2
%77 = OpCompositeConstruct %26 %74 %75 %76
%79 = OpCompositeConstruct %78 %61 %69 %77
%80 = OpTranspose %78 %79
%82 = OpAccessChain %53 %10 %81
%83 = OpLoad %12 %82
%84 = OpMatrixTimesVector %13 %83 %52
%86 = OpCompositeExtract %7 %84 0
%87 = OpCompositeExtract %7 %84 1
%88 = OpCompositeExtract %7 %84 2
%89 = OpCompositeConstruct %26 %86 %87 %88
%90 = OpMatrixTimesVector %26 %80 %89
%92 = OpAccessChain %85 %24 %91
OpStore %92 %90
%95 = OpAccessChain %93 %24 %94
OpStore %95 %52
%96 = OpLoad %25 %24
%97 = OpCompositeExtract %13 %96 0
OpStore %33 %97
%98 = OpCompositeExtract %26 %96 1
OpStore %35 %98
OpReturn
OpFunctionEnd
%106 = OpFunction %2 None %40
%101 = OpLabel
%104 = OpLoad %28 %102
%107 = OpLoad %16 %15
%108 = OpLoad %19 %18
OpBranch %109
%109 = OpLabel
%111 = OpSampledImage %110 %107 %108
%112 = OpImageSampleImplicitLod %13 %111 %104
OpStore %105 %112
%104 = OpFunction %2 None %38
%99 = OpLabel
%102 = OpLoad %26 %100
%105 = OpLoad %16 %15
%106 = OpLoad %19 %18
OpBranch %107
%107 = OpLabel
%109 = OpSampledImage %108 %105 %106
%110 = OpImageSampleImplicitLod %13 %109 %102
OpStore %103 %110
OpReturn
OpFunctionEnd