diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index c645ccf9ad..de0d83febf 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2808,7 +2808,7 @@ impl<'a, W: Write> Writer<'a, W> { // we might need to cast to unsigned integers since // GLSL's findLSB / findMSB always return signed integers let need_extra_paren = { - (fun == Mf::FindLsb || fun == Mf::FindMsb) + (fun == Mf::FindLsb || fun == Mf::FindMsb || fun == Mf::CountOneBits) && match *ctx.info[arg].ty.inner_with(&self.module.types) { crate::TypeInner::Scalar { kind: crate::ScalarKind::Uint, diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 1bce7d2719..c4280a7408 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -798,10 +798,18 @@ impl<'w> BlockContext<'w> { arg0_id, )), Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant), - Mf::ReverseBits | Mf::CountOneBits => { - log::error!("unimplemented math function {:?}", fun); - return Err(Error::FeatureNotImplemented("math function")); - } + Mf::ReverseBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitReverse, + result_type_id, + id, + arg0_id, + )), + Mf::CountOneBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitCount, + result_type_id, + id, + arg0_id, + )), Mf::ExtractBits => { let op = match arg_scalar_kind { Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract, diff --git a/tests/in/bits.wgsl b/tests/in/bits.wgsl index 1c78dae201..1e05d81f64 100644 --- a/tests/in/bits.wgsl +++ b/tests/in/bits.wgsl @@ -40,4 +40,20 @@ fn main() { u2 = firstTrailingBit(u2); i3 = firstLeadingBit(i3); u = firstLeadingBit(u); + i = countOneBits(i); + i2 = countOneBits(i2); + i3 = countOneBits(i3); + i4 = countOneBits(i4); + u = countOneBits(u); + u2 = countOneBits(u2); + u3 = countOneBits(u3); + u4 = countOneBits(u4); + i = reverseBits(i); + i2 = reverseBits(i2); + i3 = reverseBits(i3); + i4 = reverseBits(i4); + u = reverseBits(u); + u2 = reverseBits(u2); + u3 = reverseBits(u3); + u4 = reverseBits(u4); } diff --git a/tests/out/glsl/bits.main.Compute.glsl b/tests/out/glsl/bits.main.Compute.glsl index 3166d01363..0cdc8906a7 100644 --- a/tests/out/glsl/bits.main.Compute.glsl +++ b/tests/out/glsl/bits.main.Compute.glsl @@ -93,6 +93,38 @@ void main() { i3_ = findMSB(_e124); uint _e126 = u; u = uint(findMSB(_e126)); + int _e128 = i; + i = bitCount(_e128); + ivec2 _e130 = i2_; + i2_ = bitCount(_e130); + ivec3 _e132 = i3_; + i3_ = bitCount(_e132); + ivec4 _e134 = i4_; + i4_ = bitCount(_e134); + uint _e136 = u; + u = uint(bitCount(_e136)); + uvec2 _e138 = u2_; + u2_ = uvec2(bitCount(_e138)); + uvec3 _e140 = u3_; + u3_ = uvec3(bitCount(_e140)); + uvec4 _e142 = u4_; + u4_ = uvec4(bitCount(_e142)); + int _e144 = i; + i = bitfieldReverse(_e144); + ivec2 _e146 = i2_; + i2_ = bitfieldReverse(_e146); + ivec3 _e148 = i3_; + i3_ = bitfieldReverse(_e148); + ivec4 _e150 = i4_; + i4_ = bitfieldReverse(_e150); + uint _e152 = u; + u = bitfieldReverse(_e152); + uvec2 _e154 = u2_; + u2_ = bitfieldReverse(_e154); + uvec3 _e156 = u3_; + u3_ = bitfieldReverse(_e156); + uvec4 _e158 = u4_; + u4_ = bitfieldReverse(_e158); return; } diff --git a/tests/out/msl/bits.msl b/tests/out/msl/bits.msl index 88503320bb..aa35b8c17d 100644 --- a/tests/out/msl/bits.msl +++ b/tests/out/msl/bits.msl @@ -93,5 +93,37 @@ kernel void main_( i3_ = (((metal::clz(_e124) + 1) % 33) - 1); uint _e126 = u; u = (((metal::clz(_e126) + 1) % 33) - 1); + int _e128 = i; + i = metal::popcount(_e128); + metal::int2 _e130 = i2_; + i2_ = metal::popcount(_e130); + metal::int3 _e132 = i3_; + i3_ = metal::popcount(_e132); + metal::int4 _e134 = i4_; + i4_ = metal::popcount(_e134); + uint _e136 = u; + u = metal::popcount(_e136); + metal::uint2 _e138 = u2_; + u2_ = metal::popcount(_e138); + metal::uint3 _e140 = u3_; + u3_ = metal::popcount(_e140); + metal::uint4 _e142 = u4_; + u4_ = metal::popcount(_e142); + int _e144 = i; + i = metal::reverse_bits(_e144); + metal::int2 _e146 = i2_; + i2_ = metal::reverse_bits(_e146); + metal::int3 _e148 = i3_; + i3_ = metal::reverse_bits(_e148); + metal::int4 _e150 = i4_; + i4_ = metal::reverse_bits(_e150); + uint _e152 = u; + u = metal::reverse_bits(_e152); + metal::uint2 _e154 = u2_; + u2_ = metal::reverse_bits(_e154); + metal::uint3 _e156 = u3_; + u3_ = metal::reverse_bits(_e156); + metal::uint4 _e158 = u4_; + u4_ = metal::reverse_bits(_e158); return; } diff --git a/tests/out/spv/bits.spvasm b/tests/out/spv/bits.spvasm index 6d224eee20..ddd71d0b77 100644 --- a/tests/out/spv/bits.spvasm +++ b/tests/out/spv/bits.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 119 +; Bound: 151 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -163,5 +163,53 @@ OpStore %23 %116 %117 = OpLoad %6 %27 %118 = OpExtInst %6 %1 FindUMsb %117 OpStore %27 %118 +%119 = OpLoad %4 %19 +%120 = OpBitCount %4 %119 +OpStore %19 %120 +%121 = OpLoad %11 %21 +%122 = OpBitCount %11 %121 +OpStore %21 %122 +%123 = OpLoad %12 %23 +%124 = OpBitCount %12 %123 +OpStore %23 %124 +%125 = OpLoad %13 %25 +%126 = OpBitCount %13 %125 +OpStore %25 %126 +%127 = OpLoad %6 %27 +%128 = OpBitCount %6 %127 +OpStore %27 %128 +%129 = OpLoad %14 %29 +%130 = OpBitCount %14 %129 +OpStore %29 %130 +%131 = OpLoad %15 %31 +%132 = OpBitCount %15 %131 +OpStore %31 %132 +%133 = OpLoad %16 %33 +%134 = OpBitCount %16 %133 +OpStore %33 %134 +%135 = OpLoad %4 %19 +%136 = OpBitReverse %4 %135 +OpStore %19 %136 +%137 = OpLoad %11 %21 +%138 = OpBitReverse %11 %137 +OpStore %21 %138 +%139 = OpLoad %12 %23 +%140 = OpBitReverse %12 %139 +OpStore %23 %140 +%141 = OpLoad %13 %25 +%142 = OpBitReverse %13 %141 +OpStore %25 %142 +%143 = OpLoad %6 %27 +%144 = OpBitReverse %6 %143 +OpStore %27 %144 +%145 = OpLoad %14 %29 +%146 = OpBitReverse %14 %145 +OpStore %29 %146 +%147 = OpLoad %15 %31 +%148 = OpBitReverse %15 %147 +OpStore %31 %148 +%149 = OpLoad %16 %33 +%150 = OpBitReverse %16 %149 +OpStore %33 %150 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/bits.wgsl b/tests/out/wgsl/bits.wgsl index 2bdaf6e9ff..ccaebff41e 100644 --- a/tests/out/wgsl/bits.wgsl +++ b/tests/out/wgsl/bits.wgsl @@ -87,5 +87,37 @@ fn main() { i3_ = firstLeadingBit(_e124); let _e126 = u; u = firstLeadingBit(_e126); + let _e128 = i; + i = countOneBits(_e128); + let _e130 = i2_; + i2_ = countOneBits(_e130); + let _e132 = i3_; + i3_ = countOneBits(_e132); + let _e134 = i4_; + i4_ = countOneBits(_e134); + let _e136 = u; + u = countOneBits(_e136); + let _e138 = u2_; + u2_ = countOneBits(_e138); + let _e140 = u3_; + u3_ = countOneBits(_e140); + let _e142 = u4_; + u4_ = countOneBits(_e142); + let _e144 = i; + i = reverseBits(_e144); + let _e146 = i2_; + i2_ = reverseBits(_e146); + let _e148 = i3_; + i3_ = reverseBits(_e148); + let _e150 = i4_; + i4_ = reverseBits(_e150); + let _e152 = u; + u = reverseBits(_e152); + let _e154 = u2_; + u2_ = reverseBits(_e154); + let _e156 = u3_; + u3_ = reverseBits(_e156); + let _e158 = u4_; + u4_ = reverseBits(_e158); return; }