diff --git a/.gitignore b/.gitignore index a246fcd8..b7f6f44d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,9 @@ target Cargo.lock *~ .z3-trace +.DS_Store + +# nix +.envrc +.direnv +result \ No newline at end of file diff --git a/z3-sys/src/lib.rs b/z3-sys/src/lib.rs index da7b9115..16b49ef1 100644 --- a/z3-sys/src/lib.rs +++ b/z3-sys/src/lib.rs @@ -1552,6 +1552,73 @@ pub enum ErrorCode { pub type Z3_error_handler = ::std::option::Option; +#[doc(hidden)] +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct _Z3_solver_callback { + _unused: [u8; 0], +} + +/// Type of callback functions for the User Propagator +pub type Z3_solver_callback = *mut _Z3_solver_callback; + +pub type Z3_push_eh = ::std::option::Option< + unsafe extern "C" fn(cyx: *mut ::std::ffi::c_void, cd: Z3_solver_callback), +>; +pub type Z3_pop_eh = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + num_scopes: ::std::os::raw::c_uint, + ), +>; +pub type Z3_fresh_eh = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut ::std::ffi::c_void, + new_context: Z3_context, + ) -> *mut ::std::ffi::c_void, +>; +pub type Z3_fixed_eh = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + t: Z3_ast, + value: Z3_ast, + ), +>; +pub type Z3_eq_eh = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + s: Z3_ast, + t: Z3_ast, + ), +>; +pub type Z3_final_eh = ::std::option::Option< + unsafe extern "C" fn(cyx: *mut ::std::ffi::c_void, cb: Z3_solver_callback), +>; +pub type Z3_created_eh = ::std::option::Option< + unsafe extern "C" fn(cyx: *mut ::std::ffi::c_void, cb: Z3_solver_callback, t: Z3_ast), +>; +pub type Z3_decide_eh = ::std::option::Option< + unsafe extern "C" fn( + cyx: *mut ::std::ffi::c_void, + cd: Z3_solver_callback, + t: Z3_ast, + idx: ::std::os::raw::c_uint, + phase: bool, + ), +>; +pub type Z3_on_clause_eh = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut ::std::ffi::c_void, + proof_hint: Z3_ast, + n: ::std::os::raw::c_uint, + deps: *const ::std::os::raw::c_uint, + literals: Z3_ast_vector, + ), +>; + /// Precision of a given goal. Some goals can be transformed using over/under approximations. /// /// This corresponds to `Z3_goal_prec` in the C API. @@ -8034,6 +8101,157 @@ extern "C" { /// Best-effort quantifier elimination pub fn Z3_qe_lite(c: Z3_context, vars: Z3_ast_vector, body: Z3_ast) -> Z3_ast; + + /// Sets the next (registered) expression to split on. The function returns + /// false and ignores the given expression in case the expression is already + /// assigned internally (due to relevancy propagation, this assignments + /// might not have been reported yet by the fixed callback). In case the + /// function is called in the decide callback, it overrides the currently + /// selected variable and phase. + pub fn Z3_solver_next_split( + c: Z3_context, + cb: Z3_solver_callback, + t: Z3_ast, + idx: ::std::os::raw::c_uint, + phase: Z3_lbool, + ) -> bool; + + /// propagate a consequence based on fixed values and equalities. A client + /// may invoke it during the `propagate_fixed`, `propagate_eq`, `propagate_diseq`, + /// and `propagate_final` callbacks. The callback adds a propagation + /// consequence based on the fixed values passed ids and equalities eqs + /// based on parameters lhs, rhs. + /// + /// The solver might discard the propagation in case it is true in the + /// current state. The function returns false in this case; otw. the + /// function returns true. At least one propagation in the final callback + /// has to return true in order to prevent the solver from finishing. + /// + /// - `c`: context + /// - `solver_cb`: solver callback + /// - `num_ids`: number of fixed terms used as premise to propagation + /// - `ids`: array of length `num_ids` containing terms that are fixed in the current scope + /// - `num_eqs`: number of equalities used as premise to propagation + /// - `lhs`: left side of equalities + /// - `rhs`: right side of equalities + /// - `consequence`: consequence to propagate. It is typically an atomic formula, but + /// it can be an arbitrary formula. + /// + /// Assume the callback has the signature: + /// `propagate_consequence_eh(context, solver_cb, num_ids, ids, num_eqs, lhs, rhs, consequence)`. + pub fn Z3_solver_propagate_consequence( + c: Z3_context, + cb: Z3_solver_callback, + num_fixed: ::std::os::raw::c_uint, + fixed: *const Z3_ast, + num_eqs: ::std::os::raw::c_uint, + eq_lhs: *const Z3_ast, + eq_rhs: *const Z3_ast, + conseq: Z3_ast, + ) -> bool; + + /// register a callback when a new expression with a registered function is + /// used by the solver The registered function appears at the top level and + /// is created using [`Z3_solver_propagate_declare`]. + pub fn Z3_solver_propagate_created(c: Z3_context, s: Z3_solver, created_eh: Z3_created_eh); + + /// register a callback when the solver decides to split on a registered + /// expression. The callback may change the arguments by providing other + /// values by calling [`Z3_solver_next_split`]. + pub fn Z3_solver_propagate_decide(c: Z3_context, s: Z3_solver, decide_eh: Z3_decide_eh); + + /// Create uninterpreted function declaration for the user propagator. When + /// expressions using the function are created by the solver invoke a + /// callback to [`Z3_solver_propagate_created`] with arguments + /// + /// 1. context and callback solve + /// 2. `declared_expr`: expression using function that was used as the + /// top-level symbol + /// 3. `declared_id`: a unique identifier (unique within the current scope) to + /// track the expression. + pub fn Z3_solver_propagate_declare( + c: Z3_context, + name: Z3_symbol, + n: ::std::os::raw::c_uint, + domain: *const Z3_sort, + range: Z3_sort, + ) -> Z3_func_decl; + + /// register a callback on expression dis-equalities. + pub fn Z3_solver_propagate_diseq(c: Z3_context, s: Z3_solver, eq_eh: Z3_eq_eh); + + /// register a callback on expression equalities. + pub fn Z3_solver_propagate_eq(c: Z3_context, s: Z3_solver, eq_eh: Z3_eq_eh); + + /// register a callback on final check. This provides freedom to the + /// propagator to delay actions or implement a branch-and bound solver. The + /// final check is invoked when all decision variables have been assigned by + /// the solver. + /// + /// The `final_eh` callback takes as argument the origina`user_context`xt that + /// was used when calling [`Z3_solver_propagate_init`], and it takes a callback + /// context with the opaque type [`Z3_solver_callback`]. The callback context is + /// passed as argument to invoke the [`Z3_solver_propagate_consequence`] + /// function. The callback context can only be accessed (for propagation and + /// for dynamically registering expressions) within a callback. If the + /// callback context gets used for propagation or conflicts, those + /// propagations take effect and may trigger new decision variables to be + /// set. + pub fn Z3_solver_propagate_final(c: Z3_context, s: Z3_solver, final_eh: Z3_final_eh); + + /// register a callback for when an expression is bound to a fixed value. + /// The supported expression types are: + /// + /// - Booleans + /// - Bit-vectors + pub fn Z3_solver_propagate_fixed(c: Z3_context, s: Z3_solver, fixed_eh: Z3_fixed_eh); + + /// register a user-propagator with the solver. + /// + /// -`c`: context. + /// - `s`: solver object. + /// - `user_context`: a context used to maintain state for callbacks. + /// - `push_eh`: a callback invoked when scopes are pushed + /// - `pop_eh`: a callback invoked when scopes are popped + /// - `fresh_eh`: a solver may spawn new solvers internally. This callback + /// is used to produce a fresh `user_context` to be associated with fresh + /// solvers. + pub fn Z3_solver_propagate_init( + c: Z3_context, + s: Z3_solver, + user_context: *mut ::std::ffi::c_void, + push_eh: Z3_push_eh, + pop_eh: Z3_pop_eh, + fresh_eh: Z3_fresh_eh, + ); + + /// register an expression to propagate on with the solver. Only expressions + /// of type Bool and type Bit-Vector can be registered for propagation. + pub fn Z3_solver_propagate_register(c: Z3_context, s: Z3_solver, e: Z3_ast); + + /// register an expression to propagate on with the solver. Only expressions + /// of type Bool and type Bit-Vector can be registered for propagation. + /// Unlike [`Z3_solver_propagate_register`], this function takes a solver + /// callback context as argument. It can be invoked during a callback to + /// register new expressions. + pub fn Z3_solver_propagate_register_cb(c: Z3_context, cb: Z3_solver_callback, e: Z3_ast); + + /// register a callback to that retrieves assumed, inferred and deleted clauses during search. + /// + /// + /// - `c`: context. + /// - `s`: solver object. + /// - `user_context`: a context used to maintain state for callbacks. + /// - `on_clause_eh`: a callback that is invoked by when a clause is + /// * asserted to the CDCL engine (corresponding to an input clause after pre-processing) + /// * inferred by CDCL(T) using either a SAT or theory conflict/propagation + /// * deleted by the CDCL(T) engine + pub fn Z3_solver_register_on_clause( + c: Z3_context, + s: Z3_solver, + user_context: *mut ::std::ffi::c_void, + on_clause_eh: Z3_on_clause_eh, + ); } #[cfg(not(windows))] diff --git a/z3/src/func_decl.rs b/z3/src/func_decl.rs index 307e13eb..76ab037c 100644 --- a/z3/src/func_decl.rs +++ b/z3/src/func_decl.rs @@ -37,6 +37,37 @@ impl<'ctx> FuncDecl<'ctx> { } } + /// [`Self::new`] but register it for the [`UserPropagator`]s + /// + /// see [`user_propagator`] + /// + /// [user_propagator]: super::user_propagator + /// [UserPropagator]: super::user_propagator::UserPropagator + pub fn new_up>( + ctx: &'ctx Context, + name: S, + domain: &[&Sort<'ctx>], + range: &Sort<'ctx>, + ) -> Self { + assert!(domain.iter().all(|s| s.ctx.z3_ctx == ctx.z3_ctx)); + assert_eq!(ctx.z3_ctx, range.ctx.z3_ctx); + + let domain: Vec<_> = domain.iter().map(|s| s.z3_sort).collect(); + + unsafe { + Self::wrap( + ctx, + Z3_solver_propagate_declare( + ctx.z3_ctx, + name.into().as_z3_symbol(ctx), + domain.len().try_into().unwrap(), + domain.as_ptr(), + range.z3_sort, + ), + ) + } + } + /// Return the number of arguments of a function declaration. /// /// If the function declaration is a constant, then the arity is `0`. diff --git a/z3/src/lib.rs b/z3/src/lib.rs index e12d590d..941fa2fd 100644 --- a/z3/src/lib.rs +++ b/z3/src/lib.rs @@ -30,6 +30,7 @@ mod sort; mod statistics; mod symbol; mod tactic; +pub mod user_propagator; mod version; pub use crate::params::{get_global_param, reset_all_global_params, set_global_param}; diff --git a/z3/src/user_propagator.rs b/z3/src/user_propagator.rs new file mode 100644 index 00000000..ceea8d5c --- /dev/null +++ b/z3/src/user_propagator.rs @@ -0,0 +1,726 @@ +//! Z3's [User Propagator](https://microsoft.github.io/z3guide/programming/Example%20Programs/User%20Propagator/) API + +// I am following quite closly this: https://z3prover.github.io/api/html/classz3_1_1user__propagator__base.html + +use crate::{ + ast::Ast, + ast::{self, Dynamic}, + Context, Solver, +}; +use log::debug; +use std::{convert::TryInto, fmt::Debug, pin::Pin}; +use z3_sys::*; + +/// Interface to build a custom [User +/// Propagator](https://microsoft.github.io/z3guide/programming/Example%20Programs/User%20Propagator/) +/// +/// All function fowllow their C++ counterparts in the +/// [`user_propagator_base`](https://z3prover.github.io/api/html/classz3_1_1user__propagator__base.html). +/// Callbacks can be made though the `upw` paramter. Those callbacks my panic if +/// called from a wrong place. +/// +/// By default all function are implemented and do nothing +#[allow(unused_variables)] +pub trait UserPropagator<'ctx>: Debug { + fn get_context(&self) -> &'ctx Context; + + /// Called when z3 case splits + fn push(&mut self, cb: &CallBack<'ctx>) {} + + /// Called when z3 backtracks `num_scopes` times + fn pop(&mut self, cb: &CallBack<'ctx>, num_scopes: u32) {} + + /// Called when `id` is fixed to value `e` + fn fixed(&mut self, cb: &CallBack<'ctx>, id: &Dynamic<'ctx>, e: &Dynamic<'ctx>) {} + + /// Called when `x` and `y` are equated. + /// + /// This can be somewhat unreliable as z3 may call this less time than you'd + /// expect. + /// + /// See: + /// + fn eq(&mut self, cb: &CallBack<'ctx>, x: &Dynamic<'ctx>, y: &Dynamic<'ctx>) {} + + /// Same as [eq] be on negated equalities + fn neq(&mut self, cb: &CallBack<'ctx>, x: &Dynamic<'ctx>, y: &Dynamic<'ctx>) {} + + /// During the final check stage, all propagations have been processed. This + /// is an opportunity for the user-propagator to delay some analysis that + /// could be expensive to perform incrementally. It is also an opportunity + /// for the propagator to implement branch and bound optimization. + fn final_(&mut self, cb: &CallBack<'ctx>) {} + + /// `e` was created using one of the function declared with + /// [`declare_up_function`]. + /// + /// Remeber to register those expressions! (using [`CallBack::add`] or + /// [`UPSolver::add`] depending on the calling location) + /// + /// **NB**: there is no way to declare a function for specific + /// [`UserPropagator`]. + /// + /// [UPSolver::add]: super::UPSolver::add + fn created(&mut self, cb: &CallBack<'ctx>, e: &Dynamic<'ctx>) {} + + fn decide(&mut self, cb: &CallBack<'ctx>, val: &Dynamic<'ctx>, bit: u32, is_pos: bool) {} + + // TODO: figure out how to make fresh work + // fn fresh<'a>( + // upw: &'a UserPropagatorWrapper<'ctx>, + // ctx: &'a Context, + // ) -> Option<&'a dyn UserPropagator<'a>> { + // None + // } +} + +#[derive(Debug)] +pub struct CallBack<'ctx> { + cb: Z3_solver_callback, + ctx: &'ctx Context, +} + +impl<'ctx> CallBack<'ctx> { + pub fn new(cb: Z3_solver_callback, ctx: &'ctx Context) -> Self { + Self { cb, ctx } + } + + /// Sets the next (registered) expression to split on. The function returns + /// false and ignores the given expression in case the expression is already + /// assigned internally (due to relevancy propagation, this assignments + /// might not have been reported yet by the fixed callback). In case the + /// function is called in the decide callback, it overrides the currently + /// selected variable and phase. + /// + /// panics if not called from a callback + /// + /// see [`Z3_solver_next_split`] + pub fn next_split(&self, expr: &ast::Bool<'ctx>, idx: u32, phase: Option) -> bool { + debug_assert_eq!(self.get_ctx(), expr.get_ctx()); + let phase = match phase { + Some(true) => Z3_L_TRUE, + Some(false) => Z3_L_FALSE, + None => Z3_L_UNDEF, + }; + unsafe { Z3_solver_next_split(self.z3_ctx(), self.cb, expr.get_z3_ast(), idx, phase) } + } + + /// Tracks `expr` with ([`UserPropagator::fixed()`] or`UserPropagator::eq()`()]) + /// + /// If `expr` is a Boolean or Bit-vector, the [`UserPropagator::fixed()`] + /// callback gets invoked when `expr` is bound to a value.a Equalities + /// between registered expressions are reported thought + /// [`UserPropagator::eq()`]. A consumer can use the`Self::propagate`te] or + /// [`Self::conflict`] functions to invoke propagations or conflicts as a + /// consequence of these callbacks. These functions take a list of + /// identifiers for registered expressions that have been fixed. The list of + /// identifiers must correspond to already fixed values. Similarly, a list + /// of propagated equalities can be supplied. These must correspond to + /// equalities that have been registered during a callback. + /// + /// see [`Z3_solver_propagate_register_cb`] and [`Z3_solver_propagate_register`] + pub fn add(&self, expr: &impl Ast<'ctx>) { + debug_assert_eq!(self.get_ctx(), expr.get_ctx()); + unsafe { Z3_solver_propagate_register_cb(self.z3_ctx(), self.cb, expr.get_z3_ast()) }; + } + + /// Propagate a consequence based on fixed values and equalities. + /// + /// A client may invoke it during the `propagate_fixed`, `propagate_eq`, + /// `propagate_diseq`, and `propagate_final` callbacks. The callback adds a + /// propagation consequence based on the fixed values passed ids and + /// equalities eqs based on parameters lhs, rhs. + /// + /// The solver might discard the propagation in case it is true in the + /// current state. The function returns false in this case; otw. the + /// function returns true. At least one propagation in the final callback + /// has to return true in order to prevent the solver from finishing. + /// + /// - `fixed`: iterator containing terms that are fixed in the current scope + /// - `lhs`: left side of equalities + /// - `rhs`: right side of equalities + /// - `conseq`: consequence to propagate. It is typically an atomic formula, + /// but it can be an arbitrary formula. + /// + /// panics if not called from a callback or if `lhs` and `rhs` don't have the same + /// length. + /// + /// see [`Z3_solver_propagate_consequence`] + pub fn propagate<'b, I, J, A>( + &'b self, + fixed: I, + lhs: J, + rhs: J, + conseq: &'b ast::Bool<'ctx>, + ) -> bool + where + I: IntoIterator>, + J: IntoIterator, + A: Ast<'ctx> + 'b, + { + /* using generics because I need to map on the arguments anyway and it will turn + the other functions defined from `propagate` into the same things as the C++ + API this is based on */ + fn into_vec_and_check<'ctx, 'b, A: Ast<'ctx> + 'b>( + ctx: &'ctx Context, + iter: impl IntoIterator, + ) -> Vec { + iter.into_iter() + .inspect(|e| debug_assert_eq!(ctx, e.get_ctx())) + .map(|e| e.get_z3_ast()) + .collect() + } + debug_assert_eq!(self.get_ctx(), conseq.get_ctx()); + let fixed = into_vec_and_check(self.get_ctx(), fixed); + let lhs = into_vec_and_check(self.get_ctx(), lhs); + let rhs = into_vec_and_check(self.get_ctx(), rhs); + let conseq = conseq.get_z3_ast(); + assert_eq!(lhs.len(), rhs.len()); + + // not sure what the API does exactly, but it probably expects null pointers + // rather than dangling ones in case of empty vecs + let to_ptr = |v: Vec<_>| { + if v.is_empty() { + ::std::ptr::null() + } else { + v.as_ptr() + } + }; + + unsafe { + Z3_solver_propagate_consequence( + self.z3_ctx(), + self.cb, + fixed.len().try_into().unwrap(), + to_ptr(fixed), + lhs.len().try_into().unwrap(), + to_ptr(lhs), + to_ptr(rhs), + conseq, + ) + } + } + + /// triggers a confict on `fixed` + /// + /// Equivalent to [`self.propagate(fixed, [] , [], FALSE)`](Self::propagate) + /// + /// panics if not called from a callback. + pub fn conflict_on(&self, fixed: &[&ast::Bool<'ctx>]) -> bool { + self.propagate::<_, _, ast::Bool<'ctx>>( + fixed.iter().copied(), + [], + [], + &ast::Bool::from_bool(self.get_ctx(), false), + ) + } + + /// triggers a confict on `fixed` and `lhs == rhs` + /// + /// Equivalent to [`self.propagate(fixed, lhs, rhs, FALSE)`](Self::propagate) + /// + /// panics if not called from a callback or if `lhs` and `rhs` don't have the same + /// length. + pub fn conflict( + &self, + fixed: &[&ast::Bool<'ctx>], + lhs: &[&ast::Dynamic<'ctx>], + rhs: &[&ast::Dynamic<'ctx>], + ) -> bool { + self.propagate( + fixed.iter().copied(), + lhs.iter().copied(), + rhs.iter().copied(), + &ast::Bool::from_bool(self.get_ctx(), false), + ) + } + + /// Propagate `conseq` + /// + /// Equivalent to [`self.propagate(fixed, [], [], conseq)`](Self::propagate) + /// + /// panics if not called from a callback. + pub fn propagate_one(&self, fixed: &[&ast::Bool<'ctx>], conseq: &ast::Bool<'ctx>) -> bool { + self.propagate::<_, _, ast::Bool>(fixed.iter().copied(), [], [], conseq) + } + + pub fn get_ctx(&self) -> &'ctx Context { + self.ctx + } + + fn z3_ctx(&self) -> Z3_context { + self.get_ctx().z3_ctx + } +} + +/// The `on_clause` callback +/// +/// a callback that is invoked by when a clause is +/// - asserted to the CDCL engine (corresponding to an input clause after pre-processing) +/// - inferred by CDCL(T) using either a SAT or theory conflict/propagation +/// - deleted by the CDCL(T) engine +pub trait OnClause<'ctx>: Debug { + fn get_ctx(&self) -> &'ctx Context; + /// the callback + fn on_clause(&mut self, proof_hint: &Dynamic<'ctx>, deps: &[u32], literals: &[Dynamic<'ctx>]); +} + +/// A quick way to implement [`OnClause`] using a clausure +struct ClausureOnClause<'ctx, F> +where + F: FnMut(&Dynamic<'ctx>, &[u32], &[Dynamic<'ctx>]), +{ + ctx: &'ctx Context, + f: F, +} + +impl<'ctx, F> ClausureOnClause<'ctx, F> +where + F: FnMut(&Dynamic<'ctx>, &[u32], &[Dynamic<'ctx>]), +{ + pub fn new(ctx: &'ctx Context, f: F) -> Self { + Self { ctx, f } + } +} +impl<'ctx, F> Debug for ClausureOnClause<'ctx, F> +where + F: FnMut(&Dynamic<'ctx>, &[u32], &[Dynamic<'ctx>]), +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClausureOnClause") + .field("ctx", &self.ctx) + .field("f", &std::any::type_name::()) + .finish() + } +} + +impl<'ctx, F> OnClause<'ctx> for ClausureOnClause<'ctx, F> +where + F: FnMut(&Dynamic<'ctx>, &[u32], &[Dynamic<'ctx>]), +{ + fn get_ctx(&self) -> &'ctx Context { + self.ctx + } + + fn on_clause(&mut self, proof_hint: &Dynamic<'ctx>, deps: &[u32], literals: &[Dynamic<'ctx>]) { + (self.f)(proof_hint, deps, literals) + } +} + +/// Wrapper around a solver to ensure the [`UserPropagator`]s live long enough +/// +/// `'ctx` is the liftime of the [Context] and `'a` the commun liftime of the +/// [`UserPropagator`]s +/// +/// [Context]: crate::Context +#[derive(Debug)] +pub struct UPSolver<'ctx, 'a> { + solver: Solver<'ctx>, + user_propagators: Vec + 'a>>>, + on_clause_propagators: Vec + 'a>>>, +} + +impl<'ctx, 'a> UPSolver<'ctx, 'a> { + pub fn new(solver: Solver<'ctx>) -> Self { + Self { + solver, + user_propagators: Default::default(), + on_clause_propagators: Default::default(), + } + } + + pub fn solver(&self) -> &Solver<'ctx> { + &self.solver + } + + /// Add an expression to be tracked by all [`UserPropagator`]s registered to this solver + pub fn add(&self, expr: &impl Ast<'ctx>) { + let z3_ctx = self.solver().get_context().z3_ctx; + let z3_slv = self.solver().z3_slv; + unsafe { Z3_solver_propagate_register(z3_ctx, z3_slv, expr.get_z3_ast()) } + } + + /// Registers a [`UserPropagator`] to [`Self::solver`] + pub fn register_user_propagator + 'a>(&mut self, up: U) { + let pin = Box::pin(up); + unsafe { z3_user_propagator_init(pin.as_ref(), self.solver().z3_slv) } + self.user_propagators.push(pin); + } + + pub fn register_on_clause + 'a>(&mut self, up: U) { + let pin = Box::pin(up); + let s = self.z3_slv; + let c = self.ctx.z3_ctx; + let user_ctx = pin.as_ref().get_ref() as *const _ as *mut ::std::ffi::c_void; + unsafe { + Z3_solver_register_on_clause(c, s, user_ctx, Some(callbacks::clause_eh::)); + } + self.on_clause_propagators.push(pin); + } + + pub fn quick_register_on_clause(&mut self, f: F) + where + F: FnMut(&Dynamic<'ctx>, &[u32], &[Dynamic<'ctx>]) + 'a, + 'ctx: 'a, + { + let ctx: &'ctx Context = self.solver().get_context(); + // let up: ClausureOnClause<'ctx, F> = ; + self.register_on_clause(ClausureOnClause::new(ctx, f)); + } +} + +impl<'ctx> From> for UPSolver<'ctx, '_> { + fn from(solver: Solver<'ctx>) -> Self { + Self::new(solver) + } +} + +impl<'ctx> std::ops::Deref for UPSolver<'ctx, '_> { + type Target = Solver<'ctx>; + + fn deref(&self) -> &Self::Target { + &self.solver + } +} + +// ========================================================= +// =============== implementation details ================== +// ========================================================= + +/// Does all the z3 calls to register a new [`UserPropagator`]. +/// +/// At this point `z3_slv` effectively borrows `up`. I need +/// [`super::UPSolver`] to solver the liftetime problems, hence why to +/// function is `unsafe`. +pub(crate) unsafe fn z3_user_propagator_init<'ctx, U: UserPropagator<'ctx>>( + up: Pin<&U>, + z3_slv: Z3_solver, +) { + let z3_ctx = up.get_context().z3_ctx; + debug!("Z3_solver_propagate_init"); + Z3_solver_propagate_init( + z3_ctx, + z3_slv, + up.get_ref() as *const _ as *mut ::std::ffi::c_void, + Some(callbacks::push_eh::), + Some(callbacks::pop_eh::), + Some(callbacks::fresh_eh::), + ); + // we register all callbacks + // fixed + debug!("Z3_solver_propagate_fixed"); + Z3_solver_propagate_fixed(z3_ctx, z3_slv, Some(callbacks::fixed_eh::)); + // eq + debug!("Z3_solver_propagate_eq"); + Z3_solver_propagate_eq(z3_ctx, z3_slv, Some(callbacks::eq_eh::)); + // eq + debug!("Z3_solver_propagate_diseq"); + Z3_solver_propagate_diseq(z3_ctx, z3_slv, Some(callbacks::neq_eh::)); + // final + debug!("Z3_solver_propagate_final"); + Z3_solver_propagate_final(z3_ctx, z3_slv, Some(callbacks::final_eh::)); + // created + debug!("Z3_solver_propagate_created"); + Z3_solver_propagate_created(z3_ctx, z3_slv, Some(callbacks::created_eh::)); + // decide + debug!("Z3_solver_propagate_decide"); + Z3_solver_propagate_decide(z3_ctx, z3_slv, Some(callbacks::decide_eh::)); +} + +/// all the callbacks used in this file +mod callbacks { + use crate::{ + ast::{Ast, Dynamic}, + user_propagator::{CallBack, OnClause, UserPropagator}, + }; + use log::debug; + use std::convert::TryInto; + use z3_sys::*; + + /// Turns a `void*` into a `&mut Self`. + /// + /// This is highly unsafe! It panics if the pointer is `null`, no other checks are made! + unsafe fn mut_from_user_context<'ctx, 'b, U: UserPropagator<'ctx>>( + ptr: *mut ::std::ffi::c_void, + ) -> &'b mut U { + (ptr as *mut U).as_mut().unwrap() + } + + pub(crate) extern "C" fn push_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + ) { + debug!("push_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + up.push(&CallBack::new(cb, up.get_context())); + } + + pub(crate) extern "C" fn pop_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + num_scopes: ::std::os::raw::c_uint, + ) { + debug!("pop_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + up.pop(&CallBack::new(cb, up.get_context()), num_scopes); + } + + // TODO: figure out how to make this work + #[allow(unused_variables)] + #[allow(clippy::extra_unused_type_parameters)] + pub(crate) extern "C" fn fresh_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + new_context: Z3_context, + ) -> *mut ::std::ffi::c_void { + ::std::ptr::null_mut() + } + + pub(crate) extern "C" fn fixed_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + var: Z3_ast, + value: Z3_ast, + ) { + debug!("fixed_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + let var = unsafe { Dynamic::wrap(up.get_context(), var) }; + let value = unsafe { Dynamic::wrap(up.get_context(), value) }; + up.fixed(&CallBack::new(cb, up.get_context()), &var, &value); + } + + pub(crate) extern "C" fn eq_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + x: Z3_ast, + y: Z3_ast, + ) { + debug!("eq_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + let x = unsafe { Dynamic::wrap(up.get_context(), x) }; + let y = unsafe { Dynamic::wrap(up.get_context(), y) }; + up.eq(&CallBack::new(cb, up.get_context()), &x, &y); + } + + pub(crate) extern "C" fn neq_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + x: Z3_ast, + y: Z3_ast, + ) { + debug!("neq_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + let x = unsafe { Dynamic::wrap(up.get_context(), x) }; + let y = unsafe { Dynamic::wrap(up.get_context(), y) }; + up.neq(&CallBack::new(cb, up.get_context()), &x, &y); + } + + pub(crate) extern "C" fn final_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + ) { + debug!("final_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + up.final_(&CallBack::new(cb, up.get_context())); + } + + pub(crate) extern "C" fn created_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + e: Z3_ast, + ) { + debug!("created_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + let e = unsafe { Dynamic::wrap(up.get_context(), e) }; + up.created(&CallBack::new(cb, up.get_context()), &e); + } + + pub(crate) extern "C" fn decide_eh<'ctx, U: UserPropagator<'ctx>>( + ctx: *mut ::std::ffi::c_void, + cb: Z3_solver_callback, + val: Z3_ast, + bit: ::std::os::raw::c_uint, + is_pos: bool, + ) { + debug!("decide_eh"); + let up = unsafe { mut_from_user_context::(ctx) }; + let val = unsafe { Dynamic::wrap(up.get_context(), val) }; + up.decide(&CallBack::new(cb, up.get_context()), &val, bit, is_pos); + } + + pub(crate) unsafe extern "C" fn clause_eh<'ctx, U: OnClause<'ctx>>( + ctx: *mut ::std::ffi::c_void, + proof_hint: Z3_ast, + n: ::std::os::raw::c_uint, + deps: *const ::std::os::raw::c_uint, + literals: Z3_ast_vector, + ) { + debug!("clause_eh {n} {deps:?}"); + let oc = (ctx as *mut U).as_mut().unwrap(); + let n: usize = n.try_into().unwrap(); + let deps = if n == 0 { + &[] + } else { + std::slice::from_raw_parts(deps, n) + }; + let literals: Vec<_> = (0..Z3_ast_vector_size(oc.get_ctx().get_z3_context(), literals)) + .map(|i| { + Dynamic::wrap( + oc.get_ctx(), + Z3_ast_vector_get(oc.get_ctx().get_z3_context(), literals, i), + ) + }) + .collect(); + let proof_hint = Dynamic::wrap(oc.get_ctx(), proof_hint); + oc.on_clause(&proof_hint, deps, &literals); + } +} + +#[cfg(test)] +mod test { + use std::convert::TryInto; + + use crate::{ + ast::{self, Ast, Dynamic}, + user_propagator::{CallBack, UPSolver, UserPropagator}, + Config, Context, FuncDecl, Solver, Sort, + }; + + #[test] + fn on_clause() { + let _ = env_logger::try_init(); + let mut cfg = Config::default(); + cfg.set_model_generation(true); + cfg.set_proof_generation(true); + let ctx = Context::new(&cfg); + let s_sort = Sort::uninterpreted(&ctx, "S".into()); + let f = FuncDecl::new(&ctx, "f", &[&s_sort], &s_sort); + let g = FuncDecl::new(&ctx, "g", &[&s_sort], &Sort::bool(&ctx)); + let x = FuncDecl::new(&ctx, "x", &[], &s_sort).apply(&[]); + + let mut y: String = "I am a non-static lifetime check".to_owned(); + { + let mut s = UPSolver::new(Solver::new(&ctx)); + + s.quick_register_on_clause(|proof_hint, deps, literals| { + println!("on_clause:\n\tproof_hint: {proof_hint:}\n\tdeps: {deps:?}\n\tliteral: {literals:?}"); + y = "check successfull".to_owned(); + }); + + let gx = g.apply(&[&x]).as_bool().unwrap(); + let gxx = g.apply(&[&f.apply(&[&f.apply(&[&x])])]).as_bool().unwrap(); + + s.assert(&!(&gx & &gxx)); + s.assert(&(&gx | &gxx)); + s.assert(&f.apply(&[&x])._eq(&x)); + s.check(); + println!("result: {:?}", s.check()); + println!("{:?}", s.get_model()); + } + assert_eq!(&y, "check successfull") + } + + #[test] + fn user_propagator() { + /* proves f(f(f(f(x)))) = f(x) with a up that rewrites f(f(x)) into f(x) */ + let _ = env_logger::try_init(); + let mut cfg = Config::default(); + cfg.set_model_generation(true); + let ctx = Context::new(&cfg); + let s_sort = Sort::uninterpreted(&ctx, "S".into()); + let f = FuncDecl::new_up(&ctx, "f", &[&s_sort], &s_sort); + let x = FuncDecl::new(&ctx, "x", &[], &s_sort).apply(&[]); + let s = Solver::new(&ctx); + + #[derive(Debug)] + struct UP<'ctx> { + pub f: &'ctx FuncDecl<'ctx>, + pub ctx: &'ctx Context, + } + + impl<'ctx> UP<'ctx> { + fn generate_next_term(&self, e: &Dynamic<'ctx>) -> Option> { + let f1 = e.safe_decl().ok()?; + (f1.name() == self.f.name()).then_some(())?; // exits if f1 != f + let [arg1] = e.children().try_into().ok()?; + let f2 = e.safe_decl().ok()?; + (f2.name() == self.f.name()).then_some(())?; + let [arg2] = arg1.children().try_into().ok()?; + + let nt = self.f.apply(&[&arg2]); + println!("propagating: {e} = {nt}"); + Some(nt) + } + } + + impl<'ctx> UserPropagator<'ctx> for UP<'ctx> { + fn eq(&mut self, upw: &CallBack<'ctx>, x: &Dynamic<'ctx>, y: &Dynamic<'ctx>) { + println!("eq: {x} = {y}"); + for e in [x, y] { + let Some(nt) = self.generate_next_term(e) else { + continue; + }; + upw.propagate_one(&[], &e._eq(&nt)); + } + } + + fn neq(&mut self, upw: &CallBack<'ctx>, x: &Dynamic<'ctx>, y: &Dynamic<'ctx>) { + println!("neq: {x} != {y}"); + for e in [x, y] { + let Some(nt) = self.generate_next_term(e) else { + continue; + }; + upw.propagate_one(&[], &e._eq(&nt)); + } + } + + fn created(&mut self, _: &CallBack<'ctx>, e: &ast::Dynamic<'ctx>) { + println!("created: {e}") + } + + fn pop(&mut self, _: &CallBack<'ctx>, num_scopes: u32) { + println!("pop: {num_scopes:}") + } + + fn push(&mut self, _: &CallBack<'ctx>) { + println!("push") + } + + fn decide( + &mut self, + _: &CallBack<'ctx>, + val: &ast::Dynamic<'ctx>, + bit: u32, + is_pos: bool, + ) { + println!("decide: {val}, {bit:} {is_pos}") + } + + fn fixed( + &mut self, + _: &CallBack<'ctx>, + id: &ast::Dynamic<'ctx>, + e: &ast::Dynamic<'ctx>, + ) { + println!("fixed: {id} {e}") + } + + fn final_(&mut self, _: &CallBack<'ctx>) { + println!("final") + } + + fn get_context(&self) -> &'ctx Context { + self.ctx + } + } + + let mut s = UPSolver::new(s); + // let up = Box::pin(UP { f: &f, ctx: &ctx }); + s.register_user_propagator(UP { f: &f, ctx: &ctx }); + s.assert( + &f.apply(&[&f.apply(&[&f.apply(&[&f.apply(&[&x])])])]) + ._eq(&f.apply(&[&x])) + .not(), + ); + println!("result: {:?}", s.check()); + } +}