diff --git a/src/valid/function.rs b/src/valid/function.rs index c2f3cf97cb..0254775ec5 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -843,16 +843,6 @@ impl super::Validator { #[cfg(feature = "validate")] for (index, argument) in fun.arguments.iter().enumerate() { - if !self.types[argument.ty.index()] - .flags - .contains(super::TypeFlags::ARGUMENT) - { - return Err(FunctionError::InvalidArgumentType { - index, - name: argument.name.clone().unwrap_or_default(), - } - .with_span_handle(argument.ty, &module.types)); - } match module.types[argument.ty].inner.pointer_class() { Some(crate::StorageClass::Private) | Some(crate::StorageClass::Function) @@ -867,6 +857,17 @@ impl super::Validator { .with_span_handle(argument.ty, &module.types)) } } + // Check for the least informative error last. + if !self.types[argument.ty.index()] + .flags + .contains(super::TypeFlags::ARGUMENT) + { + return Err(FunctionError::InvalidArgumentType { + index, + name: argument.name.clone().unwrap_or_default(), + } + .with_span_handle(argument.ty, &module.types)); + } } self.valid_expression_set.clear(); diff --git a/src/valid/type.rs b/src/valid/type.rs index c052fa9b21..923e1b24ff 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -83,6 +83,11 @@ pub enum TypeError { UnresolvedBase(Handle), #[error("Invalid type for pointer target {0:?}")] InvalidPointerBase(Handle), + #[error("Unsized types like {base:?} must be in the `Storage` storage class, not `{class:?}`")] + InvalidPointerToUnsized { + base: Handle, + class: crate::StorageClass, + }, #[error("Expected data type, found {0:?}")] InvalidData(Handle), #[error("Base type {0:?} for the array is invalid")] @@ -255,7 +260,9 @@ impl super::Validator { width as u32, ) } - Ti::Pointer { base, class: _ } => { + Ti::Pointer { base, class } => { + use crate::StorageClass as Sc; + if base >= handle { return Err(TypeError::UnresolvedBase(base)); } @@ -264,21 +271,45 @@ impl super::Validator { if !base_info.flags.contains(TypeFlags::DATA) { return Err(TypeError::InvalidPointerBase(base)); } - // Pointers to dynamically-sized arrays are needed, to serve as - // the type of an `AccessIndex` expression referring to a - // dynamically sized array appearing as the final member of a - // top-level `Struct`. But such pointers cannot be passed to - // functions, stored in variables, etc. So, we mark them as not - // `DATA`. - let data_flag = if base_info.flags.contains(TypeFlags::SIZED) { - TypeFlags::DATA | TypeFlags::ARGUMENT - } else if let crate::TypeInner::Struct { .. } = types[base].inner { - TypeFlags::DATA | TypeFlags::ARGUMENT - } else { - TypeFlags::empty() + + // Runtime-sized values can only live in the `Storage` storage + // class, so it's useless to have a pointer to such a type in + // any other class. + // + // Detecting this problem here prevents the definition of + // functions like: + // + // fn f(p: ptr) -> ... { ... } + // + // which would otherwise be permitted, but uncallable. (They + // may also present difficulties in code generation). + if !base_info.flags.contains(TypeFlags::SIZED) { + match class { + Sc::Storage { .. } => {} + _ => { + return Err(TypeError::InvalidPointerToUnsized { base, class }); + } + } + } + + // Pointers passed as arguments to user-defined functions must + // be in the `Function`, `Private`, or `Workgroup` storage + // class. We only mark pointers in those classes as `ARGUMENT`. + // + // `Validator::validate_function` actually checks the storage + // class of pointer arguments explicitly before checking the + // `ARGUMENT` flag, to give better error messages. But it seems + // best to set `ARGUMENT` accurately anyway. + let argument_flag = match class { + Sc::Function | Sc::Private | Sc::WorkGroup => TypeFlags::ARGUMENT, + Sc::Uniform | Sc::Storage { .. } | Sc::Handle | Sc::PushConstant => { + TypeFlags::empty() + } }; - TypeInfo::new(data_flag | TypeFlags::SIZED | TypeFlags::COPY, 0) + // Pointers cannot be stored in variables, structure members, or + // array elements, so we do not mark them as `DATA`. + TypeInfo::new(argument_flag | TypeFlags::SIZED | TypeFlags::COPY, 0) } Ti::ValuePointer { size: _, diff --git a/tests/in/pointers.wgsl b/tests/in/pointers.wgsl index 0612b8a471..d21080d30b 100644 --- a/tests/in/pointers.wgsl +++ b/tests/in/pointers.wgsl @@ -9,12 +9,6 @@ struct DynamicArray { array: array; }; -fn index_dynamic_array(p: ptr, i: i32, v: u32) -> u32 { - let old = (*p).array[i]; - (*p).array[i] = v; - return old; -} - [[group(0), binding(0)]] var dynamic_array: DynamicArray; diff --git a/tests/out/spv/pointers.spvasm b/tests/out/spv/pointers.spvasm index d49a053a0c..698829fd52 100644 --- a/tests/out/spv/pointers.spvasm +++ b/tests/out/spv/pointers.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.2 ; Generator: rspirv -; Bound: 65 +; Bound: 41 OpCapability Shader OpCapability Linkage OpExtension "SPV_KHR_storage_buffer_storage_class" @@ -10,24 +10,20 @@ OpMemoryModel Logical GLSL450 OpSource GLSL 450 OpMemberName %8 0 "array" OpName %8 "DynamicArray" -OpName %12 "dynamic_array" -OpName %13 "v" -OpName %16 "f" -OpName %23 "p" -OpName %24 "i" -OpName %25 "v" -OpName %26 "index_dynamic_array" -OpName %46 "i" -OpName %47 "v" -OpName %48 "index_unsized" -OpName %57 "i" -OpName %58 "v" -OpName %59 "index_dynamic_array" +OpName %11 "dynamic_array" +OpName %12 "v" +OpName %15 "f" +OpName %22 "i" +OpName %23 "v" +OpName %24 "index_unsized" +OpName %33 "i" +OpName %34 "v" +OpName %35 "index_dynamic_array" OpDecorate %7 ArrayStride 4 OpDecorate %8 Block OpMemberDecorate %8 0 Offset 0 -OpDecorate %12 DescriptorSet 0 -OpDecorate %12 Binding 0 +OpDecorate %11 DescriptorSet 0 +OpDecorate %11 Binding 0 %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpConstant %4 10 @@ -35,81 +31,47 @@ OpDecorate %12 Binding 0 %6 = OpTypeInt 32 0 %7 = OpTypeRuntimeArray %6 %8 = OpTypeStruct %7 -%9 = OpTypePointer Workgroup %8 -%10 = OpTypePointer StorageBuffer %8 -%11 = OpTypePointer StorageBuffer %7 -%12 = OpVariable %10 StorageBuffer -%14 = OpTypePointer Function %5 -%17 = OpTypeFunction %2 -%19 = OpTypePointer Function %4 -%20 = OpConstant %6 0 -%27 = OpTypeFunction %6 %9 %4 %6 -%29 = OpTypePointer Workgroup %7 -%30 = OpTypePointer Workgroup %6 -%33 = OpTypeBool -%35 = OpConstantNull %6 -%49 = OpTypeFunction %2 %4 %6 -%51 = OpTypePointer StorageBuffer %6 -%16 = OpFunction %2 None %17 -%15 = OpLabel -%13 = OpVariable %14 Function -OpBranch %18 -%18 = OpLabel -%21 = OpAccessChain %19 %13 %20 -OpStore %21 %3 +%9 = OpTypePointer StorageBuffer %8 +%10 = OpTypePointer StorageBuffer %7 +%11 = OpVariable %9 StorageBuffer +%13 = OpTypePointer Function %5 +%16 = OpTypeFunction %2 +%18 = OpTypePointer Function %4 +%19 = OpConstant %6 0 +%25 = OpTypeFunction %2 %4 %6 +%27 = OpTypePointer StorageBuffer %6 +%15 = OpFunction %2 None %16 +%14 = OpLabel +%12 = OpVariable %13 Function +OpBranch %17 +%17 = OpLabel +%20 = OpAccessChain %18 %12 %19 +OpStore %20 %3 OpReturn OpFunctionEnd -%26 = OpFunction %6 None %27 -%23 = OpFunctionParameter %9 -%24 = OpFunctionParameter %4 -%25 = OpFunctionParameter %6 -%22 = OpLabel -OpBranch %28 -%28 = OpLabel -%31 = OpArrayLength %6 %23 0 -%32 = OpULessThan %33 %24 %31 -OpSelectionMerge %36 None -OpBranchConditional %32 %37 %36 -%37 = OpLabel -%34 = OpAccessChain %30 %23 %20 %24 -%38 = OpLoad %6 %34 +%24 = OpFunction %2 None %25 +%22 = OpFunctionParameter %4 +%23 = OpFunctionParameter %6 +%21 = OpLabel +OpBranch %26 +%26 = OpLabel +%28 = OpAccessChain %27 %11 %19 %22 +%29 = OpLoad %6 %28 +%30 = OpIAdd %6 %29 %23 +%31 = OpAccessChain %27 %11 %19 %22 +OpStore %31 %30 +OpReturn +OpFunctionEnd +%35 = OpFunction %2 None %25 +%33 = OpFunctionParameter %4 +%34 = OpFunctionParameter %6 +%32 = OpLabel OpBranch %36 %36 = OpLabel -%39 = OpPhi %6 %35 %28 %38 %37 -%40 = OpArrayLength %6 %23 0 -%41 = OpULessThan %33 %24 %40 -OpSelectionMerge %43 None -OpBranchConditional %41 %44 %43 -%44 = OpLabel -%42 = OpAccessChain %30 %23 %20 %24 -OpStore %42 %25 -OpBranch %43 -%43 = OpLabel -OpReturnValue %39 -OpFunctionEnd -%48 = OpFunction %2 None %49 -%46 = OpFunctionParameter %4 -%47 = OpFunctionParameter %6 -%45 = OpLabel -OpBranch %50 -%50 = OpLabel -%52 = OpAccessChain %51 %12 %20 %46 -%53 = OpLoad %6 %52 -%54 = OpIAdd %6 %53 %47 -%55 = OpAccessChain %51 %12 %20 %46 -OpStore %55 %54 -OpReturn -OpFunctionEnd -%59 = OpFunction %2 None %49 -%57 = OpFunctionParameter %4 -%58 = OpFunctionParameter %6 -%56 = OpLabel -OpBranch %60 -%60 = OpLabel -%61 = OpAccessChain %51 %12 %20 %57 -%62 = OpLoad %6 %61 -%63 = OpIAdd %6 %62 %58 -%64 = OpAccessChain %51 %12 %20 %57 -OpStore %64 %63 +%37 = OpAccessChain %27 %11 %19 %33 +%38 = OpLoad %6 %37 +%39 = OpIAdd %6 %38 %34 +%40 = OpAccessChain %27 %11 %19 %33 +OpStore %40 %39 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/pointers.wgsl b/tests/out/wgsl/pointers.wgsl index 153e5a6ace..5c3d0e1f8a 100644 --- a/tests/out/wgsl/pointers.wgsl +++ b/tests/out/wgsl/pointers.wgsl @@ -14,22 +14,16 @@ fn f() { return; } -fn index_dynamic_array(p: ptr, i: i32, v_1: u32) -> u32 { - let old: u32 = (*p).array_[i]; - (*p).array_[i] = v_1; - return old; -} - -fn index_unsized(i_1: i32, v_2: u32) { - let val: u32 = dynamic_array.array_[i_1]; - dynamic_array.array_[i_1] = (val + v_2); +fn index_unsized(i: i32, v_1: u32) { + let val: u32 = dynamic_array.array_[i]; + dynamic_array.array_[i] = (val + v_1); return; } -fn index_dynamic_array_1(i_2: i32, v_3: u32) { - let p_1: ptr, read_write> = (&dynamic_array.array_); - let val_1: u32 = (*p_1)[i_2]; - (*p_1)[i_2] = (val_1 + v_3); +fn index_dynamic_array(i_1: i32, v_2: u32) { + let p: ptr, read_write> = (&dynamic_array.array_); + let val_1: u32 = (*p)[i_1]; + (*p)[i_1] = (val_1 + v_2); return; } diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index c85c113db4..0109245c0b 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -662,7 +662,6 @@ fn invalid_structs() { fn invalid_functions() { check_validation_error! { "fn unacceptable_unsized(arg: array) { }", - "fn unacceptable_unsized(arg: ptr>) { }", " struct Unsized { data: array; }; fn unacceptable_unsized(arg: Unsized) { } @@ -678,19 +677,39 @@ fn invalid_functions() { if function_name == "unacceptable_unsized" && argument_name == "arg" } + // Pointer's storage class cannot hold unsized data. check_validation_error! { + "fn unacceptable_unsized(arg: ptr>) { }", " struct Unsized { data: array; }; - fn acceptable_pointer_to_unsized(arg: ptr) { } + fn unacceptable_unsized(arg: ptr) { } ": - Ok(_) + Err(naga::valid::ValidationError::Type { + error: naga::valid::TypeError::InvalidPointerToUnsized { + base: _, + class: naga::StorageClass::WorkGroup { .. }, + }, + .. + }) + } + + // Pointers of these storage classes cannot be passed as arguments. + check_validation_error! { + "fn unacceptable_ptr_class(arg: ptr>) { }": + Err(naga::valid::ValidationError::Function { + name: function_name, + error: naga::valid::FunctionError::InvalidArgumentPointerClass { + index: 0, + name: argument_name, + class: naga::StorageClass::Storage { .. }, + }, + .. + }) + if function_name == "unacceptable_ptr_class" && argument_name == "arg" } check_validation_error! { - " - struct Unsized { data: array; }; - fn unacceptable_uniform_class(arg: ptr) { } - ": + "fn unacceptable_ptr_class(arg: ptr) { }": Err(naga::valid::ValidationError::Function { name: function_name, error: naga::valid::FunctionError::InvalidArgumentPointerClass { @@ -700,7 +719,7 @@ fn invalid_functions() { }, .. }) - if function_name == "unacceptable_uniform_class" && argument_name == "arg" + if function_name == "unacceptable_ptr_class" && argument_name == "arg" } }