diff --git a/src/cnf.rs b/src/cnf.rs index 4b7c645..54e648d 100644 --- a/src/cnf.rs +++ b/src/cnf.rs @@ -30,11 +30,12 @@ impl fmt::Display for Literal { #[derive(Debug, Clone)] pub struct CNFClause { pub literals: Array1, + pub weight: Option, // Optional weight for the clause } impl CNFClause { - pub fn new(literals: Array1) -> Self { - Self { literals } + pub fn new(literals: Array1, weight: Option) -> Self { + Self { literals, weight } } // Function to return the literals in the clause. @@ -54,12 +55,13 @@ impl fmt::Display for CNFClause { pub struct CNFFormula { pub clauses: Array1, pub varnum: usize, + pub must_clause_weight: Option, // Optional weight for "must have" clauses } impl CNFFormula { - pub fn new(clauses: Array1, varnum: Option) -> Self { + pub fn new(clauses: Array1, varnum: Option, must_clause_weight: Option) -> Self { if let Some(varnum) = varnum { - Self { clauses, varnum } + Self { clauses, varnum, must_clause_weight } } else { let mut variables: HashSet = HashSet::new(); @@ -72,6 +74,7 @@ impl CNFFormula { Self { clauses, varnum: variables.len(), + must_clause_weight, } } } @@ -138,6 +141,7 @@ impl fmt::Display for CNFFormula { pub fn parse_dimacs_format(input: &str) -> CNFFormula { let mut clauses: Vec = Vec::new(); let mut varnum = None; + let mut must_clause_weight = None; for line in input.lines() { if line.starts_with('c') { @@ -151,6 +155,12 @@ pub fn parse_dimacs_format(input: &str) -> CNFFormula { varnum = Some(split.next().unwrap().parse().unwrap()); // The 'nbclauses' value is not needed for parsing continue; + } else if line.starts_with("w") { + // Parse the weight line: w weight + let mut split = line.split_whitespace(); + split.next(); // Skip the 'w' token + must_clause_weight = Some(split.next().unwrap().parse().unwrap()); + continue; } else { // Parse clause lines let literals: Vec = line @@ -164,11 +174,12 @@ pub fn parse_dimacs_format(input: &str) -> CNFFormula { }) .collect(); - clauses.push(CNFClause::new(literals.into())); + let weight = must_clause_weight; // Assign the weight to the clause + clauses.push(CNFClause::new(literals.into(), weight)); } } - CNFFormula::new(clauses.into(), varnum) + CNFFormula::new(clauses.into(), varnum, must_clause_weight) } pub fn apply_variable_mapping( @@ -191,11 +202,11 @@ pub fn apply_variable_mapping( } } - let mapped_clause = CNFClause::new(mapped_literals.into()); + let mapped_clause = CNFClause::new(mapped_literals.into(), clause.weight); mapped_clauses.push(mapped_clause); } - CNFFormula::new(mapped_clauses.into(), Some(formula.varnum)) + CNFFormula::new(mapped_clauses.into(), Some(formula.varnum), formula.must_clause_weight) } // Normalizes the variables of CNF function. That is, the smallest variable should be 0, and @@ -227,6 +238,11 @@ pub fn cnf_to_dimacs_format(formula: &CNFFormula) -> String { // Add problem line: p cnf nbvar nbclauses dimacs_string.push_str(&format!("p cnf {varnum} {num_clauses}\n")); + // Add weight line if present + if let Some(weight) = formula.must_clause_weight { + dimacs_string.push_str(&format!("w {weight}\n")); + } + // Add clauses for clause in &formula.clauses { for literal in clause.get_literals() { @@ -405,13 +421,14 @@ pub fn convert_to_cnf_formula(formula_set: &CNFFormulaSet) -> CNFFormula { .cloned() .collect::>(), ); - let clause = CNFClause::new(literals); + let clause = CNFClause::new(literals, None); clauses_array.push(clause); } CNFFormula { clauses: Array1::from(clauses_array), varnum: formula_set.varnum, + must_clause_weight: None, } } diff --git a/src/main.rs b/src/main.rs index 514b8ea..70c9bf8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ use std::collections::HashMap; use std::fs; use std::io::{self, Write}; use std::path::PathBuf; +use std::time::{Duration, Instant}; #[derive(Parser)] #[command(author, version, about, long_about = None)] @@ -62,6 +63,10 @@ pub struct SolveOpts { /// Optional output file for the transformed CNF formula #[arg(short = 'p', long)] pub transformed_output: Option, + + /// Time limit in seconds + #[arg(short = 'T', long)] + pub time_limit: Option, } #[derive(Args)] @@ -112,6 +117,10 @@ pub struct BatchOpts { /// Learning rate #[arg(short = 'l', long)] pub learning_rate: Option, + + /// Time limit in seconds + #[arg(short = 'T', long)] + pub time_limit: Option, } #[derive(Args)] @@ -158,6 +167,7 @@ fn solve(solve_opts: SolveOpts) -> Result<(), Box> { 7.0 }; let transformed_output_path = &solve_opts.transformed_output; + let time_limit = solve_opts.time_limit.map(Duration::from_secs); println!("Reading CNF formula from file..."); let cnf_string = fs::read_to_string(input_path)?; @@ -184,14 +194,50 @@ fn solve(solve_opts: SolveOpts) -> Result<(), Box> { xl: Array1::ones(normalized_formula.clauses.len()), }; - let result = simulate( - &mut state, - &normalized_formula, - tolerance, - step_size, - step_number, - learning_rate, - ); + let start_time = Instant::now(); + let mut best_solution = None; + let mut best_score = f64::INFINITY; + + while time_limit.map_or(true, |limit| start_time.elapsed() < limit) { + let result = simulate( + &mut state, + &normalized_formula, + tolerance, + step_size, + step_number, + learning_rate, + ); + + let score = normalized_formula + .clauses + .iter() + .map(|clause| { + clause + .literals + .iter() + .map(|lit| { + let value = if lit.is_negated { + !result[lit.variable] + } else { + result[lit.variable] + }; + if value { 0.0 } else { clause.weight.unwrap_or(1.0) } + }) + .sum::() + }) + .sum::(); + + if score < best_score { + best_score = score; + best_solution = Some(result.clone()); + } + + if score == 0.0 { + break; + } + } + + let result = best_solution.unwrap_or_else(|| state.v.iter().map(|&value| value > 0.0).collect()); println!("Mapping values..."); let mut mapped_values = map_values_by_indices(&var_mapping, &result); @@ -270,6 +316,7 @@ fn batch(batch_opts: BatchOpts) -> Result<(), Box> { let step_number = batch_opts.step_number; let step_size = batch_opts.step_size; let learning_rate = batch_opts.learning_rate; + let time_limit = batch_opts.time_limit.map(Duration::from_secs); println!("Reading CNF formula from file..."); let cnf_string = fs::read_to_string(input_path)?; @@ -285,8 +332,15 @@ fn batch(batch_opts: BatchOpts) -> Result<(), Box> { let mut rng = rand::thread_rng(); let mut is_satisfiable = false; let mut mapped_values = HashMap::new(); + let start_time = Instant::now(); + let mut best_solution = None; + let mut best_score = f64::INFINITY; for i in 0..batch_size { + if time_limit.map_or(false, |limit| start_time.elapsed() >= limit) { + break; + } + print!("\rRunning simulation {}.", i + 1); io::stdout().flush().unwrap(); // Flush stdout to make sure it gets printed immediately @@ -313,11 +367,37 @@ fn batch(batch_opts: BatchOpts) -> Result<(), Box> { mapped_values = map_values_by_indices(&var_mapping, &result); is_satisfiable = evaluate_cnf(&mut mapped_values, formula.clone()); + let score = normalized_formula + .clauses + .iter() + .map(|clause| { + clause + .literals + .iter() + .map(|lit| { + let value = if lit.is_negated { + !result[lit.variable] + } else { + result[lit.variable] + }; + if value { 0.0 } else { clause.weight.unwrap_or(1.0) } + }) + .sum::() + }) + .sum::(); + + if score < best_score { + best_score = score; + best_solution = Some(result.clone()); + } + if is_satisfiable { break; } } + let result = best_solution.unwrap_or_else(|| state.v.iter().map(|&value| value > 0.0).collect()); + println!("\nChecking if solution vector satisfies formula: {is_satisfiable}"); println!("Rendering variable assignments...");