Add binary op support for the ConstantSolver

This commit is contained in:
João Capucho
2021-02-05 20:39:23 +00:00
committed by Dzmitry Malyshau
parent b1fa7471d2
commit 88eee833d7
2 changed files with 116 additions and 14 deletions

View File

@@ -401,7 +401,7 @@ pub struct Constant {
}
/// A literal scalar value, used in constants.
#[derive(Debug, PartialEq, Clone)]
#[derive(Debug, PartialEq, Clone, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ScalarValue {

View File

@@ -1,6 +1,7 @@
use crate::{
arena::{Arena, Handle},
Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator,
BinaryOperator, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
UnaryOperator,
};
#[derive(Debug)]
@@ -46,6 +47,8 @@ pub enum ConstantSolvingError {
InvalidCastArg,
#[error("Cannot apply the unary op to the argument")]
InvalidUnaryOpArg,
#[error("Cannot apply the binary op to the arguments")]
InvalidBinaryOpArgs,
}
impl<'a> ConstantSolver<'a> {
@@ -78,7 +81,12 @@ impl<'a> ConstantSolver<'a> {
self.unary_op(op, tgt)
}
Expression::Binary { .. } => todo!(),
Expression::Binary { left, right, op } => {
let left = self.solve(left)?;
let right = self.solve(right)?;
self.binary_op(op, left, right)
}
Expression::Math { .. } => todo!(),
Expression::As {
convert,
@@ -97,11 +105,11 @@ impl<'a> ConstantSolver<'a> {
let array = self.solve(expr)?;
match self.constants[array].inner {
crate::ConstantInner::Scalar { .. } => {
ConstantInner::Scalar { .. } => {
Err(ConstantSolvingError::InvalidArrayLengthArg)
}
crate::ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
crate::TypeInner::Array { size, .. } => match size {
ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
TypeInner::Array { size, .. } => match size {
crate::ArraySize::Constant(constant) => Ok(constant),
crate::ArraySize::Dynamic => {
Err(ConstantSolvingError::ArrayLengthDynamic)
@@ -134,13 +142,13 @@ impl<'a> ConstantSolver<'a> {
let base = self.solve(base)?;
match self.constants[base].inner {
crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
crate::ConstantInner::Composite { ty, ref components } => {
ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
ConstantInner::Composite { ty, ref components } => {
match self.types[ty].inner {
crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. }
| crate::TypeInner::Array { .. }
| crate::TypeInner::Struct { .. } => (),
TypeInner::Vector { .. }
| TypeInner::Matrix { .. }
| TypeInner::Array { .. }
| TypeInner::Struct { .. } => (),
_ => return Err(ConstantSolvingError::InvalidAccessBase),
}
@@ -195,7 +203,7 @@ impl<'a> ConstantSolver<'a> {
ref mut components,
} => {
match self.types[ty].inner {
crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (),
TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
_ => return Err(ConstantSolvingError::InvalidCastArg),
}
@@ -238,7 +246,7 @@ impl<'a> ConstantSolver<'a> {
ref mut components,
} => {
match self.types[ty].inner {
crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (),
TypeInner::Vector { .. } | TypeInner::Matrix { .. } => (),
_ => return Err(ConstantSolvingError::InvalidCastArg),
}
@@ -254,6 +262,100 @@ impl<'a> ConstantSolver<'a> {
inner,
}))
}
fn binary_op(
&mut self,
op: BinaryOperator,
left: Handle<Constant>,
right: Handle<Constant>,
) -> Result<Handle<Constant>, ConstantSolvingError> {
let left = &self.constants[left].inner;
let right = &self.constants[right].inner;
let inner = match (left, right) {
(
ConstantInner::Scalar {
value: left_value,
width,
},
ConstantInner::Scalar {
value: right_value, ..
},
) => {
let value = match op {
BinaryOperator::Equal => ScalarValue::Bool(left_value == right_value),
BinaryOperator::NotEqual => ScalarValue::Bool(left_value != right_value),
BinaryOperator::Less => ScalarValue::Bool(left_value < right_value),
BinaryOperator::LessEqual => ScalarValue::Bool(left_value <= right_value),
BinaryOperator::Greater => ScalarValue::Bool(left_value > right_value),
BinaryOperator::GreaterEqual => ScalarValue::Bool(left_value >= right_value),
_ => match (left_value, right_value) {
(ScalarValue::Sint(a), ScalarValue::Sint(b)) => {
ScalarValue::Sint(match op {
BinaryOperator::Add => a + b,
BinaryOperator::Subtract => a - b,
BinaryOperator::Multiply => a * b,
BinaryOperator::Divide => a / b,
BinaryOperator::Modulo => a % b,
BinaryOperator::And => a & b,
BinaryOperator::ExclusiveOr => a ^ b,
BinaryOperator::InclusiveOr => a | b,
BinaryOperator::ShiftLeft => a << b,
BinaryOperator::ShiftRight => a >> b,
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
})
}
(ScalarValue::Uint(a), ScalarValue::Uint(b)) => {
ScalarValue::Uint(match op {
BinaryOperator::Add => a + b,
BinaryOperator::Subtract => a - b,
BinaryOperator::Multiply => a * b,
BinaryOperator::Divide => a / b,
BinaryOperator::Modulo => a % b,
BinaryOperator::And => a & b,
BinaryOperator::ExclusiveOr => a ^ b,
BinaryOperator::InclusiveOr => a | b,
BinaryOperator::ShiftLeft => a << b,
BinaryOperator::ShiftRight => a >> b,
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
})
}
(ScalarValue::Float(a), ScalarValue::Float(b)) => {
ScalarValue::Float(match op {
BinaryOperator::Add => a + b,
BinaryOperator::Subtract => a - b,
BinaryOperator::Multiply => a * b,
BinaryOperator::Divide => a / b,
BinaryOperator::Modulo => a % b,
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
})
}
(ScalarValue::Bool(a), ScalarValue::Bool(b)) => {
ScalarValue::Bool(match op {
BinaryOperator::LogicalAnd => *a && *b,
BinaryOperator::LogicalOr => *a || *b,
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
})
}
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
},
};
ConstantInner::Scalar {
value,
width: *width,
}
}
_ => return Err(ConstantSolvingError::InvalidBinaryOpArgs),
};
Ok(self.constants.fetch_or_append(Constant {
name: None,
specialization: None,
inner,
}))
}
}
#[cfg(test)]