Add support for vecN<i32> and vecN<u32> to dot() function (#1689)

* Allow vecN<i32> and vecN<u32> in `dot()`, first changes

* Added a test case

* Fix the test

* Changes to baking of expressions, incl args of integer dot product

* Implemented requested changes for glsl backend

* Added support for integer dot product on MSL backend

* Removed outdated code for hlsl and wgsl writers

* Implement in spv backend

* Commit modified outputs from running the tests

* cargo fmt

* Applied requested changes for both MSL and GLSL back

* Changes to spv back

* Committed all test output changes

* Cargo fmt

* Added a comment w.r.t. VK_KHR_shader_integer_dot_product

* Implemented requested spv change

* Minor change to test case

This is because I wanted to highlight the fact that the correct
id is used in the last sum of the integer dot product expression

* Changed function signature

since it could not fail, changed it to simply return `void`
This commit is contained in:
francesco-cattoglio
2022-02-03 20:03:43 +01:00
committed by GitHub
parent 42bf3545c9
commit b235973d2e
11 changed files with 442 additions and 42 deletions

View File

@@ -417,6 +417,8 @@ pub struct Writer<'a, W> {
block_id: IdGenerator,
/// Set of expressions that have associated temporary variables.
named_expressions: crate::NamedExpressions,
/// Set of expressions that need to be baked to avoid unnecessary repetition in output
need_bake_expressions: crate::NeedBakeExpressions,
}
impl<'a, W: Write> Writer<'a, W> {
@@ -468,6 +470,7 @@ impl<'a, W: Write> Writer<'a, W> {
block_id: IdGenerator::default(),
named_expressions: crate::NamedExpressions::default(),
need_bake_expressions: crate::NeedBakeExpressions::default(),
};
// Find all features required to print this module
@@ -1000,6 +1003,45 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}
/// Recomputes the set of expressions of `func` that have to be baked
/// (stored under a name) before they are used.
///
/// An expression is added when its reference count in `info` reaches the
/// threshold reported by `bake_ref_count()`. Both operands of a `dot()` whose
/// result type is a signed or unsigned integer scalar are added as well.
///
/// # Notes
/// `need_bake_expressions` is cleared first, so earlier contents are discarded.
fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
    use crate::Expression;

    self.need_bake_expressions.clear();

    for (handle, expression) in func.expressions.iter() {
        // Referenced at least as often as the baking threshold: bake it.
        if expression.bake_ref_count() <= info[handle].ref_count {
            self.need_bake_expressions.insert(handle);
        }

        // A dot product with integer arguments is written out component by
        // component, so its arguments need baking too.
        if let Expression::Math {
            fun: crate::MathFunction::Dot,
            arg,
            arg1,
            ..
        } = *expression
        {
            let is_integer_dot = matches!(
                *info[handle].ty.inner_with(&self.module.types),
                TypeInner::Scalar {
                    kind: crate::ScalarKind::Sint,
                    ..
                } | TypeInner::Scalar {
                    kind: crate::ScalarKind::Uint,
                    ..
                }
            );
            if is_integer_dot {
                self.need_bake_expressions.insert(arg);
                self.need_bake_expressions.insert(arg1.unwrap());
            }
        }
    }
}
/// Helper method used to get a name for a global
///
/// Globals have different naming schemes depending on their binding:
@@ -1151,6 +1193,7 @@ impl<'a, W: Write> Writer<'a, W> {
};
self.named_expressions.clear();
self.update_expressions_to_bake(func, info);
// Write the function header
//
@@ -1401,6 +1444,33 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}
/// Writes a dot product as a parenthesized sum of per-component products,
/// for example `( + a.x * b.x + a.y * b.y)` when `size` is 2.
///
/// The names of both operands are read from `named_expressions`; indexing
/// that map panics if an operand has no entry there.
fn write_dot_product(
    &mut self,
    arg: Handle<crate::Expression>,
    arg1: Handle<crate::Expression>,
    size: usize,
) -> BackendResult {
    write!(self.out, "(")?;

    let lhs = &self.named_expressions[&arg];
    let rhs = &self.named_expressions[&arg1];

    // Every term, including the first, is prefixed with '+'.
    // The resulting leading unary plus is fine in glsl.
    for swizzle in (0..size).map(|index| back::COMPONENTS[index]) {
        write!(
            self.out,
            " + {l}.{c} * {r}.{c}",
            l = lhs,
            r = rhs,
            c = swizzle
        )?;
    }

    write!(self.out, ")")?;
    Ok(())
}
/// Helper method used to write structs
///
/// # Notes
@@ -1490,13 +1560,10 @@ impl<'a, W: Write> Writer<'a, W> {
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! This protects the backend from generating variables whose names are reserved keywords.
Some(self.namer.call(name))
} else if self.need_bake_expressions.contains(&handle) {
Some(format!("{}{}", super::BAKE_PREFIX, handle.index()))
} else {
let min_ref_count = ctx.expressions[handle].bake_ref_count();
if min_ref_count <= info.ref_count {
Some(format!("{}{}", super::BAKE_PREFIX, handle.index()))
} else {
None
}
None
};
if let Some(name) = expr_name {
@@ -2538,7 +2605,18 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
Mf::Dot => match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Vector {
kind: crate::ScalarKind::Float,
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.write_dot_product(arg, arg1.unwrap(), size as usize)
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
),
},
Mf::Outer => "outerProduct",
Mf::Cross => "cross",
Mf::Distance => "distance",

View File

@@ -309,6 +309,8 @@ pub struct Writer<W> {
out: W,
names: FastHashMap<NameKey, String>,
named_expressions: crate::NamedExpressions,
/// Set of expressions that need to be baked to avoid unnecessary repetition in output
need_bake_expressions: crate::NeedBakeExpressions,
namer: proc::Namer,
#[cfg(test)]
put_expression_stack_pointers: FastHashSet<*const ()>,
@@ -526,6 +528,7 @@ impl<W: Write> Writer<W> {
out,
names: FastHashMap::default(),
named_expressions: crate::NamedExpressions::default(),
need_bake_expressions: crate::NeedBakeExpressions::default(),
namer: proc::Namer::default(),
#[cfg(test)]
put_expression_stack_pointers: Default::default(),
@@ -827,6 +830,33 @@ impl<W: Write> Writer<W> {
Ok(())
}
/// Emits a dot product as a parenthesized sum of per-component products,
/// for example `( + a.x * b.x + a.y * b.y)` when `size` is 2.
///
/// The names of both operands are read from `named_expressions`; indexing
/// that map panics if an operand has no entry there.
fn put_dot_product(
    &mut self,
    arg: Handle<crate::Expression>,
    arg1: Handle<crate::Expression>,
    size: usize,
) -> BackendResult {
    write!(self.out, "(")?;

    let left_name = &self.named_expressions[&arg];
    let right_name = &self.named_expressions[&arg1];

    // Each term carries its own '+', so the first one shows up as a
    // leading unary plus; that is fine in msl.
    let mut index = 0;
    while index < size {
        let letter = back::COMPONENTS[index];
        write!(
            self.out,
            " + {a}.{c} * {b}.{c}",
            a = left_name,
            b = right_name,
            c = letter
        )?;
        index += 1;
    }

    write!(self.out, ")")?;
    Ok(())
}
/// Emit code for the expression `expr_handle`.
///
/// The `is_scoped` argument is true if the surrounding operators have the
@@ -1216,7 +1246,18 @@ impl<W: Write> Writer<W> {
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
Mf::Dot => match *context.resolve_type(arg) {
crate::TypeInner::Vector {
kind: crate::ScalarKind::Float,
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.put_dot_product(arg, arg1.unwrap(), size as usize)
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
),
},
Mf::Outer => return Err(Error::UnsupportedCall(format!("{:?}", fun))),
Mf::Cross => "cross",
Mf::Distance => "distance",
@@ -1810,6 +1851,55 @@ impl<W: Write> Writer<W> {
Ok(())
}
/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
/// This function overwrites the contents of `self.need_bake_expressions`
fn update_expressions_to_bake(
&mut self,
func: &crate::Function,
info: &valid::FunctionInfo,
context: &ExpressionContext,
) {
use crate::Expression;
self.need_bake_expressions.clear();
for expr in func.expressions.iter() {
// Expressions whose reference count is above the
// threshold should always be stored in temporaries.
let expr_info = &info[expr.0];
let min_ref_count = func.expressions[expr.0].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr.0);
}
// if the expression is a Dot product with integer arguments,
// then the args needs baking as well
if let (
fun_handle,
&Expression::Math {
fun: crate::MathFunction::Dot,
arg,
arg1,
..
},
) = expr
{
use crate::TypeInner;
// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(fun_handle);
if let TypeInner::Scalar { kind, .. } = *inner {
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
}
}
}
fn start_baking_expression(
&mut self,
handle: Handle<crate::Expression>,
@@ -1913,12 +2003,7 @@ impl<W: Write> Writer<W> {
if context.expression.guarded_indices.contains(handle.index()) {
true
} else {
// Expressions whose reference count is above the
// threshold should always be stored in temporaries.
let min_ref_count = context.expression.function.expressions
[handle]
.bake_ref_count();
min_ref_count <= info.ref_count
self.need_bake_expressions.contains(&handle)
};
if bake {
@@ -2763,6 +2848,7 @@ impl<W: Write> Writer<W> {
result_struct: None,
};
self.named_expressions.clear();
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
}
@@ -3226,6 +3312,7 @@ impl<W: Write> Writer<W> {
result_struct: Some(&stage_out_name),
};
self.named_expressions.clear();
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
if ep_index + 1 != module.entry_points.len() {

View File

@@ -554,13 +554,34 @@ impl<'w> BlockContext<'w> {
Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp),
Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
// geometry
Mf::Dot => MathOp::Custom(Instruction::binary(
spirv::Op::Dot,
result_type_id,
id,
arg0_id,
arg1_id,
)),
Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) {
crate::TypeInner::Vector {
kind: crate::ScalarKind::Float,
..
} => MathOp::Custom(Instruction::binary(
spirv::Op::Dot,
result_type_id,
id,
arg0_id,
arg1_id,
)),
// TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available
crate::TypeInner::Vector { size, .. } => {
self.write_dot_product(
id,
result_type_id,
arg0_id,
arg1_id,
size as u32,
block,
);
self.cached[expr_handle] = id;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
),
},
Mf::Outer => MathOp::Custom(Instruction::binary(
spirv::Op::OuterProduct,
result_type_id,
@@ -1122,6 +1143,68 @@ impl<'w> BlockContext<'w> {
Ok(pointer)
}
/// Appends to `block` the instructions that compute a dot product as an
/// explicit sum of products, one component at a time.
///
/// For each of the `size` components, the matching element of `arg0_id` and
/// `arg1_id` is extracted, multiplied with `OpIMul` and added with `OpIAdd`
/// to a running sum that starts from a null constant of `result_type_id`.
/// The last addition is given `result_id`; every other id is freshly
/// generated.
fn write_dot_product(
    &mut self,
    result_id: Word,
    result_type_id: Word,
    arg0_id: Word,
    arg1_id: Word,
    size: u32,
    block: &mut Block,
) {
    // The running sum starts from a null constant of the result type.
    // NOTE(review): this places OpConstantNull in the function block body;
    // the SPIR-V module layout appears to expect constants among the
    // module-level declarations. Confirm this placement is valid.
    let zero_id = self.gen_id();
    block
        .body
        .push(Instruction::constant_null(result_type_id, zero_id));

    let final_index = size - 1;
    let mut running_sum = zero_id;
    for index in 0..=final_index {
        // extract the current component of each operand and multiply them
        let lhs_id = self.gen_id();
        block.body.push(Instruction::composite_extract(
            result_type_id,
            lhs_id,
            arg0_id,
            &[index],
        ));
        let rhs_id = self.gen_id();
        block.body.push(Instruction::composite_extract(
            result_type_id,
            rhs_id,
            arg1_id,
            &[index],
        ));
        let product_id = self.gen_id();
        block.body.push(Instruction::binary(
            spirv::Op::IMul,
            result_type_id,
            product_id,
            lhs_id,
            rhs_id,
        ));

        // intermediate sums get a fresh id; the final sum is the result itself
        let sum_id = if index < final_index {
            self.gen_id()
        } else {
            result_id
        };
        // add the product to what has been accumulated so far
        block.body.push(Instruction::binary(
            spirv::Op::IAdd,
            result_type_id,
            sum_id,
            running_sum,
            product_id,
        ));
        // the new sum is the accumulator for the next component
        running_sum = sum_id;
    }
}
pub(super) fn write_block(
&mut self,
label_id: Word,

View File

@@ -233,6 +233,7 @@ pub type FastHashSet<K> = rustc_hash::FxHashSet<K>;
/// Map of expressions that have associated variable names
pub(crate) type NamedExpressions = FastHashMap<Handle<Expression>, String>;
pub(crate) type NeedBakeExpressions = FastHashSet<Handle<Expression>>;
/// Early fragment tests.
///

View File

@@ -1059,7 +1059,28 @@ impl super::Validator {
));
}
}
Mf::Dot | Mf::Outer | Mf::Cross | Mf::Reflect => {
Mf::Dot => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),
};
match *arg_ty {
Ti::Vector {
kind: Sk::Float, ..
}
| Ti::Vector { kind: Sk::Sint, .. }
| Ti::Vector { kind: Sk::Uint, .. } => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
if arg1_ty != arg_ty {
return Err(ExpressionError::InvalidArgumentType(
fun,
1,
arg1.unwrap(),
));
}
}
Mf::Outer | Mf::Cross | Mf::Reflect => {
let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
(Some(ty1), None, None) => ty1,
_ => return Err(ExpressionError::WrongArgumentCount(fun)),

View File

@@ -8,8 +8,22 @@ fn test_fma() -> vec2<f32> {
return fma(a, b, c);
}
fn test_integer_dot_product() -> i32 {
let a_2 = vec2<i32>(1);
let b_2 = vec2<i32>(1);
let c_2: i32 = dot(a_2, b_2);
let a_3 = vec3<u32>(1u);
let b_3 = vec3<u32>(1u);
let c_3: u32 = dot(a_3, b_3);
// test baking of arguments
let c_4: i32 = dot(vec4<i32>(4), vec4<i32>(2));
return c_4;
}
@stage(compute) @workgroup_size(1)
fn main() {
let a = test_fma();
let b = test_integer_dot_product();
}

View File

@@ -14,8 +14,22 @@ vec2 test_fma() {
return fma(a, b, c);
}
int test_integer_dot_product() {
ivec2 a_2_ = ivec2(1);
ivec2 b_2_ = ivec2(1);
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
uvec3 a_3_ = uvec3(1u);
uvec3 b_3_ = uvec3(1u);
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
ivec4 _e11 = ivec4(4);
ivec4 _e13 = ivec4(2);
int c_4_ = ( + _e11.x * _e13.x + _e11.y * _e13.y + _e11.z * _e13.z + _e11.w * _e13.w);
return c_4_;
}
void main() {
vec2 _e0 = test_fma();
int _e1 = test_integer_dot_product();
return;
}

View File

@@ -7,9 +7,22 @@ float2 test_fma()
return mad(a, b, c);
}
int test_integer_dot_product()
{
int2 a_2_ = int2(1.xx);
int2 b_2_ = int2(1.xx);
int c_2_ = dot(a_2_, b_2_);
uint3 a_3_ = uint3(1u.xxx);
uint3 b_3_ = uint3(1u.xxx);
uint c_3_ = dot(a_3_, b_3_);
int c_4_ = dot(int4(4.xxxx), int4(2.xxxx));
return c_4_;
}
[numthreads(1, 1, 1)]
void main()
{
const float2 _e0 = test_fma();
const int _e1 = test_integer_dot_product();
return;
}

View File

@@ -11,8 +11,23 @@ metal::float2 test_fma(
return metal::fma(a, b, c);
}
int test_integer_dot_product(
) {
metal::int2 a_2_ = metal::int2(1);
metal::int2 b_2_ = metal::int2(1);
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
metal::uint3 a_3_ = metal::uint3(1u);
metal::uint3 b_3_ = metal::uint3(1u);
metal::uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
metal::int4 _e11 = metal::int4(4);
metal::int4 _e13 = metal::int4(2);
int c_4_ = ( + _e11.x * _e13.x + _e11.y * _e13.y + _e11.z * _e13.z + _e11.w * _e13.w);
return c_4_;
}
kernel void main_(
) {
metal::float2 _e0 = test_fma();
int _e1 = test_integer_dot_product();
return;
}

View File

@@ -1,33 +1,95 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 20
; Bound: 79
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %16 "main"
OpExecutionMode %16 LocalSize 1 1 1
OpEntryPoint GLCompute %74 "main"
OpExecutionMode %74 LocalSize 1 1 1
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpConstant %4 2.0
%5 = OpConstant %4 0.5
%6 = OpTypeVector %4 2
%9 = OpTypeFunction %6
%17 = OpTypeFunction %2
%8 = OpFunction %6 None %9
%7 = OpLabel
OpBranch %10
%10 = OpLabel
%11 = OpCompositeConstruct %6 %3 %3
%12 = OpCompositeConstruct %6 %5 %5
%13 = OpCompositeConstruct %6 %5 %5
%14 = OpExtInst %6 %1 Fma %11 %12 %13
OpReturnValue %14
%7 = OpTypeInt 32 1
%6 = OpConstant %7 1
%9 = OpTypeInt 32 0
%8 = OpConstant %9 1
%10 = OpConstant %7 4
%11 = OpConstant %7 2
%12 = OpTypeVector %4 2
%15 = OpTypeFunction %12
%23 = OpTypeFunction %7
%25 = OpTypeVector %7 2
%37 = OpTypeVector %9 3
%53 = OpTypeVector %7 4
%75 = OpTypeFunction %2
%29 = OpConstantNull %7
%41 = OpConstantNull %9
%57 = OpConstantNull %7
%14 = OpFunction %12 None %15
%13 = OpLabel
OpBranch %16
%16 = OpLabel
%17 = OpCompositeConstruct %12 %3 %3
%18 = OpCompositeConstruct %12 %5 %5
%19 = OpCompositeConstruct %12 %5 %5
%20 = OpExtInst %12 %1 Fma %17 %18 %19
OpReturnValue %20
OpFunctionEnd
%16 = OpFunction %2 None %17
%15 = OpLabel
OpBranch %18
%18 = OpLabel
%19 = OpFunctionCall %6 %8
%22 = OpFunction %7 None %23
%21 = OpLabel
OpBranch %24
%24 = OpLabel
%26 = OpCompositeConstruct %25 %6 %6
%27 = OpCompositeConstruct %25 %6 %6
%30 = OpCompositeExtract %7 %26 0
%31 = OpCompositeExtract %7 %27 0
%32 = OpIMul %7 %30 %31
%33 = OpIAdd %7 %29 %32
%34 = OpCompositeExtract %7 %26 1
%35 = OpCompositeExtract %7 %27 1
%36 = OpIMul %7 %34 %35
%28 = OpIAdd %7 %33 %36
%38 = OpCompositeConstruct %37 %8 %8 %8
%39 = OpCompositeConstruct %37 %8 %8 %8
%42 = OpCompositeExtract %9 %38 0
%43 = OpCompositeExtract %9 %39 0
%44 = OpIMul %9 %42 %43
%45 = OpIAdd %9 %41 %44
%46 = OpCompositeExtract %9 %38 1
%47 = OpCompositeExtract %9 %39 1
%48 = OpIMul %9 %46 %47
%49 = OpIAdd %9 %45 %48
%50 = OpCompositeExtract %9 %38 2
%51 = OpCompositeExtract %9 %39 2
%52 = OpIMul %9 %50 %51
%40 = OpIAdd %9 %49 %52
%54 = OpCompositeConstruct %53 %10 %10 %10 %10
%55 = OpCompositeConstruct %53 %11 %11 %11 %11
%58 = OpCompositeExtract %7 %54 0
%59 = OpCompositeExtract %7 %55 0
%60 = OpIMul %7 %58 %59
%61 = OpIAdd %7 %57 %60
%62 = OpCompositeExtract %7 %54 1
%63 = OpCompositeExtract %7 %55 1
%64 = OpIMul %7 %62 %63
%65 = OpIAdd %7 %61 %64
%66 = OpCompositeExtract %7 %54 2
%67 = OpCompositeExtract %7 %55 2
%68 = OpIMul %7 %66 %67
%69 = OpIAdd %7 %65 %68
%70 = OpCompositeExtract %7 %54 3
%71 = OpCompositeExtract %7 %55 3
%72 = OpIMul %7 %70 %71
%56 = OpIAdd %7 %69 %72
OpReturnValue %56
OpFunctionEnd
%74 = OpFunction %2 None %75
%73 = OpLabel
OpBranch %76
%76 = OpLabel
%77 = OpFunctionCall %12 %14
%78 = OpFunctionCall %7 %22
OpReturn
OpFunctionEnd

View File

@@ -5,8 +5,20 @@ fn test_fma() -> vec2<f32> {
return fma(a, b, c);
}
fn test_integer_dot_product() -> i32 {
let a_2_ = vec2<i32>(1);
let b_2_ = vec2<i32>(1);
let c_2_ = dot(a_2_, b_2_);
let a_3_ = vec3<u32>(1u);
let b_3_ = vec3<u32>(1u);
let c_3_ = dot(a_3_, b_3_);
let c_4_ = dot(vec4<i32>(4), vec4<i32>(2));
return c_4_;
}
@stage(compute) @workgroup_size(1, 1, 1)
fn main() {
let _e0 = test_fma();
let _e1 = test_integer_dot_product();
return;
}