Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor benchmark architecture #477

Merged
merged 22 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
fe99d1f
Rename ConvergenceExperimentSettings to ConvergenceBenchmarkSettings
AdrianSosic Feb 3, 2025
4f4abcd
Extract convergence-related attributes into subclass
AdrianSosic Feb 3, 2025
bba2c77
Move convergence benchmark functionality to separate module
AdrianSosic Feb 3, 2025
3618972
Instantiate settings type
AdrianSosic Feb 5, 2025
41a1163
Make settings classes kw-only
AdrianSosic Feb 5, 2025
7fc50cc
Rename config module to base
AdrianSosic Feb 5, 2025
f552a94
Reorder benchmark attributes
AdrianSosic Feb 5, 2025
d6c7ef0
Ensure that the benchmark function has a docstring
AdrianSosic Feb 5, 2025
64e1a17
Drop unused name attribute
AdrianSosic Feb 5, 2025
8c4cb6d
Switch to modern cattrs hook registration mechanism
AdrianSosic Feb 5, 2025
9afcb1b
Update lockfile
AdrianSosic Feb 5, 2025
689ec00
Refine some docstrings
AdrianSosic Feb 5, 2025
c8fd731
Update branch attribute to allow None and filter out None components …
fabianliebig Feb 6, 2025
828c804
Simplify validator using built-in optional decorator
AdrianSosic Feb 6, 2025
acb82b1
Fix branch attribute for detached head cases
AdrianSosic Feb 6, 2025
6c0eaff
Refactor ConvergenceBenchmark attributes
AdrianSosic Feb 6, 2025
29fd0d7
Correctly unstructure Benchmark subclasses
AdrianSosic Feb 6, 2025
cf88fe2
Bring back benchmark name as property
AdrianSosic Feb 6, 2025
9cc2dcd
Update branch attribute to handle detached head state by setting it t…
fabianliebig Feb 6, 2025
5b3f8c6
Update branch converter to use '-branchless-' for None values for rea…
fabianliebig Feb 7, 2025
2dfc817
Update ResultMetadata branch attribute to allow None values for detac…
fabianliebig Feb 7, 2025
2322b2a
Simplify converter expression
AdrianSosic Feb 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions .lockfiles/py310-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ anyio==4.4.0
# via
# httpx
# jupyter-server
appnope==0.1.4 ; platform_system == 'Darwin'
appnope==0.1.4 ; sys_platform == 'darwin'
# via ipykernel
argon2-cffi==23.1.0
# via jupyter-server
Expand Down Expand Up @@ -61,7 +61,7 @@ cachetools==5.4.0
# via
# streamlit
# tox
cattrs==23.2.3
cattrs==24.1.2
# via baybe (pyproject.toml)
certifi==2024.7.4
# via
Expand Down Expand Up @@ -240,7 +240,7 @@ importlib-metadata==7.1.0
# opentelemetry-api
iniconfig==2.0.0
# via pytest
intel-openmp==2021.4.0 ; platform_system == 'Windows'
intel-openmp==2021.4.0 ; sys_platform == 'win32'
# via mkl
interface-meta==1.3.0
# via formulaic
Expand Down Expand Up @@ -393,7 +393,7 @@ mdurl==0.1.2
# via markdown-it-py
mistune==3.0.2
# via nbconvert
mkl==2021.4.0 ; platform_system == 'Windows'
mkl==2021.4.0 ; sys_platform == 'win32'
# via torch
mmh3==5.0.1
# via e3fp
Expand Down Expand Up @@ -487,36 +487,36 @@ numpy==1.26.4
# types-seaborn
# xarray
# xyzpy
nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cublas-cu12==12.1.3.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-cupti-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cuda-runtime-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cudnn-cu12==8.9.2.26 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cufft-cu12==11.0.2.54 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-curand-cu12==10.3.2.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusolver-cu12==11.4.5.107 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-cusparse-cu12==12.1.0.106 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nccl-cu12==2.20.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvjitlink-cu12==12.5.82 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and platform_system == 'Linux'
nvidia-nvtx-cu12==12.1.105 ; platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
onnx==1.16.1
# via
Expand Down Expand Up @@ -922,7 +922,7 @@ sympy==1.13.1
# via
# onnxruntime
# torch
tbb==2021.13.0 ; platform_system == 'Windows'
tbb==2021.13.0 ; sys_platform == 'win32'
# via mkl
tenacity==8.5.0
# via
Expand Down Expand Up @@ -1007,7 +1007,7 @@ traitlets==5.14.3
# nbclient
# nbconvert
# nbformat
triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and platform_system == 'Linux'
triton==2.3.1 ; python_full_version < '3.12' and platform_machine == 'x86_64' and sys_platform == 'linux'
# via torch
typeguard==2.13.3
# via
Expand Down Expand Up @@ -1050,7 +1050,7 @@ virtualenv==20.26.3
# via
# pre-commit
# tox
watchdog==4.0.1 ; platform_system != 'Darwin'
watchdog==4.0.1 ; sys_platform != 'darwin'
# via streamlit
wcwidth==0.2.13
# via prompt-toolkit
Expand Down
12 changes: 10 additions & 2 deletions benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Benchmarking module for performance tracking."""

from benchmarks.definition import Benchmark
from benchmarks.definition import (
Benchmark,
BenchmarkSettings,
ConvergenceBenchmark,
ConvergenceBenchmarkSettings,
)
from benchmarks.result import Result

__all__ = [
"Result",
"Benchmark",
"BenchmarkSettings",
"ConvergenceBenchmark",
"ConvergenceBenchmarkSettings",
"Result",
]
12 changes: 8 additions & 4 deletions benchmarks/definition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Benchmark task definitions."""
"""Benchmark definitions."""

from benchmarks.definition.config import (
from benchmarks.definition.base import (
Benchmark,
BenchmarkSettings,
ConvergenceExperimentSettings,
)
from benchmarks.definition.convergence import (
ConvergenceBenchmark,
ConvergenceBenchmarkSettings,
)

__all__ = [
"ConvergenceExperimentSettings",
"Benchmark",
"BenchmarkSettings",
"ConvergenceBenchmark",
"ConvergenceBenchmarkSettings",
]
86 changes: 86 additions & 0 deletions benchmarks/definition/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Basic benchmark configuration."""

import time
from abc import ABC
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import Generic, TypeVar

from attrs import define, field
from attrs.validators import instance_of
from cattrs import override
from cattrs.gen import make_dict_unstructure_fn
from pandas import DataFrame

from baybe.utils.random import temporary_seed
from benchmarks.result import Result, ResultMetadata
from benchmarks.serialization import BenchmarkSerialization, converter


@define(kw_only=True, frozen=True)
class BenchmarkSettings(ABC, BenchmarkSerialization):
    """Common base for benchmark settings.

    Instances are frozen and all attributes must be passed by keyword.
    """

    random_seed: int = field(default=1337, validator=instance_of(int))
    """The random seed of the benchmark; must be an ``int``."""


# Type variable for the settings type a benchmark is parametrized with; bound to
# `BenchmarkSettings`, so only that class and its subclasses are admissible.
BenchmarkSettingsType = TypeVar("BenchmarkSettingsType", bound=BenchmarkSettings)


@define(frozen=True)
class Benchmark(Generic[BenchmarkSettingsType], BenchmarkSerialization):
    """Base class bundling a benchmark callable with its settings.

    Calling an instance runs the callable on the settings and wraps the
    outcome, together with timing metadata, into a ``Result``.
    """

    function: Callable[[BenchmarkSettingsType], DataFrame] = field()
    """The callable holding the benchmark logic; it must carry a docstring."""

    settings: BenchmarkSettingsType = field()
    """The settings object handed to ``function`` on execution."""

    @function.validator
    def _validate_function(self, _, function) -> None:
        """Reject callables that have no docstring.

        Raises:
            ValueError: If ``function.__doc__`` is ``None``.
        """
        if function.__doc__ is not None:
            return
        raise ValueError("The benchmark function must have a docstring.")

    @property
    def name(self) -> str:
        """The ``__name__`` of the benchmark function."""
        return self.function.__name__

    @property
    def description(self) -> str:
        """The docstring of the benchmark function."""
        docstring = self.function.__doc__
        # The validator already refuses functions without a docstring; the
        # assertion mainly narrows the type for static checkers.
        assert docstring is not None
        return docstring

    def __call__(self) -> Result:
        """Run the benchmark function and return its packaged result.

        The function is invoked with ``self.settings`` inside the
        ``temporary_seed`` context for ``settings.random_seed``. Only the
        function call itself is timed; the start timestamp is taken in UTC
        before entering the context.
        """
        started_at = datetime.now(timezone.utc)

        with temporary_seed(self.settings.random_seed):
            tic = time.perf_counter()
            data = self.function(self.settings)
            toc = time.perf_counter()

        metadata = ResultMetadata(
            start_datetime=started_at,
            duration=timedelta(seconds=toc - tic),
        )

        return Result(self.name, data, metadata)


@converter.register_unstructure_hook
def unstructure_benchmark(benchmark: Benchmark) -> dict:
    """Turn a benchmark instance into a plain dictionary.

    The result starts with the ``name`` and ``description`` properties,
    followed by the unstructured attributes of the instance. The ``function``
    attribute is omitted.
    """
    # Generate the unstructure function for the runtime class of the instance
    # (not for `Benchmark` itself) and leave out the callable.
    unstructure_attributes = make_dict_unstructure_fn(
        type(benchmark), converter, function=override(omit=True)
    )

    payload = {
        "name": benchmark.name,
        "description": benchmark.description,
    }
    # Attribute entries come last and take precedence on key collisions.
    payload.update(unstructure_attributes(benchmark))
    return payload
103 changes: 0 additions & 103 deletions benchmarks/definition/config.py

This file was deleted.

39 changes: 39 additions & 0 deletions benchmarks/definition/convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Convergence benchmark configuration."""

from typing import Any

from attrs import define, field
from attrs.validators import deep_mapping, instance_of, optional

from benchmarks.definition.base import Benchmark, BenchmarkSettings


@define(kw_only=True, frozen=True)
class ConvergenceBenchmarkSettings(BenchmarkSettings):
    """Settings for benchmarks analyzing recommender convergence.

    All three attributes are required, keyword-only and validated as ``int``.
    """

    batch_size: int = field(validator=instance_of(int))
    """The batch size used for recommendations."""

    n_doe_iterations: int = field(validator=instance_of(int))
    """How many Design of Experiment iterations are run."""

    n_mc_iterations: int = field(validator=instance_of(int))
    """How many Monte Carlo iterations are run."""


@define(frozen=True)
class ConvergenceBenchmark(Benchmark[ConvergenceBenchmarkSettings]):
    """Benchmark definition for convergence analyses."""

    optimal_target_values: dict[str, Any] | None = field(
        validator=optional(
            deep_mapping(
                mapping_validator=instance_of(dict),
                key_validator=instance_of(str),
                # Values are deliberately unconstrained: accept anything.
                value_validator=lambda *_: None,
            )
        ),
        default=None,
    )
    """The optimal values that can be achieved for the targets **individually**.

    Either ``None`` (the default) or a ``dict`` with string keys.
    """
2 changes: 1 addition & 1 deletion benchmarks/domains/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Benchmark domains."""

from benchmarks.definition.config import Benchmark
from benchmarks.definition.base import Benchmark
from benchmarks.domains.synthetic_2C1D_1C import synthetic_2C1D_1C_benchmark

BENCHMARKS: list[Benchmark] = [
Expand Down
Loading
Loading