Add support for inverse hyperbolic trignometric functions

Hlsl and wgsl don't support them directly so a polyfill is used taken
from the msl spec.

`asinh` = `log(x + sqrt(x * x + 1.0))`
`acosh` = `log(x + sqrt(x * x - 1.0))`
`atanh` = `0.5 * log((1.0 + x) / (1.0 – x))`
This commit is contained in:
João Capucho
2021-08-20 17:11:55 +01:00
committed by Dzmitry Malyshau
parent 644fa684ba
commit bbf3e465f3
17 changed files with 454 additions and 135 deletions

View File

@@ -2242,6 +2242,9 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Asinh => "asinh",
Mf::Acosh => "acosh",
Mf::Atanh => "atanh",
// glsl doesn't have atan2 function
// use two-argument variation of the atan function
Mf::Atan2 => "atan",

View File

@@ -1635,76 +1635,105 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
} => {
use crate::MathFunction as Mf;
let fun_name = match fun {
// comparison
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
// trigonometry
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
// decomposition
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "frac",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
//Mf::Outer => ,
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceforward",
Mf::Reflect => "reflect",
Mf::Refract => "refract",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "lerp",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "rsqrt",
//Mf::Inverse =>,
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountOneBits => "countbits",
Mf::ReverseBits => "reversebits",
_ => return Err(Error::Unimplemented(format!("write_expr_math {:?}", fun))),
};
match fun {
Mf::Asinh | Mf::Acosh => {
write!(self.out, "log(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " + sqrt(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " * ")?;
self.write_expr(module, arg, func_ctx)?;
if fun == Mf::Asinh {
write!(self.out, " + 1.0))")?
} else {
write!(self.out, " - 1.0))")?
}
}
Mf::Atanh => {
write!(self.out, "0.5 * log((1.0 + ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") / (1.0 - ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
_ => {
let fun_name = match fun {
// comparison
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
// trigonometry
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
// decomposition
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "frac",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
//Mf::Outer => ,
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceforward",
Mf::Reflect => "reflect",
Mf::Refract => "refract",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "lerp",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "rsqrt",
//Mf::Inverse =>,
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountOneBits => "countbits",
Mf::ReverseBits => "reversebits",
_ => {
return Err(Error::Unimplemented(format!(
"write_expr_math {:?}",
fun
)))
}
};
write!(self.out, "{}(", fun_name)?;
self.write_expr(module, arg, func_ctx)?;
if let Some(arg) = arg1 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "{}(", fun_name)?;
self.write_expr(module, arg, func_ctx)?;
if let Some(arg) = arg1 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
if let Some(arg) = arg2 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
write!(self.out, ")")?
}
}
if let Some(arg) = arg2 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
write!(self.out, ")")?
}
Expression::Swizzle {
size,

View File

@@ -1140,6 +1140,9 @@ impl<W: Write> Writer<W> {
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
Mf::Asinh => "asinh",
Mf::Acosh => "acosh",
Mf::Atanh => "atanh",
// decomposition
Mf::Ceil => "ceil",
Mf::Floor => "floor",

View File

@@ -558,6 +558,9 @@ impl<'w> BlockContext<'w> {
Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
// decomposition
Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
Mf::Round => MathOp::Ext(spirv::GLOp::Round),

View File

@@ -1240,75 +1240,101 @@ impl<W: Write> Writer<W> {
} => {
use crate::MathFunction as Mf;
let fun_name = match fun {
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
// trigonometry
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
// decomposition
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
Mf::Outer => "outerProduct",
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceForward",
Mf::Reflect => "reflect",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothStep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "inverseSqrt",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountOneBits => "countOneBits",
Mf::ReverseBits => "reverseBits",
_ => {
return Err(Error::UnsupportedMathFunction(fun));
// NOTE: If https://github.com/gpuweb/gpuweb/issues/1622 ever is
// accepted, replace this with the builtin functions
match fun {
Mf::Asinh | Mf::Acosh => {
write!(self.out, "log(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " + sqrt(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " * ")?;
self.write_expr(module, arg, func_ctx)?;
if fun == Mf::Asinh {
write!(self.out, " + 1.0))")?
} else {
write!(self.out, " - 1.0))")?
}
}
};
Mf::Atanh => {
write!(self.out, "0.5 * log((1.0 + ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") / (1.0 - ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
_ => {
let fun_name = match fun {
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
// trigonometry
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
// decomposition
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
Mf::Outer => "outerProduct",
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceForward",
Mf::Reflect => "reflect",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothStep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "inverseSqrt",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountOneBits => "countOneBits",
Mf::ReverseBits => "reverseBits",
_ => {
return Err(Error::UnsupportedMathFunction(fun));
}
};
write!(self.out, "{}(", fun_name)?;
self.write_expr(module, arg, func_ctx)?;
if let Some(arg) = arg1 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "{}(", fun_name)?;
self.write_expr(module, arg, func_ctx)?;
if let Some(arg) = arg1 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
if let Some(arg) = arg2 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
write!(self.out, ")")?
}
}
if let Some(arg) = arg2 {
write!(self.out, ", ")?;
self.write_expr(module, arg, func_ctx)?;
}
write!(self.out, ")")?
}
Expression::Swizzle {
size,

View File

@@ -451,7 +451,7 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module
}
}
"sin" | "exp" | "exp2" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin"
| "log" | "log2" | "radians" | "degrees" => {
| "log" | "log2" | "radians" | "degrees" | "asinh" | "acosh" | "atanh" => {
// bits layout
// bit 0 trough 1 - dims
for bits in 0..(0b100) {
@@ -481,6 +481,9 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module
"asin" => MacroCall::MathFunction(MathFunction::Asin),
"log" => MacroCall::MathFunction(MathFunction::Log),
"log2" => MacroCall::MathFunction(MathFunction::Log2),
"asinh" => MacroCall::MathFunction(MathFunction::Asinh),
"acosh" => MacroCall::MathFunction(MathFunction::Acosh),
"atanh" => MacroCall::MathFunction(MathFunction::Atanh),
"radians" => MacroCall::ConstMultiply(std::f64::consts::PI / 180.0),
"degrees" => MacroCall::ConstMultiply(180.0 / std::f64::consts::PI),
_ => unreachable!(),

View File

@@ -1984,6 +1984,9 @@ impl<I: Iterator<Item = u32>> Parser<I> {
Glo::Cosh => Mf::Cosh,
Glo::Tanh => Mf::Tanh,
Glo::Atan2 => Mf::Atan2,
Glo::Asinh => Mf::Asinh,
Glo::Acosh => Mf::Acosh,
Glo::Atanh => Mf::Atanh,
Glo::Pow => Mf::Pow,
Glo::Exp => Mf::Exp,
Glo::Log => Mf::Log,

View File

@@ -802,6 +802,9 @@ pub enum MathFunction {
Asin,
Atan,
Atan2,
Asinh,
Acosh,
Atanh,
// decomposition
Ceil,
Floor,

View File

@@ -151,6 +151,9 @@ impl super::MathFunction {
Self::Asin => 1,
Self::Atan => 1,
Self::Atan2 => 2,
Self::Asinh => 1,
Self::Acosh => 1,
Self::Atanh => 1,
// decomposition
Self::Ceil => 1,
Self::Floor => 1,

View File

@@ -554,6 +554,9 @@ impl<'a> ResolveContext<'a> {
Mf::Asin |
Mf::Atan |
Mf::Atan2 |
Mf::Asinh |
Mf::Acosh |
Mf::Atanh |
// decomposition
Mf::Ceil |
Mf::Floor |

View File

@@ -931,6 +931,9 @@ impl super::Validator {
| Mf::Acos
| Mf::Asin
| Mf::Atan
| Mf::Asinh
| Mf::Acosh
| Mf::Atanh
| Mf::Ceil
| Mf::Floor
| Mf::Round

Binary file not shown.

View File

@@ -0,0 +1,22 @@
static float a = (float)0;
void main1()
{
float b = (float)0;
float c = (float)0;
float d = (float)0;
float _expr8 = a;
b = log(_expr8 + sqrt(_expr8 * _expr8 + 1.0));
float _expr10 = a;
c = log(_expr10 + sqrt(_expr10 * _expr10 - 1.0));
float _expr12 = a;
d = 0.5 * log((1.0 + _expr12) / (1.0 - _expr12));
return;
}
void main()
{
main1();
}

View File

@@ -0,0 +1,3 @@
vertex=(main:vs_5_1 )
fragment=()
compute=()

View File

@@ -0,0 +1,182 @@
(
types: [
(
name: None,
inner: Scalar(
kind: Float,
width: 4,
),
),
(
name: None,
inner: Pointer(
base: 1,
class: Function,
),
),
(
name: None,
inner: Pointer(
base: 1,
class: Private,
),
),
],
constants: [
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Sint(0),
),
),
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Sint(1),
),
),
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Sint(2),
),
),
(
name: None,
specialization: None,
inner: Scalar(
width: 4,
value: Sint(3),
),
),
],
global_variables: [
(
name: Some("a"),
class: Private,
binding: None,
ty: 1,
init: None,
),
],
functions: [
(
name: Some("main"),
arguments: [],
result: None,
local_variables: [
(
name: Some("b"),
ty: 1,
init: None,
),
(
name: Some("c"),
ty: 1,
init: None,
),
(
name: Some("d"),
ty: 1,
init: None,
),
],
expressions: [
GlobalVariable(1),
Constant(1),
Constant(2),
Constant(3),
Constant(4),
LocalVariable(1),
LocalVariable(2),
LocalVariable(3),
Load(
pointer: 1,
),
Math(
fun: Asinh,
arg: 9,
arg1: None,
arg2: None,
),
Load(
pointer: 1,
),
Math(
fun: Acosh,
arg: 11,
arg1: None,
arg2: None,
),
Load(
pointer: 1,
),
Math(
fun: Atanh,
arg: 13,
arg1: None,
arg2: None,
),
],
named_expressions: {},
body: [
Emit((
start: 8,
end: 10,
)),
Store(
pointer: 6,
value: 10,
),
Emit((
start: 10,
end: 12,
)),
Store(
pointer: 7,
value: 12,
),
Emit((
start: 12,
end: 14,
)),
Store(
pointer: 8,
value: 14,
),
Return(
value: None,
),
],
),
],
entry_points: [
(
name: "main",
stage: Vertex,
early_depth_test: None,
workgroup_size: (0, 0, 0),
function: (
name: Some("main_wrap"),
arguments: [],
result: None,
local_variables: [],
expressions: [],
named_expressions: {},
body: [
Call(
function: 1,
arguments: [],
result: None,
),
],
),
),
],
)

View File

@@ -0,0 +1,20 @@
var<private> a: f32;
fn main1() {
var b: f32;
var c: f32;
var d: f32;
let _e8: f32 = a;
b = log(_e8 + sqrt(_e8 * _e8 + 1.0));
let _e10: f32 = a;
c = log(_e10 + sqrt(_e10 * _e10 - 1.0));
let _e12: f32 = a;
d = 0.5 * log((1.0 + _e12) / (1.0 - _e12));
return;
}
[[stage(vertex)]]
fn main() {
main1();
}

View File

@@ -498,6 +498,16 @@ fn convert_spv_shadow() {
convert_spv("shadow", true, Targets::IR | Targets::ANALYSIS);
}
#[cfg(feature = "spv-in")]
#[test]
fn convert_spv_inverse_hyperbolic_trig_functions() {
convert_spv(
"inv-hyperbolic-trig-functions",
true,
Targets::HLSL | Targets::WGSL | Targets::IR,
);
}
#[cfg(all(feature = "spv-in", feature = "spv-out"))]
#[test]
fn convert_spv_pointer_access() {