diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index 6df233aa97..ce00a3b181 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -579,8 +579,112 @@ impl Context { let right_inner = self.typifier.get(right, &parser.module.types); match (left_inner, right_inner) { - (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) - | (&TypeInner::Matrix { .. }, &TypeInner::Matrix { .. }) => match op { + ( + &TypeInner::Matrix { + columns: left_columns, + rows: left_rows, + .. + }, + &TypeInner::Matrix { + columns: right_columns, + rows: right_rows, + .. + }, + ) => { + // Check that the two arguments have the same dimensions + if left_columns != right_columns || left_rows != right_rows { + parser.errors.push(Error { + kind: ErrorKind::SemanticError( + format!( + "Cannot apply operation to {:?} and {:?}", + left_inner, right_inner + ) + .into(), + ), + meta, + }) + } + + match op { + BinaryOperator::Equal | BinaryOperator::NotEqual => { + // Naga IR doesn't support matrix comparisons so we need to + // compare the columns individually and then fold them together + // + // The folding is done using a logical and for equality and + // a logical or for inequality + let equals = op == BinaryOperator::Equal; + + let (op, combine, fun) = match equals { + true => ( + BinaryOperator::Equal, + BinaryOperator::LogicalAnd, + RelationalFunction::All, + ), + false => ( + BinaryOperator::NotEqual, + BinaryOperator::LogicalOr, + RelationalFunction::Any, + ), + }; + + let mut root = None; + + for index in 0..left_columns as u32 { + // Get the column vectors + let left_vector = self.add_expression( + Expression::AccessIndex { base: left, index }, + meta, + body, + ); + let right_vector = self.add_expression( + Expression::AccessIndex { base: right, index }, + meta, + body, + ); + + let argument = self.expressions.append( + Expression::Binary { + op, + left: left_vector, + right: right_vector, + }, + meta, + ); + + // The result of comparing two vectors is a boolean vector + // so use a relational function like all to get a single + // boolean value + let compare = self.add_expression( + Expression::Relational { fun, argument }, + meta, + body, + ); + + // Fold the result + root = Some(match root { + Some(right) => self.add_expression( + Expression::Binary { + op: combine, + left: compare, + right, + }, + meta, + body, + ), + None => compare, + }); + } + + root.unwrap() + } + _ => self.add_expression( + Expression::Binary { left, op, right }, + meta, + body, + ), + } + } + (&TypeInner::Vector { .. }, &TypeInner::Vector { .. }) => match op { BinaryOperator::Equal | BinaryOperator::NotEqual => { let equals = op == BinaryOperator::Equal; diff --git a/tests/in/glsl/expressions.frag b/tests/in/glsl/expressions.frag index f1a959f47f..37e9fe78c2 100644 --- a/tests/in/glsl/expressions.frag +++ b/tests/in/glsl/expressions.frag @@ -62,6 +62,16 @@ void testBinOpUintUVec(uint a, uvec4 b) { v = a ^ b; } +void testBinOpMatMat(mat3 a, mat3 b) { + mat3 v; + bool c; + v = a * b; + v = a + b; + v = a - b; + c = a == b; + c = a != b; +} + void testStructConstructor() { struct BST { int data; diff --git a/tests/out/wgsl/expressions-frag.wgsl b/tests/out/wgsl/expressions-frag.wgsl index 0b3f7ad46e..e9c5fdb166 100644 --- a/tests/out/wgsl/expressions-frag.wgsl +++ b/tests/out/wgsl/expressions-frag.wgsl @@ -179,6 +179,32 @@ fn testBinOpUintUVec(a_10: u32, b_10: vec4) { return; } +fn testBinOpMatMat(a_12: mat3x3, b_12: mat3x3) { + var a_13: mat3x3; + var b_13: mat3x3; + var v_6: mat3x3; + var c: bool; + + a_13 = a_12; + b_13 = b_12; + let _e6 = a_13; + let _e7 = b_13; + v_6 = (_e6 * _e7); + let _e9 = a_13; + let _e10 = b_13; + v_6 = (_e9 + _e10); + let _e12 = a_13; + let _e13 = b_13; + v_6 = (_e12 - _e13); + let _e15 = a_13; + let _e16 = b_13; + c = (all((_e15[2] == _e16[2])) && (all((_e15[1] == _e16[1])) && all((_e15[0] == _e16[0])))); + let _e31 = a_13; + let _e32 = b_13; + c = (any((_e31[2] != _e32[2])) || (any((_e31[1] != _e32[1])) || any((_e31[0] != _e32[0])))); + return; +} + fn testStructConstructor() { var tree: BST = BST(1); @@ -194,7 +220,7 @@ fn testArrayConstructor() { } -fn privatePointer(a_12: ptr) { +fn privatePointer(a_14: ptr) { return; }