[msl-out] Update firstLeadingBit for signed integers (#2235)

The previous MSL output code for `firstLeadingBit` only supported unsigned integers.

Also fixes #2236.

Co-authored-by: Jim Blandy <jimb@red-bean.com>
This commit is contained in:
Evan Mark Hopkins
2023-02-02 12:47:17 -05:00
committed by GitHub
parent fe851fb008
commit fb2d438dbd
7 changed files with 243 additions and 182 deletions

View File

@@ -1721,9 +1721,49 @@ impl<W: Write> Writer<W> {
self.put_expression(arg, context, true)?;
write!(self.out, ") + 1) % 33) - 1)")?;
} else if fun == Mf::FindMsb {
write!(self.out, "((({NAMESPACE}::clz(")?;
let inner = context.resolve_type(arg);
write!(self.out, "{NAMESPACE}::select(31 - {NAMESPACE}::clz(")?;
if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
write!(self.out, "{NAMESPACE}::select(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ~")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " < 0)")?;
} else {
self.put_expression(arg, context, true)?;
}
write!(self.out, "), ")?;
// or metal will complain that select is ambiguous
match *inner {
crate::TypeInner::Vector { size, kind, .. } => {
let size = back::vector_size_str(size);
if let crate::ScalarKind::Sint = kind {
write!(self.out, "int{size}")?;
} else {
write!(self.out, "uint{size}")?;
}
}
crate::TypeInner::Scalar { kind, .. } => {
if let crate::ScalarKind::Sint = kind {
write!(self.out, "int")?;
} else {
write!(self.out, "uint")?;
}
}
_ => (),
}
write!(self.out, "(-1), ")?;
self.put_expression(arg, context, true)?;
write!(self.out, ") + 1) % 33) - 1)")?
write!(self.out, " == 0 || ")?;
self.put_expression(arg, context, true)?;
write!(self.out, " == -1)")?;
} else if fun == Mf::Unpack2x16float {
write!(self.out, "float2(as_type<half2>(")?;
self.put_expression(arg, context, false)?;
@@ -2275,42 +2315,42 @@ impl<W: Write> Writer<W> {
) {
use crate::Expression;
self.need_bake_expressions.clear();
for expr in func.expressions.iter() {
for (expr_handle, expr) in func.expressions.iter() {
// Expressions whose reference count is above the
// threshold should always be stored in temporaries.
let expr_info = &info[expr.0];
let min_ref_count = func.expressions[expr.0].bake_ref_count();
let expr_info = &info[expr_handle];
let min_ref_count = func.expressions[expr_handle].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr.0);
self.need_bake_expressions.insert(expr_handle);
}
// WGSL's `dot` function works on any `vecN` type, but Metal's only
// works on floating-point vectors, so we emit inline code for
// integer vector `dot` calls. But that code uses each argument `N`
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.
if let (
fun_handle,
&Expression::Math {
fun: crate::MathFunction::Dot,
arg,
arg1,
..
},
) = expr
{
use crate::TypeInner;
// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(fun_handle);
if let TypeInner::Scalar { kind, .. } = *inner {
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
if let Expression::Math { fun, arg, arg1, .. } = *expr {
match fun {
crate::MathFunction::Dot => {
// WGSL's `dot` function works on any `vecN` type, but Metal's only
// works on floating-point vectors, so we emit inline code for
// integer vector `dot` calls. But that code uses each argument `N`
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.
use crate::TypeInner;
// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(expr_handle);
if let TypeInner::Scalar { kind, .. } = *inner {
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
_ => {}
}
crate::MathFunction::FindMsb => {
self.need_bake_expressions.insert(arg);
}
_ => {}
}
}
}

View File

@@ -39,6 +39,8 @@ fn main() {
i = firstTrailingBit(i);
u2 = firstTrailingBit(u2);
i3 = firstLeadingBit(i3);
u3 = firstLeadingBit(u3);
i = firstLeadingBit(i);
u = firstLeadingBit(u);
i = countOneBits(i);
i2 = countOneBits(i2);

View File

@@ -93,40 +93,44 @@ void main() {
u2_ = uvec2(findLSB(_e122));
ivec3 _e124 = i3_;
i3_ = findMSB(_e124);
uint _e126 = u;
u = uint(findMSB(_e126));
uvec3 _e126 = u3_;
u3_ = uvec3(findMSB(_e126));
int _e128 = i;
i = bitCount(_e128);
ivec2 _e130 = i2_;
i2_ = bitCount(_e130);
ivec3 _e132 = i3_;
i3_ = bitCount(_e132);
ivec4 _e134 = i4_;
i4_ = bitCount(_e134);
uint _e136 = u;
u = uint(bitCount(_e136));
uvec2 _e138 = u2_;
u2_ = uvec2(bitCount(_e138));
uvec3 _e140 = u3_;
u3_ = uvec3(bitCount(_e140));
uvec4 _e142 = u4_;
u4_ = uvec4(bitCount(_e142));
int _e144 = i;
i = bitfieldReverse(_e144);
ivec2 _e146 = i2_;
i2_ = bitfieldReverse(_e146);
ivec3 _e148 = i3_;
i3_ = bitfieldReverse(_e148);
ivec4 _e150 = i4_;
i4_ = bitfieldReverse(_e150);
uint _e152 = u;
u = bitfieldReverse(_e152);
uvec2 _e154 = u2_;
u2_ = bitfieldReverse(_e154);
uvec3 _e156 = u3_;
u3_ = bitfieldReverse(_e156);
uvec4 _e158 = u4_;
u4_ = bitfieldReverse(_e158);
i = findMSB(_e128);
uint _e130 = u;
u = uint(findMSB(_e130));
int _e132 = i;
i = bitCount(_e132);
ivec2 _e134 = i2_;
i2_ = bitCount(_e134);
ivec3 _e136 = i3_;
i3_ = bitCount(_e136);
ivec4 _e138 = i4_;
i4_ = bitCount(_e138);
uint _e140 = u;
u = uint(bitCount(_e140));
uvec2 _e142 = u2_;
u2_ = uvec2(bitCount(_e142));
uvec3 _e144 = u3_;
u3_ = uvec3(bitCount(_e144));
uvec4 _e146 = u4_;
u4_ = uvec4(bitCount(_e146));
int _e148 = i;
i = bitfieldReverse(_e148);
ivec2 _e150 = i2_;
i2_ = bitfieldReverse(_e150);
ivec3 _e152 = i3_;
i3_ = bitfieldReverse(_e152);
ivec4 _e154 = i4_;
i4_ = bitfieldReverse(_e154);
uint _e156 = u;
u = bitfieldReverse(_e156);
uvec2 _e158 = u2_;
u2_ = bitfieldReverse(_e158);
uvec3 _e160 = u3_;
u3_ = bitfieldReverse(_e160);
uvec4 _e162 = u4_;
u4_ = bitfieldReverse(_e162);
return;
}

View File

@@ -92,40 +92,44 @@ kernel void main_(
metal::uint2 _e122 = u2_;
u2_ = (((metal::ctz(_e122) + 1) % 33) - 1);
metal::int3 _e124 = i3_;
i3_ = (((metal::clz(_e124) + 1) % 33) - 1);
uint _e126 = u;
u = (((metal::clz(_e126) + 1) % 33) - 1);
i3_ = metal::select(31 - metal::clz(metal::select(_e124, ~_e124, _e124 < 0)), int3(-1), _e124 == 0 || _e124 == -1);
metal::uint3 _e126 = u3_;
u3_ = metal::select(31 - metal::clz(_e126), uint3(-1), _e126 == 0 || _e126 == -1);
int _e128 = i;
i = metal::popcount(_e128);
metal::int2 _e130 = i2_;
i2_ = metal::popcount(_e130);
metal::int3 _e132 = i3_;
i3_ = metal::popcount(_e132);
metal::int4 _e134 = i4_;
i4_ = metal::popcount(_e134);
uint _e136 = u;
u = metal::popcount(_e136);
metal::uint2 _e138 = u2_;
u2_ = metal::popcount(_e138);
metal::uint3 _e140 = u3_;
u3_ = metal::popcount(_e140);
metal::uint4 _e142 = u4_;
u4_ = metal::popcount(_e142);
int _e144 = i;
i = metal::reverse_bits(_e144);
metal::int2 _e146 = i2_;
i2_ = metal::reverse_bits(_e146);
metal::int3 _e148 = i3_;
i3_ = metal::reverse_bits(_e148);
metal::int4 _e150 = i4_;
i4_ = metal::reverse_bits(_e150);
uint _e152 = u;
u = metal::reverse_bits(_e152);
metal::uint2 _e154 = u2_;
u2_ = metal::reverse_bits(_e154);
metal::uint3 _e156 = u3_;
u3_ = metal::reverse_bits(_e156);
metal::uint4 _e158 = u4_;
u4_ = metal::reverse_bits(_e158);
i = metal::select(31 - metal::clz(metal::select(_e128, ~_e128, _e128 < 0)), int(-1), _e128 == 0 || _e128 == -1);
uint _e130 = u;
u = metal::select(31 - metal::clz(_e130), uint(-1), _e130 == 0 || _e130 == -1);
int _e132 = i;
i = metal::popcount(_e132);
metal::int2 _e134 = i2_;
i2_ = metal::popcount(_e134);
metal::int3 _e136 = i3_;
i3_ = metal::popcount(_e136);
metal::int4 _e138 = i4_;
i4_ = metal::popcount(_e138);
uint _e140 = u;
u = metal::popcount(_e140);
metal::uint2 _e142 = u2_;
u2_ = metal::popcount(_e142);
metal::uint3 _e144 = u3_;
u3_ = metal::popcount(_e144);
metal::uint4 _e146 = u4_;
u4_ = metal::popcount(_e146);
int _e148 = i;
i = metal::reverse_bits(_e148);
metal::int2 _e150 = i2_;
i2_ = metal::reverse_bits(_e150);
metal::int3 _e152 = i3_;
i3_ = metal::reverse_bits(_e152);
metal::int4 _e154 = i4_;
i4_ = metal::reverse_bits(_e154);
uint _e156 = u;
u = metal::reverse_bits(_e156);
metal::uint2 _e158 = u2_;
u2_ = metal::reverse_bits(_e158);
metal::uint3 _e160 = u3_;
u3_ = metal::reverse_bits(_e160);
metal::uint4 _e162 = u4_;
u4_ = metal::reverse_bits(_e162);
return;
}

View File

@@ -16,7 +16,8 @@ vertex void main_(
metal::float4 e = metal::saturate(v);
metal::float4 g = metal::refract(v, v, 1.0);
int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y);
uint first_leading_bit_abs = (((metal::clz(metal::abs(0u)) + 1) % 33) - 1);
uint _e13 = metal::abs(0u);
uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1);
int clz_a = metal::clz(-1);
uint clz_b = metal::clz(1u);
metal::int2 clz_c = metal::clz(metal::int2(-1));

View File

@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 161
; Bound: 165
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
@@ -172,56 +172,62 @@ OpStore %34 %124
%125 = OpLoad %12 %25
%126 = OpExtInst %12 %1 FindSMsb %125
OpStore %25 %126
%127 = OpLoad %6 %31
%128 = OpExtInst %6 %1 FindUMsb %127
OpStore %31 %128
%127 = OpLoad %15 %37
%128 = OpExtInst %15 %1 FindUMsb %127
OpStore %37 %128
%129 = OpLoad %4 %19
%130 = OpBitCount %4 %129
%130 = OpExtInst %4 %1 FindSMsb %129
OpStore %19 %130
%131 = OpLoad %11 %22
%132 = OpBitCount %11 %131
OpStore %22 %132
%133 = OpLoad %12 %25
%134 = OpBitCount %12 %133
OpStore %25 %134
%135 = OpLoad %13 %28
%136 = OpBitCount %13 %135
OpStore %28 %136
%137 = OpLoad %6 %31
%138 = OpBitCount %6 %137
OpStore %31 %138
%139 = OpLoad %14 %34
%140 = OpBitCount %14 %139
OpStore %34 %140
%141 = OpLoad %15 %37
%142 = OpBitCount %15 %141
OpStore %37 %142
%143 = OpLoad %16 %40
%144 = OpBitCount %16 %143
OpStore %40 %144
%145 = OpLoad %4 %19
%146 = OpBitReverse %4 %145
OpStore %19 %146
%147 = OpLoad %11 %22
%148 = OpBitReverse %11 %147
OpStore %22 %148
%149 = OpLoad %12 %25
%150 = OpBitReverse %12 %149
OpStore %25 %150
%151 = OpLoad %13 %28
%152 = OpBitReverse %13 %151
OpStore %28 %152
%153 = OpLoad %6 %31
%154 = OpBitReverse %6 %153
OpStore %31 %154
%155 = OpLoad %14 %34
%156 = OpBitReverse %14 %155
OpStore %34 %156
%157 = OpLoad %15 %37
%158 = OpBitReverse %15 %157
OpStore %37 %158
%159 = OpLoad %16 %40
%160 = OpBitReverse %16 %159
OpStore %40 %160
%131 = OpLoad %6 %31
%132 = OpExtInst %6 %1 FindUMsb %131
OpStore %31 %132
%133 = OpLoad %4 %19
%134 = OpBitCount %4 %133
OpStore %19 %134
%135 = OpLoad %11 %22
%136 = OpBitCount %11 %135
OpStore %22 %136
%137 = OpLoad %12 %25
%138 = OpBitCount %12 %137
OpStore %25 %138
%139 = OpLoad %13 %28
%140 = OpBitCount %13 %139
OpStore %28 %140
%141 = OpLoad %6 %31
%142 = OpBitCount %6 %141
OpStore %31 %142
%143 = OpLoad %14 %34
%144 = OpBitCount %14 %143
OpStore %34 %144
%145 = OpLoad %15 %37
%146 = OpBitCount %15 %145
OpStore %37 %146
%147 = OpLoad %16 %40
%148 = OpBitCount %16 %147
OpStore %40 %148
%149 = OpLoad %4 %19
%150 = OpBitReverse %4 %149
OpStore %19 %150
%151 = OpLoad %11 %22
%152 = OpBitReverse %11 %151
OpStore %22 %152
%153 = OpLoad %12 %25
%154 = OpBitReverse %12 %153
OpStore %25 %154
%155 = OpLoad %13 %28
%156 = OpBitReverse %13 %155
OpStore %28 %156
%157 = OpLoad %6 %31
%158 = OpBitReverse %6 %157
OpStore %31 %158
%159 = OpLoad %14 %34
%160 = OpBitReverse %14 %159
OpStore %34 %160
%161 = OpLoad %15 %37
%162 = OpBitReverse %15 %161
OpStore %37 %162
%163 = OpLoad %16 %40
%164 = OpBitReverse %16 %163
OpStore %40 %164
OpReturn
OpFunctionEnd

View File

@@ -87,39 +87,43 @@ fn main() {
u2_ = firstTrailingBit(_e122);
let _e124 = i3_;
i3_ = firstLeadingBit(_e124);
let _e126 = u;
u = firstLeadingBit(_e126);
let _e126 = u3_;
u3_ = firstLeadingBit(_e126);
let _e128 = i;
i = countOneBits(_e128);
let _e130 = i2_;
i2_ = countOneBits(_e130);
let _e132 = i3_;
i3_ = countOneBits(_e132);
let _e134 = i4_;
i4_ = countOneBits(_e134);
let _e136 = u;
u = countOneBits(_e136);
let _e138 = u2_;
u2_ = countOneBits(_e138);
let _e140 = u3_;
u3_ = countOneBits(_e140);
let _e142 = u4_;
u4_ = countOneBits(_e142);
let _e144 = i;
i = reverseBits(_e144);
let _e146 = i2_;
i2_ = reverseBits(_e146);
let _e148 = i3_;
i3_ = reverseBits(_e148);
let _e150 = i4_;
i4_ = reverseBits(_e150);
let _e152 = u;
u = reverseBits(_e152);
let _e154 = u2_;
u2_ = reverseBits(_e154);
let _e156 = u3_;
u3_ = reverseBits(_e156);
let _e158 = u4_;
u4_ = reverseBits(_e158);
i = firstLeadingBit(_e128);
let _e130 = u;
u = firstLeadingBit(_e130);
let _e132 = i;
i = countOneBits(_e132);
let _e134 = i2_;
i2_ = countOneBits(_e134);
let _e136 = i3_;
i3_ = countOneBits(_e136);
let _e138 = i4_;
i4_ = countOneBits(_e138);
let _e140 = u;
u = countOneBits(_e140);
let _e142 = u2_;
u2_ = countOneBits(_e142);
let _e144 = u3_;
u3_ = countOneBits(_e144);
let _e146 = u4_;
u4_ = countOneBits(_e146);
let _e148 = i;
i = reverseBits(_e148);
let _e150 = i2_;
i2_ = reverseBits(_e150);
let _e152 = i3_;
i3_ = reverseBits(_e152);
let _e154 = i4_;
i4_ = reverseBits(_e154);
let _e156 = u;
u = reverseBits(_e156);
let _e158 = u2_;
u2_ = reverseBits(_e158);
let _e160 = u3_;
u3_ = reverseBits(_e160);
let _e162 = u4_;
u4_ = reverseBits(_e162);
return;
}