From b235973d2e6e645fef346859b9a033aeba7360f5 Mon Sep 17 00:00:00 2001 From: francesco-cattoglio Date: Thu, 3 Feb 2022 20:03:43 +0100 Subject: [PATCH] Add support for `vecN` and `vecN` to `dot()` function (#1689) * Allow vecN and vecN in `dot()`, first changes * Added a test case * Fix the test * Changes to baking of expressions, incl args of integer dot product * Implemented requested changes for glsl backend * Added support for integer dot product on MSL backend * Removed outdated code for hlsl and wgls writers * Implement in spv backend * Commit modified outputs from running the tests * cargo fmt * Applied requested changes for both MSL and GLSL back * Changes to spv back * Committed all test output changes * Cargo fmt * Added a comment w.r.t. VK_KHR_shader_integer_dot_product * Implemented requested svp change * Minor change to test case This is because I wanted to highlight the fact that the correct id is used in the last sum of the integer dot product expression * Changed function signature since it could not fail, changed it to simply return `void` --- src/back/glsl/mod.rs | 92 +++++++++++++++++-- src/back/msl/writer.rs | 101 ++++++++++++++++++-- src/back/spv/block.rs | 97 ++++++++++++++++++-- src/lib.rs | 1 + src/valid/expression.rs | 23 ++++- tests/in/functions.wgsl | 14 +++ tests/out/glsl/functions.main.Compute.glsl | 14 +++ tests/out/hlsl/functions.hlsl | 13 +++ tests/out/msl/functions.msl | 15 +++ tests/out/spv/functions.spvasm | 102 +++++++++++++++++---- tests/out/wgsl/functions.wgsl | 12 +++ 11 files changed, 442 insertions(+), 42 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 7c4e1d7df1..df1d6de146 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -417,6 +417,8 @@ pub struct Writer<'a, W> { block_id: IdGenerator, /// Set of expressions that have associated temporary variables. named_expressions: crate::NamedExpressions, + /// Set of expressions that need to be baked to avoid unnecessary repetition in output + need_bake_expressions: crate::NeedBakeExpressions, } impl<'a, W: Write> Writer<'a, W> { @@ -468,6 +470,7 @@ impl<'a, W: Write> Writer<'a, W> { block_id: IdGenerator::default(), named_expressions: crate::NamedExpressions::default(), + need_bake_expressions: crate::NeedBakeExpressions::default(), }; // Find all features required to print this module @@ -1000,6 +1003,45 @@ impl<'a, W: Write> Writer<'a, W> { Ok(()) } + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// Clears `need_bake_expressions` set before adding to it + fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) { + use crate::Expression; + self.need_bake_expressions.clear(); + for expr in func.expressions.iter() { + let expr_info = &info[expr.0]; + let min_ref_count = func.expressions[expr.0].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(expr.0); + } + // if the expression is a Dot product with integer arguments, + // then the args needs baking as well + if let ( + fun_handle, + &Expression::Math { + fun: crate::MathFunction::Dot, + arg, + arg1, + .. + }, + ) = expr + { + let inner = info[fun_handle].ty.inner_with(&self.module.types); + if let TypeInner::Scalar { kind, .. } = *inner { + match kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + } + _ => {} + } + } + } + } + } + /// Helper method used to get a name for a global /// /// Globals have different naming schemes depending on their binding: @@ -1151,6 +1193,7 @@ impl<'a, W: Write> Writer<'a, W> { }; self.named_expressions.clear(); + self.update_expressions_to_bake(func, info); // Write the function header // @@ -1401,6 +1444,33 @@ impl<'a, W: Write> Writer<'a, W> { Ok(()) } + /// Helper method used to output a dot product as an arithmetic expression + /// + fn write_dot_product( + &mut self, + arg: Handle, + arg1: Handle, + size: usize, + ) -> BackendResult { + write!(self.out, "(")?; + + let arg0_name = &self.named_expressions[&arg]; + let arg1_name = &self.named_expressions[&arg1]; + + // This will print an extra '+' at the beginning but that is fine in glsl + for index in 0..size { + let component = back::COMPONENTS[index]; + write!( + self.out, + " + {}.{} * {}.{}", + arg0_name, component, arg1_name, component + )?; + } + + write!(self.out, ")")?; + Ok(()) + } + /// Helper method used to write structs /// /// # Notes @@ -1490,13 +1560,10 @@ impl<'a, W: Write> Writer<'a, W> { // Otherwise, we could accidentally write variable name instead of full expression. // Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords. Some(self.namer.call(name)) + } else if self.need_bake_expressions.contains(&handle) { + Some(format!("{}{}", super::BAKE_PREFIX, handle.index())) } else { - let min_ref_count = ctx.expressions[handle].bake_ref_count(); - if min_ref_count <= info.ref_count { - Some(format!("{}{}", super::BAKE_PREFIX, handle.index())) - } else { - None - } + None }; if let Some(name) = expr_name { @@ -2538,7 +2605,18 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Log2 => "log2", Mf::Pow => "pow", // geometry - Mf::Dot => "dot", + Mf::Dot => match *ctx.info[arg].ty.inner_with(&self.module.types) { + crate::TypeInner::Vector { + kind: crate::ScalarKind::Float, + .. + } => "dot", + crate::TypeInner::Vector { size, .. } => { + return self.write_dot_product(arg, arg1.unwrap(), size as usize) + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, Mf::Outer => "outerProduct", Mf::Cross => "cross", Mf::Distance => "distance", diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 26afaae64a..d1931e85c8 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -309,6 +309,8 @@ pub struct Writer { out: W, names: FastHashMap, named_expressions: crate::NamedExpressions, + /// Set of expressions that need to be baked to avoid unnecessary repetition in output + need_bake_expressions: crate::NeedBakeExpressions, namer: proc::Namer, #[cfg(test)] put_expression_stack_pointers: FastHashSet<*const ()>, @@ -526,6 +528,7 @@ impl Writer { out, names: FastHashMap::default(), named_expressions: crate::NamedExpressions::default(), + need_bake_expressions: crate::NeedBakeExpressions::default(), namer: proc::Namer::default(), #[cfg(test)] put_expression_stack_pointers: Default::default(), @@ -827,6 +830,33 @@ impl Writer { Ok(()) } + /// Emit code for the arithmetic expression of the dot product. + /// + fn put_dot_product( + &mut self, + arg: Handle, + arg1: Handle, + size: usize, + ) -> BackendResult { + write!(self.out, "(")?; + + let arg0_name = &self.named_expressions[&arg]; + let arg1_name = &self.named_expressions[&arg1]; + + // This will print an extra '+' at the beginning but that is fine in msl + for index in 0..size { + let component = back::COMPONENTS[index]; + write!( + self.out, + " + {}.{} * {}.{}", + arg0_name, component, arg1_name, component + )?; + } + + write!(self.out, ")")?; + Ok(()) + } + /// Emit code for the expression `expr_handle`. /// /// The `is_scoped` argument is true if the surrounding operators have the @@ -1216,7 +1246,18 @@ impl Writer { Mf::Log2 => "log2", Mf::Pow => "pow", // geometry - Mf::Dot => "dot", + Mf::Dot => match *context.resolve_type(arg) { + crate::TypeInner::Vector { + kind: crate::ScalarKind::Float, + .. + } => "dot", + crate::TypeInner::Vector { size, .. } => { + return self.put_dot_product(arg, arg1.unwrap(), size as usize) + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, Mf::Outer => return Err(Error::UnsupportedCall(format!("{:?}", fun))), Mf::Cross => "cross", Mf::Distance => "distance", @@ -1810,6 +1851,55 @@ impl Writer { Ok(()) } + /// Helper method used to find which expressions of a given function require baking + /// + /// # Notes + /// This function overwrites the contents of `self.need_bake_expressions` + fn update_expressions_to_bake( + &mut self, + func: &crate::Function, + info: &valid::FunctionInfo, + context: &ExpressionContext, + ) { + use crate::Expression; + self.need_bake_expressions.clear(); + for expr in func.expressions.iter() { + // Expressions whose reference count is above the + // threshold should always be stored in temporaries. + let expr_info = &info[expr.0]; + let min_ref_count = func.expressions[expr.0].bake_ref_count(); + if min_ref_count <= expr_info.ref_count { + self.need_bake_expressions.insert(expr.0); + } + // if the expression is a Dot product with integer arguments, + // then the args needs baking as well + if let ( + fun_handle, + &Expression::Math { + fun: crate::MathFunction::Dot, + arg, + arg1, + .. + }, + ) = expr + { + use crate::TypeInner; + // check what kind of product this is depending + // on the resolve type of the Dot function itself + let inner = context.resolve_type(fun_handle); + if let TypeInner::Scalar { kind, .. } = *inner { + match kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + } + _ => {} + } + } + } + } + } + fn start_baking_expression( &mut self, handle: Handle, @@ -1913,12 +2003,7 @@ impl Writer { if context.expression.guarded_indices.contains(handle.index()) { true } else { - // Expressions whose reference count is above the - // threshold should always be stored in temporaries. - let min_ref_count = context.expression.function.expressions - [handle] - .bake_ref_count(); - min_ref_count <= info.ref_count + self.need_bake_expressions.contains(&handle) }; if bake { @@ -2763,6 +2848,7 @@ impl Writer { result_struct: None, }; self.named_expressions.clear(); + self.update_expressions_to_bake(fun, fun_info, &context.expression); self.put_block(back::Level(1), &fun.body, &context)?; writeln!(self.out, "}}")?; } @@ -3226,6 +3312,7 @@ impl Writer { result_struct: Some(&stage_out_name), }; self.named_expressions.clear(); + self.update_expressions_to_bake(fun, fun_info, &context.expression); self.put_block(back::Level(1), &fun.body, &context)?; writeln!(self.out, "}}")?; if ep_index + 1 != module.entry_points.len() { diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index ab33acfc82..c4ad830679 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -554,13 +554,34 @@ impl<'w> BlockContext<'w> { Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp), Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp), // geometry - Mf::Dot => MathOp::Custom(Instruction::binary( - spirv::Op::Dot, - result_type_id, - id, - arg0_id, - arg1_id, - )), + Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Vector { + kind: crate::ScalarKind::Float, + .. + } => MathOp::Custom(Instruction::binary( + spirv::Op::Dot, + result_type_id, + id, + arg0_id, + arg1_id, + )), + // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available + crate::TypeInner::Vector { size, .. } => { + self.write_dot_product( + id, + result_type_id, + arg0_id, + arg1_id, + size as u32, + block, + ); + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, Mf::Outer => MathOp::Custom(Instruction::binary( spirv::Op::OuterProduct, result_type_id, @@ -1122,6 +1143,68 @@ impl<'w> BlockContext<'w> { Ok(pointer) } + /// Build the instructions for the arithmetic expression of a dot product + fn write_dot_product( + &mut self, + result_id: Word, + result_type_id: Word, + arg0_id: Word, + arg1_id: Word, + size: u32, + block: &mut Block, + ) { + let const_null = self.gen_id(); + block + .body + .push(Instruction::constant_null(result_type_id, const_null)); + + let mut partial_sum = const_null; + let last_component = size - 1; + for index in 0..=last_component { + // compute the product of the current components + let a_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + a_id, + arg0_id, + &[index], + )); + let b_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + b_id, + arg1_id, + &[index], + )); + let prod_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::IMul, + result_type_id, + prod_id, + a_id, + b_id, + )); + + // choose the id for the next sum, depending on current index + let id = if index == last_component { + result_id + } else { + self.gen_id() + }; + + // sum the computed product with the partial sum + block.body.push(Instruction::binary( + spirv::Op::IAdd, + result_type_id, + id, + partial_sum, + prod_id, + )); + // set the id of the result as the previous partial sum + partial_sum = id; + } + } + pub(super) fn write_block( &mut self, label_id: Word, diff --git a/src/lib.rs b/src/lib.rs index f384c6ea14..ca516b2bd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -233,6 +233,7 @@ pub type FastHashSet = rustc_hash::FxHashSet; /// Map of expressions that have associated variable names pub(crate) type NamedExpressions = FastHashMap, String>; +pub(crate) type NeedBakeExpressions = FastHashSet>; /// Early fragment tests. /// diff --git a/src/valid/expression.rs b/src/valid/expression.rs index f7a03abc03..1cba9e9e4c 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1059,7 +1059,28 @@ impl super::Validator { )); } } - Mf::Dot | Mf::Outer | Mf::Cross | Mf::Reflect => { + Mf::Dot => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Vector { + kind: Sk::Float, .. + } + | Ti::Vector { kind: Sk::Sint, .. } + | Ti::Vector { kind: Sk::Uint, .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Outer | Mf::Cross | Mf::Reflect => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), diff --git a/tests/in/functions.wgsl b/tests/in/functions.wgsl index 71f736a2b9..6a15da1e37 100644 --- a/tests/in/functions.wgsl +++ b/tests/in/functions.wgsl @@ -8,8 +8,22 @@ fn test_fma() -> vec2 { return fma(a, b, c); } +fn test_integer_dot_product() -> i32 { + let a_2 = vec2(1); + let b_2 = vec2(1); + let c_2: i32 = dot(a_2, b_2); + + let a_3 = vec3(1u); + let b_3 = vec3(1u); + let c_3: u32 = dot(a_3, b_3); + + // test baking of arguments + let c_4: i32 = dot(vec4(4), vec4(2)); + return c_4; +} @stage(compute) @workgroup_size(1) fn main() { let a = test_fma(); + let b = test_integer_dot_product(); } diff --git a/tests/out/glsl/functions.main.Compute.glsl b/tests/out/glsl/functions.main.Compute.glsl index d3e7f7d171..07e4e7cdea 100644 --- a/tests/out/glsl/functions.main.Compute.glsl +++ b/tests/out/glsl/functions.main.Compute.glsl @@ -14,8 +14,22 @@ vec2 test_fma() { return fma(a, b, c); } +int test_integer_dot_product() { + ivec2 a_2_ = ivec2(1); + ivec2 b_2_ = ivec2(1); + int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y); + uvec3 a_3_ = uvec3(1u); + uvec3 b_3_ = uvec3(1u); + uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z); + ivec4 _e11 = ivec4(4); + ivec4 _e13 = ivec4(2); + int c_4_ = ( + _e11.x * _e13.x + _e11.y * _e13.y + _e11.z * _e13.z + _e11.w * _e13.w); + return c_4_; +} + void main() { vec2 _e0 = test_fma(); + int _e1 = test_integer_dot_product(); return; } diff --git a/tests/out/hlsl/functions.hlsl b/tests/out/hlsl/functions.hlsl index 37a4ecfafa..c1b420f22d 100644 --- a/tests/out/hlsl/functions.hlsl +++ b/tests/out/hlsl/functions.hlsl @@ -7,9 +7,22 @@ float2 test_fma() return mad(a, b, c); } +int test_integer_dot_product() +{ + int2 a_2_ = int2(1.xx); + int2 b_2_ = int2(1.xx); + int c_2_ = dot(a_2_, b_2_); + uint3 a_3_ = uint3(1u.xxx); + uint3 b_3_ = uint3(1u.xxx); + uint c_3_ = dot(a_3_, b_3_); + int c_4_ = dot(int4(4.xxxx), int4(2.xxxx)); + return c_4_; +} + [numthreads(1, 1, 1)] void main() { const float2 _e0 = test_fma(); + const int _e1 = test_integer_dot_product(); return; } diff --git a/tests/out/msl/functions.msl b/tests/out/msl/functions.msl index 11574c679e..73c5a9fa64 100644 --- a/tests/out/msl/functions.msl +++ b/tests/out/msl/functions.msl @@ -11,8 +11,23 @@ metal::float2 test_fma( return metal::fma(a, b, c); } +int test_integer_dot_product( +) { + metal::int2 a_2_ = metal::int2(1); + metal::int2 b_2_ = metal::int2(1); + int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y); + metal::uint3 a_3_ = metal::uint3(1u); + metal::uint3 b_3_ = metal::uint3(1u); + metal::uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z); + metal::int4 _e11 = metal::int4(4); + metal::int4 _e13 = metal::int4(2); + int c_4_ = ( + _e11.x * _e13.x + _e11.y * _e13.y + _e11.z * _e13.z + _e11.w * _e13.w); + return c_4_; +} + kernel void main_( ) { metal::float2 _e0 = test_fma(); + int _e1 = test_integer_dot_product(); return; } diff --git a/tests/out/spv/functions.spvasm b/tests/out/spv/functions.spvasm index 9c1a62ae4c..da2d240aab 100644 --- a/tests/out/spv/functions.spvasm +++ b/tests/out/spv/functions.spvasm @@ -1,33 +1,95 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 20 +; Bound: 79 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %16 "main" -OpExecutionMode %16 LocalSize 1 1 1 +OpEntryPoint GLCompute %74 "main" +OpExecutionMode %74 LocalSize 1 1 1 %2 = OpTypeVoid %4 = OpTypeFloat 32 %3 = OpConstant %4 2.0 %5 = OpConstant %4 0.5 -%6 = OpTypeVector %4 2 -%9 = OpTypeFunction %6 -%17 = OpTypeFunction %2 -%8 = OpFunction %6 None %9 -%7 = OpLabel -OpBranch %10 -%10 = OpLabel -%11 = OpCompositeConstruct %6 %3 %3 -%12 = OpCompositeConstruct %6 %5 %5 -%13 = OpCompositeConstruct %6 %5 %5 -%14 = OpExtInst %6 %1 Fma %11 %12 %13 -OpReturnValue %14 +%7 = OpTypeInt 32 1 +%6 = OpConstant %7 1 +%9 = OpTypeInt 32 0 +%8 = OpConstant %9 1 +%10 = OpConstant %7 4 +%11 = OpConstant %7 2 +%12 = OpTypeVector %4 2 +%15 = OpTypeFunction %12 +%23 = OpTypeFunction %7 +%25 = OpTypeVector %7 2 +%37 = OpTypeVector %9 3 +%53 = OpTypeVector %7 4 +%75 = OpTypeFunction %2 +%29 = OpConstantNull %7 +%41 = OpConstantNull %9 +%57 = OpConstantNull %7 +%14 = OpFunction %12 None %15 +%13 = OpLabel +OpBranch %16 +%16 = OpLabel +%17 = OpCompositeConstruct %12 %3 %3 +%18 = OpCompositeConstruct %12 %5 %5 +%19 = OpCompositeConstruct %12 %5 %5 +%20 = OpExtInst %12 %1 Fma %17 %18 %19 +OpReturnValue %20 OpFunctionEnd -%16 = OpFunction %2 None %17 -%15 = OpLabel -OpBranch %18 -%18 = OpLabel -%19 = OpFunctionCall %6 %8 +%22 = OpFunction %7 None %23 +%21 = OpLabel +OpBranch %24 +%24 = OpLabel +%26 = OpCompositeConstruct %25 %6 %6 +%27 = OpCompositeConstruct %25 %6 %6 +%30 = OpCompositeExtract %7 %26 0 +%31 = OpCompositeExtract %7 %27 0 +%32 = OpIMul %7 %30 %31 +%33 = OpIAdd %7 %29 %32 +%34 = OpCompositeExtract %7 %26 1 +%35 = OpCompositeExtract %7 %27 1 +%36 = OpIMul %7 %34 %35 +%28 = OpIAdd %7 %33 %36 +%38 = OpCompositeConstruct %37 %8 %8 %8 +%39 = OpCompositeConstruct %37 %8 %8 %8 +%42 = OpCompositeExtract %9 %38 0 +%43 = OpCompositeExtract %9 %39 0 +%44 = OpIMul %9 %42 %43 +%45 = OpIAdd %9 %41 %44 +%46 = OpCompositeExtract %9 %38 1 +%47 = OpCompositeExtract %9 %39 1 +%48 = OpIMul %9 %46 %47 +%49 = OpIAdd %9 %45 %48 +%50 = OpCompositeExtract %9 %38 2 +%51 = OpCompositeExtract %9 %39 2 +%52 = OpIMul %9 %50 %51 +%40 = OpIAdd %9 %49 %52 +%54 = OpCompositeConstruct %53 %10 %10 %10 %10 +%55 = OpCompositeConstruct %53 %11 %11 %11 %11 +%58 = OpCompositeExtract %7 %54 0 +%59 = OpCompositeExtract %7 %55 0 +%60 = OpIMul %7 %58 %59 +%61 = OpIAdd %7 %57 %60 +%62 = OpCompositeExtract %7 %54 1 +%63 = OpCompositeExtract %7 %55 1 +%64 = OpIMul %7 %62 %63 +%65 = OpIAdd %7 %61 %64 +%66 = OpCompositeExtract %7 %54 2 +%67 = OpCompositeExtract %7 %55 2 +%68 = OpIMul %7 %66 %67 +%69 = OpIAdd %7 %65 %68 +%70 = OpCompositeExtract %7 %54 3 +%71 = OpCompositeExtract %7 %55 3 +%72 = OpIMul %7 %70 %71 +%56 = OpIAdd %7 %69 %72 +OpReturnValue %56 +OpFunctionEnd +%74 = OpFunction %2 None %75 +%73 = OpLabel +OpBranch %76 +%76 = OpLabel +%77 = OpFunctionCall %12 %14 +%78 = OpFunctionCall %7 %22 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/functions.wgsl b/tests/out/wgsl/functions.wgsl index 5a53e38c6e..159daeb5b6 100644 --- a/tests/out/wgsl/functions.wgsl +++ b/tests/out/wgsl/functions.wgsl @@ -5,8 +5,20 @@ fn test_fma() -> vec2 { return fma(a, b, c); } +fn test_integer_dot_product() -> i32 { + let a_2_ = vec2(1); + let b_2_ = vec2(1); + let c_2_ = dot(a_2_, b_2_); + let a_3_ = vec3(1u); + let b_3_ = vec3(1u); + let c_3_ = dot(a_3_, b_3_); + let c_4_ = dot(vec4(4), vec4(2)); + return c_4_; +} + @stage(compute) @workgroup_size(1, 1, 1) fn main() { let _e0 = test_fma(); + let _e1 = test_integer_dot_product(); return; }