diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 3433c2f1ca..ff00ded348 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2996,6 +2996,35 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Transpose => "transpose", Mf::Determinant => "determinant", // bits + Mf::CountTrailingZeros => { + match *ctx.info[arg].ty.inner_with(&self.module.types) { + crate::TypeInner::Vector { size, kind, .. } => { + let s = back::vector_size_str(size); + if let crate::ScalarKind::Uint = kind { + write!(self.out, "min(uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), uvec{s}(32u))")?; + } else { + write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), uvec{s}(32u)))")?; + } + } + crate::TypeInner::Scalar { kind, .. } => { + if let crate::ScalarKind::Uint = kind { + write!(self.out, "min(uint(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), 32u)")?; + } else { + write!(self.out, "int(min(uint(findLSB(")?; + self.write_expr(arg, ctx)?; + write!(self.out, ")), 32u))")?; + } + } + _ => unreachable!(), + }; + return Ok(()); + } Mf::CountLeadingZeros => { if self.options.version.supports_integer_functions() { match *ctx.info[arg].ty.inner_with(&self.module.types) { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index c075fe3fe2..450dfe85ca 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2551,6 +2551,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Unpack2x16float, Regular(&'static str), MissingIntOverload(&'static str), + CountTrailingZeros, CountLeadingZeros, } @@ -2614,6 +2615,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Mf::Transpose => Function::Regular("transpose"), Mf::Determinant => Function::Regular("determinant"), // bits + Mf::CountTrailingZeros => Function::CountTrailingZeros, Mf::CountLeadingZeros => Function::CountLeadingZeros, Mf::CountOneBits => Function::MissingIntOverload("countbits"), Mf::ReverseBits => Function::MissingIntOverload("reversebits"), @@ -2682,6 +2684,41 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } } + Function::CountTrailingZeros => { + match *func_ctx.info[arg].ty.inner_with(&module.types) { + TypeInner::Vector { size, kind, .. } => { + let s = match size { + crate::VectorSize::Bi => ".xx", + crate::VectorSize::Tri => ".xxx", + crate::VectorSize::Quad => ".xxxx", + }; + + if let ScalarKind::Uint = kind { + write!(self.out, "min((32u){s}, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min((32u){s}, asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))))")?; + } + } + TypeInner::Scalar { kind, .. } => { + if let ScalarKind::Uint = kind { + write!(self.out, "min(32u, firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } else { + write!(self.out, "asint(min(32u, asuint(firstbitlow(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))))")?; + } + } + _ => unreachable!(), + } + + return Ok(()); + } Function::CountLeadingZeros => { match *func_ctx.info[arg].ty.inner_with(&module.types) { TypeInner::Vector { size, kind, .. } => { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 040541de41..c11765da96 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1686,6 +1686,7 @@ impl Writer { Mf::Transpose => "transpose", Mf::Determinant => "determinant", // bits + Mf::CountTrailingZeros => "ctz", Mf::CountLeadingZeros => "clz", Mf::CountOneBits => "popcount", Mf::ReverseBits => "reverse_bits", diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 4b0c52e8e3..11d5782633 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -874,6 +874,49 @@ impl<'w> BlockContext<'w> { id, arg0_id, )), + Mf::CountTrailingZeros => { + let uint = crate::ScalarValue::Uint(32); + let uint_id = match *arg_ty { + crate::TypeInner::Vector { size, width, .. } => { + let ty = LocalType::Value { + vector_size: Some(size), + kind: crate::ScalarKind::Uint, + width, + pointer_space: None, + } + .into(); + + self.temp_list.clear(); + self.temp_list.resize( + size as _, + self.writer.get_constant_scalar(uint, width), + ); + + self.writer.get_constant_composite(ty, &self.temp_list) + } + crate::TypeInner::Scalar { width, .. } => { + self.writer.get_constant_scalar(uint, width) + } + _ => unreachable!(), + }; + + let lsb_id = self.gen_id(); + block.body.push(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FindILsb, + result_type_id, + lsb_id, + &[arg0_id], + )); + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::UMin, + result_type_id, + id, + &[uint_id, lsb_id], + )) + } Mf::CountLeadingZeros => { let int = crate::ScalarValue::Sint(31); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 4f58cd9ee8..be4d9e4423 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1578,6 +1578,7 @@ impl Writer { Mf::Transpose => Function::Regular("transpose"), Mf::Determinant => Function::Regular("determinant"), // bits + Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"), Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"), Mf::CountOneBits => Function::Regular("countOneBits"), Mf::ReverseBits => Function::Regular("reverseBits"), diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index 4164332166..0e20774dc5 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -191,6 +191,7 @@ pub fn map_standard_fun(word: &str) -> Option { "transpose" => Mf::Transpose, "determinant" => Mf::Determinant, // bits + "countTrailingZeros" => Mf::CountTrailingZeros, "countLeadingZeros" => Mf::CountLeadingZeros, "countOneBits" => Mf::CountOneBits, "reverseBits" => Mf::ReverseBits, diff --git a/src/lib.rs b/src/lib.rs index 136205ca17..e60d8e0cc4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1066,6 +1066,7 @@ pub enum MathFunction { Transpose, Determinant, // bits + CountTrailingZeros, CountLeadingZeros, CountOneBits, ReverseBits, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 0a8e4a961a..6a8bfa03c7 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -279,6 +279,7 @@ impl super::MathFunction { Self::Transpose => 1, Self::Determinant => 1, // bits + Self::CountTrailingZeros => 1, Self::CountLeadingZeros => 1, Self::CountOneBits => 1, Self::ReverseBits => 1, diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index cac32ebb78..3b9fa1d50b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -793,6 +793,7 @@ impl<'a> ResolveContext<'a> { )), }, // bits + Mf::CountTrailingZeros | Mf::CountLeadingZeros | Mf::CountOneBits | Mf::ReverseBits | diff --git a/src/valid/expression.rs b/src/valid/expression.rs index d599e6b62f..01d6910eba 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1223,7 +1223,8 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } - Mf::CountLeadingZeros + Mf::CountTrailingZeros + | Mf::CountLeadingZeros | Mf::CountOneBits | Mf::ReverseBits | Mf::FindLsb diff --git a/tests/in/math-functions.wgsl b/tests/in/math-functions.wgsl index 1c2a8e4579..db50880d14 100644 --- a/tests/in/math-functions.wgsl +++ b/tests/in/math-functions.wgsl @@ -10,6 +10,14 @@ fn main() { let g = refract(v, v, f); let const_dot = dot(vec2(), vec2()); let first_leading_bit_abs = firstLeadingBit(abs(0u)); + let ctz_a = countTrailingZeros(0u); + let ctz_b = countTrailingZeros(0); + let ctz_c = countTrailingZeros(0xFFFFFFFFu); + let ctz_d = countTrailingZeros(-1); + let ctz_e = countTrailingZeros(vec2(0u)); + let ctz_f = countTrailingZeros(vec2(0)); + let ctz_g = countTrailingZeros(vec2(1u)); + let ctz_h = countTrailingZeros(vec2(1)); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1)); diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index 8cfa6a10b5..3c5c1dd345 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -14,10 +14,18 @@ void main() { vec4 g = refract(v, v, 1.0); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); + uint ctz_a = min(uint(findLSB(0u)), 32u); + int ctz_b = int(min(uint(findLSB(0)), 32u)); + uint ctz_c = min(uint(findLSB(4294967295u)), 32u); + int ctz_d = int(min(uint(findLSB(-1)), 32u)); + uvec2 ctz_e = min(uvec2(findLSB(uvec2(0u))), uvec2(32u)); + ivec2 ctz_f = ivec2(min(uvec2(findLSB(ivec2(0))), uvec2(32u))); + uvec2 ctz_g = min(uvec2(findLSB(uvec2(1u))), uvec2(32u)); + ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u))); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); - ivec2 _e20 = ivec2(-1); - ivec2 clz_c = mix(ivec2(31) - findMSB(_e20), ivec2(0), lessThan(_e20, ivec2(0))); + ivec2 _e40 = ivec2(-1); + ivec2 clz_c = mix(ivec2(31) - findMSB(_e40), ivec2(0), lessThan(_e40, ivec2(0))); uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u))); } diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index 2a95c849c9..958e77d80a 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -10,9 +10,17 @@ void main() float4 g = refract(v, v, 1.0); int const_dot = dot(int2(0, 0), int2(0, 0)); uint first_leading_bit_abs = firstbithigh(abs(0u)); + uint ctz_a = min(32u, firstbitlow(0u)); + int ctz_b = asint(min(32u, asuint(firstbitlow(0)))); + uint ctz_c = min(32u, firstbitlow(4294967295u)); + int ctz_d = asint(min(32u, asuint(firstbitlow(-1)))); + uint2 ctz_e = min((32u).xx, firstbitlow((0u).xx)); + int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow((0).xx)))); + uint2 ctz_g = min((32u).xx, firstbitlow((1u).xx)); + int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow((1).xx)))); int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1)); uint clz_b = asuint(31 - firstbithigh(1u)); - int2 _expr20 = (-1).xx; - int2 clz_c = (_expr20 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr20)); + int2 _expr40 = (-1).xx; + int2 clz_c = (_expr40 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr40)); uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx)); } diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index 3fdb7b75a5..c2aac6ef98 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -18,6 +18,14 @@ vertex void main_( int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y); uint _e13 = metal::abs(0u); uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1); + uint ctz_a = metal::ctz(0u); + int ctz_b = metal::ctz(0); + uint ctz_c = metal::ctz(4294967295u); + int ctz_d = metal::ctz(-1); + metal::uint2 ctz_e = metal::ctz(metal::uint2(0u)); + metal::int2 ctz_f = metal::ctz(metal::int2(0)); + metal::uint2 ctz_g = metal::ctz(metal::uint2(1u)); + metal::int2 ctz_h = metal::ctz(metal::int2(1)); int clz_a = metal::clz(-1); uint clz_b = metal::clz(1u); metal::int2 clz_c = metal::clz(metal::int2(-1)); diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index a65f535181..d1a98b4e43 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,11 +1,11 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 54 +; Bound: 78 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %16 "main" +OpEntryPoint Vertex %18 "main" %2 = OpTypeVoid %4 = OpTypeFloat 32 %3 = OpConstant %4 1.0 @@ -14,50 +14,74 @@ OpEntryPoint Vertex %16 "main" %6 = OpConstant %7 0 %9 = OpTypeInt 32 0 %8 = OpConstant %9 0 -%10 = OpConstant %7 -1 -%11 = OpConstant %9 1 -%12 = OpTypeVector %4 4 -%13 = OpTypeVector %7 2 -%14 = OpConstantComposite %13 %6 %6 -%17 = OpTypeFunction %2 -%25 = OpConstantComposite %12 %5 %5 %5 %5 -%26 = OpConstantComposite %12 %3 %3 %3 %3 -%40 = OpConstant %7 31 -%47 = OpConstantComposite %13 %40 %40 -%49 = OpTypeVector %9 2 -%29 = OpConstantNull %7 -%16 = OpFunction %2 None %17 -%15 = OpLabel -OpBranch %18 -%18 = OpLabel -%19 = OpCompositeConstruct %12 %5 %5 %5 %5 -%20 = OpExtInst %4 %1 Degrees %3 -%21 = OpExtInst %4 %1 Radians %3 -%22 = OpExtInst %12 %1 Degrees %19 -%23 = OpExtInst %12 %1 Radians %19 -%24 = OpExtInst %12 %1 FClamp %19 %25 %26 -%27 = OpExtInst %12 %1 Refract %19 %19 %3 -%30 = OpCompositeExtract %7 %14 0 -%31 = OpCompositeExtract %7 %14 0 -%32 = OpIMul %7 %30 %31 -%33 = OpIAdd %7 %29 %32 -%34 = OpCompositeExtract %7 %14 1 -%35 = OpCompositeExtract %7 %14 1 -%36 = OpIMul %7 %34 %35 -%28 = OpIAdd %7 %33 %36 -%37 = OpCopyObject %9 %8 -%38 = OpExtInst %9 %1 FindUMsb %37 -%39 = OpExtInst %7 %1 FindUMsb %10 -%41 = OpISub %7 %40 %39 -%42 = OpExtInst %7 %1 FindUMsb %11 -%43 = OpISub %7 %40 %42 -%44 = OpBitcast %9 %43 -%45 = OpCompositeConstruct %13 %10 %10 -%46 = OpExtInst %13 %1 FindUMsb %45 -%48 = OpISub %13 %47 %46 -%50 = OpCompositeConstruct %49 %11 %11 -%51 = OpExtInst %13 %1 FindUMsb %50 -%52 = OpISub %13 %47 %51 -%53 = OpBitcast %49 %52 +%10 = OpConstant %9 4294967295 +%11 = OpConstant %7 -1 +%12 = OpConstant %9 1 +%13 = OpConstant %7 1 +%14 = OpTypeVector %4 4 +%15 = OpTypeVector %7 2 +%16 = OpConstantComposite %15 %6 %6 +%19 = OpTypeFunction %2 +%27 = OpConstantComposite %14 %5 %5 %5 %5 +%28 = OpConstantComposite %14 %3 %3 %3 %3 +%42 = OpConstant %9 32 +%50 = OpTypeVector %9 2 +%53 = OpConstantComposite %50 %42 %42 +%65 = OpConstant %7 31 +%72 = OpConstantComposite %15 %65 %65 +%31 = OpConstantNull %7 +%18 = OpFunction %2 None %19 +%17 = OpLabel +OpBranch %20 +%20 = OpLabel +%21 = OpCompositeConstruct %14 %5 %5 %5 %5 +%22 = OpExtInst %4 %1 Degrees %3 +%23 = OpExtInst %4 %1 Radians %3 +%24 = OpExtInst %14 %1 Degrees %21 +%25 = OpExtInst %14 %1 Radians %21 +%26 = OpExtInst %14 %1 FClamp %21 %27 %28 +%29 = OpExtInst %14 %1 Refract %21 %21 %3 +%32 = OpCompositeExtract %7 %16 0 +%33 = OpCompositeExtract %7 %16 0 +%34 = OpIMul %7 %32 %33 +%35 = OpIAdd %7 %31 %34 +%36 = OpCompositeExtract %7 %16 1 +%37 = OpCompositeExtract %7 %16 1 +%38 = OpIMul %7 %36 %37 +%30 = OpIAdd %7 %35 %38 +%39 = OpCopyObject %9 %8 +%40 = OpExtInst %9 %1 FindUMsb %39 +%43 = OpExtInst %9 %1 FindILsb %8 +%41 = OpExtInst %9 %1 UMin %42 %43 +%45 = OpExtInst %7 %1 FindILsb %6 +%44 = OpExtInst %7 %1 UMin %42 %45 +%47 = OpExtInst %9 %1 FindILsb %10 +%46 = OpExtInst %9 %1 UMin %42 %47 +%49 = OpExtInst %7 %1 FindILsb %11 +%48 = OpExtInst %7 %1 UMin %42 %49 +%51 = OpCompositeConstruct %50 %8 %8 +%54 = OpExtInst %50 %1 FindILsb %51 +%52 = OpExtInst %50 %1 UMin %53 %54 +%55 = OpCompositeConstruct %15 %6 %6 +%57 = OpExtInst %15 %1 FindILsb %55 +%56 = OpExtInst %15 %1 UMin %53 %57 +%58 = OpCompositeConstruct %50 %12 %12 +%60 = OpExtInst %50 %1 FindILsb %58 +%59 = OpExtInst %50 %1 UMin %53 %60 +%61 = OpCompositeConstruct %15 %13 %13 +%63 = OpExtInst %15 %1 FindILsb %61 +%62 = OpExtInst %15 %1 UMin %53 %63 +%64 = OpExtInst %7 %1 FindUMsb %11 +%66 = OpISub %7 %65 %64 +%67 = OpExtInst %7 %1 FindUMsb %12 +%68 = OpISub %7 %65 %67 +%69 = OpBitcast %9 %68 +%70 = OpCompositeConstruct %15 %11 %11 +%71 = OpExtInst %15 %1 FindUMsb %70 +%73 = OpISub %15 %72 %71 +%74 = OpCompositeConstruct %50 %12 %12 +%75 = OpExtInst %15 %1 FindUMsb %74 +%76 = OpISub %15 %72 %75 +%77 = OpBitcast %50 %76 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/math-functions.wgsl b/tests/out/wgsl/math-functions.wgsl index d91a26cff4..71ae0fd749 100644 --- a/tests/out/wgsl/math-functions.wgsl +++ b/tests/out/wgsl/math-functions.wgsl @@ -9,6 +9,14 @@ fn main() { let g = refract(v, v, 1.0); let const_dot = dot(vec2(0, 0), vec2(0, 0)); let first_leading_bit_abs = firstLeadingBit(abs(0u)); + let ctz_a = countTrailingZeros(0u); + let ctz_b = countTrailingZeros(0); + let ctz_c = countTrailingZeros(4294967295u); + let ctz_d = countTrailingZeros(-1); + let ctz_e = countTrailingZeros(vec2(0u)); + let ctz_f = countTrailingZeros(vec2(0)); + let ctz_g = countTrailingZeros(vec2(1u)); + let ctz_h = countTrailingZeros(vec2(1)); let clz_a = countLeadingZeros(-1); let clz_b = countLeadingZeros(1u); let clz_c = countLeadingZeros(vec2(-1));