diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 39c1ab7490..0b0d115c57 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1161,7 +1161,7 @@ impl super::Validator { )); } } - Mf::Outer | Mf::Cross | Mf::Reflect => { + Mf::Outer | Mf::Reflect => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), @@ -1172,8 +1172,31 @@ impl super::Validator { Sc { kind: Sk::Float, .. }, - size: vector_size, - } if fun != Mf::Cross || vector_size == crate::VectorSize::Tri => {} + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Cross => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + size: crate::VectorSize::Tri, + } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty {