Add support for unary ops to the constant solver

Fix wrong result when casting to bool
Add tests
This commit is contained in:
João Capucho
2021-02-01 14:58:26 +00:00
committed by Dzmitry Malyshau
parent 53bd721895
commit fcbf2aa4c4

View File

@@ -1,6 +1,6 @@
use crate::{
arena::{Arena, Handle},
ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type,
ArraySize, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, UnaryOperator,
};
#[derive(Debug)]
@@ -44,6 +44,8 @@ pub enum ConstantSolvingError {
Bitcast,
#[error("Cannot cast type")]
InvalidCastArg,
#[error("Cannot apply the unary op to the argument")]
InvalidUnaryOpArg,
}
impl<'a> ConstantSolver<'a> {
@@ -71,7 +73,11 @@ impl<'a> ConstantSolver<'a> {
inner: ConstantInner::Composite { ty, components },
}))
}
Expression::Unary { .. } => todo!(),
Expression::Unary { expr, op } => {
let tgt = self.solve(expr)?;
self.unary_op(op, tgt)
}
Expression::Binary { .. } => todo!(),
Expression::Math { .. } => todo!(),
Expression::As {
@@ -196,7 +202,7 @@ impl<'a> ConstantSolver<'a> {
ScalarKind::Sint => *value = ScalarValue::Sint(inner_cast(intial)),
ScalarKind::Uint => *value = ScalarValue::Uint(inner_cast(intial)),
ScalarKind::Float => *value = ScalarValue::Float(inner_cast(intial)),
ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::<u64>(intial) == 0),
ScalarKind::Bool => *value = ScalarValue::Bool(inner_cast::<u64>(intial) != 0),
}
}
ConstantInner::Composite {
@@ -220,4 +226,210 @@ impl<'a> ConstantSolver<'a> {
inner,
}))
}
fn unary_op(
&mut self,
op: UnaryOperator,
constant: Handle<Constant>,
) -> Result<Handle<Constant>, ConstantSolvingError> {
let mut inner = self.constants[constant].inner.clone();
match inner {
ConstantInner::Scalar { ref mut value, .. } => match op {
UnaryOperator::Negate => match value {
ScalarValue::Sint(v) => *v = -*v,
ScalarValue::Float(v) => *v = -*v,
_ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
},
UnaryOperator::Not => match value {
ScalarValue::Sint(v) => *v = !*v,
ScalarValue::Uint(v) => *v = !*v,
ScalarValue::Bool(v) => *v = !*v,
_ => return Err(ConstantSolvingError::InvalidUnaryOpArg),
},
},
ConstantInner::Composite {
ty,
ref mut components,
} => {
match self.types[ty].inner {
crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => (),
_ => return Err(ConstantSolvingError::InvalidCastArg),
}
for component in components {
*component = self.unary_op(op, *component)?
}
}
}
Ok(self.constants.fetch_or_append(Constant {
name: None,
specialization: None,
inner,
}))
}
}
#[cfg(test)]
mod tests {
use std::vec;
use crate::{
Arena, Constant, ConstantInner, Expression, ScalarKind, ScalarValue, Type, TypeInner,
UnaryOperator, VectorSize,
};
use super::ConstantSolver;
#[test]
fn unary_op() {
let mut types = Arena::new();
let mut expressions = Arena::new();
let mut constants = Arena::new();
let vec_ty = types.append(Type {
name: None,
inner: TypeInner::Vector {
size: VectorSize::Bi,
kind: ScalarKind::Sint,
width: 4,
},
});
let h = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Sint(4),
},
});
let h1 = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Sint(8),
},
});
let vec_h = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Composite {
ty: vec_ty,
components: vec![h, h1],
},
});
let expr = expressions.append(Expression::Constant(h));
let expr1 = expressions.append(Expression::Constant(vec_h));
let root1 = expressions.append(Expression::Unary {
op: UnaryOperator::Negate,
expr,
});
let root2 = expressions.append(Expression::Unary {
op: UnaryOperator::Not,
expr,
});
let root3 = expressions.append(Expression::Unary {
op: UnaryOperator::Not,
expr: expr1,
});
let mut solver = ConstantSolver {
types: &types,
expressions: &expressions,
constants: &mut constants,
};
let res1 = solver.solve(root1).unwrap();
let res2 = solver.solve(root2).unwrap();
let res3 = solver.solve(root3).unwrap();
assert_eq!(
constants[res1].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Sint(-4),
},
);
assert_eq!(
constants[res2].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Sint(!4),
},
);
let res3_inner = &constants[res3].inner;
match res3_inner {
ConstantInner::Composite { ty, components } => {
assert_eq!(*ty, vec_ty);
let mut components_iter = components.iter().copied();
assert_eq!(
constants[components_iter.next().unwrap()].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Sint(!4),
},
);
assert_eq!(
constants[components_iter.next().unwrap()].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Sint(!8),
},
);
assert!(components_iter.next().is_none());
}
_ => panic!("Expected vector"),
}
}
#[test]
fn cast() {
let mut expressions = Arena::new();
let mut constants = Arena::new();
let h = constants.append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Sint(4),
},
});
let expr = expressions.append(Expression::Constant(h));
let root = expressions.append(Expression::As {
expr,
kind: ScalarKind::Bool,
convert: true,
});
let mut solver = ConstantSolver {
types: &Arena::new(),
expressions: &expressions,
constants: &mut constants,
};
let res = solver.solve(root).unwrap();
assert_eq!(
constants[res].inner,
ConstantInner::Scalar {
width: 4,
value: ScalarValue::Bool(true),
},
);
}
}