diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index cfd53cf380..4762b79759 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -1456,7 +1456,9 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, ")")? } - // `Binary` we just write `left op right` + // `Binary` we just write `left op right`, except when dealing with + // comparison operations on vectors as they are implemented with + // builtin functions. // Once again we wrap everything in parantheses to avoid precedence issues Expression::Binary { op, left, right } => { // Holds `Some(function_name)` if the binary operation is @@ -1470,6 +1472,8 @@ impl<'a, W: Write> Writer<'a, W> { BinaryOperator::LessEqual => Some("lessThanEqual"), BinaryOperator::Greater => Some("greaterThan"), BinaryOperator::GreaterEqual => Some("greaterThanEqual"), + BinaryOperator::Equal => Some("equal"), + BinaryOperator::NotEqual => Some("notEqual"), _ => None, } } else { diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index 7962f59b37..157e72c0b9 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -2,8 +2,8 @@ use super::{constants::ConstantSolver, error::ErrorKind}; use crate::{ proc::{ResolveContext, Typifier}, Arena, BinaryOperator, Binding, Constant, Expression, FastHashMap, Function, FunctionArgument, - GlobalVariable, Handle, Interpolation, LocalVariable, Module, ShaderStage, Statement, - StorageClass, Type, + GlobalVariable, Handle, Interpolation, LocalVariable, Module, RelationalFunction, ShaderStage, + Statement, StorageClass, Type, }; #[derive(Debug)] @@ -57,6 +57,49 @@ impl Program { })) } + /// Helper function to insert equality expressions, this handles the special + /// case of `vec1 == vec2` and `vec1 != vec2` since in the IR they are + /// represented as `all(equal(vec1, vec2))` and `any(notEqual(vec1, vec2))` + pub fn equality_expr( + &mut self, + equals: bool, + left: &ExpressionRule, + right: &ExpressionRule, + ) -> Result { + let left_is_vector = match self.resolve_type(left.expression)? { + crate::TypeInner::Vector { .. } => true, + _ => false, + }; + + let rigth_is_vector = match self.resolve_type(right.expression)? { + crate::TypeInner::Vector { .. } => true, + _ => false, + }; + + let (op, fun) = match equals { + true => (BinaryOperator::Equal, RelationalFunction::All), + false => (BinaryOperator::NotEqual, RelationalFunction::Any), + }; + + let expr = + ExpressionRule::from_expression(self.context.expressions.append(Expression::Binary { + op, + left: left.expression, + right: right.expression, + })); + + Ok(if left_is_vector && rigth_is_vector { + ExpressionRule::from_expression(self.context.expressions.append( + Expression::Relational { + fun, + argument: expr.expression, + }, + )) + } else { + expr + }) + } + pub fn resolve_type( &mut self, handle: Handle, diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 3b6b880a96..7238cac128 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -146,7 +146,8 @@ impl Program { statements: fc.args.into_iter().flat_map(|a| a.statements).collect(), }) } - "lessThan" | "greaterThan" => { + "lessThan" | "greaterThan" | "lessThanEqual" | "greaterThanEqual" | "equal" + | "notEqual" => { if fc.args.len() != 2 { return Err(ErrorKind::WrongNumberArgs(name, 2, fc.args.len())); } @@ -155,6 +156,10 @@ impl Program { op: match name.as_str() { "lessThan" => BinaryOperator::Less, "greaterThan" => BinaryOperator::Greater, + "lessThanEqual" => BinaryOperator::LessEqual, + "greaterThanEqual" => BinaryOperator::GreaterEqual, + "equal" => BinaryOperator::Equal, + "notEqual" => BinaryOperator::NotEqual, _ => unreachable!(), }, left: fc.args[0].expression, diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index b9c889cd76..dc48cea8be 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -367,10 +367,10 @@ pomelo! { } equality_expression ::= relational_expression; equality_expression ::= equality_expression(left) EqOp relational_expression(right) { - extra.binary_expr(BinaryOperator::Equal, &left, &right) + extra.equality_expr(true, &left, &right)? } equality_expression ::= equality_expression(left) NeOp relational_expression(right) { - extra.binary_expr(BinaryOperator::NotEqual, &left, &right) + extra.equality_expr(false, &left, &right)? } and_expression ::= equality_expression; and_expression ::= and_expression(left) Ampersand equality_expression(right) {