diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index d424c11b20..a4635ac2e7 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1721,9 +1721,49 @@ impl Writer { self.put_expression(arg, context, true)?; write!(self.out, ") + 1) % 33) - 1)")?; } else if fun == Mf::FindMsb { - write!(self.out, "((({NAMESPACE}::clz(")?; + let inner = context.resolve_type(arg); + + write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?; + + if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() { + write!(self.out, "{NAMESPACE}::select(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ~")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " < 0)")?; + } else { + self.put_expression(arg, context, true)?; + } + + write!(self.out, "), ")?; + + // or metal will complain that select is ambiguous + match *inner { + crate::TypeInner::Vector { size, kind, .. } => { + let size = back::vector_size_str(size); + if let crate::ScalarKind::Sint = kind { + write!(self.out, "int{size}")?; + } else { + write!(self.out, "uint{size}")?; + } + } + crate::TypeInner::Scalar { kind, .. } => { + if let crate::ScalarKind::Sint = kind { + write!(self.out, "int")?; + } else { + write!(self.out, "uint")?; + } + } + _ => (), + } + + write!(self.out, "(-1), ")?; self.put_expression(arg, context, true)?; - write!(self.out, ") + 1) % 33) - 1)")? + write!(self.out, " == 0 || ")?; + self.put_expression(arg, context, true)?; + write!(self.out, " == -1)")?; } else if fun == Mf::Unpack2x16float { write!(self.out, "float2(as_type(")?; self.put_expression(arg, context, false)?; @@ -2275,42 +2315,42 @@ impl Writer { ) { use crate::Expression; self.need_bake_expressions.clear(); - for expr in func.expressions.iter() { + for (expr_handle, expr) in func.expressions.iter() { // Expressions whose reference count is above the // threshold should always be stored in temporaries. - let expr_info = &info[expr.0]; - let min_ref_count = func.expressions[expr.0].bake_ref_count(); + let expr_info = &info[expr_handle]; + let min_ref_count = func.expressions[expr_handle].bake_ref_count(); if min_ref_count <= expr_info.ref_count { - self.need_bake_expressions.insert(expr.0); + self.need_bake_expressions.insert(expr_handle); } - // WGSL's `dot` function works on any `vecN` type, but Metal's only - // works on floating-point vectors, so we emit inline code for - // integer vector `dot` calls. But that code uses each argument `N` - // times, once for each component (see `put_dot_product`), so to - // avoid duplicated evaluation, we must bake integer operands. - if let ( - fun_handle, - &Expression::Math { - fun: crate::MathFunction::Dot, - arg, - arg1, - .. - }, - ) = expr - { - use crate::TypeInner; - // check what kind of product this is depending - // on the resolve type of the Dot function itself - let inner = context.resolve_type(fun_handle); - if let TypeInner::Scalar { kind, .. } = *inner { - match kind { - crate::ScalarKind::Sint | crate::ScalarKind::Uint => { - self.need_bake_expressions.insert(arg); - self.need_bake_expressions.insert(arg1.unwrap()); + if let Expression::Math { fun, arg, arg1, .. } = *expr { + match fun { + crate::MathFunction::Dot => { + // WGSL's `dot` function works on any `vecN` type, but Metal's only + // works on floating-point vectors, so we emit inline code for + // integer vector `dot` calls. But that code uses each argument `N` + // times, once for each component (see `put_dot_product`), so to + // avoid duplicated evaluation, we must bake integer operands. + + use crate::TypeInner; + // check what kind of product this is depending + // on the resolve type of the Dot function itself + let inner = context.resolve_type(expr_handle); + if let TypeInner::Scalar { kind, .. } = *inner { + match kind { + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + self.need_bake_expressions.insert(arg); + self.need_bake_expressions.insert(arg1.unwrap()); + } + _ => {} + } } - _ => {} } + crate::MathFunction::FindMsb => { + self.need_bake_expressions.insert(arg); + } + _ => {} } } } diff --git a/tests/in/bits.wgsl b/tests/in/bits.wgsl index 1e05d81f64..549ff08ec7 100644 --- a/tests/in/bits.wgsl +++ b/tests/in/bits.wgsl @@ -39,6 +39,8 @@ fn main() { i = firstTrailingBit(i); u2 = firstTrailingBit(u2); i3 = firstLeadingBit(i3); + u3 = firstLeadingBit(u3); + i = firstLeadingBit(i); u = firstLeadingBit(u); i = countOneBits(i); i2 = countOneBits(i2); diff --git a/tests/out/glsl/bits.main.Compute.glsl b/tests/out/glsl/bits.main.Compute.glsl index 504fb7c94f..1c17638faf 100644 --- a/tests/out/glsl/bits.main.Compute.glsl +++ b/tests/out/glsl/bits.main.Compute.glsl @@ -93,40 +93,44 @@ void main() { u2_ = uvec2(findLSB(_e122)); ivec3 _e124 = i3_; i3_ = findMSB(_e124); - uint _e126 = u; - u = uint(findMSB(_e126)); + uvec3 _e126 = u3_; + u3_ = uvec3(findMSB(_e126)); int _e128 = i; - i = bitCount(_e128); - ivec2 _e130 = i2_; - i2_ = bitCount(_e130); - ivec3 _e132 = i3_; - i3_ = bitCount(_e132); - ivec4 _e134 = i4_; - i4_ = bitCount(_e134); - uint _e136 = u; - u = uint(bitCount(_e136)); - uvec2 _e138 = u2_; - u2_ = uvec2(bitCount(_e138)); - uvec3 _e140 = u3_; - u3_ = uvec3(bitCount(_e140)); - uvec4 _e142 = u4_; - u4_ = uvec4(bitCount(_e142)); - int _e144 = i; - i = bitfieldReverse(_e144); - ivec2 _e146 = i2_; - i2_ = bitfieldReverse(_e146); - ivec3 _e148 = i3_; - i3_ = bitfieldReverse(_e148); - ivec4 _e150 = i4_; - i4_ = bitfieldReverse(_e150); - uint _e152 = u; - u = bitfieldReverse(_e152); - uvec2 _e154 = u2_; - u2_ = bitfieldReverse(_e154); - uvec3 _e156 = u3_; - u3_ = bitfieldReverse(_e156); - uvec4 _e158 = u4_; - u4_ = bitfieldReverse(_e158); + i = findMSB(_e128); + uint _e130 = u; + u = uint(findMSB(_e130)); + int _e132 = i; + i = bitCount(_e132); + ivec2 _e134 = i2_; + i2_ = bitCount(_e134); + ivec3 _e136 = i3_; + i3_ = bitCount(_e136); + ivec4 _e138 = i4_; + i4_ = bitCount(_e138); + uint _e140 = u; + u = uint(bitCount(_e140)); + uvec2 _e142 = u2_; + u2_ = uvec2(bitCount(_e142)); + uvec3 _e144 = u3_; + u3_ = uvec3(bitCount(_e144)); + uvec4 _e146 = u4_; + u4_ = uvec4(bitCount(_e146)); + int _e148 = i; + i = bitfieldReverse(_e148); + ivec2 _e150 = i2_; + i2_ = bitfieldReverse(_e150); + ivec3 _e152 = i3_; + i3_ = bitfieldReverse(_e152); + ivec4 _e154 = i4_; + i4_ = bitfieldReverse(_e154); + uint _e156 = u; + u = bitfieldReverse(_e156); + uvec2 _e158 = u2_; + u2_ = bitfieldReverse(_e158); + uvec3 _e160 = u3_; + u3_ = bitfieldReverse(_e160); + uvec4 _e162 = u4_; + u4_ = bitfieldReverse(_e162); return; } diff --git a/tests/out/msl/bits.msl b/tests/out/msl/bits.msl index 2a95996c70..2a00b6b843 100644 --- a/tests/out/msl/bits.msl +++ b/tests/out/msl/bits.msl @@ -92,40 +92,44 @@ kernel void main_( metal::uint2 _e122 = u2_; u2_ = (((metal::ctz(_e122) + 1) % 33) - 1); metal::int3 _e124 = i3_; - i3_ = (((metal::clz(_e124) + 1) % 33) - 1); - uint _e126 = u; - u = (((metal::clz(_e126) + 1) % 33) - 1); + i3_ = metal::select(31 - metal::clz(metal::select(_e124, ~_e124, _e124 < 0)), int3(-1), _e124 == 0 || _e124 == -1); + metal::uint3 _e126 = u3_; + u3_ = metal::select(31 - metal::clz(_e126), uint3(-1), _e126 == 0 || _e126 == -1); int _e128 = i; - i = metal::popcount(_e128); - metal::int2 _e130 = i2_; - i2_ = metal::popcount(_e130); - metal::int3 _e132 = i3_; - i3_ = metal::popcount(_e132); - metal::int4 _e134 = i4_; - i4_ = metal::popcount(_e134); - uint _e136 = u; - u = metal::popcount(_e136); - metal::uint2 _e138 = u2_; - u2_ = metal::popcount(_e138); - metal::uint3 _e140 = u3_; - u3_ = metal::popcount(_e140); - metal::uint4 _e142 = u4_; - u4_ = metal::popcount(_e142); - int _e144 = i; - i = metal::reverse_bits(_e144); - metal::int2 _e146 = i2_; - i2_ = metal::reverse_bits(_e146); - metal::int3 _e148 = i3_; - i3_ = metal::reverse_bits(_e148); - metal::int4 _e150 = i4_; - i4_ = metal::reverse_bits(_e150); - uint _e152 = u; - u = metal::reverse_bits(_e152); - metal::uint2 _e154 = u2_; - u2_ = metal::reverse_bits(_e154); - metal::uint3 _e156 = u3_; - u3_ = metal::reverse_bits(_e156); - metal::uint4 _e158 = u4_; - u4_ = metal::reverse_bits(_e158); + i = metal::select(31 - metal::clz(metal::select(_e128, ~_e128, _e128 < 0)), int(-1), _e128 == 0 || _e128 == -1); + uint _e130 = u; + u = metal::select(31 - metal::clz(_e130), uint(-1), _e130 == 0 || _e130 == -1); + int _e132 = i; + i = metal::popcount(_e132); + metal::int2 _e134 = i2_; + i2_ = metal::popcount(_e134); + metal::int3 _e136 = i3_; + i3_ = metal::popcount(_e136); + metal::int4 _e138 = i4_; + i4_ = metal::popcount(_e138); + uint _e140 = u; + u = metal::popcount(_e140); + metal::uint2 _e142 = u2_; + u2_ = metal::popcount(_e142); + metal::uint3 _e144 = u3_; + u3_ = metal::popcount(_e144); + metal::uint4 _e146 = u4_; + u4_ = metal::popcount(_e146); + int _e148 = i; + i = metal::reverse_bits(_e148); + metal::int2 _e150 = i2_; + i2_ = metal::reverse_bits(_e150); + metal::int3 _e152 = i3_; + i3_ = metal::reverse_bits(_e152); + metal::int4 _e154 = i4_; + i4_ = metal::reverse_bits(_e154); + uint _e156 = u; + u = metal::reverse_bits(_e156); + metal::uint2 _e158 = u2_; + u2_ = metal::reverse_bits(_e158); + metal::uint3 _e160 = u3_; + u3_ = metal::reverse_bits(_e160); + metal::uint4 _e162 = u4_; + u4_ = metal::reverse_bits(_e162); return; } diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index 0838a338c5..3fdb7b75a5 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -16,7 +16,8 @@ vertex void main_( metal::float4 e = metal::saturate(v); metal::float4 g = metal::refract(v, v, 1.0); int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y); - uint first_leading_bit_abs = (((metal::clz(metal::abs(0u)) + 1) % 33) - 1); + uint _e13 = metal::abs(0u); + uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1); int clz_a = metal::clz(-1); uint clz_b = metal::clz(1u); metal::int2 clz_c = metal::clz(metal::int2(-1)); diff --git a/tests/out/spv/bits.spvasm b/tests/out/spv/bits.spvasm index b8c17d8709..750a2545b7 100644 --- a/tests/out/spv/bits.spvasm +++ b/tests/out/spv/bits.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 161 +; Bound: 165 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -172,56 +172,62 @@ OpStore %34 %124 %125 = OpLoad %12 %25 %126 = OpExtInst %12 %1 FindSMsb %125 OpStore %25 %126 -%127 = OpLoad %6 %31 -%128 = OpExtInst %6 %1 FindUMsb %127 -OpStore %31 %128 +%127 = OpLoad %15 %37 +%128 = OpExtInst %15 %1 FindUMsb %127 +OpStore %37 %128 %129 = OpLoad %4 %19 -%130 = OpBitCount %4 %129 +%130 = OpExtInst %4 %1 FindSMsb %129 OpStore %19 %130 -%131 = OpLoad %11 %22 -%132 = OpBitCount %11 %131 -OpStore %22 %132 -%133 = OpLoad %12 %25 -%134 = OpBitCount %12 %133 -OpStore %25 %134 -%135 = OpLoad %13 %28 -%136 = OpBitCount %13 %135 -OpStore %28 %136 -%137 = OpLoad %6 %31 -%138 = OpBitCount %6 %137 -OpStore %31 %138 -%139 = OpLoad %14 %34 -%140 = OpBitCount %14 %139 -OpStore %34 %140 -%141 = OpLoad %15 %37 -%142 = OpBitCount %15 %141 -OpStore %37 %142 -%143 = OpLoad %16 %40 -%144 = OpBitCount %16 %143 -OpStore %40 %144 -%145 = OpLoad %4 %19 -%146 = OpBitReverse %4 %145 -OpStore %19 %146 -%147 = OpLoad %11 %22 -%148 = OpBitReverse %11 %147 -OpStore %22 %148 -%149 = OpLoad %12 %25 -%150 = OpBitReverse %12 %149 -OpStore %25 %150 -%151 = OpLoad %13 %28 -%152 = OpBitReverse %13 %151 -OpStore %28 %152 -%153 = OpLoad %6 %31 -%154 = OpBitReverse %6 %153 -OpStore %31 %154 -%155 = OpLoad %14 %34 -%156 = OpBitReverse %14 %155 -OpStore %34 %156 -%157 = OpLoad %15 %37 -%158 = OpBitReverse %15 %157 -OpStore %37 %158 -%159 = OpLoad %16 %40 -%160 = OpBitReverse %16 %159 -OpStore %40 %160 +%131 = OpLoad %6 %31 +%132 = OpExtInst %6 %1 FindUMsb %131 +OpStore %31 %132 +%133 = OpLoad %4 %19 +%134 = OpBitCount %4 %133 +OpStore %19 %134 +%135 = OpLoad %11 %22 +%136 = OpBitCount %11 %135 +OpStore %22 %136 +%137 = OpLoad %12 %25 +%138 = OpBitCount %12 %137 +OpStore %25 %138 +%139 = OpLoad %13 %28 +%140 = OpBitCount %13 %139 +OpStore %28 %140 +%141 = OpLoad %6 %31 +%142 = OpBitCount %6 %141 +OpStore %31 %142 +%143 = OpLoad %14 %34 +%144 = OpBitCount %14 %143 +OpStore %34 %144 +%145 = OpLoad %15 %37 +%146 = OpBitCount %15 %145 +OpStore %37 %146 +%147 = OpLoad %16 %40 +%148 = OpBitCount %16 %147 +OpStore %40 %148 +%149 = OpLoad %4 %19 +%150 = OpBitReverse %4 %149 +OpStore %19 %150 +%151 = OpLoad %11 %22 +%152 = OpBitReverse %11 %151 +OpStore %22 %152 +%153 = OpLoad %12 %25 +%154 = OpBitReverse %12 %153 +OpStore %25 %154 +%155 = OpLoad %13 %28 +%156 = OpBitReverse %13 %155 +OpStore %28 %156 +%157 = OpLoad %6 %31 +%158 = OpBitReverse %6 %157 +OpStore %31 %158 +%159 = OpLoad %14 %34 +%160 = OpBitReverse %14 %159 +OpStore %34 %160 +%161 = OpLoad %15 %37 +%162 = OpBitReverse %15 %161 +OpStore %37 %162 +%163 = OpLoad %16 %40 +%164 = OpBitReverse %16 %163 +OpStore %40 %164 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/bits.wgsl b/tests/out/wgsl/bits.wgsl index 88876ec05c..c203ac94d5 100644 --- a/tests/out/wgsl/bits.wgsl +++ b/tests/out/wgsl/bits.wgsl @@ -87,39 +87,43 @@ fn main() { u2_ = firstTrailingBit(_e122); let _e124 = i3_; i3_ = firstLeadingBit(_e124); - let _e126 = u; - u = firstLeadingBit(_e126); + let _e126 = u3_; + u3_ = firstLeadingBit(_e126); let _e128 = i; - i = countOneBits(_e128); - let _e130 = i2_; - i2_ = countOneBits(_e130); - let _e132 = i3_; - i3_ = countOneBits(_e132); - let _e134 = i4_; - i4_ = countOneBits(_e134); - let _e136 = u; - u = countOneBits(_e136); - let _e138 = u2_; - u2_ = countOneBits(_e138); - let _e140 = u3_; - u3_ = countOneBits(_e140); - let _e142 = u4_; - u4_ = countOneBits(_e142); - let _e144 = i; - i = reverseBits(_e144); - let _e146 = i2_; - i2_ = reverseBits(_e146); - let _e148 = i3_; - i3_ = reverseBits(_e148); - let _e150 = i4_; - i4_ = reverseBits(_e150); - let _e152 = u; - u = reverseBits(_e152); - let _e154 = u2_; - u2_ = reverseBits(_e154); - let _e156 = u3_; - u3_ = reverseBits(_e156); - let _e158 = u4_; - u4_ = reverseBits(_e158); + i = firstLeadingBit(_e128); + let _e130 = u; + u = firstLeadingBit(_e130); + let _e132 = i; + i = countOneBits(_e132); + let _e134 = i2_; + i2_ = countOneBits(_e134); + let _e136 = i3_; + i3_ = countOneBits(_e136); + let _e138 = i4_; + i4_ = countOneBits(_e138); + let _e140 = u; + u = countOneBits(_e140); + let _e142 = u2_; + u2_ = countOneBits(_e142); + let _e144 = u3_; + u3_ = countOneBits(_e144); + let _e146 = u4_; + u4_ = countOneBits(_e146); + let _e148 = i; + i = reverseBits(_e148); + let _e150 = i2_; + i2_ = reverseBits(_e150); + let _e152 = i3_; + i3_ = reverseBits(_e152); + let _e154 = i4_; + i4_ = reverseBits(_e154); + let _e156 = u; + u = reverseBits(_e156); + let _e158 = u2_; + u2_ = reverseBits(_e158); + let _e160 = u3_; + u3_ = reverseBits(_e160); + let _e162 = u4_; + u4_ = reverseBits(_e162); return; }