Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yaml config sets #2876

Merged
merged 24 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions src/_nebari/config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
import pathlib
from typing import Optional

from packaging.requirements import SpecifierSet
from pydantic import BaseModel, ConfigDict, field_validator

from _nebari._version import __version__
from _nebari.utils import yaml

logger = logging.getLogger(__name__)


class ConfigSetMetadata(BaseModel):
model_config: ConfigDict = ConfigDict(extra="allow", arbitrary_types_allowed=True)
name: str # for use with guided init
Adam-D-Lewis marked this conversation as resolved.
Show resolved Hide resolved
description: Optional[str] = None
nebari_version: str | SpecifierSet

@field_validator("nebari_version")
@classmethod
def validate_version_requirement(cls, version_req):
if isinstance(version_req, str):
version_req = SpecifierSet(version_req, prereleases=True)

return version_req

def check_version(self, version):
if not self.nebari_version.contains(version, prereleases=True):
raise ValueError(
f'Nebari version "{version}" is not compatible with '
f'version requirement {self.nebari_version} for "{self.name}" config set.'
)


class ConfigSet(BaseModel):
metadata: ConfigSetMetadata
config: dict


def read_config_set(config_set_filepath: str):
"""Read a config set from a config file."""

filename = pathlib.Path(config_set_filepath)

with filename.open() as f:
config_set_yaml = yaml.load(f)

config_set = ConfigSet(**config_set_yaml)

# validation
config_set.metadata.check_version(__version__)

return config_set
10 changes: 8 additions & 2 deletions src/_nebari/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pydantic
import requests

from _nebari import constants
from _nebari import constants, utils
from _nebari.config_set import read_config_set
from _nebari.provider import git
from _nebari.provider.cicd import github
from _nebari.provider.cloud import amazon_web_services, azure_cloud, google_cloud
Expand Down Expand Up @@ -47,6 +48,7 @@ def render_config(
region: str = None,
disable_prompt: bool = False,
ssl_cert_email: str = None,
config_set: str = None,
) -> Dict[str, Any]:
config = {
"provider": cloud_provider,
Expand Down Expand Up @@ -176,13 +178,17 @@ def render_config(
config["certificate"] = {"type": CertificateEnum.letsencrypt.value}
config["certificate"]["acme_email"] = ssl_cert_email

if config_set:
config_set = read_config_set(config_set)
config = utils.deep_merge(config, config_set.config)

# validate configuration and convert to model
from nebari.plugins import nebari_plugin_manager

try:
config_model = nebari_plugin_manager.config_schema.model_validate(config)
except pydantic.ValidationError as e:
print(str(e))
Adam-D-Lewis marked this conversation as resolved.
Show resolved Hide resolved
raise e

if repository_auto_provision:
match = re.search(github_url_regex, repository)
Expand Down
16 changes: 15 additions & 1 deletion src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re
import sys
import tempfile
import warnings
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union

from pydantic import ConfigDict, Field, field_validator, model_validator
Expand Down Expand Up @@ -613,11 +614,23 @@ def check_provider(cls, data: Any) -> Any:
data[provider] = provider_enum_model_map[provider]()
else:
# if the provider field is invalid, it won't be set when this validator is called
# so we need to check for it explicitly here, and set the `pre` to True
# so we need to check for it explicitly here, and set mode to "before"
# TODO: this is a workaround, check if there is a better way to do this in Pydantic v2
raise ValueError(
f"'{provider}' is not a valid enumeration member; permitted: local, existing, aws, gcp, azure"
)
set_providers = {
provider
for provider in provider_name_abbreviation_map.keys()
if provider in data and data[provider]
}
expected_provider_config = provider_enum_name_map[provider]
extra_provider_config = set_providers - {expected_provider_config}
if extra_provider_config:
warnings.warn(
f"Provider is set to {getattr(provider, 'value', provider)}, but configuration defined for other providers: {extra_provider_config}"
)

else:
set_providers = [
provider
Expand All @@ -631,6 +644,7 @@ def check_provider(cls, data: Any) -> Any:
data["provider"] = provider_name_abbreviation_map[set_providers[0]]
elif num_providers == 0:
data["provider"] = schema.ProviderEnum.local.value

return data


Expand Down
9 changes: 9 additions & 0 deletions src/_nebari/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class InitInputs(schema.Base):
region: Optional[str] = None
ssl_cert_email: Optional[schema.email_pydantic] = None
disable_prompt: bool = False
config_set: Optional[str] = None
output: pathlib.Path = pathlib.Path("nebari-config.yaml")
explicit: int = 0

Expand Down Expand Up @@ -134,6 +135,7 @@ def handle_init(inputs: InitInputs, config_schema: BaseModel):
terraform_state=inputs.terraform_state,
ssl_cert_email=inputs.ssl_cert_email,
disable_prompt=inputs.disable_prompt,
config_set=inputs.config_set,
)

try:
Expand Down Expand Up @@ -496,6 +498,12 @@ def init(
False,
is_eager=True,
),
config_set: str = typer.Option(
None,
"--config-set",
"-s",
help="Apply a pre-defined set of nebari configuration options.",
),
output: str = typer.Option(
pathlib.Path("nebari-config.yaml"),
"--output",
Expand Down Expand Up @@ -554,6 +562,7 @@ def init(
inputs.terraform_state = terraform_state
inputs.ssl_cert_email = ssl_cert_email
inputs.disable_prompt = disable_prompt
inputs.config_set = config_set
inputs.output = output
inputs.explicit = explicit

Expand Down
4 changes: 2 additions & 2 deletions src/_nebari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def modified_environ(*remove: List[str], **update: Dict[str, str]):


def deep_merge(*args):
"""Deep merge multiple dictionaries.
"""Deep merge multiple dictionaries. Preserves order in dicts and lists.

>>> value_1 = {
'a': [1, 2],
Expand Down Expand Up @@ -190,7 +190,7 @@ def deep_merge(*args):

if isinstance(d1, dict) and isinstance(d2, dict):
d3 = {}
for key in d1.keys() | d2.keys():
for key in tuple(d1.keys()) + tuple(d2.keys()):
if key in d1 and key in d2:
d3[key] = deep_merge(d1[key], d2[key])
elif key in d1:
Expand Down
73 changes: 73 additions & 0 deletions tests/tests_unit/test_config_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from unittest.mock import patch

import pytest
from packaging.requirements import SpecifierSet

from _nebari.config_set import ConfigSetMetadata, read_config_set

test_version = "2024.12.2"


@pytest.mark.parametrize(
"version_input,test_version,should_pass",
[
# Standard version tests
(">=2024.12.0,<2025.0.0", "2024.12.2", True),
(SpecifierSet(">=2024.12.0,<2025.0.0"), "2024.12.2", True),
# Pre-release version requirement tests
(">=2024.12.0rc1,<2025.0.0", "2024.12.0rc1", True),
(SpecifierSet(">=2024.12.0rc1"), "2024.12.0rc2", True),
# Pre-release test version against standard requirement
(">=2024.12.0,<2025.0.0", "2024.12.1rc1", True),
(SpecifierSet(">=2024.12.0,<2025.0.0"), "2024.12.1rc1", True),
# Failing cases
(">=2025.0.0", "2024.12.2rc1", False),
(SpecifierSet(">=2025.0.0rc1"), "2024.12.2", False),
],
)
def test_version_requirement(version_input, test_version, should_pass):
metadata = ConfigSetMetadata(name="test-config", nebari_version=version_input)

if should_pass:
metadata.check_version(test_version)
else:
with pytest.raises(ValueError) as exc_info:
metadata.check_version(test_version)
assert "Nebari version" in str(exc_info.value)


def test_read_config_set_valid(tmp_path):
config_set_yaml = """
metadata:
name: test-config
nebari_version: ">=2024.12.0"
config:
key: value
"""
config_set_filepath = tmp_path / "config_set.yaml"
config_set_filepath.write_text(config_set_yaml)
with patch("_nebari.config_set.__version__", "2024.12.2"):
config_set = read_config_set(str(config_set_filepath))
assert config_set.metadata.name == "test-config"
assert config_set.config["key"] == "value"


def test_read_config_set_invalid_version(tmp_path):
config_set_yaml = """
metadata:
name: test-config
nebari_version: ">=2025.0.0"
config:
key: value
"""
config_set_filepath = tmp_path / "config_set.yaml"
config_set_filepath.write_text(config_set_yaml)

with patch("_nebari.config_set.__version__", "2024.12.2"):
with pytest.raises(ValueError) as exc_info:
read_config_set(str(config_set_filepath))
assert "Nebari version" in str(exc_info.value)


if __name__ == "__main__":
pytest.main()
10 changes: 10 additions & 0 deletions tests/tests_unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,13 @@ def test_set_provider(config_schema, provider):
result_config_dict = config.model_dump()
assert provider in result_config_dict
assert result_config_dict[provider]["kube_context"] == "some_context"


def test_provider_config_mismatch_warning(config_schema):
config_dict = {
"project_name": "test",
"provider": "local",
"existing": {"kube_context": "some_context"}, # <-- Doesn't match the provider
}
with pytest.warns(UserWarning, match="configuration defined for other providers"):
config_schema(**config_dict)
1 change: 1 addition & 0 deletions tests/tests_unit/test_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_check_immutable_fields_immutable_change(
mock_model_fields, mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy(deep=True)
old_config.local = None
old_config.provider = schema.ProviderEnum.gcp
mock_get_state.return_value = old_config.model_dump()

Expand Down
74 changes: 73 additions & 1 deletion tests/tests_unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion
from _nebari.utils import JsonDiff, JsonDiffEnum, byte_unit_conversion, deep_merge


@pytest.mark.parametrize(
Expand Down Expand Up @@ -64,3 +64,75 @@ def test_JsonDiff_modified():
diff = JsonDiff(obj1, obj2)
modifieds = diff.modified()
assert sorted(modifieds) == sorted([(["b", "!"], 2, 3), (["+"], 4, 5)])


def test_deep_merge_order_preservation_dict():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
"e": {"f": {"g": {}}},
"m": 1,
}

value_2 = {
"a": [3, 4],
"b": {"d": 2, "z": [7]},
"e": {"f": {"h": 1}},
"m": [1],
}

expected_result = {
"a": [1, 2, 3, 4],
"b": {"c": 1, "z": [5, 6, 7], "d": 2},
"e": {"f": {"g": {}, "h": 1}},
"m": 1,
}

result = deep_merge(value_1, value_2)
assert result == expected_result
assert list(result.keys()) == list(expected_result.keys())
assert list(result["b"].keys()) == list(expected_result["b"].keys())
assert list(result["e"]["f"].keys()) == list(expected_result["e"]["f"].keys())


def test_deep_merge_order_preservation_list():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
}

value_2 = {
"a": [3, 4],
"b": {"d": 2, "z": [7]},
}

expected_result = {
"a": [1, 2, 3, 4],
"b": {"c": 1, "z": [5, 6, 7], "d": 2},
}

result = deep_merge(value_1, value_2)
assert result == expected_result
assert result["a"] == expected_result["a"]
assert result["b"]["z"] == expected_result["b"]["z"]


def test_deep_merge_single_dict():
value_1 = {
"a": [1, 2],
"b": {"c": 1, "z": [5, 6]},
}

expected_result = value_1

result = deep_merge(value_1)
assert result == expected_result
assert list(result.keys()) == list(expected_result.keys())
assert list(result["b"].keys()) == list(expected_result["b"].keys())


def test_deep_merge_empty():
expected_result = {}

result = deep_merge()
assert result == expected_result
Loading