diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 0558c90963..fb3e1f47ab 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2242,6 +2242,9 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Acos => "acos", Mf::Asin => "asin", Mf::Atan => "atan", + Mf::Asinh => "asinh", + Mf::Acosh => "acosh", + Mf::Atanh => "atanh", // glsl doesn't have atan2 function // use two-argument variation of the atan function Mf::Atan2 => "atan", diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 108d2b6c64..87dabd8637 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1635,76 +1635,105 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } => { use crate::MathFunction as Mf; - let fun_name = match fun { - // comparison - Mf::Abs => "abs", - Mf::Min => "min", - Mf::Max => "max", - Mf::Clamp => "clamp", - // trigonometry - Mf::Cos => "cos", - Mf::Cosh => "cosh", - Mf::Sin => "sin", - Mf::Sinh => "sinh", - Mf::Tan => "tan", - Mf::Tanh => "tanh", - Mf::Acos => "acos", - Mf::Asin => "asin", - Mf::Atan => "atan", - Mf::Atan2 => "atan2", - // decomposition - Mf::Ceil => "ceil", - Mf::Floor => "floor", - Mf::Round => "round", - Mf::Fract => "frac", - Mf::Trunc => "trunc", - Mf::Modf => "modf", - Mf::Frexp => "frexp", - Mf::Ldexp => "ldexp", - // exponent - Mf::Exp => "exp", - Mf::Exp2 => "exp2", - Mf::Log => "log", - Mf::Log2 => "log2", - Mf::Pow => "pow", - // geometry - Mf::Dot => "dot", - //Mf::Outer => , - Mf::Cross => "cross", - Mf::Distance => "distance", - Mf::Length => "length", - Mf::Normalize => "normalize", - Mf::FaceForward => "faceforward", - Mf::Reflect => "reflect", - Mf::Refract => "refract", - // computational - Mf::Sign => "sign", - Mf::Fma => "fma", - Mf::Mix => "lerp", - Mf::Step => "step", - Mf::SmoothStep => "smoothstep", - Mf::Sqrt => "sqrt", - Mf::InverseSqrt => "rsqrt", - //Mf::Inverse =>, - Mf::Transpose => "transpose", - Mf::Determinant => "determinant", - // bits - Mf::CountOneBits => "countbits", - Mf::ReverseBits => "reversebits", - _ => return Err(Error::Unimplemented(format!("write_expr_math {:?}", fun))), - }; + match fun { + Mf::Asinh | Mf::Acosh => { + write!(self.out, "log(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " + sqrt(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " * ")?; + self.write_expr(module, arg, func_ctx)?; + if fun == Mf::Asinh { + write!(self.out, " + 1.0))")? + } else { + write!(self.out, " - 1.0))")? + } + } + Mf::Atanh => { + write!(self.out, "0.5 * log((1.0 + ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") / (1.0 - ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } + _ => { + let fun_name = match fun { + // comparison + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "round", + Mf::Fract => "frac", + Mf::Trunc => "trunc", + Mf::Modf => "modf", + Mf::Frexp => "frexp", + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => "dot", + //Mf::Outer => , + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceforward", + Mf::Reflect => "reflect", + Mf::Refract => "refract", + // computational + Mf::Sign => "sign", + Mf::Fma => "fma", + Mf::Mix => "lerp", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "rsqrt", + //Mf::Inverse =>, + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + // bits + Mf::CountOneBits => "countbits", + Mf::ReverseBits => "reversebits", + _ => { + return Err(Error::Unimplemented(format!( + "write_expr_math {:?}", + fun + ))) + } + }; - write!(self.out, "{}(", fun_name)?; - self.write_expr(module, arg, func_ctx)?; - if let Some(arg) = arg1 { - write!(self.out, ", ")?; - self.write_expr(module, arg, func_ctx)?; + write!(self.out, "{}(", fun_name)?; + self.write_expr(module, arg, func_ctx)?; + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } } - if let Some(arg) = arg2 { - write!(self.out, ", ")?; - self.write_expr(module, arg, func_ctx)?; - } - write!(self.out, ")")? } Expression::Swizzle { size, diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index a15a4d2071..0d1fbd36bb 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1140,6 +1140,9 @@ impl Writer { Mf::Asin => "asin", Mf::Atan => "atan", Mf::Atan2 => "atan2", + Mf::Asinh => "asinh", + Mf::Acosh => "acosh", + Mf::Atanh => "atanh", // decomposition Mf::Ceil => "ceil", Mf::Floor => "floor", diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 41b37373e7..19450c02f7 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -558,6 +558,9 @@ impl<'w> BlockContext<'w> { Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh), Mf::Atan => MathOp::Ext(spirv::GLOp::Atan), Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2), + Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh), + Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh), + Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh), // decomposition Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil), Mf::Round => MathOp::Ext(spirv::GLOp::Round), diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index b101b745dc..2dfc72e808 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1240,75 +1240,101 @@ impl Writer { } => { use crate::MathFunction as Mf; - let fun_name = match fun { - Mf::Abs => "abs", - Mf::Min => "min", - Mf::Max => "max", - Mf::Clamp => "clamp", - // trigonometry - Mf::Cos => "cos", - Mf::Cosh => "cosh", - Mf::Sin => "sin", - Mf::Sinh => "sinh", - Mf::Tan => "tan", - Mf::Tanh => "tanh", - Mf::Acos => "acos", - Mf::Asin => "asin", - Mf::Atan => "atan", - Mf::Atan2 => "atan2", - // decomposition - Mf::Ceil => "ceil", - Mf::Floor => "floor", - Mf::Round => "round", - Mf::Fract => "fract", - Mf::Trunc => "trunc", - Mf::Modf => "modf", - Mf::Frexp => "frexp", - Mf::Ldexp => "ldexp", - // exponent - Mf::Exp => "exp", - Mf::Exp2 => "exp2", - Mf::Log => "log", - Mf::Log2 => "log2", - Mf::Pow => "pow", - // geometry - Mf::Dot => "dot", - Mf::Outer => "outerProduct", - Mf::Cross => "cross", - Mf::Distance => "distance", - Mf::Length => "length", - Mf::Normalize => "normalize", - Mf::FaceForward => "faceForward", - Mf::Reflect => "reflect", - // computational - Mf::Sign => "sign", - Mf::Fma => "fma", - Mf::Mix => "mix", - Mf::Step => "step", - Mf::SmoothStep => "smoothStep", - Mf::Sqrt => "sqrt", - Mf::InverseSqrt => "inverseSqrt", - Mf::Transpose => "transpose", - Mf::Determinant => "determinant", - // bits - Mf::CountOneBits => "countOneBits", - Mf::ReverseBits => "reverseBits", - _ => { - return Err(Error::UnsupportedMathFunction(fun)); + // NOTE: If https://github.com/gpuweb/gpuweb/issues/1622 ever is + // accepted, replace this with the builtin functions + match fun { + Mf::Asinh | Mf::Acosh => { + write!(self.out, "log(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " + sqrt(")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, " * ")?; + self.write_expr(module, arg, func_ctx)?; + if fun == Mf::Asinh { + write!(self.out, " + 1.0))")? + } else { + write!(self.out, " - 1.0))")? + } } - }; + Mf::Atanh => { + write!(self.out, "0.5 * log((1.0 + ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, ") / (1.0 - ")?; + self.write_expr(module, arg, func_ctx)?; + write!(self.out, "))")?; + } + _ => { + let fun_name = match fun { + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + // trigonometry + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + // decomposition + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "round", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => "modf", + Mf::Frexp => "frexp", + Mf::Ldexp => "ldexp", + // exponent + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + // geometry + Mf::Dot => "dot", + Mf::Outer => "outerProduct", + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceForward", + Mf::Reflect => "reflect", + // computational + Mf::Sign => "sign", + Mf::Fma => "fma", + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothStep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "inverseSqrt", + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + // bits + Mf::CountOneBits => "countOneBits", + Mf::ReverseBits => "reverseBits", + _ => { + return Err(Error::UnsupportedMathFunction(fun)); + } + }; - write!(self.out, "{}(", fun_name)?; - self.write_expr(module, arg, func_ctx)?; - if let Some(arg) = arg1 { - write!(self.out, ", ")?; - self.write_expr(module, arg, func_ctx)?; + write!(self.out, "{}(", fun_name)?; + self.write_expr(module, arg, func_ctx)?; + if let Some(arg) = arg1 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + if let Some(arg) = arg2 { + write!(self.out, ", ")?; + self.write_expr(module, arg, func_ctx)?; + } + write!(self.out, ")")? + } } - if let Some(arg) = arg2 { - write!(self.out, ", ")?; - self.write_expr(module, arg, func_ctx)?; - } - write!(self.out, ")")? } Expression::Swizzle { size, diff --git a/src/front/glsl/builtins.rs b/src/front/glsl/builtins.rs index e7e995674f..4af2c7de4b 100644 --- a/src/front/glsl/builtins.rs +++ b/src/front/glsl/builtins.rs @@ -451,7 +451,7 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module } } "sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin" - | "log" | "log2" | "radians" | "degrees" => { + | "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh" => { // bits layout // bit 0 trough 1 - dims for bits in 0..(0b100) { @@ -481,6 +481,9 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module "asin" => MacroCall::MathFunction(MathFunction::Asin), "log" => MacroCall::MathFunction(MathFunction::Log), "log2" => MacroCall::MathFunction(MathFunction::Log2), + "asinh" => MacroCall::MathFunction(MathFunction::Asinh), + "acosh" => MacroCall::MathFunction(MathFunction::Acosh), + "atanh" => MacroCall::MathFunction(MathFunction::Atanh), "radians" => MacroCall::ConstMultiply(std::f64::consts::PI / 180.0), "degrees" => MacroCall::ConstMultiply(180.0 / std::f64::consts::PI), _ => unreachable!(), diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 2dd40ea563..ce7080873d 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -1984,6 +1984,9 @@ impl> Parser { Glo::Cosh => Mf::Cosh, Glo::Tanh => Mf::Tanh, Glo::Atan2 => Mf::Atan2, + Glo::Asinh => Mf::Asinh, + Glo::Acosh => Mf::Acosh, + Glo::Atanh => Mf::Atanh, Glo::Pow => Mf::Pow, Glo::Exp => Mf::Exp, Glo::Log => Mf::Log, diff --git a/src/lib.rs b/src/lib.rs index f3904eb099..d5167f3729 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -802,6 +802,9 @@ pub enum MathFunction { Asin, Atan, Atan2, + Asinh, + Acosh, + Atanh, // decomposition Ceil, Floor, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 7919c4267e..867782ffe7 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -151,6 +151,9 @@ impl super::MathFunction { Self::Asin => 1, Self::Atan => 1, Self::Atan2 => 2, + Self::Asinh => 1, + Self::Acosh => 1, + Self::Atanh => 1, // decomposition Self::Ceil => 1, Self::Floor => 1, diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index afd4d73bfd..e5635f9013 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -554,6 +554,9 @@ impl<'a> ResolveContext<'a> { Mf::Asin | Mf::Atan | Mf::Atan2 | + Mf::Asinh | + Mf::Acosh | + Mf::Atanh | // decomposition Mf::Ceil | Mf::Floor | diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 2324fb5184..e3e189453b 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -931,6 +931,9 @@ impl super::Validator { | Mf::Acos | Mf::Asin | Mf::Atan + | Mf::Asinh + | Mf::Acosh + | Mf::Atanh | Mf::Ceil | Mf::Floor | Mf::Round diff --git a/tests/in/spv/inv-hyperbolic-trig-functions.spv b/tests/in/spv/inv-hyperbolic-trig-functions.spv new file mode 100644 index 0000000000..a87d8deec5 Binary files /dev/null and b/tests/in/spv/inv-hyperbolic-trig-functions.spv differ diff --git a/tests/out/hlsl/inv-hyperbolic-trig-functions.hlsl b/tests/out/hlsl/inv-hyperbolic-trig-functions.hlsl new file mode 100644 index 0000000000..5857257ffa --- /dev/null +++ b/tests/out/hlsl/inv-hyperbolic-trig-functions.hlsl @@ -0,0 +1,22 @@ + +static float a = (float)0; + +void main1() +{ + float b = (float)0; + float c = (float)0; + float d = (float)0; + + float _expr8 = a; + b = log(_expr8 + sqrt(_expr8 * _expr8 + 1.0)); + float _expr10 = a; + c = log(_expr10 + sqrt(_expr10 * _expr10 - 1.0)); + float _expr12 = a; + d = 0.5 * log((1.0 + _expr12) / (1.0 - _expr12)); + return; +} + +void main() +{ + main1(); +} diff --git a/tests/out/hlsl/inv-hyperbolic-trig-functions.hlsl.config b/tests/out/hlsl/inv-hyperbolic-trig-functions.hlsl.config new file mode 100644 index 0000000000..f72fafd91f --- /dev/null +++ b/tests/out/hlsl/inv-hyperbolic-trig-functions.hlsl.config @@ -0,0 +1,3 @@ +vertex=(main:vs_5_1 ) +fragment=() +compute=() diff --git a/tests/out/ir/inv-hyperbolic-trig-functions.ron b/tests/out/ir/inv-hyperbolic-trig-functions.ron new file mode 100644 index 0000000000..bdf53ea71f --- /dev/null +++ b/tests/out/ir/inv-hyperbolic-trig-functions.ron @@ -0,0 +1,182 @@ +( + types: [ + ( + name: None, + inner: Scalar( + kind: Float, + width: 4, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Function, + ), + ), + ( + name: None, + inner: Pointer( + base: 1, + class: Private, + ), + ), + ], + constants: [ + ( + name: None, + specialization: None, + inner: Scalar( + width: 4, + value: Sint(0), + ), + ), + ( + name: None, + specialization: None, + inner: Scalar( + width: 4, + value: Sint(1), + ), + ), + ( + name: None, + specialization: None, + inner: Scalar( + width: 4, + value: Sint(2), + ), + ), + ( + name: None, + specialization: None, + inner: Scalar( + width: 4, + value: Sint(3), + ), + ), + ], + global_variables: [ + ( + name: Some("a"), + class: Private, + binding: None, + ty: 1, + init: None, + ), + ], + functions: [ + ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("b"), + ty: 1, + init: None, + ), + ( + name: Some("c"), + ty: 1, + init: None, + ), + ( + name: Some("d"), + ty: 1, + init: None, + ), + ], + expressions: [ + GlobalVariable(1), + Constant(1), + Constant(2), + Constant(3), + Constant(4), + LocalVariable(1), + LocalVariable(2), + LocalVariable(3), + Load( + pointer: 1, + ), + Math( + fun: Asinh, + arg: 9, + arg1: None, + arg2: None, + ), + Load( + pointer: 1, + ), + Math( + fun: Acosh, + arg: 11, + arg1: None, + arg2: None, + ), + Load( + pointer: 1, + ), + Math( + fun: Atanh, + arg: 13, + arg1: None, + arg2: None, + ), + ], + named_expressions: {}, + body: [ + Emit(( + start: 8, + end: 10, + )), + Store( + pointer: 6, + value: 10, + ), + Emit(( + start: 10, + end: 12, + )), + Store( + pointer: 7, + value: 12, + ), + Emit(( + start: 12, + end: 14, + )), + Store( + pointer: 8, + value: 14, + ), + Return( + value: None, + ), + ], + ), + ], + entry_points: [ + ( + name: "main", + stage: Vertex, + early_depth_test: None, + workgroup_size: (0, 0, 0), + function: ( + name: Some("main_wrap"), + arguments: [], + result: None, + local_variables: [], + expressions: [], + named_expressions: {}, + body: [ + Call( + function: 1, + arguments: [], + result: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/tests/out/wgsl/inv-hyperbolic-trig-functions.wgsl b/tests/out/wgsl/inv-hyperbolic-trig-functions.wgsl new file mode 100644 index 0000000000..7378cb3001 --- /dev/null +++ b/tests/out/wgsl/inv-hyperbolic-trig-functions.wgsl @@ -0,0 +1,20 @@ +var a: f32; + +fn main1() { + var b: f32; + var c: f32; + var d: f32; + + let _e8: f32 = a; + b = log(_e8 + sqrt(_e8 * _e8 + 1.0)); + let _e10: f32 = a; + c = log(_e10 + sqrt(_e10 * _e10 - 1.0)); + let _e12: f32 = a; + d = 0.5 * log((1.0 + _e12) / (1.0 - _e12)); + return; +} + +[[stage(vertex)]] +fn main() { + main1(); +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index e5c9c3e768..e2a957e241 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -498,6 +498,16 @@ fn convert_spv_shadow() { convert_spv("shadow", true, Targets::IR | Targets::ANALYSIS); } +#[cfg(feature = "spv-in")] +#[test] +fn convert_spv_inverse_hyperbolic_trig_functions() { + convert_spv( + "inv-hyperbolic-trig-functions", + true, + Targets::HLSL | Targets::WGSL | Targets::IR, + ); +} + #[cfg(all(feature = "spv-in", feature = "spv-out"))] #[test] fn convert_spv_pointer_access() {