Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CI #51

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ kurobako_core = { path = "kurobako_core", version = "0.1" }
kurobako_problems = { path = "kurobako_problems", version = "0.1" }
kurobako_solvers = { path = "kurobako_solvers", version = "0.2" }
nasbench = "0.1"
num = "0.3"
num = "0.4"
num-integer = "0.1"
ordered-float = "2"
rand = "0.8"
Expand Down
14 changes: 7 additions & 7 deletions kurobako_core/src/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ impl VariableBuilder {
Self {
name: name.to_owned(),
range: Range::Continuous {
low: std::f64::NEG_INFINITY,
high: std::f64::INFINITY,
low: f64::NEG_INFINITY,
high: f64::INFINITY,
},
distribution: Distribution::Uniform,
constraint: None,
Expand All @@ -75,7 +75,7 @@ impl VariableBuilder {

/// Sets the name of this variable.
pub fn name(mut self, name: &str) -> Self {
self.name = name.to_owned();
name.clone_into(&mut self.name);
self
}

Expand Down Expand Up @@ -121,7 +121,7 @@ impl VariableBuilder {
///
/// This is equivalent to `self.categorical(&["false", "true"])`.
pub fn boolean(self) -> Self {
self.categorical(&["false", "true"])
self.categorical(["false", "true"])
}

/// Sets the range of this variable.
Expand Down Expand Up @@ -242,11 +242,11 @@ fn is_not_finite(x: &f64) -> bool {
}

fn neg_infinity() -> f64 {
std::f64::NEG_INFINITY
f64::NEG_INFINITY
}

fn infinity() -> f64 {
std::f64::INFINITY
f64::INFINITY
}

/// Variable range.
Expand Down Expand Up @@ -396,7 +396,7 @@ mod tests {
let vars = vec![
var("a").continuous(-10.0, 10.0).finish()?,
var("b").discrete(0, 5).finish()?,
var("c").categorical(&["foo", "bar", "baz"]).finish()?,
var("c").categorical(["foo", "bar", "baz"]).finish()?,
];

let constraint = Constraint::new("(a + b) < 2");
Expand Down
2 changes: 1 addition & 1 deletion kurobako_core/src/epi/problem/external_program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use structopt::StructOpt;

thread_local! {
static FACTORY_CACHE : RefCell<Option<(Vec<u8>, ExternalProgramProblemFactory)>> =
RefCell::new(None);
const { RefCell::new(None) };
}

/// Recipe for the problem implemented by an external program.
Expand Down
8 changes: 6 additions & 2 deletions kurobako_core/src/trial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ pub struct EvaluatedTrial {
pub struct IdGen {
next: u64,
}
impl Default for IdGen {
fn default() -> Self {
Self::new()
}
}
impl IdGen {
/// Makes a new `IdGen` instance.
pub const fn new() -> Self {
Expand Down Expand Up @@ -201,15 +206,14 @@ impl Deref for Values {

mod nullable_f64_vec {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::f64::NAN;

pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<f64>, D::Error>
where
D: Deserializer<'de>,
{
let v: Vec<Option<f64>> = Deserialize::deserialize(deserializer)?;
Ok(v.into_iter()
.map(|v| if let Some(v) = v { v } else { NAN })
.map(|v| if let Some(v) = v { v } else { f64::NAN })
.collect())
}

Expand Down
6 changes: 3 additions & 3 deletions kurobako_problems/src/hpobench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ impl ProblemFactory for HpobenchProblemFactory {
arXiv preprint arXiv:1905.04970 (2019).",
)
.attr("github", "https://github.com/automl/nas_benchmarks")
.param(domain::var("activation_fn_1").categorical(&["tanh", "relu"]))
.param(domain::var("activation_fn_2").categorical(&["tanh", "relu"]))
.param(domain::var("activation_fn_1").categorical(["tanh", "relu"]))
.param(domain::var("activation_fn_2").categorical(["tanh", "relu"]))
.param(domain::var("batch_size").discrete(0, 4))
.param(domain::var("dropout_1").discrete(0, 3))
.param(domain::var("dropout_2").discrete(0, 3))
.param(domain::var("init_lr").discrete(0, 6))
.param(domain::var("lr_schedule").categorical(&["cosine", "const"]))
.param(domain::var("lr_schedule").categorical(["cosine", "const"]))
.param(domain::var("n_units_1").discrete(0, 6))
.param(domain::var("n_units_2").discrete(0, 6))
.value(domain::var("Validation MSE").continuous(0.0, f64::INFINITY))
Expand Down
11 changes: 4 additions & 7 deletions kurobako_problems/src/nasbench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ impl Evaluator for NasbenchEvaluator {
/// [nas_cifar10.py]: https://github.com/automl/nas_benchmarks/blob/c1bae6632bf15d45ba49c269c04dbbeb3f0379f0/tabular_benchmarks/nas_cifar10.py
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
#[derive(Default)]
pub enum Encoding {
#[default]
A,
B,
C,
Expand All @@ -233,7 +235,7 @@ impl Encoding {
fn common_params() -> Vec<VariableBuilder> {
let mut params = Vec::new();
for i in 0..5 {
params.push(domain::var(&format!("op{}", i)).categorical(&[
params.push(domain::var(&format!("op{}", i)).categorical([
"conv1x1-bn-relu",
"conv3x3-bn-relu",
"maxpool3x3",
Expand Down Expand Up @@ -297,7 +299,7 @@ impl Encoding {
fn edges_a(params: &[f64]) -> HashSet<usize> {
let mut edges = HashSet::new();
for (i, p) in params.iter().enumerate() {
if (*p - 1.0).abs() < std::f64::EPSILON {
if (*p - 1.0).abs() < f64::EPSILON {
edges.insert(i);
}
}
Expand Down Expand Up @@ -342,11 +344,6 @@ impl FromStr for Encoding {
}
}
}
impl Default for Encoding {
fn default() -> Self {
Encoding::A
}
}

/// Evaluation metric.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
Expand Down
13 changes: 8 additions & 5 deletions kurobako_problems/src/sigopt/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
use super::bessel::bessel0;
use kurobako_core::{ErrorKind, Result};
use std::f64::consts::PI;
use std::f64::EPSILON;
use std::fmt;
use std::iter;

Expand Down Expand Up @@ -122,7 +121,7 @@ impl TestFunction for Csendes {

fn evaluate(&self, xs: &[f64]) -> f64 {
xs.iter()
.map(|&x| x.powi(6) * (2.0 + (1.0 / (x + EPSILON)).sin()))
.map(|&x| x.powi(6) * (2.0 + (1.0 / (x + f64::EPSILON)).sin()))
.sum()
}
}
Expand Down Expand Up @@ -394,7 +393,7 @@ impl McCourtBase {
e_mat: &'static [&'static [f64]],
) -> impl 'a + Iterator<Item = f64> {
e_mat.iter().zip(centers.iter()).map(move |(evec, center)| {
let mut max = std::f64::NEG_INFINITY;
let mut max = f64::NEG_INFINITY;
for x in xs
.iter()
.zip(center.iter())
Expand Down Expand Up @@ -2374,7 +2373,7 @@ mod tests {
fn shekel05_works() {
assert_eq!(
Shekel05.evaluate(&[4.0, 4.0, 4.0, 4.0]),
-10.152719932456289
-10.152_719_932_456_29
);
}

Expand Down Expand Up @@ -2402,7 +2401,11 @@ mod tests {
#[test]
fn styblinskitang_works() {
assert_eq!(
StyblinskiTang.evaluate(&[-2.903534018185960, -2.903534018185960, -2.903534018185960]),
StyblinskiTang.evaluate(&[
-2.903_534_018_185_96,
-2.903_534_018_185_96,
-2.903_534_018_185_96
]),
-117.49849711131424
);
}
Expand Down
15 changes: 7 additions & 8 deletions src/batch_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use kurobako_core::trial::{Params, Values};
use kurobako_core::{ErrorKind, Result};
use serde::Deserialize;
use serde::Serialize;
use std::io;
use structopt::StructOpt;
use serde_json::Error;
use std::io;
use std::io::Write;
use structopt::StructOpt;

/// Options of the `kurobako batch-evaluate` command.
#[derive(Debug, Clone, StructOpt)]
Expand All @@ -31,7 +31,7 @@ pub struct BatchEvaluateOpt {
#[derive(Debug, Clone, Deserialize)]
struct EvalCall {
params: Params,
step: Option<u64>
step: Option<u64>,
}

#[derive(Debug, Clone, Serialize)]
Expand All @@ -50,7 +50,7 @@ impl BatchEvaluateOpt {

let problem = track!(problem_factory.create_problem(rng))?;
let mut writer = io::stdout();
loop{
loop {
let mut line = String::new();
let n = io::stdin().read_line(&mut line)?;
if n == 0 {
Expand All @@ -64,15 +64,14 @@ impl BatchEvaluateOpt {
ErrorKind::InvalidInput
);


let evaluator_or_error = track!(problem.create_evaluator(params.clone()));

let values = match evaluator_or_error {
Ok(mut evaluator) => {
let s = step.unwrap_or_else(|| problem_spec.steps.last());
let (_, values) = track!(evaluator.evaluate(s))?;
values
},
}
Err(e) => {
if *e.kind() != ErrorKind::UnevaluableParams {
return Err(e);
Expand All @@ -82,8 +81,8 @@ impl BatchEvaluateOpt {
}
};

serde_json::to_writer(&mut writer, &EvalReply{values}).map_err(Error::from)?;
writer.write("\n".as_bytes())?;
serde_json::to_writer(&mut writer, &EvalReply { values }).map_err(Error::from)?;
writer.write_all("\n".as_bytes())?;
writer.flush()?;
}
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl NasbenchOpt {
);

let file = track!(
std::fs::File::open(&tfrecord_format_dataset_path).map_err(Error::from);
std::fs::File::open(tfrecord_format_dataset_path).map_err(Error::from);
tfrecord_format_dataset_path
)?;
let nasbench = track!(nasbench::NasBench::from_tfrecord_reader(
Expand Down
8 changes: 4 additions & 4 deletions src/dataset/surrogate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ impl SurrogateOpt {
let mut table = TableBuilder::new();
let column_types = trials[0]
.distributions
.iter()
.map(|(_, d)| {
.values()
.map(|d| {
if matches!(d, Distribution::CategoricalDistribution { .. }) {
ColumnType::Categorical
} else {
Expand Down Expand Up @@ -208,11 +208,11 @@ impl SurrogateOpt {
track!(std::fs::create_dir_all(&dir).map_err(Error::from))?;

let spec_path = dir.join("spec.json");
let spec_file = track!(std::fs::File::create(&spec_path).map_err(Error::from))?;
let spec_file = track!(std::fs::File::create(spec_path).map_err(Error::from))?;
serde_json::to_writer(spec_file, &spec)?;

let regressor_path = dir.join("model.bin");
let regressor_file = track!(std::fs::File::create(&regressor_path).map_err(Error::from))?;
let regressor_file = track!(std::fs::File::create(regressor_path).map_err(Error::from))?;
model.regressor.serialize(BufWriter::new(regressor_file))?;

eprintln!("Saved the surrogate model to the direcotry {:?}", dir);
Expand Down
2 changes: 1 addition & 1 deletion src/plot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl PlotOpt {

fn execute_gnuplot(script: &str) -> Result<()> {
let output = track!(Command::new("gnuplot")
.args(&["-e", script])
.args(["-e", script])
.output()
.map_err(Error::from))?;
if !output.status.success() {
Expand Down
Loading
Loading