diff --git a/naga/src/back/hlsl/help.rs b/naga/src/back/hlsl/help.rs index 782f73175f..f39ec3cd60 100644 --- a/naga/src/back/hlsl/help.rs +++ b/naga/src/back/hlsl/help.rs @@ -1318,7 +1318,7 @@ impl super::Writer<'_, W> { crate::BinaryOperator::Modulo, Some( scalar @ crate::Scalar { - kind: ScalarKind::Sint | ScalarKind::Uint, + kind: ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float, .. }, ), @@ -1367,6 +1367,14 @@ impl super::Writer<'_, W> { ScalarKind::Uint => { writeln!(self.out, "{level}return lhs % (rhs == 0u ? 1u : rhs);")? } + // HLSL's fmod has the same definition as WGSL's % operator but due + // to its implementation in DXC it is not as accurate as the WGSL spec + // requires it to be. See: + // - https://shader-playground.timjones.io/0c8572816dbb6fc4435cc5d016a978a7 + // - https://github.com/llvm/llvm-project/blob/50f9b8acafdca48e87e6b8e393c1f116a2d193ee/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h#L78-L81 + ScalarKind::Float => { + writeln!(self.out, "{level}return lhs - rhs * trunc(lhs / rhs);")? + } _ => unreachable!(), } writeln!(self.out, "}}")?; diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 873a4a8a91..192238d1f8 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2924,7 +2924,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { right, } if matches!( func_ctx.resolve_type(expr, &module.types).scalar_kind(), - Some(ScalarKind::Sint | ScalarKind::Uint) + Some(ScalarKind::Sint | ScalarKind::Uint | ScalarKind::Float) ) => { write!(self.out, "{MOD_FUNCTION}(")?; @@ -2934,21 +2934,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ")")?; } - // While HLSL supports float operands with the % operator it is only - // defined in cases where both sides are either positive or negative. - Expression::Binary { - op: crate::BinaryOperator::Modulo, - left, - right, - } if func_ctx.resolve_type(left, &module.types).scalar_kind() - == Some(ScalarKind::Float) => - { - write!(self.out, "fmod(")?; - self.write_expr(module, left, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, right, func_ctx)?; - write!(self.out, ")")?; - } Expression::Binary { op, left, right } => { write!(self.out, "(")?; self.write_expr(module, left, func_ctx)?; diff --git a/naga/tests/out/hlsl/wgsl-operators.hlsl b/naga/tests/out/hlsl/wgsl-operators.hlsl index bfc4cb6ec2..d76518b734 100644 --- a/naga/tests/out/hlsl/wgsl-operators.hlsl +++ b/naga/tests/out/hlsl/wgsl-operators.hlsl @@ -90,6 +90,10 @@ uint naga_mod(uint lhs, uint rhs) { return lhs % (rhs == 0u ? 1u : rhs); } +float naga_mod(float lhs, float rhs) { + return lhs - rhs * trunc(lhs / rhs); +} + int2 naga_mod(int2 lhs, int2 rhs) { int2 divisor = ((lhs == int(-2147483647 - 1) & rhs == -1) | (rhs == 0)) ? 1 : rhs; return lhs - (lhs / divisor) * divisor; @@ -99,6 +103,10 @@ uint3 naga_mod(uint3 lhs, uint3 rhs) { return lhs % (rhs == 0u ? 1u : rhs); } +float4 naga_mod(float4 lhs, float4 rhs) { + return lhs - rhs * trunc(lhs / rhs); +} + uint2 naga_div(uint2 lhs, uint2 rhs) { return lhs / (rhs == 0u ? 1u : rhs); } @@ -107,6 +115,10 @@ uint2 naga_mod(uint2 lhs, uint2 rhs) { return lhs % (rhs == 0u ? 1u : rhs); } +float2 naga_mod(float2 lhs, float2 rhs) { + return lhs - rhs * trunc(lhs / rhs); +} + float3x3 ZeroValuefloat3x3() { return (float3x3)0; } @@ -153,10 +165,10 @@ void arithmetic() float4 div5_ = ((2.0).xxxx / (1.0).xxxx); int rem0_ = naga_mod(int(2), int(1)); uint rem1_ = naga_mod(2u, 1u); - float rem2_ = fmod(2.0, 1.0); + float rem2_ = naga_mod(2.0, 1.0); int2 rem3_ = naga_mod((int(2)).xx, (int(1)).xx); uint3 rem4_ = naga_mod((2u).xxx, (1u).xxx); - float4 rem5_ = fmod((2.0).xxxx, (1.0).xxxx); + float4 rem5_ = naga_mod((2.0).xxxx, (1.0).xxxx); { int2 add0_1 = asint(asuint((int(2)).xx) + asuint((int(1)).xx)); int2 add1_1 = asint(asuint((int(2)).xx) + asuint((int(1)).xx)); @@ -186,8 +198,8 @@ void arithmetic() int2 rem1_1 = naga_mod((int(2)).xx, (int(1)).xx); uint2 rem2_1 = naga_mod((2u).xx, (1u).xx); uint2 rem3_1 = naga_mod((2u).xx, (1u).xx); - float2 rem4_1 = fmod((2.0).xx, (1.0).xx); - float2 rem5_1 = fmod((2.0).xx, (1.0).xx); + float2 rem4_1 = naga_mod((2.0).xx, (1.0).xx); + float2 rem5_1 = naga_mod((2.0).xx, (1.0).xx); } float3x3 add = (ZeroValuefloat3x3() + ZeroValuefloat3x3()); float3x3 sub = (ZeroValuefloat3x3() - ZeroValuefloat3x3());