diff --git a/src/builder.rs b/src/builder.rs index dc99381..2fb392b 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -767,6 +767,48 @@ impl CpModelBuilder { pub fn solve_with_parameters(&self, params: &proto::SatParameters) -> proto::CpSolverResponse { ffi::solve_with_parameters(self.proto(), params) } + + /// Solves the model with the given + /// [parameters][proto::SatParameters], + /// a solution handler that is called with feasible solutions [proto::CpSolverResponse], + /// and returns the final [proto::CpSolverResponse]. + /// + /// The given function will be called on each improving feasible solution found + /// during the search. For a non-optimization problem, if the option + /// [proto::SatParameters::enumerate_all_solutions] to find all + /// solutions was set, then this will be called on each new solution. + /// + /// Please note that it does not work in parallel + /// (i. e. parameter [proto::SatParameters::num_search_workers] > 1). + /// + /// ``` + /// # use std::cell::RefCell; + /// # use std::rc::Rc; + /// # use cp_sat::builder::CpModelBuilder; + /// # use cp_sat::proto::{SatParameters, CpSolverResponse}; + /// let mut model = CpModelBuilder::default(); + /// // linear constraint will only allow a = 2, a = 3 and a = 4 + /// let a = model.new_int_var([(2, 7)]); + /// model.add_linear_constraint([(3, a)], [(0, 13)]); + /// let mut params = SatParameters::default(); + /// params.enumerate_all_solutions = Some(true); + /// + /// let memory = Rc::new(RefCell::new(Vec::new())); + /// let memory2 = memory.clone(); + /// let handler = move |response: CpSolverResponse| { + /// memory2.borrow_mut().push(response); + /// }; + /// + /// let _response = model.solve_with_parameters_and_handler(¶ms, handler); + /// assert_eq!(3, memory.borrow().len()); + /// ``` + pub fn solve_with_parameters_and_handler( + &self, + params: &proto::SatParameters, + handler: impl FnMut(proto::CpSolverResponse) + 'static, + ) -> proto::CpSolverResponse { + ffi::solve_with_parameters_and_handler(self.proto(), params, Box::new(handler)) + } } /// Boolean variable identifier. diff --git a/src/cp_sat_wrapper.cpp b/src/cp_sat_wrapper.cpp index 1dc82db..c2a1126 100644 --- a/src/cp_sat_wrapper.cpp +++ b/src/cp_sat_wrapper.cpp @@ -51,6 +51,66 @@ cp_sat_wrapper_solve( return out_buf; } +/** + * Solution handler that is called on every encountered solution. + * + * Arguments: + * - serialized buffer of a CpSolverResponse + * - length of the buffer + * - additional data passed from the outside + */ +typedef void (*solution_handler)(unsigned char*, size_t, void*); + +/** + * Similar to cp_sat_wrapper_solve_with_parameters, but with a callback function + * for all encountered solutions. + * + * - handler: called on every solution + * - handler_data: additional data that is provided to the callback + */ +extern "C" unsigned char* +cp_sat_wrapper_solve_with_parameters_and_handler( + unsigned char* model_buf, + size_t model_size, + unsigned char* params_buf, + size_t params_size, + solution_handler handler, + void* handler_data, + size_t* out_size) +{ + sat::Model extra_model; + sat::CpModelProto model; + bool res = model.ParseFromArray(model_buf, model_size); + assert(res); + + sat::SatParameters params; + res = params.ParseFromArray(params_buf, params_size); + assert(res); + + extra_model.Add(sat::NewSatParameters(params)); + + // local function that serializes the CpSolverResponse for the provided solution handler + auto wrapped_handler = [&](const operations_research::sat::CpSolverResponse& curr_response) { + // serialize CpSolverResponse + size_t response_size = curr_response.ByteSizeLong(); + unsigned char* response_buf = (unsigned char*) malloc(response_size); + bool curr_res = curr_response.SerializeToArray(response_buf, response_size); + assert(curr_res); + + handler(response_buf, response_size, handler_data); + }; + extra_model.Add(sat::NewFeasibleSolutionObserver(wrapped_handler)); + + sat::CpSolverResponse response = sat::SolveCpModel(model, &extra_model); + + *out_size = response.ByteSizeLong(); + unsigned char* out_buf = (unsigned char*) malloc(*out_size); + res = response.SerializeToArray(out_buf, *out_size); + assert(res); + + return out_buf; +} + extern "C" char* cp_sat_wrapper_cp_model_stats(unsigned char* model_buf, size_t model_size) { sat::CpModelProto model; diff --git a/src/ffi.rs b/src/ffi.rs index 759eabe..7284c6c 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -2,6 +2,7 @@ use crate::proto; use libc::c_char; use prost::Message; use std::ffi::CStr; +use std::ffi::c_void; extern "C" { fn cp_sat_wrapper_solve( @@ -16,6 +17,15 @@ extern "C" { params_size: usize, out_size: &mut usize, ) -> *mut u8; + fn cp_sat_wrapper_solve_with_parameters_and_handler( + model_buf: *const u8, + model_size: usize, + params_buf: *const u8, + params_size: usize, + handler_caller: extern "C" fn(*const u8, usize, *mut c_void), + handler: *mut c_void, + out_size: &mut usize, + ) -> *mut u8; fn cp_sat_wrapper_cp_model_stats(model_buf: *const u8, model_size: usize) -> *mut c_char; fn cp_sat_wrapper_cp_solver_response_stats( response_buf: *const u8, @@ -72,6 +82,66 @@ pub fn solve_with_parameters( response } +/// User provided solution handler that is called with feasible solutions. +pub type SolutionHandler = Box; + +/// Solves the given [CpModelProto][crate::proto::CpModelProto] with +/// the given parameters, +/// and calls the [SolutionHandler] on each improving feasible solution found +/// during the search. For a non-optimization problem, if the option +/// [proto::SatParameters.enumerate_all_solutions] to find all +/// solutions was set, then this will be called on each new solution. +/// +/// Please note that it does not work in parallel +/// (i. e. parameter [proto::SatParameters::num_search_workers] > 1). +pub fn solve_with_parameters_and_handler( + model: &proto::CpModelProto, + params: &proto::SatParameters, + mut handler: SolutionHandler, +) -> proto::CpSolverResponse { + let mut model_buf = Vec::default(); + model.encode(&mut model_buf).unwrap(); + let mut params_buf = Vec::default(); + params.encode(&mut params_buf).unwrap(); + + let mut out_size = 0; + let res = unsafe { + cp_sat_wrapper_solve_with_parameters_and_handler( + model_buf.as_ptr(), + model_buf.len(), + params_buf.as_ptr(), + params_buf.len(), + solution_handler_caller, + &mut handler as *mut _ as *mut c_void, + &mut out_size, + ) + }; + let out_slice = unsafe { std::slice::from_raw_parts(res, out_size) }; + let response = proto::CpSolverResponse::decode(out_slice).unwrap(); + unsafe { libc::free(res as _) }; + response +} + +/// Callback that is called from cpp code and transforms a buffered response to a +/// [proto::CpSolverResponse] that can be used by a [SolutionHandler]. +/// +/// # Arguments +/// - `response_buf` and `response_size`: buffer and size of a [proto::CpSolverResponse] +/// - `handler`: a user provided solution handler [SolutionHandler] that accepts a +/// [proto::CpSolverResponse] +extern "C" fn solution_handler_caller(response_buf: *const u8, response_size: usize, handler: *mut c_void) { + let response_slice = unsafe { + std::slice::from_raw_parts(response_buf, response_size) + }; + let response = proto::CpSolverResponse::decode(response_slice).unwrap(); + unsafe { libc::free(response_buf as _) }; + + unsafe { + let tmp = handler as *mut SolutionHandler; + (*tmp)(response); + } +} + /// Returns a string with some statistics on the given /// [CpModelProto][crate::proto::CpModelProto]. pub fn cp_model_stats(model: &proto::CpModelProto) -> String { diff --git a/tests/solution_handler.rs b/tests/solution_handler.rs new file mode 100644 index 0000000..959541f --- /dev/null +++ b/tests/solution_handler.rs @@ -0,0 +1,63 @@ +use std::cell::RefCell; +use std::collections::HashSet; +use std::rc::Rc; +use cp_sat::builder::CpModelBuilder; +use cp_sat::proto::{SatParameters, CpSolverResponse}; + +/// In a non-optimization problem all feasible solutions should be found. +#[test] +fn enumeration_solution_handler() { + let mut model = CpModelBuilder::default(); + // linear constraint will only allow a = 2, a = 3 and a = 4 + let a = model.new_int_var([(2, 7)]); + model.add_linear_constraint([(3, a)], [(0, 13)]); + let mut params = SatParameters::default(); + params.enumerate_all_solutions = Some(true); + + let memory = Rc::new(RefCell::new(Vec::new())); + let memory2 = memory.clone(); + let handler = move |response: CpSolverResponse| { + memory2.borrow_mut().push(response); + }; + + let _response = model.solve_with_parameters_and_handler(¶ms, handler); + + assert_eq!(3, memory.borrow().len()); + + let expected = HashSet::from([2, 3, 4]); + let actual = memory + .borrow() + .iter() + .map(|response| a.solution_value(response)) + .collect::>(); + + assert_eq!(expected, actual); +} + +/// In an optimization problem at least one feasible solution should be found. +#[test] +fn optimization_solution_handler() { + let mut model = CpModelBuilder::default(); + // linear constraint will only allow a = 2, a = 3 and a = 4 + let a = model.new_int_var([(2, 7)]); + model.add_linear_constraint([(3, a)], [(0, 13)]); + model.minimize(a); + let mut params = SatParameters::default(); + params.enumerate_all_solutions = Some(true); + + let memory = Rc::new(RefCell::new(Vec::new())); + let memory2 = memory.clone(); + let handler = move |response: CpSolverResponse| { + memory2.borrow_mut().push(response); + }; + + let response = model.solve_with_parameters_and_handler(¶ms, handler); + + assert_eq!(2, a.solution_value(&response)); + + // At least one feasible solution is encountered. + // As we do not know how often the solution improves, or whether the first + // feasible solution is already the optimal one, we cannot expect more than one + // improvement. + assert!(memory.borrow().len() >= 1); +}