Integers (#91)

We already had sorts and values. Now we have operators too. These are only really in the SMT backend right now.
This commit is contained in:
Alex Ozdemir
2022-07-26 15:12:14 -07:00
committed by GitHub
parent 2d0709342e
commit 0b3c936a40
5 changed files with 285 additions and 2 deletions

View File

@@ -81,6 +81,7 @@ struct FrontendOptions {
#[structopt(long)]
lint_prim_rec: bool,
#[cfg(feature = "zok")]
/// In Z#, "isolate" assertions. That is, assertions in if/then/else expressions only take
/// effect if that branch is active.
///

View File

@@ -122,6 +122,11 @@ pub enum Op {
/// Takes the modulus.
UbvToPf(FieldT),
/// Integer n-ary operator
IntNaryOp(IntNaryOp),
/// Integer comparison operator
IntBinPred(IntBinPred),
/// Binary operator, with arguments (array, index).
///
/// Gets the value at index in array.
@@ -213,6 +218,18 @@ pub const PF_RECIP: Op = Op::PfUnOp(PfUnOp::Recip);
pub const PF_ADD: Op = Op::PfNaryOp(PfNaryOp::Add);
/// prime-field multiplication
pub const PF_MUL: Op = Op::PfNaryOp(PfNaryOp::Mul);
/// integer addition
pub const INT_ADD: Op = Op::IntNaryOp(IntNaryOp::Add);
/// integer multiplication
pub const INT_MUL: Op = Op::IntNaryOp(IntNaryOp::Mul);
/// integer less than
pub const INT_LT: Op = Op::IntBinPred(IntBinPred::Lt);
/// integer less than or equal
pub const INT_LE: Op = Op::IntBinPred(IntBinPred::Le);
/// integer greater than
pub const INT_GT: Op = Op::IntBinPred(IntBinPred::Gt);
/// integer greater than or equal
pub const INT_GE: Op = Op::IntBinPred(IntBinPred::Ge);
impl Op {
/// Number of arguments for this operator. `None` if n-ary.
@@ -247,6 +264,8 @@ impl Op {
Op::FpToFp(_) => Some(1),
Op::PfUnOp(_) => Some(1),
Op::PfNaryOp(_) => None,
Op::IntNaryOp(_) => None,
Op::IntBinPred(_) => Some(2),
Op::UbvToPf(_) => Some(1),
Op::Select => Some(2),
Op::Store => Some(3),
@@ -291,6 +310,8 @@ impl Display for Op {
Op::FpToFp(a) => write!(f, "(fp2fp {})", a),
Op::PfUnOp(a) => write!(f, "{}", a),
Op::PfNaryOp(a) => write!(f, "{}", a),
Op::IntNaryOp(a) => write!(f, "{}", a),
Op::IntBinPred(a) => write!(f, "{}", a),
Op::UbvToPf(a) => write!(f, "(bv2pf {})", a.modulus()),
Op::Select => write!(f, "select"),
Op::Store => write!(f, "store"),
@@ -589,6 +610,48 @@ impl Display for PfUnOp {
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
/// Integer n-ary operator
pub enum IntNaryOp {
/// Finite field (+)
Add,
/// Finite field (*)
Mul,
}
impl Display for IntNaryOp {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
IntNaryOp::Add => write!(f, "intadd"),
IntNaryOp::Mul => write!(f, "intmul"),
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
/// Integer binary predicate. See [Op::Eq] for equality.
pub enum IntBinPred {
/// Integer (<)
Lt,
/// Integer (>)
Gt,
/// Integer (<=)
Le,
/// Integer (>=)
Ge,
}
impl Display for IntBinPred {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
IntBinPred::Lt => write!(f, "<"),
IntBinPred::Gt => write!(f, ">"),
IntBinPred::Le => write!(f, "<="),
IntBinPred::Ge => write!(f, ">="),
}
}
}
#[derive(Clone, PartialEq, Eq, Hash)]
/// A term: an operator applied to arguements
pub struct TermData {
@@ -1279,6 +1342,15 @@ impl Value {
}
}
#[track_caller]
/// Get the underlying bit-vector constant, or panic!
pub fn as_int(&self) -> &Integer {
if let Value::Int(b) = self {
b
} else {
panic!("Not a bit-vec: {}", self)
}
}
#[track_caller]
/// Get the underlying prime field constant, if possible.
pub fn as_pf(&self) -> &FieldV {
if let Value::Field(b) = self {
@@ -1510,6 +1582,27 @@ fn eval_value(vs: &mut TermMap<Value>, h: &FxHashMap<String, Value>, c: Term) ->
},
)
}),
Op::IntBinPred(o) => Value::Bool({
let a = vs.get(&c.cs[0]).unwrap().as_int();
let b = vs.get(&c.cs[1]).unwrap().as_int();
match o {
IntBinPred::Ge => a >= b,
IntBinPred::Gt => a > b,
IntBinPred::Le => a <= b,
IntBinPred::Lt => a < b,
}
}),
Op::IntNaryOp(o) => Value::Int({
let mut xs = c.cs.iter().map(|c| vs.get(c).unwrap().as_int().clone());
let f = xs.next().unwrap();
xs.fold(
f,
match o {
IntNaryOp::Add => std::ops::Add::add,
IntNaryOp::Mul => std::ops::Mul::mul,
},
)
}),
Op::UbvToPf(fty) => Value::Field(fty.new_v(vs.get(&c.cs[0]).unwrap().as_bv().uint())),
// tuple
Op::Tuple => Value::Tuple(c.cs.iter().map(|c| vs.get(c).unwrap().clone()).collect()),

View File

@@ -258,6 +258,12 @@ impl<'src> IrInterp<'src> {
Leaf(Ident, b"*") => Ok(Op::PfNaryOp(PfNaryOp::Mul)),
Leaf(Ident, b"pfrecip") => Ok(Op::PfUnOp(PfUnOp::Recip)),
Leaf(Ident, b"-") => Ok(Op::PfUnOp(PfUnOp::Neg)),
Leaf(Ident, b"<") => Ok(INT_LT),
Leaf(Ident, b"<=") => Ok(INT_LE),
Leaf(Ident, b">") => Ok(INT_GT),
Leaf(Ident, b">=") => Ok(INT_GE),
Leaf(Ident, b"intadd") => Ok(INT_ADD),
Leaf(Ident, b"intmul") => Ok(INT_MUL),
Leaf(Ident, b"select") => Ok(Op::Select),
Leaf(Ident, b"store") => Ok(Op::Store),
Leaf(Ident, b"tuple") => Ok(Op::Tuple),

View File

@@ -55,6 +55,8 @@ fn check_dependencies(t: &Term) -> Vec<Term> {
Op::FpToFp(_) => Vec::new(),
Op::PfUnOp(_) => vec![t.cs[0].clone()],
Op::PfNaryOp(_) => vec![t.cs[0].clone()],
Op::IntNaryOp(_) => Vec::new(),
Op::IntBinPred(_) => Vec::new(),
Op::UbvToPf(_) => Vec::new(),
Op::Select => vec![t.cs[0].clone()],
Op::Store => vec![t.cs[0].clone()],
@@ -122,6 +124,8 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result<Sort, TypeErrorReason> {
Op::FpToFp(32) => Ok(Sort::F32),
Op::PfUnOp(_) => Ok(get_ty(&t.cs[0]).clone()),
Op::PfNaryOp(_) => Ok(get_ty(&t.cs[0]).clone()),
Op::IntNaryOp(_) => Ok(Sort::Int),
Op::IntBinPred(_) => Ok(Sort::Bool),
Op::UbvToPf(m) => Ok(Sort::Field(m.clone())),
Op::Select => array_or(get_ty(&t.cs[0]), "select").map(|(_, v)| v.clone()),
Op::Store => Ok(get_ty(&t.cs[0]).clone()),
@@ -329,6 +333,15 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result<Sort, TypeErrorRea
}
(Op::UbvToPf(m), &[a]) => bv_or(a, "ubv-to-pf").map(|_| Sort::Field(m.clone())),
(Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()),
(Op::IntNaryOp(_), a) => {
let ctx = "int nary op";
all_eq_or(a.iter().cloned(), ctx)
.and_then(|t| int_or(t, ctx))
.map(|a| a.clone())
}
(Op::IntBinPred(_), &[a, b]) => int_or(a, "int bin pred")
.and_then(|_| int_or(b, "int bin pred"))
.map(|_| Sort::Bool),
(Op::Select, &[Sort::Array(k, v, _), a]) => eq_or(k, a, "select").map(|_| (**v).clone()),
(Op::Store, &[Sort::Array(k, v, n), a, b]) => eq_or(k, a, "store")
.and_then(|_| eq_or(v, b, "store"))
@@ -465,12 +478,14 @@ pub struct TypeError {
pub enum TypeErrorReason {
/// Two sorts should be equal
NotEqual(Sort, Sort, &'static str),
/// A sort should be a boolean
/// A sort should be boolean
ExpectedBool(Sort, &'static str),
/// A sort should be a floating-point
ExpectedFp(Sort, &'static str),
/// A sort should be a bit-vector
ExpectedBv(Sort, &'static str),
/// A sort should be integer
ExpectedInt(Sort, &'static str),
/// A sort should be a prime field
ExpectedPf(Sort, &'static str),
/// A sort should be an array
@@ -495,6 +510,14 @@ fn bv_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason
}
}
fn int_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a Sort, TypeErrorReason> {
if let Sort::Int = a {
Ok(a)
} else {
Err(TypeErrorReason::ExpectedInt(a.clone(), ctx))
}
}
fn array_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<(&'a Sort, &'a Sort), TypeErrorReason> {
if let Sort::Array(k, v, _) = a {
Ok((&*k, &*v))

View File

@@ -45,7 +45,8 @@ impl Expr2Smt<()> for Value {
match self {
Value::Bool(b) => write!(w, "{}", b)?,
Value::Field(f) => write!(w, "#f{}m{}", f.i(), f.modulus())?,
Value::Int(i) => write!(w, "{}", i)?,
Value::Int(i) if i >= &Integer::new() => write!(w, "{}", i)?,
Value::Int(i) => write!(w, "(- 0 {})", *i.as_neg())?,
Value::BitVector(b) => write!(w, "{}", b)?,
Value::F32(f) => {
let (sign, exp, mant) = f.decompose_raw();
@@ -163,6 +164,18 @@ impl Expr2Smt<()> for TermData {
write!(w, "(ffneg")?;
true
}
Op::IntNaryOp(IntNaryOp::Mul) => {
write!(w, "(*")?;
true
}
Op::IntNaryOp(IntNaryOp::Add) => {
write!(w, "(+")?;
true
}
Op::IntBinPred(o) => {
write!(w, "({}", o)?;
true
}
o => panic!("Cannot give {} to SMT solver", o),
};
if s_expr_children {
@@ -230,6 +243,8 @@ impl<'a, R: std::io::BufRead> IdentParser<String, Sort, &'a mut SmtParser<R>> fo
fn parse_type(self, input: &'a mut SmtParser<R>) -> SmtRes<Sort> {
if input.try_tag("Bool")? {
Ok(Sort::Bool)
} else if input.try_tag("Int")? {
Ok(Sort::Int)
} else if input.try_tag("(_ BitVec")? {
let n = input
.try_int(|s, b| {
@@ -292,6 +307,10 @@ impl<'a, Br: ::std::io::BufRead> ModelParser<String, Sort, Value, &'a mut SmtPar
let int_literal = input.get_sexpr()?;
let i = Integer::from_str_radix(int_literal, 10).unwrap();
Value::Field(f.new_v(i))
} else if let Sort::Int = s {
let int_literal = input.get_sexpr()?;
let i = Integer::from_str_radix(int_literal, 10).unwrap();
Value::Int(i)
} else {
unimplemented!("Could not parse model suffix: {}", input.buff_rest())
};
@@ -585,4 +604,145 @@ mod test {
.unwrap();
solver.check_sat().unwrap()
}
#[test]
fn int_model() {
let t = text::parse_term(
b"
(declare ((a int) (b int))
(and
(or (= (intadd a b) 1)
(= (intadd a b) 0))
(< a 1)
(> 1 b)
(>= a 0)
(<= 0 b)
)
)
",
);
assert_eq!(
find_model(&t),
Some(
vec![
("a".to_owned(), Value::Int(0.into())),
("b".to_owned(), Value::Int(0.into())),
]
.into_iter()
.collect()
)
)
}
#[test]
fn int_no_model() {
let t = text::parse_term(
b"
(declare ((a int) (b int))
(and
(or (= (intadd a b) 1)
(= (intadd a b) 1))
(< a 1)
(> 1 b)
(>= a 0)
(<= 0 b)
)
)
",
);
assert_eq!(find_model(&t), None)
}
#[test]
fn int_model_nia() {
let t = text::parse_term(
b"
(declare ((a int) (b int))
(and
(= (intmul a a) b)
(= (intmul b b) a)
(not (= a 0))
)
)
",
);
assert_eq!(
find_model(&t),
Some(
vec![
("a".to_owned(), Value::Int(1.into())),
("b".to_owned(), Value::Int(1.into())),
]
.into_iter()
.collect()
)
)
}
#[test]
fn int_model_div() {
let t = text::parse_term(
b"
(declare ((a int) (q int) (r int))
(and
(= a (intadd (intmul q 5) r))
(>= r 0)
(< r 5)
(= (intadd a (intmul -1 r)) 10)
(>= a 14)
)
)
",
);
assert_eq!(
find_model(&t),
Some(
vec![
("a".to_owned(), Value::Int(14.into())),
("r".to_owned(), Value::Int(4.into())),
("q".to_owned(), Value::Int(2.into())),
]
.into_iter()
.collect()
)
)
}
#[test]
fn bv_model_div() {
let t = text::parse_term(
b"
(declare ((a (bv 8)) (q (bv 8)) (r (bv 8)))
(and
(= a (bvadd (bvmul q #x05) r))
(bvuge r #x00)
(bvult r #x05)
(= (bvsub a r) #x0a)
(bvuge a #x0e)
)
)
",
);
assert_eq!(
find_model(&t),
Some(
vec![
(
"a".to_owned(),
Value::BitVector(BitVector::new(Integer::from(14), 8))
),
(
"r".to_owned(),
Value::BitVector(BitVector::new(Integer::from(4), 8))
),
(
"q".to_owned(),
Value::BitVector(BitVector::new(Integer::from(2), 8))
),
]
.into_iter()
.collect()
)
)
}
}