Support num_workgroups builtin

This commit is contained in:
Dzmitry Malyshau
2021-08-17 01:07:15 -04:00
committed by Dzmitry Malyshau
parent 73be8c7454
commit 79d899fe4c
20 changed files with 206 additions and 111 deletions

View File

@@ -3,6 +3,7 @@
## TBD
- API:
- atomic types and functions
- `num_workgroups` built-in
- WGSL `select()` order of true/false is swapped
## v0.5 (2021-06-18)

View File

@@ -2643,6 +2643,7 @@ fn glsl_built_in(built_in: crate::BuiltIn, output: bool) -> &'static str {
Bi::LocalInvocationIndex => "gl_LocalInvocationIndex",
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
}
}

View File

@@ -96,7 +96,13 @@ impl crate::BuiltIn {
Self::LocalInvocationId => "SV_GroupThreadID",
Self::LocalInvocationIndex => "SV_GroupIndex",
Self::WorkGroupId => "SV_GroupID",
_ => return Err(Error::Unimplemented(format!("builtin {:?}", self))),
// The specific semantic we use here doesn't matter, because references
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {:?}", self)))
}
})
}
}

View File

@@ -15,6 +15,7 @@ const SPECIAL_CBUF_TYPE: &str = "NagaConstants";
const SPECIAL_CBUF_VAR: &str = "_NagaConstants";
const SPECIAL_BASE_VERTEX: &str = "base_vertex";
const SPECIAL_BASE_INSTANCE: &str = "base_instance";
const SPECIAL_OTHER: &str = "other";
/// Structure that contains the information required for generating a
/// wrapped structure of all entry point arguments.
@@ -105,6 +106,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "struct {} {{", SPECIAL_CBUF_TYPE)?;
writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_VERTEX)?;
writeln!(self.out, "{}int {};", back::INDENT, SPECIAL_BASE_INSTANCE)?;
writeln!(self.out, "{}uint {};", back::INDENT, SPECIAL_OTHER)?;
writeln!(self.out, "}};")?;
write!(
self.out,
@@ -1234,10 +1236,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(
self.out,
"({}.{} + ",
SPECIAL_CBUF_VAR, SPECIAL_BASE_INSTANCE
SPECIAL_CBUF_VAR, SPECIAL_BASE_INSTANCE,
)?;
")"
}
Some(crate::BuiltIn::NumWorkGroups) => {
// Note: despite their names (`SPECIAL_BASE_VERTEX` and `SPECIAL_BASE_INSTANCE`),
// in compute shaders these special constants, together with `SPECIAL_OTHER`,
// contain the number of workgroups, which we are using here.
write!(
self.out,
"uint3({}.{}, {}.{}, {}.{})",
SPECIAL_CBUF_VAR,
SPECIAL_BASE_VERTEX,
SPECIAL_CBUF_VAR,
SPECIAL_BASE_INSTANCE,
SPECIAL_CBUF_VAR,
SPECIAL_OTHER,
)?;
return Ok(());
}
_ => "",
};

View File

@@ -357,7 +357,8 @@ impl ResolvedBinding {
Bi::LocalInvocationIndex => "thread_index_in_threadgroup",
Bi::WorkGroupId => "threadgroup_position_in_grid",
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
_ => return Err(Error::UnsupportedBuiltIn(built_in)),
Bi::NumWorkGroups => "threadgroups_per_grid",
Bi::CullDistance => return Err(Error::UnsupportedBuiltIn(built_in)),
};
write!(out, "{}", name)?;
}

View File

@@ -1043,6 +1043,7 @@ impl Writer {
Bi::LocalInvocationIndex => BuiltIn::LocalInvocationIndex,
Bi::WorkGroupId => BuiltIn::WorkgroupId,
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
};
self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);

View File

@@ -1596,6 +1596,7 @@ fn builtin_str(built_in: crate::BuiltIn) -> Option<&'static str> {
Bi::GlobalInvocationId => Some("global_invocation_id"),
Bi::WorkGroupId => Some("workgroup_id"),
Bi::WorkGroupSize => Some("workgroup_size"),
Bi::NumWorkGroups => Some("num_workgroups"),
Bi::SampleIndex => Some("sample_index"),
Bi::SampleMask => Some("sample_mask"),
Bi::PrimitiveIndex => Some("primitive_index"),

View File

@@ -156,6 +156,16 @@ impl Parser {
false,
StorageQualifier::Input,
),
"gl_NumWorkGroups" => add_builtin(
TypeInner::Vector {
size: VectorSize::Tri,
kind: ScalarKind::Uint,
width: 4,
},
BuiltIn::NumWorkGroups,
false,
StorageQualifier::Input,
),
"gl_FrontFacing" => add_builtin(
TypeInner::Scalar {
kind: ScalarKind::Bool,

View File

@@ -141,6 +141,7 @@ pub(super) fn map_builtin(word: spirv::Word) -> Result<crate::BuiltIn, Error> {
Some(Bi::LocalInvocationIndex) => crate::BuiltIn::LocalInvocationIndex,
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
_ => return Err(Error::UnsupportedBuiltIn(word)),
})
}

View File

@@ -32,6 +32,7 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
"workgroup_id" => crate::BuiltIn::WorkGroupId,
"workgroup_size" => crate::BuiltIn::WorkGroupSize,
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
_ => return Err(Error::UnknownBuiltin(span)),
})
}

View File

@@ -281,6 +281,7 @@ pub enum BuiltIn {
LocalInvocationIndex,
WorkGroupId,
WorkGroupSize,
NumWorkGroups,
}
/// Number of bytes per scalar.

View File

@@ -406,7 +406,8 @@ impl FunctionInfo {
crate::BuiltIn::FrontFacing
// per-work-group built-ins are uniform
| crate::BuiltIn::WorkGroupId
| crate::BuiltIn::WorkGroupSize => true,
| crate::BuiltIn::WorkGroupSize
| crate::BuiltIn::NumWorkGroups => true,
_ => false,
},
// only flat inputs are uniform

View File

@@ -216,7 +216,8 @@ impl VaryingContext<'_> {
Bi::GlobalInvocationId
| Bi::LocalInvocationId
| Bi::WorkGroupId
| Bi::WorkGroupSize => (
| Bi::WorkGroupSize
| Bi::NumWorkGroups => (
self.stage == St::Compute && !self.output,
*ty_inner
== Ti::Vector {

View File

@@ -2,4 +2,11 @@
spv_version: (1, 0),
spv_capabilities: [ Shader, SampleRateShading ],
spv_adjust_coordinate_space: false,
hlsl_custom: true,
hlsl: (
shader_model: V5_1,
binding_map: {},
fake_missing_bindings: false,
special_constants_binding: Some((space: 1, register: 0)),
),
)

View File

@@ -33,13 +33,15 @@ fn fragment(
return FragmentOutput(in.varying, mask, color);
}
var<workgroup> output: array<u32, 1>;
// Compute entry point that takes the compute-stage built-in inputs listed below.
// The body reads every one of them and stores their sum in `output[0]`.
[[stage(compute), workgroup_size(1)]]
fn compute(
[[builtin(global_invocation_id)]] global_id: vec3<u32>,
[[builtin(local_invocation_id)]] local_id: vec3<u32>,
[[builtin(local_invocation_index)]] local_index: u32,
[[builtin(workgroup_id)]] wg_id: vec3<u32>,
//TODO: https://github.com/gpuweb/gpuweb/issues/1590
//[[builtin(workgroup_size)]] wg_size: vec3<u32>,
[[builtin(num_workgroups)]] num_wgs: vec3<u32>,
) {
// Only the `.x` component of each vector input is used; `local_index` is a scalar.
output[0] = global_id.x + local_id.x + local_index + wg_id.x + num_wgs.x;
}

View File

@@ -1,3 +1,9 @@
struct NagaConstants {
int base_vertex;
int base_instance;
uint other;
};
ConstantBuffer<NagaConstants> _NagaConstants: register(b0, space1);
struct VertexOutput {
float4 position : SV_Position;
@@ -10,6 +16,8 @@ struct FragmentOutput {
float color : SV_Target0;
};
groupshared uint output[1];
struct VertexInput_vertex {
uint color1 : LOC10;
uint instance_index1 : SV_InstanceID;
@@ -28,11 +36,12 @@ struct ComputeInput_compute {
uint3 local_id1 : SV_GroupThreadID;
uint local_index1 : SV_GroupIndex;
uint3 wg_id1 : SV_GroupID;
uint3 num_wgs1 : SV_GroupID;
};
VertexOutput vertex(VertexInput_vertex vertexinput_vertex)
{
uint tmp = ((vertexinput_vertex.vertex_index1 + vertexinput_vertex.instance_index1) + vertexinput_vertex.color1);
uint tmp = (((_NagaConstants.base_vertex + vertexinput_vertex.vertex_index1) + (_NagaConstants.base_instance + vertexinput_vertex.instance_index1)) + vertexinput_vertex.color1);
const VertexOutput vertexoutput1 = { float4(1.0.xxxx), float(tmp) };
return vertexoutput1;
}
@@ -48,5 +57,6 @@ FragmentOutput fragment(FragmentInput_fragment fragmentinput_fragment)
[numthreads(1, 1, 1)]
void compute(ComputeInput_compute computeinput_compute)
{
output[0] = ((((computeinput_compute.global_id1.x + computeinput_compute.local_id1.x) + computeinput_compute.local_index1) + computeinput_compute.wg_id1.x) + uint3(_NagaConstants.base_vertex, _NagaConstants.base_instance, _NagaConstants.other).x);
return;
}

View File

@@ -1,6 +1,7 @@
struct NagaConstants {
int base_vertex;
int base_instance;
uint other;
};
ConstantBuffer<NagaConstants> _NagaConstants: register(b1);

View File

@@ -11,6 +11,9 @@ struct FragmentOutput {
metal::uint sample_mask;
float color;
};
struct type4 {
metal::uint inner[1];
};
struct vertex1Input {
metal::uint color [[attribute(10)]];
@@ -61,6 +64,9 @@ kernel void compute1(
, metal::uint3 local_id [[thread_position_in_threadgroup]]
, metal::uint local_index [[thread_index_in_threadgroup]]
, metal::uint3 wg_id [[threadgroup_position_in_grid]]
, metal::uint3 num_wgs [[threadgroups_per_grid]]
, threadgroup type4& output
) {
output.inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x;
return;
}

View File

@@ -1,124 +1,146 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 76
; Bound: 95
OpCapability Shader
OpCapability SampleRateShading
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %25 "vertex" %14 %17 %19 %21 %23
OpEntryPoint Fragment %54 "fragment" %37 %40 %43 %46 %48 %50 %51 %53
OpEntryPoint GLCompute %74 "compute" %65 %68 %70 %72
OpExecutionMode %54 OriginUpperLeft
OpExecutionMode %54 DepthReplacing
OpExecutionMode %74 LocalSize 1 1 1
OpMemberDecorate %9 0 Offset 0
OpMemberDecorate %9 1 Offset 16
OpMemberDecorate %10 0 Offset 0
OpMemberDecorate %10 1 Offset 4
OpMemberDecorate %10 2 Offset 8
OpDecorate %14 BuiltIn VertexIndex
OpDecorate %17 BuiltIn InstanceIndex
OpDecorate %19 Location 10
OpDecorate %21 BuiltIn Position
OpDecorate %23 Location 1
OpDecorate %37 BuiltIn FragCoord
OpDecorate %40 Location 1
OpDecorate %43 BuiltIn FrontFacing
OpDecorate %46 BuiltIn SampleId
OpDecorate %48 BuiltIn SampleMask
OpDecorate %50 BuiltIn FragDepth
OpDecorate %51 BuiltIn SampleMask
OpDecorate %53 Location 0
OpDecorate %65 BuiltIn GlobalInvocationId
OpDecorate %68 BuiltIn LocalInvocationId
OpDecorate %70 BuiltIn LocalInvocationIndex
OpDecorate %72 BuiltIn WorkgroupId
OpEntryPoint Vertex %31 "vertex" %20 %23 %25 %27 %29
OpEntryPoint Fragment %60 "fragment" %43 %46 %49 %52 %54 %56 %57 %59
OpEntryPoint GLCompute %82 "compute" %71 %74 %76 %78 %80
OpExecutionMode %60 OriginUpperLeft
OpExecutionMode %60 DepthReplacing
OpExecutionMode %82 LocalSize 1 1 1
OpMemberDecorate %12 0 Offset 0
OpMemberDecorate %12 1 Offset 16
OpMemberDecorate %13 0 Offset 0
OpMemberDecorate %13 1 Offset 4
OpMemberDecorate %13 2 Offset 8
OpDecorate %15 ArrayStride 4
OpDecorate %20 BuiltIn VertexIndex
OpDecorate %23 BuiltIn InstanceIndex
OpDecorate %25 Location 10
OpDecorate %27 BuiltIn Position
OpDecorate %29 Location 1
OpDecorate %43 BuiltIn FragCoord
OpDecorate %46 Location 1
OpDecorate %49 BuiltIn FrontFacing
OpDecorate %52 BuiltIn SampleId
OpDecorate %54 BuiltIn SampleMask
OpDecorate %56 BuiltIn FragDepth
OpDecorate %57 BuiltIn SampleMask
OpDecorate %59 Location 0
OpDecorate %71 BuiltIn GlobalInvocationId
OpDecorate %74 BuiltIn LocalInvocationId
OpDecorate %76 BuiltIn LocalInvocationIndex
OpDecorate %78 BuiltIn WorkgroupId
OpDecorate %80 BuiltIn NumWorkgroups
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpConstant %4 1.0
%6 = OpTypeInt 32 0
%5 = OpConstant %6 1
%7 = OpConstant %4 0.0
%8 = OpTypeVector %4 4
%9 = OpTypeStruct %8 %4
%10 = OpTypeStruct %4 %6 %4
%11 = OpTypeBool
%12 = OpTypeVector %6 3
%15 = OpTypePointer Input %6
%14 = OpVariable %15 Input
%17 = OpVariable %15 Input
%19 = OpVariable %15 Input
%22 = OpTypePointer Output %8
%21 = OpVariable %22 Output
%24 = OpTypePointer Output %4
%23 = OpVariable %24 Output
%26 = OpTypeFunction %2
%38 = OpTypePointer Input %8
%37 = OpVariable %38 Input
%41 = OpTypePointer Input %4
%40 = OpVariable %41 Input
%9 = OpTypeInt 32 1
%8 = OpConstant %9 1
%10 = OpConstant %9 0
%11 = OpTypeVector %4 4
%12 = OpTypeStruct %11 %4
%13 = OpTypeStruct %4 %6 %4
%14 = OpTypeBool
%15 = OpTypeArray %6 %8
%16 = OpTypeVector %6 3
%18 = OpTypePointer Workgroup %15
%17 = OpVariable %18 Workgroup
%21 = OpTypePointer Input %6
%20 = OpVariable %21 Input
%23 = OpVariable %21 Input
%25 = OpVariable %21 Input
%28 = OpTypePointer Output %11
%27 = OpVariable %28 Output
%30 = OpTypePointer Output %4
%29 = OpVariable %30 Output
%32 = OpTypeFunction %2
%44 = OpTypePointer Input %11
%43 = OpVariable %44 Input
%46 = OpVariable %15 Input
%48 = OpVariable %15 Input
%50 = OpVariable %24 Output
%52 = OpTypePointer Output %6
%51 = OpVariable %52 Output
%53 = OpVariable %24 Output
%66 = OpTypePointer Input %12
%65 = OpVariable %66 Input
%68 = OpVariable %66 Input
%70 = OpVariable %15 Input
%72 = OpVariable %66 Input
%25 = OpFunction %2 None %26
%13 = OpLabel
%16 = OpLoad %6 %14
%18 = OpLoad %6 %17
%20 = OpLoad %6 %19
OpBranch %27
%27 = OpLabel
%28 = OpIAdd %6 %16 %18
%29 = OpIAdd %6 %28 %20
%30 = OpCompositeConstruct %8 %3 %3 %3 %3
%31 = OpConvertUToF %4 %29
%32 = OpCompositeConstruct %9 %30 %31
%33 = OpCompositeExtract %8 %32 0
OpStore %21 %33
%34 = OpCompositeExtract %4 %32 1
OpStore %23 %34
%47 = OpTypePointer Input %4
%46 = OpVariable %47 Input
%50 = OpTypePointer Input %14
%49 = OpVariable %50 Input
%52 = OpVariable %21 Input
%54 = OpVariable %21 Input
%56 = OpVariable %30 Output
%58 = OpTypePointer Output %6
%57 = OpVariable %58 Output
%59 = OpVariable %30 Output
%72 = OpTypePointer Input %16
%71 = OpVariable %72 Input
%74 = OpVariable %72 Input
%76 = OpVariable %21 Input
%78 = OpVariable %72 Input
%80 = OpVariable %72 Input
%84 = OpTypePointer Workgroup %6
%93 = OpConstant %6 0
%31 = OpFunction %2 None %32
%19 = OpLabel
%22 = OpLoad %6 %20
%24 = OpLoad %6 %23
%26 = OpLoad %6 %25
OpBranch %33
%33 = OpLabel
%34 = OpIAdd %6 %22 %24
%35 = OpIAdd %6 %34 %26
%36 = OpCompositeConstruct %11 %3 %3 %3 %3
%37 = OpConvertUToF %4 %35
%38 = OpCompositeConstruct %12 %36 %37
%39 = OpCompositeExtract %11 %38 0
OpStore %27 %39
%40 = OpCompositeExtract %4 %38 1
OpStore %29 %40
OpReturn
OpFunctionEnd
%54 = OpFunction %2 None %26
%35 = OpLabel
%39 = OpLoad %8 %37
%42 = OpLoad %4 %40
%36 = OpCompositeConstruct %9 %39 %42
%60 = OpFunction %2 None %32
%41 = OpLabel
%45 = OpLoad %11 %43
%47 = OpLoad %6 %46
%49 = OpLoad %6 %48
OpBranch %55
%55 = OpLabel
%56 = OpShiftLeftLogical %6 %5 %47
%57 = OpBitwiseAnd %6 %49 %56
%58 = OpSelect %4 %45 %3 %7
%59 = OpCompositeExtract %4 %36 1
%60 = OpCompositeConstruct %10 %59 %57 %58
%61 = OpCompositeExtract %4 %60 0
OpStore %50 %61
%62 = OpCompositeExtract %6 %60 1
OpStore %51 %62
%63 = OpCompositeExtract %4 %60 2
OpStore %53 %63
%48 = OpLoad %4 %46
%42 = OpCompositeConstruct %12 %45 %48
%51 = OpLoad %14 %49
%53 = OpLoad %6 %52
%55 = OpLoad %6 %54
OpBranch %61
%61 = OpLabel
%62 = OpShiftLeftLogical %6 %5 %53
%63 = OpBitwiseAnd %6 %55 %62
%64 = OpSelect %4 %51 %3 %7
%65 = OpCompositeExtract %4 %42 1
%66 = OpCompositeConstruct %13 %65 %63 %64
%67 = OpCompositeExtract %4 %66 0
OpStore %56 %67
%68 = OpCompositeExtract %6 %66 1
OpStore %57 %68
%69 = OpCompositeExtract %4 %66 2
OpStore %59 %69
OpReturn
OpFunctionEnd
%74 = OpFunction %2 None %26
%64 = OpLabel
%67 = OpLoad %12 %65
%69 = OpLoad %12 %68
%71 = OpLoad %6 %70
%73 = OpLoad %12 %72
OpBranch %75
%75 = OpLabel
%82 = OpFunction %2 None %32
%70 = OpLabel
%73 = OpLoad %16 %71
%75 = OpLoad %16 %74
%77 = OpLoad %6 %76
%79 = OpLoad %16 %78
%81 = OpLoad %16 %80
OpBranch %83
%83 = OpLabel
%85 = OpCompositeExtract %6 %73 0
%86 = OpCompositeExtract %6 %75 0
%87 = OpIAdd %6 %85 %86
%88 = OpIAdd %6 %87 %77
%89 = OpCompositeExtract %6 %79 0
%90 = OpIAdd %6 %88 %89
%91 = OpCompositeExtract %6 %81 0
%92 = OpIAdd %6 %90 %91
%94 = OpAccessChain %84 %17 %93
OpStore %94 %92
OpReturn
OpFunctionEnd

View File

@@ -9,6 +9,8 @@ struct FragmentOutput {
[[location(0)]] color: f32;
};
var<workgroup> output: array<u32,1>;
[[stage(vertex)]]
fn vertex([[builtin(vertex_index)]] vertex_index: u32, [[builtin(instance_index)]] instance_index: u32, [[location(10)]] color: u32) -> VertexOutput {
let tmp: u32 = ((vertex_index + instance_index) + color);
@@ -23,6 +25,7 @@ fn fragment(in: VertexOutput, [[builtin(front_facing)]] front_facing: bool, [[bu
}
[[stage(compute), workgroup_size(1, 1, 1)]]
fn compute([[builtin(global_invocation_id)]] global_id: vec3<u32>, [[builtin(local_invocation_id)]] local_id: vec3<u32>, [[builtin(local_invocation_index)]] local_index: u32, [[builtin(workgroup_id)]] wg_id: vec3<u32>) {
fn compute([[builtin(global_invocation_id)]] global_id: vec3<u32>, [[builtin(local_invocation_id)]] local_id: vec3<u32>, [[builtin(local_invocation_index)]] local_index: u32, [[builtin(workgroup_id)]] wg_id: vec3<u32>, [[builtin(num_workgroups)]] num_wgs: vec3<u32>) {
output[0] = ((((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x);
return;
}