diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index a6e117772b..31533615ab 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1941,6 +1941,7 @@ impl<'a, W: Write> Writer<'a, W> { Mf::Normalize => "normalize", Mf::FaceForward => "faceforward", Mf::Reflect => "reflect", + Mf::Refract => "refract", // computational Mf::Sign => "sign", Mf::Fma => "fma", diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 0f9ed3c2b1..eeb4afe0ca 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1023,6 +1023,7 @@ impl Writer { Mf::Normalize => "normalize", Mf::FaceForward => "faceforward", Mf::Reflect => "reflect", + Mf::Refract => "refract", // computational Mf::Sign => "sign", Mf::Fma => "fma", diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 42065ae029..d20fe48bb1 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1821,6 +1821,7 @@ impl Writer { Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize), Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward), Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect), + Mf::Refract => MathOp::Ext(spirv::GLOp::Refract), // exponent Mf::Exp => MathOp::Ext(spirv::GLOp::Exp), Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2), diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index b33db13603..ccea724d5e 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -1737,6 +1737,7 @@ impl> Parser { Glo::Normalize => Mf::Normalize, Glo::FaceForward => Mf::FaceForward, Glo::Reflect => Mf::Reflect, + Glo::Refract => Mf::Refract, _ => return Err(Error::UnsupportedExtInst(inst_id)), }; diff --git a/src/lib.rs b/src/lib.rs index ce0897ae37..270fbf974c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -630,6 +630,7 @@ pub enum MathFunction { Normalize, FaceForward, Reflect, + Refract, // computational Sign, Fma, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 27f5039f4e..14eb33df9e 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -150,6 +150,7 @@ impl super::MathFunction { Self::Normalize => 1, Self::FaceForward => 3, Self::Reflect => 2, + Self::Refract => 3, // computational Self::Sign => 1, Self::Fma => 3, diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 86a85e2b0e..4ba71949b2 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -561,7 +561,8 @@ impl<'a> ResolveContext<'a> { }, Mf::Normalize | Mf::FaceForward | - Mf::Reflect => res_arg.clone(), + Mf::Reflect | + Mf::Refract => res_arg.clone(), // computational Mf::Sign | Mf::Fma | diff --git a/src/valid/expression.rs b/src/valid/expression.rs index dcde1d087c..60bb169011 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -964,6 +964,47 @@ impl super::Validator { )); } } + Mf::Refract => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty) { + (Some(ty1), Some(ty2)) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + + match *arg_ty { + Ti::Vector { + kind: Sk::Float, .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + + match (arg_ty, arg2_ty) { + ( + &Ti::Vector { + width: vector_width, + .. + }, + &Ti::Scalar { + width: scalar_width, + kind: Sk::Float, + }, + ) if vector_width == scalar_width => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + } Mf::Normalize => { if arg1_ty.is_some() | arg2_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun));