From 88eee833d77f2e9378c21ffa3b43e7e1ea80315a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Fri, 5 Feb 2021 20:39:23 +0000 Subject: [PATCH] Add binary op support for the ConstantSolver --- src/lib.rs | 2 +- src/proc/constants.rs | 128 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 116 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 86f9cb5bb8..f2febc9ac3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -401,7 +401,7 @@ pub struct Constant { } /// A literal scalar value, used in constants. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Clone, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ScalarValue { diff --git a/src/proc/constants.rs b/src/proc/constants.rs index f5bf3ea70f..14f748a0fa 100644 --- a/src/proc/constants.rs +++ b/src/proc/constants.rs @@ -1,6 +1,7 @@ use crate::{ arena::{Arena, Handle}, - Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator, + BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner, + UnaryOperator, }; #[derive(Debug)] @@ -46,6 +47,8 @@ pub enum ConstantSolvingError { InvalidCastArg, #[error("Cannot apply the unary op to the argument")] InvalidUnaryOpArg, + #[error("Cannot apply the binary op to the arguments")] + InvalidBinaryOpArgs, } impl<'a> ConstantSolver<'a> { @@ -78,7 +81,12 @@ impl<'a> ConstantSolver<'a> { self.unary_op(op, tgt) } - Expression::Binary { .. } => todo!(), + Expression::Binary { left, right, op } => { + let left = self.solve(left)?; + let right = self.solve(right)?; + + self.binary_op(op, left, right) + } Expression::Math { .. } => todo!(), Expression::As { convert, @@ -97,11 +105,11 @@ impl<'a> ConstantSolver<'a> { let array = self.solve(expr)?; match self.constants[array].inner { - crate::ConstantInner::Scalar { .. } => { + ConstantInner::Scalar { .. } => { Err(ConstantSolvingError::InvalidArrayLengthArg) } - crate::ConstantInner::Composite { ty, .. } => match self.types[ty].inner { - crate::TypeInner::Array { size, .. } => match size { + ConstantInner::Composite { ty, .. } => match self.types[ty].inner { + TypeInner::Array { size, .. } => match size { crate::ArraySize::Constant(constant) => Ok(constant), crate::ArraySize::Dynamic => { Err(ConstantSolvingError::ArrayLengthDynamic) @@ -134,13 +142,13 @@ impl<'a> ConstantSolver<'a> { let base = self.solve(base)?; match self.constants[base].inner { - crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase), - crate::ConstantInner::Composite { ty, ref components } => { + ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase), + ConstantInner::Composite { ty, ref components } => { match self.types[ty].inner { - crate::TypeInner::Vector { .. } - | crate::TypeInner::Matrix { .. } - | crate::TypeInner::Array { .. } - | crate::TypeInner::Struct { .. } => (), + TypeInner::Vector { .. } + | TypeInner::Matrix { .. } + | TypeInner::Array { .. } + | TypeInner::Struct { .. } => (), _ => return Err(ConstantSolvingError::InvalidAccessBase), } @@ -195,7 +203,7 @@ impl<'a> ConstantSolver<'a> { ref mut components, } => { match self.types[ty].inner { - crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (), + TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (), _ => return Err(ConstantSolvingError::InvalidCastArg), } @@ -238,7 +246,7 @@ impl<'a> ConstantSolver<'a> { ref mut components, } => { match self.types[ty].inner { - crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (), + TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (), _ => return Err(ConstantSolvingError::InvalidCastArg), } @@ -254,6 +262,100 @@ impl<'a> ConstantSolver<'a> { inner, })) } + + fn binary_op( + &mut self, + op: BinaryOperator, + left: Handle, + right: Handle, + ) -> Result, ConstantSolvingError> { + let left = &self.constants[left].inner; + let right = &self.constants[right].inner; + + let inner = match (left, right) { + ( + ConstantInner::Scalar { + value: left_value, + width, + }, + ConstantInner::Scalar { + value: right_value, .. + }, + ) => { + let value = match op { + BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value), + BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value), + BinaryOperator::Less => ScalarValue::Bool(left_value < right_value), + BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value), + BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value), + BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value), + + _ => match (left_value, right_value) { + (ScalarValue::Sint(a), ScalarValue::Sint(b)) => { + ScalarValue::Sint(match op { + BinaryOperator::Add => a + b, + BinaryOperator::Subtract => a - b, + BinaryOperator::Multiply => a * b, + BinaryOperator::Divide => a / b, + BinaryOperator::Modulo => a % b, + BinaryOperator::And => a & b, + BinaryOperator::ExclusiveOr => a ^ b, + BinaryOperator::InclusiveOr => a | b, + BinaryOperator::ShiftLeft => a << b, + BinaryOperator::ShiftRight => a >> b, + _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), + }) + } + (ScalarValue::Uint(a), ScalarValue::Uint(b)) => { + ScalarValue::Uint(match op { + BinaryOperator::Add => a + b, + BinaryOperator::Subtract => a - b, + BinaryOperator::Multiply => a * b, + BinaryOperator::Divide => a / b, + BinaryOperator::Modulo => a % b, + BinaryOperator::And => a & b, + BinaryOperator::ExclusiveOr => a ^ b, + BinaryOperator::InclusiveOr => a | b, + BinaryOperator::ShiftLeft => a << b, + BinaryOperator::ShiftRight => a >> b, + _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), + }) + } + (ScalarValue::Float(a), ScalarValue::Float(b)) => { + ScalarValue::Float(match op { + BinaryOperator::Add => a + b, + BinaryOperator::Subtract => a - b, + BinaryOperator::Multiply => a * b, + BinaryOperator::Divide => a / b, + BinaryOperator::Modulo => a % b, + _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), + }) + } + (ScalarValue::Bool(a), ScalarValue::Bool(b)) => { + ScalarValue::Bool(match op { + BinaryOperator::LogicalAnd => *a && *b, + BinaryOperator::LogicalOr => *a || *b, + _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), + }) + } + _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), + }, + }; + + ConstantInner::Scalar { + value, + width: *width, + } + } + _ => return Err(ConstantSolvingError::InvalidBinaryOpArgs), + }; + + Ok(self.constants.fetch_or_append(Constant { + name: None, + specialization: None, + inner, + })) + } } #[cfg(test)]