Skip to content

Commit

Permalink
Refactor serialization classes to use BenchmarkSerialization and upda…
Browse files Browse the repository at this point in the history
…te persistence handling for GitHub CI
  • Loading branch information
fabianliebig committed Nov 27, 2024
1 parent 7fd83a5 commit 1056454
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 18 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/manual_benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@ jobs:
with:
app-id: ${{ vars.APP_ID }}
private-key: ${{ secrets.APP_PRIVATE_KEY }}

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: ${{ secrets.AWS_ROLE_TO_ASSUME }}
role-session-name: Github_Add_Runner
aws-region: eu-central-1

- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@v2

- name: Execute Lambda function
run: |
aws lambda invoke --function-name jit_runner_register_and_create_runner_container --cli-binary-format raw-in-base64-out --payload '{"github_api_secret": "${{ steps.generate-token.outputs.token }}", "count_container": 1, "container_compute": "XL", "repository": "${{ github.repository }}" }' response.json
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/definition/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@

from baybe.utils.random import temporary_seed
from benchmarks.result import Result, ResultMetadata
from benchmarks.serialization import Serializable, converter
from benchmarks.serialization import BenchmarkSerialization, converter


@define(frozen=True)
class BenchmarkSettings(ABC, Serializable):
class BenchmarkSettings(ABC, BenchmarkSerialization):
"""Benchmark configuration for recommender analyses."""

random_seed: int = field(validator=instance_of(int), kw_only=True, default=1337)
Expand All @@ -42,7 +42,7 @@ class ConvergenceExperimentSettings(BenchmarkSettings):


@define(frozen=True)
class Benchmark(Generic[BenchmarkSettingsType], Serializable):
class Benchmark(Generic[BenchmarkSettingsType], BenchmarkSerialization):
"""The base class for a benchmark executable."""

settings: BenchmarkSettingsType = field()
Expand Down
17 changes: 10 additions & 7 deletions benchmarks/persistence/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
VARNAME_BENCHMARKING_PERSISTENCE_PATH = "BAYBE_BENCHMARKING_PERSISTENCE_PATH"
PERSIST_DATA_TO_S3_BUCKET = VARNAME_BENCHMARKING_PERSISTENCE_PATH in os.environ

VARNAME_GITHUB_CI = "CI"
RUNS_ON_GITHUB_CI = VARNAME_GITHUB_CI in os.environ


class PathStrategy(Enum):
"""The way a path extension is constructed."""
Expand Down Expand Up @@ -106,7 +109,7 @@ def get_path(self, strategy: PathStrategy) -> Path:
return Path(path)


class ObjectWriter(Protocol):
class ObjectStorage(Protocol):
"""Interface for interacting with storage."""

def write_json(self, object: dict, path_constructor: PathConstructor) -> None:
Expand All @@ -121,7 +124,7 @@ def write_json(self, object: dict, path_constructor: PathConstructor) -> None:


@define
class S3ObjectWriter(ObjectWriter):
class S3ObjectStorage(ObjectStorage):
"""Class for persisting objects in an S3 bucket."""

_bucket_name: str = field(validator=instance_of(str), init=False)
Expand Down Expand Up @@ -168,7 +171,7 @@ def write_json(self, object: dict, path_constructor: PathConstructor) -> None:


@define
class LocalFileSystemObjectWriter(ObjectWriter):
class LocalFileSystemObjectStorage(ObjectStorage):
"""Class for persisting JSON serializable dicts locally."""

folder_path_prefix: Path = field(converter=Path, default=Path("."))
Expand Down Expand Up @@ -202,15 +205,15 @@ def write_json(self, object: dict, path_constructor: PathConstructor) -> None:
json.dump(object, file)


def make_object_writer() -> ObjectWriter:
def make_object_writer() -> ObjectStorage:
"""Create a persistence handler based on the environment variables.
Returns:
The persistence handler.
"""
if PERSIST_DATA_TO_S3_BUCKET:
return S3ObjectWriter()
return LocalFileSystemObjectWriter()
if not RUNS_ON_GITHUB_CI:
return LocalFileSystemObjectStorage()
return S3ObjectStorage()


def make_path_constructor(benchmark: Benchmark, result: Result) -> PathConstructor:
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/result/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from attrs.validators import instance_of
from cattrs.gen import make_dict_unstructure_fn

from benchmarks.serialization import Serializable, converter
from benchmarks.serialization import BenchmarkSerialization, converter


@define(frozen=True)
class ResultMetadata(Serializable):
class ResultMetadata(BenchmarkSerialization):
"""The metadata of a benchmark result."""

start_datetime: datetime = field(validator=instance_of(datetime))
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from pandas import DataFrame

from benchmarks.result import ResultMetadata
from benchmarks.serialization import Serializable
from benchmarks.serialization import BenchmarkSerialization


@define(frozen=True)
class Result(Serializable):
class Result(BenchmarkSerialization):
"""A single result of the benchmarking."""

benchmark_identifier: str = field(validator=instance_of(str))
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)


class Serializable:
class BenchmarkSerialization:
"""Mixin class providing serialization methods."""

def to_dict(self) -> dict[str, Any]:
Expand Down

0 comments on commit 1056454

Please sign in to comment.