From e461d30865150a338c41aa8b57b4b859db5ca7da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Sat, 28 May 2022 00:08:02 +0100 Subject: [PATCH] glsl-in: Fix matrix multiplication check The previous check compared rows to rows and columns to columns but multiplication of matrices only needs the columns of the left matrix to be equal to the rows of the right matrix. --- src/front/glsl/context.rs | 11 +++++++---- tests/in/glsl/expressions.frag | 4 ++++ tests/out/wgsl/expressions-frag.wgsl | 13 +++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index ed60759082..13eb5d180a 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -617,11 +617,14 @@ impl Context { width: right_width, }, ) => { + let dimensions_ok = if op == BinaryOperator::Multiply { + left_columns == right_rows + } else { + left_columns == right_columns && left_rows == right_rows + }; + // Check that the two arguments have the same dimensions - if left_columns != right_columns - || left_rows != right_rows - || left_width != right_width - { + if !dimensions_ok || left_width != right_width { parser.errors.push(Error { kind: ErrorKind::SemanticError( format!( diff --git a/tests/in/glsl/expressions.frag b/tests/in/glsl/expressions.frag index 8dd07c6525..acf0ea9213 100644 --- a/tests/in/glsl/expressions.frag +++ b/tests/in/glsl/expressions.frag @@ -128,6 +128,10 @@ void ternary(bool a) { uint nested = a ? (a ? (a ? 2u : 3) : 4u) : 5; } +void testMatrixMultiplication(mat4x3 a, mat4x4 b) { + mat4x3 c = a * b; +} + out vec4 o_color; void main() { privatePointer(global); diff --git a/tests/out/wgsl/expressions-frag.wgsl b/tests/out/wgsl/expressions-frag.wgsl index b4364e597c..7cf1eace0e 100644 --- a/tests/out/wgsl/expressions-frag.wgsl +++ b/tests/out/wgsl/expressions-frag.wgsl @@ -356,6 +356,19 @@ fn ternary(a_20: bool) { return; } +fn testMatrixMultiplication(a_22: mat4x3, b_18: mat4x4) { + var a_23: mat4x3; + var b_19: mat4x4; + var c_2: mat4x3; + + a_23 = a_22; + b_19 = b_18; + let _e5 = a_23; + let _e6 = b_19; + c_2 = (_e5 * _e6); + return; +} + fn main_1() { var local_5: f32;