Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions src/cnf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ impl fmt::Display for Literal {
#[derive(Debug, Clone)]
pub struct CNFClause {
pub literals: Array1<Literal>,
pub weight: Option<f64>, // Optional weight for the clause
}

impl CNFClause {
pub fn new(literals: Array1<Literal>) -> Self {
Self { literals }
pub fn new(literals: Array1<Literal>, weight: Option<f64>) -> Self {
Self { literals, weight }
}

// Function to return the literals in the clause.
Expand All @@ -54,12 +55,13 @@ impl fmt::Display for CNFClause {
pub struct CNFFormula {
pub clauses: Array1<CNFClause>,
pub varnum: usize,
pub must_clause_weight: Option<f64>, // Optional weight for "must have" clauses
}

impl CNFFormula {
pub fn new(clauses: Array1<CNFClause>, varnum: Option<usize>) -> Self {
pub fn new(clauses: Array1<CNFClause>, varnum: Option<usize>, must_clause_weight: Option<f64>) -> Self {
if let Some(varnum) = varnum {
Self { clauses, varnum }
Self { clauses, varnum, must_clause_weight }
} else {
let mut variables: HashSet<usize> = HashSet::new();

Expand All @@ -72,6 +74,7 @@ impl CNFFormula {
Self {
clauses,
varnum: variables.len(),
must_clause_weight,
}
}
}
Expand Down Expand Up @@ -138,6 +141,7 @@ impl fmt::Display for CNFFormula {
pub fn parse_dimacs_format(input: &str) -> CNFFormula {
let mut clauses: Vec<CNFClause> = Vec::new();
let mut varnum = None;
let mut must_clause_weight = None;

for line in input.lines() {
if line.starts_with('c') {
Expand All @@ -151,6 +155,12 @@ pub fn parse_dimacs_format(input: &str) -> CNFFormula {
varnum = Some(split.next().unwrap().parse().unwrap());
// The 'nbclauses' value is not needed for parsing
continue;
} else if line.starts_with("w") {
// Parse the weight line: w weight
let mut split = line.split_whitespace();
split.next(); // Skip the 'w' token
must_clause_weight = Some(split.next().unwrap().parse().unwrap());
continue;
} else {
// Parse clause lines
let literals: Vec<Literal> = line
Expand All @@ -164,11 +174,12 @@ pub fn parse_dimacs_format(input: &str) -> CNFFormula {
})
.collect();

clauses.push(CNFClause::new(literals.into()));
let weight = must_clause_weight; // Assign the weight to the clause
clauses.push(CNFClause::new(literals.into(), weight));
}
}

CNFFormula::new(clauses.into(), varnum)
CNFFormula::new(clauses.into(), varnum, must_clause_weight)
}

pub fn apply_variable_mapping(
Expand All @@ -191,11 +202,11 @@ pub fn apply_variable_mapping(
}
}

let mapped_clause = CNFClause::new(mapped_literals.into());
let mapped_clause = CNFClause::new(mapped_literals.into(), clause.weight);
mapped_clauses.push(mapped_clause);
}

CNFFormula::new(mapped_clauses.into(), Some(formula.varnum))
CNFFormula::new(mapped_clauses.into(), Some(formula.varnum), formula.must_clause_weight)
}

// Normalizes the variables of CNF function. That is, the smallest variable should be 0, and
Expand Down Expand Up @@ -227,6 +238,11 @@ pub fn cnf_to_dimacs_format(formula: &CNFFormula) -> String {
// Add problem line: p cnf nbvar nbclauses
dimacs_string.push_str(&format!("p cnf {varnum} {num_clauses}\n"));

// Add weight line if present
if let Some(weight) = formula.must_clause_weight {
dimacs_string.push_str(&format!("w {weight}\n"));
}

// Add clauses
for clause in &formula.clauses {
for literal in clause.get_literals() {
Expand Down Expand Up @@ -405,13 +421,14 @@ pub fn convert_to_cnf_formula(formula_set: &CNFFormulaSet) -> CNFFormula {
.cloned()
.collect::<Vec<_>>(),
);
let clause = CNFClause::new(literals);
let clause = CNFClause::new(literals, None);
clauses_array.push(clause);
}

CNFFormula {
clauses: Array1::from(clauses_array),
varnum: formula_set.varnum,
must_clause_weight: None,
}
}

Expand Down
96 changes: 88 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::collections::HashMap;
use std::fs;
use std::io::{self, Write};
use std::path::PathBuf;
use std::time::{Duration, Instant};

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
Expand Down Expand Up @@ -62,6 +63,10 @@ pub struct SolveOpts {
/// Optional output file for the transformed CNF formula
#[arg(short = 'p', long)]
pub transformed_output: Option<PathBuf>,

/// Time limit in seconds
#[arg(short = 'T', long)]
pub time_limit: Option<u64>,
}

#[derive(Args)]
Expand Down Expand Up @@ -112,6 +117,10 @@ pub struct BatchOpts {
/// Learning rate
#[arg(short = 'l', long)]
pub learning_rate: Option<f64>,

/// Time limit in seconds
#[arg(short = 'T', long)]
pub time_limit: Option<u64>,
}

#[derive(Args)]
Expand Down Expand Up @@ -158,6 +167,7 @@ fn solve(solve_opts: SolveOpts) -> Result<(), Box<dyn std::error::Error>> {
7.0
};
let transformed_output_path = &solve_opts.transformed_output;
let time_limit = solve_opts.time_limit.map(Duration::from_secs);

println!("Reading CNF formula from file...");
let cnf_string = fs::read_to_string(input_path)?;
Expand All @@ -184,14 +194,50 @@ fn solve(solve_opts: SolveOpts) -> Result<(), Box<dyn std::error::Error>> {
xl: Array1::ones(normalized_formula.clauses.len()),
};

let result = simulate(
&mut state,
&normalized_formula,
tolerance,
step_size,
step_number,
learning_rate,
);
let start_time = Instant::now();
let mut best_solution = None;
let mut best_score = f64::INFINITY;

while time_limit.map_or(true, |limit| start_time.elapsed() < limit) {
let result = simulate(
&mut state,
&normalized_formula,
tolerance,
step_size,
step_number,
learning_rate,
);

let score = normalized_formula
.clauses
.iter()
.map(|clause| {
clause
.literals
.iter()
.map(|lit| {
let value = if lit.is_negated {
!result[lit.variable]
} else {
result[lit.variable]
};
if value { 0.0 } else { clause.weight.unwrap_or(1.0) }
})
.sum::<f64>()
})
.sum::<f64>();

if score < best_score {
best_score = score;
best_solution = Some(result.clone());
}

if score == 0.0 {
break;
}
}

let result = best_solution.unwrap_or_else(|| state.v.iter().map(|&value| value > 0.0).collect());

println!("Mapping values...");
let mut mapped_values = map_values_by_indices(&var_mapping, &result);
Expand Down Expand Up @@ -270,6 +316,7 @@ fn batch(batch_opts: BatchOpts) -> Result<(), Box<dyn std::error::Error>> {
let step_number = batch_opts.step_number;
let step_size = batch_opts.step_size;
let learning_rate = batch_opts.learning_rate;
let time_limit = batch_opts.time_limit.map(Duration::from_secs);

println!("Reading CNF formula from file...");
let cnf_string = fs::read_to_string(input_path)?;
Expand All @@ -285,8 +332,15 @@ fn batch(batch_opts: BatchOpts) -> Result<(), Box<dyn std::error::Error>> {
let mut rng = rand::thread_rng();
let mut is_satisfiable = false;
let mut mapped_values = HashMap::new();
let start_time = Instant::now();
let mut best_solution = None;
let mut best_score = f64::INFINITY;

for i in 0..batch_size {
if time_limit.map_or(false, |limit| start_time.elapsed() >= limit) {
break;
}

print!("\rRunning simulation {}.", i + 1);
io::stdout().flush().unwrap(); // Flush stdout to make sure it gets printed immediately

Expand All @@ -313,11 +367,37 @@ fn batch(batch_opts: BatchOpts) -> Result<(), Box<dyn std::error::Error>> {
mapped_values = map_values_by_indices(&var_mapping, &result);
is_satisfiable = evaluate_cnf(&mut mapped_values, formula.clone());

let score = normalized_formula
.clauses
.iter()
.map(|clause| {
clause
.literals
.iter()
.map(|lit| {
let value = if lit.is_negated {
!result[lit.variable]
} else {
result[lit.variable]
};
if value { 0.0 } else { clause.weight.unwrap_or(1.0) }
})
.sum::<f64>()
})
.sum::<f64>();

if score < best_score {
best_score = score;
best_solution = Some(result.clone());
}

if is_satisfiable {
break;
}
}

let result = best_solution.unwrap_or_else(|| state.v.iter().map(|&value| value > 0.0).collect());

println!("\nChecking if solution vector satisfies formula: {is_satisfiable}");

println!("Rendering variable assignments...");
Expand Down