mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Add FindLsb / FindMsb (#1473)
* Add FindLsb / FindMsb * Fixes and tests for FindLsb/FindMsb * Add findLsb / findMsb as WGSL builtins * Fix tests * Fix incompatible type issue with MSL output * Requested changes * Test fewer cases of findLsb/findMsb
This commit is contained in:
@@ -2485,6 +2485,8 @@ impl<'a, W: Write> Writer<'a, W> {
|
||||
Mf::ReverseBits => "bitfieldReverse",
|
||||
Mf::ExtractBits => "bitfieldExtract",
|
||||
Mf::InsertBits => "bitfieldInsert",
|
||||
Mf::FindLsb => "findLSB",
|
||||
Mf::FindMsb => "findMSB",
|
||||
// data packing
|
||||
Mf::Pack4x8snorm => "packSnorm4x8",
|
||||
Mf::Pack4x8unorm => "packUnorm4x8",
|
||||
|
||||
@@ -1874,6 +1874,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
// bits
|
||||
Mf::CountOneBits => Function::Regular("countbits"),
|
||||
Mf::ReverseBits => Function::Regular("reversebits"),
|
||||
Mf::FindLsb => Function::Regular("firstbitlow"),
|
||||
Mf::FindMsb => Function::Regular("firstbithigh"),
|
||||
_ => return Err(Error::Unimplemented(format!("write_expr_math {:?}", fun))),
|
||||
};
|
||||
|
||||
|
||||
@@ -1099,6 +1099,21 @@ impl<W: Write> Writer<W> {
|
||||
crate::TypeInner::Scalar { .. } => true,
|
||||
_ => false,
|
||||
};
|
||||
let argument_size_suffix = match *context.resolve_type(arg) {
|
||||
crate::TypeInner::Vector {
|
||||
size: crate::VectorSize::Bi,
|
||||
..
|
||||
} => "2",
|
||||
crate::TypeInner::Vector {
|
||||
size: crate::VectorSize::Tri,
|
||||
..
|
||||
} => "3",
|
||||
crate::TypeInner::Vector {
|
||||
size: crate::VectorSize::Quad,
|
||||
..
|
||||
} => "4",
|
||||
_ => "",
|
||||
};
|
||||
|
||||
let fun_name = match fun {
|
||||
// comparison
|
||||
@@ -1162,6 +1177,8 @@ impl<W: Write> Writer<W> {
|
||||
Mf::ReverseBits => "reverse_bits",
|
||||
Mf::ExtractBits => "extract_bits",
|
||||
Mf::InsertBits => "insert_bits",
|
||||
Mf::FindLsb => "",
|
||||
Mf::FindMsb => "",
|
||||
// data packing
|
||||
Mf::Pack4x8snorm => "pack_float_to_unorm4x8",
|
||||
Mf::Pack4x8unorm => "pack_float_to_snorm4x8",
|
||||
@@ -1182,6 +1199,22 @@ impl<W: Write> Writer<W> {
|
||||
write!(self.out, " - ")?;
|
||||
self.put_expression(arg1.unwrap(), context, false)?;
|
||||
write!(self.out, ")")?;
|
||||
} else if fun == Mf::FindLsb {
|
||||
write!(
|
||||
self.out,
|
||||
"(((1 + int{}({}::ctz(",
|
||||
argument_size_suffix, NAMESPACE
|
||||
)?;
|
||||
self.put_expression(arg, context, true)?;
|
||||
write!(self.out, "))) % 33) - 1)")?;
|
||||
} else if fun == Mf::FindMsb {
|
||||
write!(
|
||||
self.out,
|
||||
"(((1 + int{}({}::clz(",
|
||||
argument_size_suffix, NAMESPACE
|
||||
)?;
|
||||
self.put_expression(arg, context, true)?;
|
||||
write!(self.out, "))) % 33) - 1)")?;
|
||||
} else if fun == Mf::Unpack2x16float {
|
||||
write!(self.out, "float2(as_type<half2>(")?;
|
||||
self.put_expression(arg, context, false)?;
|
||||
|
||||
@@ -662,6 +662,12 @@ impl<'w> BlockContext<'w> {
|
||||
arg2_id,
|
||||
arg3_id,
|
||||
)),
|
||||
Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
|
||||
Mf::FindMsb => MathOp::Ext(match arg_scalar_kind {
|
||||
Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
|
||||
Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
|
||||
other => unimplemented!("Unexpected findMSB({:?})", other),
|
||||
}),
|
||||
Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
|
||||
Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
|
||||
Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
|
||||
|
||||
@@ -1566,6 +1566,8 @@ impl<W: Write> Writer<W> {
|
||||
Mf::ReverseBits => Function::Regular("reverseBits"),
|
||||
Mf::ExtractBits => Function::Regular("extractBits"),
|
||||
Mf::InsertBits => Function::Regular("insertBits"),
|
||||
Mf::FindLsb => Function::Regular("findLsb"),
|
||||
Mf::FindMsb => Function::Regular("findMsb"),
|
||||
// data packing
|
||||
Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
|
||||
Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
|
||||
|
||||
@@ -630,12 +630,15 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module
|
||||
))
|
||||
}
|
||||
}
|
||||
"bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" => {
|
||||
"bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB"
|
||||
| "findMSB" => {
|
||||
let fun = match name {
|
||||
"bitCount" => MathFunction::CountOneBits,
|
||||
"bitfieldReverse" => MathFunction::ReverseBits,
|
||||
"bitfieldExtract" => MathFunction::ExtractBits,
|
||||
"bitfieldInsert" => MathFunction::InsertBits,
|
||||
"findLSB" => MathFunction::FindLsb,
|
||||
"findMSB" => MathFunction::FindMsb,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
|
||||
@@ -2637,6 +2637,8 @@ impl<I: Iterator<Item = u32>> Parser<I> {
|
||||
Glo::UnpackHalf2x16 => Mf::Unpack2x16float,
|
||||
Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm,
|
||||
Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm,
|
||||
Glo::FindILsb => Mf::FindLsb,
|
||||
Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb,
|
||||
_ => return Err(Error::UnsupportedExtInst(inst_id)),
|
||||
};
|
||||
|
||||
|
||||
@@ -199,6 +199,8 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
|
||||
"reverseBits" => Mf::ReverseBits,
|
||||
"extractBits" => Mf::ExtractBits,
|
||||
"insertBits" => Mf::InsertBits,
|
||||
"findLsb" => Mf::FindLsb,
|
||||
"findMsb" => Mf::FindMsb,
|
||||
// data packing
|
||||
"pack4x8snorm" => Mf::Pack4x8snorm,
|
||||
"pack4x8unorm" => Mf::Pack4x8unorm,
|
||||
|
||||
@@ -926,6 +926,8 @@ pub enum MathFunction {
|
||||
ReverseBits,
|
||||
ExtractBits,
|
||||
InsertBits,
|
||||
FindLsb,
|
||||
FindMsb,
|
||||
// data packing
|
||||
Pack4x8snorm,
|
||||
Pack4x8unorm,
|
||||
|
||||
@@ -262,6 +262,8 @@ impl super::MathFunction {
|
||||
Self::ReverseBits => 1,
|
||||
Self::ExtractBits => 3,
|
||||
Self::InsertBits => 4,
|
||||
Self::FindLsb => 1,
|
||||
Self::FindMsb => 1,
|
||||
// data packing
|
||||
Self::Pack4x8snorm => 1,
|
||||
Self::Pack4x8unorm => 1,
|
||||
|
||||
@@ -803,6 +803,16 @@ impl<'a> ResolveContext<'a> {
|
||||
Mf::ReverseBits |
|
||||
Mf::ExtractBits |
|
||||
Mf::InsertBits => res_arg.clone(),
|
||||
Mf::FindLsb |
|
||||
Mf::FindMsb => match *res_arg.inner_with(types) {
|
||||
Ti::Scalar { kind: _, width } =>
|
||||
TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Sint, width }),
|
||||
Ti::Vector { size, kind: _, width } =>
|
||||
TypeResolution::Value(Ti::Vector { size, kind: crate::ScalarKind::Sint, width }),
|
||||
ref other => return Err(ResolveError::IncompatibleOperands(
|
||||
format!("{:?}({:?})", fun, other)
|
||||
)),
|
||||
},
|
||||
// data packing
|
||||
Mf::Pack4x8snorm |
|
||||
Mf::Pack4x8unorm |
|
||||
|
||||
@@ -1231,7 +1231,7 @@ impl super::Validator {
|
||||
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
|
||||
}
|
||||
}
|
||||
Mf::CountOneBits | Mf::ReverseBits => {
|
||||
Mf::CountOneBits | Mf::ReverseBits | Mf::FindLsb | Mf::FindMsb => {
|
||||
if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
|
||||
return Err(ExpressionError::WrongArgumentCount(fun));
|
||||
}
|
||||
|
||||
@@ -36,4 +36,8 @@ fn main() {
|
||||
u2 = extractBits(u2, 5u, 10u);
|
||||
u3 = extractBits(u3, 5u, 10u);
|
||||
u4 = extractBits(u4, 5u, 10u);
|
||||
i = findLsb(i);
|
||||
i2 = findLsb(u2);
|
||||
i3 = findMsb(i3);
|
||||
i = findMsb(u);
|
||||
}
|
||||
|
||||
@@ -37,4 +37,20 @@ void main() {
|
||||
u2 = bitfieldExtract(u2, 5, 10);
|
||||
u3 = bitfieldExtract(u3, 5, 10);
|
||||
u4 = bitfieldExtract(u4, 5, 10);
|
||||
i = findLSB(i);
|
||||
i2 = findLSB(i2);
|
||||
i3 = findLSB(i3);
|
||||
i4 = findLSB(i4);
|
||||
i = findLSB(u);
|
||||
i2 = findLSB(u2);
|
||||
i3 = findLSB(u3);
|
||||
i4 = findLSB(u4);
|
||||
i = findMSB(i);
|
||||
i2 = findMSB(i2);
|
||||
i3 = findMSB(i3);
|
||||
i4 = findMSB(i4);
|
||||
i = findMSB(u);
|
||||
i2 = findMSB(u2);
|
||||
i3 = findMSB(u3);
|
||||
i4 = findMSB(u4);
|
||||
}
|
||||
@@ -85,6 +85,14 @@ void main() {
|
||||
u3_ = bitfieldExtract(_e112, int(5u), int(10u));
|
||||
uvec4 _e116 = u4_;
|
||||
u4_ = bitfieldExtract(_e116, int(5u), int(10u));
|
||||
int _e120 = i;
|
||||
i = findLSB(_e120);
|
||||
uvec2 _e122 = u2_;
|
||||
i2_ = findLSB(_e122);
|
||||
ivec3 _e124 = i3_;
|
||||
i3_ = findMSB(_e124);
|
||||
uint _e126 = u;
|
||||
i = findMSB(_e126);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -83,5 +83,13 @@ kernel void main_(
|
||||
u3_ = metal::extract_bits(_e112, 5u, 10u);
|
||||
metal::uint4 _e116 = u4_;
|
||||
u4_ = metal::extract_bits(_e116, 5u, 10u);
|
||||
int _e120 = i;
|
||||
i = (((1 + int(metal::ctz(_e120))) % 33) - 1);
|
||||
metal::uint2 _e122 = u2_;
|
||||
i2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1);
|
||||
metal::int3 _e124 = i3_;
|
||||
i3_ = (((1 + int3(metal::clz(_e124))) % 33) - 1);
|
||||
metal::uint _e126 = u;
|
||||
i = (((1 + int(metal::clz(_e126))) % 33) - 1);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
; SPIR-V
|
||||
; Version: 1.1
|
||||
; Generator: rspirv
|
||||
; Bound: 111
|
||||
; Bound: 119
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
@@ -151,5 +151,17 @@ OpStore %31 %108
|
||||
%109 = OpLoad %16 %33
|
||||
%110 = OpBitFieldUExtract %16 %109 %9 %10
|
||||
OpStore %33 %110
|
||||
%111 = OpLoad %4 %19
|
||||
%112 = OpExtInst %4 %1 FindILsb %111
|
||||
OpStore %19 %112
|
||||
%113 = OpLoad %14 %29
|
||||
%114 = OpExtInst %11 %1 FindILsb %113
|
||||
OpStore %21 %114
|
||||
%115 = OpLoad %12 %23
|
||||
%116 = OpExtInst %12 %1 FindSMsb %115
|
||||
OpStore %23 %116
|
||||
%117 = OpLoad %6 %27
|
||||
%118 = OpExtInst %4 %1 FindUMsb %117
|
||||
OpStore %19 %118
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
@@ -79,5 +79,13 @@ fn main() {
|
||||
u3_ = extractBits(_e112, 5u, 10u);
|
||||
let _e116 = u4_;
|
||||
u4_ = extractBits(_e116, 5u, 10u);
|
||||
let _e120 = i;
|
||||
i = findLsb(_e120);
|
||||
let _e122 = u2_;
|
||||
i2_ = findLsb(_e122);
|
||||
let _e124 = i3_;
|
||||
i3_ = findMsb(_e124);
|
||||
let _e126 = u;
|
||||
i = findMsb(_e126);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -70,6 +70,38 @@ fn main_1() {
|
||||
u3_ = extractBits(_e207, u32(5), u32(10));
|
||||
let _e216 = u4_;
|
||||
u4_ = extractBits(_e216, u32(5), u32(10));
|
||||
let _e223 = i;
|
||||
i = findLsb(_e223);
|
||||
let _e226 = i2_;
|
||||
i2_ = findLsb(_e226);
|
||||
let _e229 = i3_;
|
||||
i3_ = findLsb(_e229);
|
||||
let _e232 = i4_;
|
||||
i4_ = findLsb(_e232);
|
||||
let _e235 = u;
|
||||
i = findLsb(_e235);
|
||||
let _e238 = u2_;
|
||||
i2_ = findLsb(_e238);
|
||||
let _e241 = u3_;
|
||||
i3_ = findLsb(_e241);
|
||||
let _e244 = u4_;
|
||||
i4_ = findLsb(_e244);
|
||||
let _e247 = i;
|
||||
i = findMsb(_e247);
|
||||
let _e250 = i2_;
|
||||
i2_ = findMsb(_e250);
|
||||
let _e253 = i3_;
|
||||
i3_ = findMsb(_e253);
|
||||
let _e256 = i4_;
|
||||
i4_ = findMsb(_e256);
|
||||
let _e259 = u;
|
||||
i = findMsb(_e259);
|
||||
let _e262 = u2_;
|
||||
i2_ = findMsb(_e262);
|
||||
let _e265 = u3_;
|
||||
i3_ = findMsb(_e265);
|
||||
let _e268 = u4_;
|
||||
i4_ = findMsb(_e268);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user