mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
ConstantEvaluator::swizzle: Handle vector concatenation and indexing (#2485)
* ConstantEvaluator::swizzle: Handle vector concatenation, indexing. * Handle vector Compose expressions nested two deep. * Move `flatten_compose` to `proc`, and make it a free function. * [spv-out] Ensure that we flatten Compose for OpConstantCompose.
This commit is contained in:
committed by
Teodor Tanasoaia
parent
f0ac838019
commit
1a4b526d9a
@@ -243,21 +243,24 @@ impl<'w> BlockContext<'w> {
|
||||
self.writer.constant_ids[init.index()]
|
||||
}
|
||||
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
|
||||
crate::Expression::Compose {
|
||||
ty: _,
|
||||
ref components,
|
||||
} => {
|
||||
crate::Expression::Compose { ty, ref components } => {
|
||||
self.temp_list.clear();
|
||||
for &component in components {
|
||||
self.temp_list.push(self.cached[component]);
|
||||
}
|
||||
|
||||
if self.ir_function.expressions.is_const(expr_handle) {
|
||||
let ty = self
|
||||
.writer
|
||||
.get_expression_lookup_type(&self.fun_info[expr_handle].ty);
|
||||
self.writer.get_constant_composite(ty, &self.temp_list)
|
||||
self.temp_list.extend(
|
||||
crate::proc::flatten_compose(
|
||||
ty,
|
||||
components,
|
||||
&self.ir_function.expressions,
|
||||
&self.ir_module.types,
|
||||
)
|
||||
.map(|component| self.cached[component]),
|
||||
);
|
||||
self.writer
|
||||
.get_constant_composite(LookupType::Handle(ty), &self.temp_list)
|
||||
} else {
|
||||
self.temp_list
|
||||
.extend(components.iter().map(|&component| self.cached[component]));
|
||||
|
||||
let id = self.gen_id();
|
||||
block.body.push(Instruction::composite_construct(
|
||||
result_type_id,
|
||||
|
||||
@@ -1269,10 +1269,14 @@ impl Writer {
|
||||
self.get_constant_null(type_id)
|
||||
}
|
||||
crate::Expression::Compose { ty, ref components } => {
|
||||
let component_ids: Vec<_> = components
|
||||
.iter()
|
||||
.map(|component| self.constant_ids[component.index()])
|
||||
.collect();
|
||||
let component_ids: Vec<_> = crate::proc::flatten_compose(
|
||||
ty,
|
||||
components,
|
||||
&ir_module.const_expressions,
|
||||
&ir_module.types,
|
||||
)
|
||||
.map(|component| self.constant_ids[component.index()])
|
||||
.collect();
|
||||
self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
|
||||
}
|
||||
crate::Expression::Splat { size, value } => {
|
||||
|
||||
@@ -73,6 +73,8 @@ pub enum ConstantEvaluatorError {
|
||||
SplatScalarOnly,
|
||||
#[error("Can only swizzle vector constants")]
|
||||
SwizzleVectorOnly,
|
||||
#[error("swizzle component not present in source expression")]
|
||||
SwizzleOutOfBounds,
|
||||
#[error("Type is not constructible")]
|
||||
TypeNotConstructible,
|
||||
#[error("Subexpression(s) are not constant")]
|
||||
@@ -306,20 +308,31 @@ impl ConstantEvaluator<'_> {
|
||||
let expr = Expression::Splat { size, value };
|
||||
Ok(self.register_evaluated_expr(expr, span))
|
||||
}
|
||||
Expression::Compose {
|
||||
ty,
|
||||
components: ref src_components,
|
||||
} => {
|
||||
Expression::Compose { ty, ref components } => {
|
||||
let dst_ty = get_dst_ty(ty)?;
|
||||
|
||||
let components = pattern
|
||||
let mut flattened = [src_constant; 4]; // dummy value
|
||||
let len =
|
||||
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
|
||||
.zip(flattened.iter_mut())
|
||||
.map(|(component, elt)| *elt = component)
|
||||
.count();
|
||||
let flattened = &flattened[..len];
|
||||
|
||||
let swizzled_components = pattern[..size as usize]
|
||||
.iter()
|
||||
.take(size as usize)
|
||||
.map(|&sc| src_components[sc as usize])
|
||||
.collect();
|
||||
.map(|&sc| {
|
||||
let sc = sc as usize;
|
||||
if let Some(elt) = flattened.get(sc) {
|
||||
Ok(*elt)
|
||||
} else {
|
||||
Err(ConstantEvaluatorError::SwizzleOutOfBounds)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<Handle<Expression>>, _>>()?;
|
||||
let expr = Expression::Compose {
|
||||
ty: dst_ty,
|
||||
components,
|
||||
components: swizzled_components,
|
||||
};
|
||||
Ok(self.register_evaluated_expr(expr, span))
|
||||
}
|
||||
@@ -455,9 +468,8 @@ impl ConstantEvaluator<'_> {
|
||||
.components()
|
||||
.ok_or(ConstantEvaluatorError::InvalidAccessBase)?;
|
||||
|
||||
components
|
||||
.get(index)
|
||||
.copied()
|
||||
crate::proc::flatten_compose(ty, components, self.expressions, self.types)
|
||||
.nth(index)
|
||||
.ok_or(ConstantEvaluatorError::InvalidAccessIndex)
|
||||
}
|
||||
_ => Err(ConstantEvaluatorError::InvalidAccessBase),
|
||||
|
||||
@@ -638,6 +638,61 @@ impl GlobalCtx<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return an iterator over the individual components assembled by a
|
||||
/// `Compose` expression.
|
||||
///
|
||||
/// Given `ty` and `components` from an `Expression::Compose`, return an
|
||||
/// iterator over the components of the resulting value.
|
||||
///
|
||||
/// Normally, this would just be an iterator over `components`. However,
|
||||
/// `Compose` expressions can concatenate vectors, in which case the i'th
|
||||
/// value being composed is not generally the i'th element of `components`.
|
||||
/// This function consults `ty` to decide if this concatenation is occuring,
|
||||
/// and returns an iterator that produces the components of the result of
|
||||
/// the `Compose` expression in either case.
|
||||
pub fn flatten_compose<'arenas>(
|
||||
ty: crate::Handle<crate::Type>,
|
||||
components: &'arenas [crate::Handle<crate::Expression>],
|
||||
expressions: &'arenas crate::Arena<crate::Expression>,
|
||||
types: &'arenas crate::UniqueArena<crate::Type>,
|
||||
) -> impl Iterator<Item = crate::Handle<crate::Expression>> + 'arenas {
|
||||
// Returning `impl Iterator` is a bit tricky. We may or may not want to
|
||||
// flatten the components, but we have to settle on a single concrete
|
||||
// type to return. The below is a single iterator chain that handles
|
||||
// both the flattening and non-flattening cases.
|
||||
let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner {
|
||||
(size as usize, true)
|
||||
} else {
|
||||
(components.len(), false)
|
||||
};
|
||||
|
||||
fn flattener<'c>(
|
||||
component: &'c crate::Handle<crate::Expression>,
|
||||
is_vector: bool,
|
||||
expressions: &'c crate::Arena<crate::Expression>,
|
||||
) -> &'c [crate::Handle<crate::Expression>] {
|
||||
if is_vector {
|
||||
if let crate::Expression::Compose {
|
||||
ty: _,
|
||||
components: ref subcomponents,
|
||||
} = expressions[*component]
|
||||
{
|
||||
return subcomponents;
|
||||
}
|
||||
}
|
||||
std::slice::from_ref(component)
|
||||
}
|
||||
|
||||
// Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten
|
||||
// two levels.
|
||||
components
|
||||
.iter()
|
||||
.flat_map(move |component| flattener(component, is_vector, expressions))
|
||||
.flat_map(move |component| flattener(component, is_vector, expressions))
|
||||
.take(size)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_size() {
|
||||
let module = crate::Module::default();
|
||||
|
||||
14
tests/in/const-exprs.wgsl
Normal file
14
tests/in/const-exprs.wgsl
Normal file
@@ -0,0 +1,14 @@
|
||||
@group(0) @binding(0) var<storage, read_write> out: vec4<i32>;
|
||||
@group(0) @binding(1) var<storage, read_write> out2: i32;
|
||||
@group(0) @binding(2) var<storage, read_write> out3: i32;
|
||||
|
||||
@compute @workgroup_size(1)
|
||||
fn main() {
|
||||
let a = vec2(1, 2);
|
||||
let b = vec2(3, 4);
|
||||
out = vec4(a, b).wzyx;
|
||||
|
||||
out2 = vec4(a, b)[1];
|
||||
|
||||
out3 = vec4(vec3(vec2(6, 7), 8), 9)[0];
|
||||
}
|
||||
23
tests/out/glsl/const-exprs.main.Compute.glsl
Normal file
23
tests/out/glsl/const-exprs.main.Compute.glsl
Normal file
@@ -0,0 +1,23 @@
|
||||
#version 310 es
|
||||
|
||||
precision highp float;
|
||||
precision highp int;
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; };
|
||||
|
||||
layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; };
|
||||
|
||||
layout(std430) buffer type_1_block_2Compute { int _group_0_binding_2_cs; };
|
||||
|
||||
|
||||
void main() {
|
||||
ivec2 a = ivec2(1, 2);
|
||||
ivec2 b = ivec2(3, 4);
|
||||
_group_0_binding_0_cs = ivec4(4, 3, 2, 1);
|
||||
_group_0_binding_1_cs = 2;
|
||||
_group_0_binding_2_cs = 6;
|
||||
return;
|
||||
}
|
||||
|
||||
14
tests/out/hlsl/const-exprs.hlsl
Normal file
14
tests/out/hlsl/const-exprs.hlsl
Normal file
@@ -0,0 +1,14 @@
|
||||
RWByteAddressBuffer out_ : register(u0);
|
||||
RWByteAddressBuffer out2_ : register(u1);
|
||||
RWByteAddressBuffer out3_ : register(u2);
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main()
|
||||
{
|
||||
int2 a = int2(1, 2);
|
||||
int2 b = int2(3, 4);
|
||||
out_.Store4(0, asuint(int4(4, 3, 2, 1)));
|
||||
out2_.Store(0, asuint(2));
|
||||
out3_.Store(0, asuint(6));
|
||||
return;
|
||||
}
|
||||
12
tests/out/hlsl/const-exprs.ron
Normal file
12
tests/out/hlsl/const-exprs.ron
Normal file
@@ -0,0 +1,12 @@
|
||||
(
|
||||
vertex:[
|
||||
],
|
||||
fragment:[
|
||||
],
|
||||
compute:[
|
||||
(
|
||||
entry_point:"main",
|
||||
target_profile:"cs_5_1",
|
||||
),
|
||||
],
|
||||
)
|
||||
19
tests/out/msl/const-exprs.msl
Normal file
19
tests/out/msl/const-exprs.msl
Normal file
@@ -0,0 +1,19 @@
|
||||
// language: metal2.0
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using metal::uint;
|
||||
|
||||
|
||||
kernel void main_(
|
||||
device metal::int4& out [[user(fake0)]]
|
||||
, device int& out2_ [[user(fake0)]]
|
||||
, device int& out3_ [[user(fake0)]]
|
||||
) {
|
||||
metal::int2 a = metal::int2(1, 2);
|
||||
metal::int2 b = metal::int2(3, 4);
|
||||
out = metal::int4(4, 3, 2, 1);
|
||||
out2_ = 2;
|
||||
out3_ = 6;
|
||||
return;
|
||||
}
|
||||
60
tests/out/spv/const-exprs.spvasm
Normal file
60
tests/out/spv/const-exprs.spvasm
Normal file
@@ -0,0 +1,60 @@
|
||||
; SPIR-V
|
||||
; Version: 1.1
|
||||
; Generator: rspirv
|
||||
; Bound: 34
|
||||
OpCapability Shader
|
||||
OpExtension "SPV_KHR_storage_buffer_storage_class"
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %16 "main"
|
||||
OpExecutionMode %16 LocalSize 1 1 1
|
||||
OpDecorate %6 DescriptorSet 0
|
||||
OpDecorate %6 Binding 0
|
||||
OpDecorate %7 Block
|
||||
OpMemberDecorate %7 0 Offset 0
|
||||
OpDecorate %9 DescriptorSet 0
|
||||
OpDecorate %9 Binding 1
|
||||
OpDecorate %10 Block
|
||||
OpMemberDecorate %10 0 Offset 0
|
||||
OpDecorate %12 DescriptorSet 0
|
||||
OpDecorate %12 Binding 2
|
||||
OpDecorate %13 Block
|
||||
OpMemberDecorate %13 0 Offset 0
|
||||
%2 = OpTypeVoid
|
||||
%4 = OpTypeInt 32 1
|
||||
%3 = OpTypeVector %4 4
|
||||
%5 = OpTypeVector %4 2
|
||||
%7 = OpTypeStruct %3
|
||||
%8 = OpTypePointer StorageBuffer %7
|
||||
%6 = OpVariable %8 StorageBuffer
|
||||
%10 = OpTypeStruct %4
|
||||
%11 = OpTypePointer StorageBuffer %10
|
||||
%9 = OpVariable %11 StorageBuffer
|
||||
%13 = OpTypeStruct %4
|
||||
%14 = OpTypePointer StorageBuffer %13
|
||||
%12 = OpVariable %14 StorageBuffer
|
||||
%17 = OpTypeFunction %2
|
||||
%18 = OpTypePointer StorageBuffer %3
|
||||
%20 = OpTypeInt 32 0
|
||||
%19 = OpConstant %20 0
|
||||
%22 = OpTypePointer StorageBuffer %4
|
||||
%25 = OpConstant %4 1
|
||||
%26 = OpConstant %4 2
|
||||
%27 = OpConstantComposite %5 %25 %26
|
||||
%28 = OpConstant %4 3
|
||||
%29 = OpConstant %4 4
|
||||
%30 = OpConstantComposite %5 %28 %29
|
||||
%31 = OpConstantComposite %3 %29 %28 %26 %25
|
||||
%32 = OpConstant %4 6
|
||||
%16 = OpFunction %2 None %17
|
||||
%15 = OpLabel
|
||||
%21 = OpAccessChain %18 %6 %19
|
||||
%23 = OpAccessChain %22 %9 %19
|
||||
%24 = OpAccessChain %22 %12 %19
|
||||
OpBranch %33
|
||||
%33 = OpLabel
|
||||
OpStore %21 %31
|
||||
OpStore %23 %26
|
||||
OpStore %24 %32
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
16
tests/out/wgsl/const-exprs.wgsl
Normal file
16
tests/out/wgsl/const-exprs.wgsl
Normal file
@@ -0,0 +1,16 @@
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> out: vec4<i32>;
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> out2_: i32;
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> out3_: i32;
|
||||
|
||||
@compute @workgroup_size(1, 1, 1)
|
||||
fn main() {
|
||||
let a = vec2<i32>(1, 2);
|
||||
let b = vec2<i32>(3, 4);
|
||||
out = vec4<i32>(4, 3, 2, 1);
|
||||
out2_ = 2;
|
||||
out3_ = 6;
|
||||
return;
|
||||
}
|
||||
@@ -777,6 +777,10 @@ fn convert_wgsl() {
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
("msl-varyings", Targets::METAL),
|
||||
(
|
||||
"const-exprs",
|
||||
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
|
||||
),
|
||||
];
|
||||
|
||||
for &(name, targets) in inputs.iter() {
|
||||
|
||||
Reference in New Issue
Block a user