From fcbf2aa4c48cb8874f15fade4df6320dbcbc8eaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Mon, 1 Feb 2021 14:58:26 +0000 Subject: [PATCH] Add support for unary ops to the constant solver Fix wrong result when casting to bool Add tests --- src/proc/constants.rs | 218 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 215 insertions(+), 3 deletions(-) diff --git a/src/proc/constants.rs b/src/proc/constants.rs index 37560a06be..8097ed80c6 100644 --- a/src/proc/constants.rs +++ b/src/proc/constants.rs @@ -1,6 +1,6 @@ use crate::{ arena::{Arena, Handle}, - ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, + ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator, }; #[derive(Debug)] @@ -44,6 +44,8 @@ pub enum ConstantSolvingError { Bitcast, #[error("Cannot cast type")] InvalidCastArg, + #[error("Cannot apply the unary op to the argument")] + InvalidUnaryOpArg, } impl<'a> ConstantSolver<'a> { @@ -71,7 +73,11 @@ impl<'a> ConstantSolver<'a> { inner: ConstantInner::Composite { ty, components }, })) } - Expression::Unary { .. } => todo!(), + Expression::Unary { expr, op } => { + let tgt = self.solve(expr)?; + + self.unary_op(op, tgt) + } Expression::Binary { .. } => todo!(), Expression::Math { .. } => todo!(), Expression::As { @@ -196,7 +202,7 @@ impl<'a> ConstantSolver<'a> { ScalarKind::Sint => *value = ScalarValue::Sint(inner_cast(intial)), ScalarKind::Uint => *value = ScalarValue::Uint(inner_cast(intial)), ScalarKind::Float => *value = ScalarValue::Float(inner_cast(intial)), - ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::(intial) == 0), + ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::(intial) != 0), } } ConstantInner::Composite { @@ -220,4 +226,210 @@ impl<'a> ConstantSolver<'a> { inner, })) } + + fn unary_op( + &mut self, + op: UnaryOperator, + constant: Handle, + ) -> Result, ConstantSolvingError> { + let mut inner = self.constants[constant].inner.clone(); + + match inner { + ConstantInner::Scalar { ref mut value, .. } => match op { + UnaryOperator::Negate => match value { + ScalarValue::Sint(v) => *v = -*v, + ScalarValue::Float(v) => *v = -*v, + _ => return Err(ConstantSolvingError::InvalidUnaryOpArg), + }, + UnaryOperator::Not => match value { + ScalarValue::Sint(v) => *v = !*v, + ScalarValue::Uint(v) => *v = !*v, + ScalarValue::Bool(v) => *v = !*v, + _ => return Err(ConstantSolvingError::InvalidUnaryOpArg), + }, + }, + ConstantInner::Composite { + ty, + ref mut components, + } => { + match self.types[ty].inner { + crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (), + _ => return Err(ConstantSolvingError::InvalidCastArg), + } + + for component in components { + *component = self.unary_op(op, *component)? + } + } + } + + Ok(self.constants.fetch_or_append(Constant { + name: None, + specialization: None, + inner, + })) + } +} + +#[cfg(test)] +mod tests { + use std::vec; + + use crate::{ + Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner, + UnaryOperator, VectorSize, + }; + + use super::ConstantSolver; + + #[test] + fn unary_op() { + let mut types = Arena::new(); + let mut expressions = Arena::new(); + let mut constants = Arena::new(); + + let vec_ty = types.append(Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Bi, + kind: ScalarKind::Sint, + width: 4, + }, + }); + + let h = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Sint(4), + }, + }); + + let h1 = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Sint(8), + }, + }); + + let vec_h = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Composite { + ty: vec_ty, + components: vec![h, h1], + }, + }); + + let expr = expressions.append(Expression::Constant(h)); + let expr1 = expressions.append(Expression::Constant(vec_h)); + + let root1 = expressions.append(Expression::Unary { + op: UnaryOperator::Negate, + expr, + }); + + let root2 = expressions.append(Expression::Unary { + op: UnaryOperator::Not, + expr, + }); + + let root3 = expressions.append(Expression::Unary { + op: UnaryOperator::Not, + expr: expr1, + }); + + let mut solver = ConstantSolver { + types: &types, + expressions: &expressions, + constants: &mut constants, + }; + + let res1 = solver.solve(root1).unwrap(); + let res2 = solver.solve(root2).unwrap(); + let res3 = solver.solve(root3).unwrap(); + + assert_eq!( + constants[res1].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Sint(-4), + }, + ); + + assert_eq!( + constants[res2].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Sint(!4), + }, + ); + + let res3_inner = &constants[res3].inner; + + match res3_inner { + ConstantInner::Composite { ty, components } => { + assert_eq!(*ty, vec_ty); + let mut components_iter = components.iter().copied(); + assert_eq!( + constants[components_iter.next().unwrap()].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Sint(!4), + }, + ); + assert_eq!( + constants[components_iter.next().unwrap()].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Sint(!8), + }, + ); + assert!(components_iter.next().is_none()); + } + _ => panic!("Expected vector"), + } + } + + #[test] + fn cast() { + let mut expressions = Arena::new(); + let mut constants = Arena::new(); + + let h = constants.append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Sint(4), + }, + }); + + let expr = expressions.append(Expression::Constant(h)); + + let root = expressions.append(Expression::As { + expr, + kind: ScalarKind::Bool, + convert: true, + }); + + let mut solver = ConstantSolver { + types: &Arena::new(), + expressions: &expressions, + constants: &mut constants, + }; + + let res = solver.solve(root).unwrap(); + + assert_eq!( + constants[res].inner, + ConstantInner::Scalar { + width: 4, + value: ScalarValue::Bool(true), + }, + ); + } }