diff --git a/metadata-ingestion/pyproject.toml b/metadata-ingestion/pyproject.toml index 6cf7eb1f6fb055..c181e76ce5ea25 100644 --- a/metadata-ingestion/pyproject.toml +++ b/metadata-ingestion/pyproject.toml @@ -56,5 +56,11 @@ max-complexity = 20 [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" +[tool.ruff.lint.flake8-tidy-imports.banned-api] +# pydantic v2 deprecations +"pydantic.validator" = { msg = "Use pydantic.field_validator instead of deprecated validator" } +"pydantic.root_validator" = { msg = "Use pydantic.model_validator instead of deprecated root_validator" } + [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] +"src/datahub/configuration/pydantic_migration_helpers.py" = ["TID251"] # Intentional V1 imports for backward compatibility diff --git a/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py b/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py index cb279a4652f129..6eb2025a824328 100644 --- a/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py +++ b/metadata-ingestion/src/datahub/api/entities/common/serialized_value.py @@ -104,7 +104,7 @@ def as_pydantic_object( assert self.schema_ref assert self.schema_ref == model_type.__name__ object_dict = self.as_raw_json() - return model_type.parse_obj(object_dict) + return model_type.model_validate(object_dict) @classmethod def from_resource_value( @@ -131,7 +131,7 @@ def create( elif isinstance(object, BaseModel): return SerializedResourceValue( content_type=models.SerializedValueContentTypeClass.JSON, - blob=json.dumps(object.dict(), sort_keys=True).encode("utf-8"), + blob=json.dumps(object.model_dump(), sort_keys=True).encode("utf-8"), schema_type=models.SerializedValueSchemaTypeClass.JSON, schema_ref=object.__class__.__name__, ) diff --git a/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py b/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py index 238c3223699a27..0138878ac29953 100644 --- a/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py +++ b/metadata-ingestion/src/datahub/api/entities/corpgroup/corpgroup.py @@ -2,10 +2,9 @@ import logging from dataclasses import dataclass -from typing import Callable, Iterable, List, Optional, Union +from typing import Any, Callable, Iterable, List, Optional, Union -import pydantic -from pydantic import BaseModel +from pydantic import BaseModel, field_validator import datahub.emitter.mce_builder as builder from datahub.api.entities.corpuser.corpuser import CorpUser, CorpUserGenerationConfig @@ -70,9 +69,15 @@ class CorpGroup(BaseModel): _rename_admins_to_owners = pydantic_renamed_field("admins", "owners") - @pydantic.validator("owners", "members", each_item=True) - def make_urn_if_needed(cls, v): - if isinstance(v, str): + @field_validator("owners", "members", mode="before") + @classmethod + def make_urn_if_needed(cls, v: Any) -> Any: + if isinstance(v, list): + return [ + builder.make_user_urn(item) if isinstance(item, str) else item + for item in v + ] + elif isinstance(v, str): return builder.make_user_urn(v) return v diff --git a/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py b/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py index 9fe1ebedafca7e..6d1fba5e2d8088 100644 --- a/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py +++ b/metadata-ingestion/src/datahub/api/entities/corpuser/corpuser.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Callable, Iterable, List, Optional -import 
pydantic +from pydantic import model_validator import datahub.emitter.mce_builder as builder from datahub.configuration.common import ConfigModel @@ -65,16 +65,16 @@ class CorpUser(ConfigModel): picture_link: Optional[str] = None phone: Optional[str] = None - @pydantic.validator("full_name", always=True) - def full_name_can_be_built_from_first_name_last_name(v, values): - if not v: - if "first_name" in values or "last_name" in values: - first_name = values.get("first_name") or "" - last_name = values.get("last_name") or "" - full_name = f"{first_name} {last_name}" if last_name else first_name - return full_name - else: - return v + @model_validator(mode="after") + def full_name_can_be_built_from_first_name_last_name(self) -> "CorpUser": + if not self.full_name: + if self.first_name or self.last_name: + first_name = self.first_name or "" + last_name = self.last_name or "" + self.full_name = ( + f"{first_name} {last_name}" if last_name else first_name + ) + return self @property def urn(self): diff --git a/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py b/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py index 571e01a1f432fe..f029c3332b8848 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py +++ b/metadata-ingestion/src/datahub/api/entities/dataproduct/dataproduct.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import pydantic +from pydantic import field_validator, model_validator from ruamel.yaml import YAML from typing_extensions import assert_never @@ -71,7 +71,8 @@ class Ownership(ConfigModel): id: str type: str - @pydantic.validator("type") + @field_validator("type", mode="after") + @classmethod def ownership_type_must_be_mappable_or_custom(cls, v: str) -> str: _, _ = builder.validate_ownership_type(v) return v @@ -116,30 +117,49 @@ class DataProduct(ConfigModel): output_ports: Optional[List[str]] = None _original_yaml_dict: Optional[dict] = None - @pydantic.validator("assets", each_item=True) - def assets_must_be_urns(cls, v: str) -> str: - try: - Urn.from_string(v) - except Exception as e: - raise ValueError(f"asset {v} is not an urn: {e}") from e - - return v - - @pydantic.validator("output_ports", each_item=True) - def output_ports_must_be_urns(cls, v: str) -> str: - try: - Urn.create_from_string(v) - except Exception as e: - raise ValueError(f"Output port {v} is not an urn: {e}") from e + @field_validator("assets", mode="before") + @classmethod + def assets_must_be_urns(cls, v: Any) -> Any: + if isinstance(v, list): + for item in v: + try: + Urn.from_string(item) + except Exception as e: + raise ValueError(f"asset {item} is not an urn: {e}") from e + return v + else: + try: + Urn.from_string(v) + except Exception as e: + raise ValueError(f"asset {v} is not an urn: {e}") from e + return v + @field_validator("output_ports", mode="before") + @classmethod + def output_ports_must_be_urns(cls, v: Any) -> Any: + if v is not None: + if isinstance(v, list): + for item in v: + try: + Urn.create_from_string(item) + except Exception as e: + raise ValueError( + f"Output port {item} is not an urn: {e}" + ) from e + else: + try: + Urn.create_from_string(v) + except Exception as e: + raise ValueError(f"Output port {v} is not an urn: {e}") from e return v - @pydantic.validator("output_ports", each_item=True) - def output_ports_must_be_from_asset_list(cls, v: str, values: dict) -> str: - assets = values.get("assets", []) - if v not in assets: - raise 
ValueError(f"Output port {v} is not in asset list") - return v + @model_validator(mode="after") + def output_ports_must_be_from_asset_list(self) -> "DataProduct": + if self.output_ports and self.assets: + for port in self.output_ports: + if port not in self.assets: + raise ValueError(f"Output port {port} is not in asset list") + return self @property def urn(self) -> str: @@ -454,7 +474,7 @@ def _patch_ownership( patches_add.append(new_owner) else: patches_add.append( - Ownership(id=new_owner, type=new_owner_type).dict() + Ownership(id=new_owner, type=new_owner_type).model_dump() ) mutation_needed = bool(patches_replace or patches_drop or patches_add) @@ -485,8 +505,8 @@ def patch_yaml( raise Exception("Original Data Product was not loaded from yaml") orig_dictionary = original_dataproduct._original_yaml_dict - original_dataproduct_dict = original_dataproduct.dict() - this_dataproduct_dict = self.dict() + original_dataproduct_dict = original_dataproduct.model_dump() + this_dataproduct_dict = self.model_dump() for simple_field in ["display_name", "description", "external_url"]: if original_dataproduct_dict.get(simple_field) != this_dataproduct_dict.get( simple_field @@ -566,7 +586,7 @@ def to_yaml( yaml = YAML(typ="rt") # default, if not specfied, is 'rt' (round-trip) yaml.indent(mapping=2, sequence=4, offset=2) yaml.default_flow_style = False - yaml.dump(self.dict(), fp) + yaml.dump(self.model_dump(), fp) @staticmethod def get_patch_builder( diff --git a/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py b/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py index 7924f5d86e9e9a..c3f0b75ad8eae0 100644 --- a/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py +++ b/metadata-ingestion/src/datahub/api/entities/dataset/dataset.py @@ -3,6 +3,7 @@ import time from pathlib import Path from typing import ( + Any, Dict, Iterable, List, @@ -19,8 +20,9 @@ BaseModel, Field, StrictStr, - root_validator, - validator, + ValidationInfo, + field_validator, + model_validator, ) from ruamel.yaml import YAML from typing_extensions import TypeAlias @@ -213,14 +215,15 @@ def from_schema_field( ), ) - @validator("urn", pre=True, always=True) - def either_id_or_urn_must_be_filled_out(cls, v, values): - if not v and not values.get("id"): + @model_validator(mode="after") + def either_id_or_urn_must_be_filled_out(self) -> "SchemaFieldSpecification": + if not self.urn and not self.id: raise ValueError("Either id or urn must be present") - return v + return self - @root_validator(pre=True) - def sync_doc_into_description(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def sync_doc_into_description(cls, values: Any) -> Any: """Synchronize doc into description field if doc is provided.""" description = values.get("description") doc = values.pop("doc", None) @@ -348,8 +351,9 @@ class SchemaSpecification(BaseModel): fields: Optional[List[SchemaFieldSpecification]] = None raw_schema: Optional[str] = None - @validator("file") - def file_must_be_avsc(cls, v): + @field_validator("file", mode="after") + @classmethod + def file_must_be_avsc(cls, v: Optional[str]) -> Optional[str]: if v and not v.endswith(".avsc"): raise ValueError("file must be a .avsc file") return v @@ -359,7 +363,8 @@ class Ownership(ConfigModel): id: str type: str - @validator("type") + @field_validator("type", mode="after") + @classmethod def ownership_type_must_be_mappable_or_custom(cls, v: str) -> str: _, _ = validate_ownership_type(v) return v @@ -397,30 +402,36 @@ def platform_urn(self) -> 
str: dataset_urn = DatasetUrn.from_string(self.urn) return str(dataset_urn.get_data_platform_urn()) - @validator("urn", pre=True, always=True) - def urn_must_be_present(cls, v, values): + @field_validator("urn", mode="before") + @classmethod + def urn_must_be_present(cls, v: Any, info: ValidationInfo) -> Any: if not v: + values = info.data assert "id" in values, "id must be present if urn is not" assert "platform" in values, "platform must be present if urn is not" assert "env" in values, "env must be present if urn is not" return make_dataset_urn(values["platform"], values["id"], values["env"]) return v - @validator("name", pre=True, always=True) - def name_filled_with_id_if_not_present(cls, v, values): + @field_validator("name", mode="before") + @classmethod + def name_filled_with_id_if_not_present(cls, v: Any, info: ValidationInfo) -> Any: if not v: + values = info.data assert "id" in values, "id must be present if name is not" return values["id"] return v - @validator("platform") - def platform_must_not_be_urn(cls, v): - if v.startswith("urn:li:dataPlatform:"): + @field_validator("platform", mode="after") + @classmethod + def platform_must_not_be_urn(cls, v: Optional[str]) -> Optional[str]: + if v and v.startswith("urn:li:dataPlatform:"): return v[len("urn:li:dataPlatform:") :] return v - @validator("structured_properties") - def simplify_structured_properties(cls, v): + @field_validator("structured_properties", mode="after") + @classmethod + def simplify_structured_properties(cls, v: Any) -> Any: return StructuredPropertiesHelper.simplify_structured_properties_list(v) def _mint_auditstamp(self, message: str) -> AuditStampClass: @@ -461,7 +472,7 @@ def from_yaml(cls, file: str) -> Iterable["Dataset"]: if isinstance(datasets, dict): datasets = [datasets] for dataset_raw in datasets: - dataset = Dataset.parse_obj(dataset_raw) + dataset = Dataset.model_validate(dataset_raw) # dataset = Dataset.model_validate(dataset_raw, strict=True) yield dataset diff --git a/metadata-ingestion/src/datahub/api/entities/external/lake_formation_external_entites.py b/metadata-ingestion/src/datahub/api/entities/external/lake_formation_external_entites.py index 3c3469b5d81850..24ee30b763cb23 100644 --- a/metadata-ingestion/src/datahub/api/entities/external/lake_formation_external_entites.py +++ b/metadata-ingestion/src/datahub/api/entities/external/lake_formation_external_entites.py @@ -12,7 +12,7 @@ # https://learn.microsoft.com/en-us/azure/databricks/database-objects/tags#constraint from typing import Any, Dict, Optional -from pydantic import validator +from pydantic import field_validator from typing_extensions import ClassVar from datahub.api.entities.external.external_tag import ExternalTag @@ -50,11 +50,10 @@ class LakeFormationTag(ExternalTag): value: Optional[LakeFormationTagValueText] = None catalog: Optional[str] = None - # Pydantic v1 validators - @validator("key", pre=True) + @field_validator("key", mode="before") @classmethod def _validate_key(cls, v: Any) -> LakeFormationTagKeyText: - """Validate and convert key field for Pydantic v1.""" + """Validate and convert key field.""" if isinstance(v, LakeFormationTagKeyText): return v @@ -64,10 +63,10 @@ def _validate_key(cls, v: Any) -> LakeFormationTagKeyText: return LakeFormationTagKeyText(raw_text=v) - @validator("value", pre=True) + @field_validator("value", mode="before") @classmethod def _validate_value(cls, v: Any) -> Optional[LakeFormationTagValueText]: - """Validate and convert value field for Pydantic v1.""" + """Validate and convert value 
field.""" if v is None: return None diff --git a/metadata-ingestion/src/datahub/api/entities/external/unity_catalog_external_entites.py b/metadata-ingestion/src/datahub/api/entities/external/unity_catalog_external_entites.py index 399c6e66b33985..03ec130715be0b 100644 --- a/metadata-ingestion/src/datahub/api/entities/external/unity_catalog_external_entites.py +++ b/metadata-ingestion/src/datahub/api/entities/external/unity_catalog_external_entites.py @@ -12,8 +12,7 @@ # https://learn.microsoft.com/en-us/azure/databricks/database-objects/tags#constraint from typing import Any, Dict, Optional, Set -# Import validator for Pydantic v1 (always needed since we removed conditional logic) -from pydantic import validator +from pydantic import field_validator from typing_extensions import ClassVar from datahub.api.entities.external.external_tag import ExternalTag @@ -62,11 +61,10 @@ class UnityCatalogTag(ExternalTag): key: UnityCatalogTagKeyText value: Optional[UnityCatalogTagValueText] = None - # Pydantic v1 validators - @validator("key", pre=True) + @field_validator("key", mode="before") @classmethod def _validate_key(cls, v: Any) -> UnityCatalogTagKeyText: - """Validate and convert key field for Pydantic v1.""" + """Validate and convert key field.""" if isinstance(v, UnityCatalogTagKeyText): return v @@ -76,10 +74,10 @@ def _validate_key(cls, v: Any) -> UnityCatalogTagKeyText: return UnityCatalogTagKeyText(raw_text=v) - @validator("value", pre=True) + @field_validator("value", mode="before") @classmethod def _validate_value(cls, v: Any) -> Optional[UnityCatalogTagValueText]: - """Validate and convert value field for Pydantic v1.""" + """Validate and convert value field.""" if v is None: return None diff --git a/metadata-ingestion/src/datahub/api/entities/forms/forms.py b/metadata-ingestion/src/datahub/api/entities/forms/forms.py index 4760ba7105b18b..3e42d8bf339233 100644 --- a/metadata-ingestion/src/datahub/api/entities/forms/forms.py +++ b/metadata-ingestion/src/datahub/api/entities/forms/forms.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union import yaml -from pydantic import Field, validator +from pydantic import Field, model_validator from ruamel.yaml import YAML from typing_extensions import Literal @@ -70,11 +70,13 @@ class Prompt(ConfigModel): structured_property_urn: Optional[str] = Field(default=None, validate_default=True) required: Optional[bool] = None - @validator("structured_property_urn", pre=True, always=True) - def structured_property_urn_must_be_present(cls, v, values): - if not v and values.get("structured_property_id"): - return Urn.make_structured_property_urn(values["structured_property_id"]) - return v + @model_validator(mode="after") + def structured_property_urn_must_be_present(self) -> "Prompt": + if not self.structured_property_urn and self.structured_property_id: + self.structured_property_urn = Urn.make_structured_property_urn( + self.structured_property_id + ) + return self class FormType(Enum): @@ -122,13 +124,13 @@ class Forms(ConfigModel): group_owners: Optional[List[str]] = None # can be group IDs or urns actors: Optional[Actors] = None - @validator("urn", pre=True, always=True) - def urn_must_be_present(cls, v, values): - if not v: - if values.get("id") is None: + @model_validator(mode="after") + def urn_must_be_present(self) -> "Forms": + if not self.urn: + if self.id is None: raise ValueError("Form id must be present if urn is not") - return f"urn:li:form:{values['id']}" - return v + self.urn = f"urn:li:form:{self.id}" + return self @staticmethod 
def create(file: str) -> None: @@ -137,7 +139,7 @@ def create(file: str) -> None: with get_default_graph(ClientMode.CLI) as emitter, open(file) as fp: forms: List[dict] = yaml.safe_load(fp) for form_raw in forms: - form = Forms.parse_obj(form_raw) + form = Forms.model_validate(form_raw) try: if not FormType.has_value(form.type): @@ -445,4 +447,4 @@ def to_yaml( yaml = YAML(typ="rt") # default, if not specfied, is 'rt' (round-trip) yaml.indent(mapping=2, sequence=4, offset=2) yaml.default_flow_style = False - yaml.dump(self.dict(), fp) + yaml.dump(self.model_dump(), fp) diff --git a/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py b/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py index 9ee0c6cadf9138..ce4b4cddeb0612 100644 --- a/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py +++ b/metadata-ingestion/src/datahub/api/entities/structuredproperties/structuredproperties.py @@ -4,7 +4,7 @@ from typing import Iterable, List, Optional, Type, Union import yaml -from pydantic import Field, StrictStr, validator +from pydantic import Field, StrictStr, field_validator, model_validator from ruamel.yaml import YAML from datahub.configuration.common import ConfigModel @@ -61,9 +61,12 @@ def _validate_entity_type_urn(cls: Type, v: str) -> str: class TypeQualifierAllowedTypes(ConfigModel): allowed_types: List[str] - _check_allowed_types = validator("allowed_types", each_item=True, allow_reuse=True)( - _validate_entity_type_urn - ) + @field_validator("allowed_types", mode="before") + @classmethod + def _check_allowed_types(cls, v: Union[str, List[str]]) -> Union[str, List[str]]: + if isinstance(v, list): + return [_validate_entity_type_urn(cls, item) for item in v] + return _validate_entity_type_urn(cls, v) class StructuredProperties(ConfigModel): @@ -80,11 +83,15 @@ class StructuredProperties(ConfigModel): type_qualifier: Optional[TypeQualifierAllowedTypes] = None immutable: Optional[bool] = False - _check_entity_types = validator("entity_types", each_item=True, allow_reuse=True)( - _validate_entity_type_urn - ) + @field_validator("entity_types", mode="before") + @classmethod + def _check_entity_types(cls, v: Union[str, List[str]]) -> Union[str, List[str]]: + if isinstance(v, list): + return [_validate_entity_type_urn(cls, item) for item in v] + return _validate_entity_type_urn(cls, v) - @validator("type") + @field_validator("type", mode="after") + @classmethod def validate_type(cls, v: str) -> str: # This logic is somewhat hacky, since we need to deal with # 1. 
fully qualified urns @@ -123,13 +130,13 @@ def fqn(self) -> str: ) return id - @validator("urn", pre=True, always=True) - def urn_must_be_present(cls, v, values): - if not v: - if "id" not in values: + @model_validator(mode="after") + def urn_must_be_present(self) -> "StructuredProperties": + if not self.urn: + if not hasattr(self, "id") or not self.id: raise ValueError("id must be present if urn is not") - return f"urn:li:structuredProperty:{values['id']}" - return v + self.urn = f"urn:li:structuredProperty:{self.id}" + return self @staticmethod def from_yaml(file: str) -> List["StructuredProperties"]: @@ -138,7 +145,7 @@ def from_yaml(file: str) -> List["StructuredProperties"]: result: List[StructuredProperties] = [] for structuredproperty_raw in structuredproperties: - result.append(StructuredProperties.parse_obj(structuredproperty_raw)) + result.append(StructuredProperties.model_validate(structuredproperty_raw)) return result def generate_mcps(self) -> List[MetadataChangeProposalWrapper]: @@ -225,7 +232,7 @@ def to_yaml( yaml = YAML(typ="rt") # default, if not specfied, is 'rt' (round-trip) yaml.indent(mapping=2, sequence=4, offset=2) yaml.default_flow_style = False - yaml.dump(self.dict(), fp) + yaml.dump(self.model_dump(), fp) @staticmethod def list_urns(graph: DataHubGraph) -> Iterable[str]: diff --git a/metadata-ingestion/src/datahub/cli/check_cli.py b/metadata-ingestion/src/datahub/cli/check_cli.py index e58f6ba5c17971..d964461b801f50 100644 --- a/metadata-ingestion/src/datahub/cli/check_cli.py +++ b/metadata-ingestion/src/datahub/cli/check_cli.py @@ -316,7 +316,7 @@ def test_allow_deny(config: str, input: str, pattern_key: str) -> None: click.secho(f"{pattern_key} is not defined in the config", fg="red") exit(1) - allow_deny_pattern = AllowDenyPattern.parse_obj(pattern_dict) + allow_deny_pattern = AllowDenyPattern.model_validate(pattern_dict) if allow_deny_pattern.allowed(input): click.secho(f"✅ {input} is allowed by {pattern_key}", fg="green") exit(0) @@ -372,7 +372,7 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None: pattern_dicts = [pattern_dicts] for pattern_dict in pattern_dicts: - path_spec_pattern = PathSpec.parse_obj(pattern_dict) + path_spec_pattern = PathSpec.model_validate(pattern_dict) if path_spec_pattern.allowed(input): click.echo(f"{input} is allowed by {path_spec_pattern}") else: diff --git a/metadata-ingestion/src/datahub/cli/config_utils.py b/metadata-ingestion/src/datahub/cli/config_utils.py index 11b31f6c27a8b6..54164db795e35e 100644 --- a/metadata-ingestion/src/datahub/cli/config_utils.py +++ b/metadata-ingestion/src/datahub/cli/config_utils.py @@ -114,7 +114,7 @@ def load_client_config() -> DatahubClientConfig: try: _ensure_datahub_config() client_config_dict = get_raw_client_config() - datahub_config: DatahubClientConfig = DatahubConfig.parse_obj( + datahub_config: DatahubClientConfig = DatahubConfig.model_validate( client_config_dict ).gms return datahub_config @@ -146,7 +146,7 @@ def write_gms_config( logger.debug( f"Failed to retrieve config from file {DATAHUB_CONFIG_PATH}: {e}. This isn't fatal." 
) - config_dict = {**previous_config, **config.dict()} + config_dict = {**previous_config, **config.model_dump()} else: - config_dict = config.dict() + config_dict = config.model_dump() persist_raw_datahub_config(config_dict) diff --git a/metadata-ingestion/src/datahub/cli/lite_cli.py b/metadata-ingestion/src/datahub/cli/lite_cli.py index 5feee9188ece87..a1f64b2d752e6c 100644 --- a/metadata-ingestion/src/datahub/cli/lite_cli.py +++ b/metadata-ingestion/src/datahub/cli/lite_cli.py @@ -40,13 +40,13 @@ class DuckDBLiteConfigWrapper(DuckDBLiteConfig): class LiteCliConfig(DatahubConfig): lite: LiteLocalConfig = LiteLocalConfig( - type="duckdb", config=DuckDBLiteConfigWrapper().dict() + type="duckdb", config=DuckDBLiteConfigWrapper().model_dump() ) def get_lite_config() -> LiteLocalConfig: client_config_dict = get_raw_client_config() - lite_config = LiteCliConfig.parse_obj(client_config_dict) + lite_config = LiteCliConfig.model_validate(client_config_dict) return lite_config.lite @@ -55,7 +55,9 @@ def _get_datahub_lite(read_only: bool = False) -> DataHubLiteLocal: if lite_config.type == "duckdb": lite_config.config["read_only"] = read_only - duckdb_lite = get_datahub_lite(config_dict=lite_config.dict(), read_only=read_only) + duckdb_lite = get_datahub_lite( + config_dict=lite_config.model_dump(), read_only=read_only + ) return duckdb_lite @@ -308,7 +310,7 @@ def search( ): result_str = searchable.id if details: - result_str = json.dumps(searchable.dict()) + result_str = json.dumps(searchable.model_dump()) # suppress id if we have already seen it in the non-detailed response if details or searchable.id not in result_ids: click.secho(result_str) @@ -321,7 +323,7 @@ def search( def write_lite_config(lite_config: LiteLocalConfig) -> None: cli_config = get_raw_client_config() assert isinstance(cli_config, dict) - cli_config["lite"] = lite_config.dict() + cli_config["lite"] = lite_config.model_dump() persist_raw_datahub_config(cli_config) @@ -332,12 +334,12 @@ def write_lite_config(lite_config: LiteLocalConfig) -> None: @telemetry.with_telemetry() def init(ctx: click.Context, type: Optional[str], file: Optional[str]) -> None: lite_config = get_lite_config() - new_lite_config_dict = lite_config.dict() + new_lite_config_dict = lite_config.model_dump() # Update the type and config sections only new_lite_config_dict["type"] = type if file: new_lite_config_dict["config"]["file"] = file - new_lite_config = LiteLocalConfig.parse_obj(new_lite_config_dict) + new_lite_config = LiteLocalConfig.model_validate(new_lite_config_dict) if lite_config != new_lite_config: if click.confirm( f"Will replace datahub lite config {lite_config} with {new_lite_config}" diff --git a/metadata-ingestion/src/datahub/cli/migrate.py b/metadata-ingestion/src/datahub/cli/migrate.py index fa2fa1c5a7232b..2e8b3197e63868 100644 --- a/metadata-ingestion/src/datahub/cli/migrate.py +++ b/metadata-ingestion/src/datahub/cli/migrate.py @@ -318,13 +318,13 @@ def migrate_containers( try: newKey: Union[SchemaKey, DatabaseKey, ProjectIdKey, BigQueryDatasetKey] if subType == "Schema": - newKey = SchemaKey.parse_obj(customProperties) + newKey = SchemaKey.model_validate(customProperties) elif subType == "Database": - newKey = DatabaseKey.parse_obj(customProperties) + newKey = DatabaseKey.model_validate(customProperties) elif subType == "Project": - newKey = ProjectIdKey.parse_obj(customProperties) + newKey = ProjectIdKey.model_validate(customProperties) elif subType == "Dataset": - newKey = BigQueryDatasetKey.parse_obj(customProperties) + newKey = 
BigQueryDatasetKey.model_validate(customProperties) else: log.warning(f"Invalid subtype {subType}. Skipping") continue diff --git a/metadata-ingestion/src/datahub/cli/quickstart_versioning.py b/metadata-ingestion/src/datahub/cli/quickstart_versioning.py index fc87e2e703472e..e16f3a45b7820b 100644 --- a/metadata-ingestion/src/datahub/cli/quickstart_versioning.py +++ b/metadata-ingestion/src/datahub/cli/quickstart_versioning.py @@ -80,7 +80,7 @@ def fetch_quickstart_config(cls) -> "QuickstartVersionMappingConfig": path = os.path.expanduser(LOCAL_QUICKSTART_MAPPING_FILE) with open(path) as f: config_raw = yaml.safe_load(f) - return cls.parse_obj(config_raw) + return cls.model_validate(config_raw) config_raw = None try: @@ -110,7 +110,7 @@ def fetch_quickstart_config(cls) -> "QuickstartVersionMappingConfig": } ) - config = cls.parse_obj(config_raw) + config = cls.model_validate(config_raw) # If stable is not defined in the config, we need to fetch the latest version from github. if config.quickstart_version_map.get("stable") is None: @@ -177,7 +177,7 @@ def save_quickstart_config( path = os.path.expanduser(path) os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w") as f: - yaml.dump(config.dict(), f) + yaml.dump(config.model_dump(), f) logger.info(f"Saved quickstart config to {path}.") diff --git a/metadata-ingestion/src/datahub/cli/specific/group_cli.py b/metadata-ingestion/src/datahub/cli/specific/group_cli.py index b76be9986456e9..0c5c052a6dfa30 100644 --- a/metadata-ingestion/src/datahub/cli/specific/group_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/group_cli.py @@ -42,7 +42,7 @@ def upsert(file: Path, override_editable: bool) -> None: with get_default_graph(ClientMode.CLI) as emitter: for group_config in group_configs: try: - datahub_group = CorpGroup.parse_obj(group_config) + datahub_group = CorpGroup.model_validate(group_config) for mcp in datahub_group.generate_mcp( generation_config=CorpGroupGenerationConfig( override_editable=override_editable, datahub_graph=emitter diff --git a/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py b/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py index 85a6c58c94ec77..4ee0b2b8163c37 100644 --- a/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/structuredproperties_cli.py @@ -85,7 +85,7 @@ def to_yaml_list( with open(file, "r") as fp: existing_objects = yaml.load(fp) # this is a list of dicts existing_objects = [ - StructuredProperties.parse_obj(obj) for obj in existing_objects + StructuredProperties.model_validate(obj) for obj in existing_objects ] objects = [obj for obj in objects] # do a positional update of the existing objects diff --git a/metadata-ingestion/src/datahub/cli/specific/user_cli.py b/metadata-ingestion/src/datahub/cli/specific/user_cli.py index 868f4855cf08f1..3259e29f984f3f 100644 --- a/metadata-ingestion/src/datahub/cli/specific/user_cli.py +++ b/metadata-ingestion/src/datahub/cli/specific/user_cli.py @@ -42,7 +42,7 @@ def upsert(file: Path, override_editable: bool) -> None: with get_default_graph(ClientMode.CLI) as emitter: for user_config in user_configs: try: - datahub_user: CorpUser = CorpUser.parse_obj(user_config) + datahub_user: CorpUser = CorpUser.model_validate(user_config) emitter.emit_all( datahub_user.generate_mcp( diff --git a/metadata-ingestion/src/datahub/configuration/common.py b/metadata-ingestion/src/datahub/configuration/common.py index af0f862e8a2189..ad15c327c2cd18 
100644 --- a/metadata-ingestion/src/datahub/configuration/common.py +++ b/metadata-ingestion/src/datahub/configuration/common.py @@ -140,6 +140,18 @@ def _schema_extra(schema: Dict[str, Any], model: Type["ConfigModel"]) -> None: @classmethod def parse_obj_allow_extras(cls, obj: Any) -> Self: + """Parse an object while allowing extra fields. + + 'parse_obj' in Pydantic v1 is equivalent to 'model_validate' in Pydantic v2. + However, 'parse_obj_allow_extras' in v1 is not directly available in v2. + + `model_validate(..., strict=False)` does not work because it still raises errors on extra fields; + strict=False only affects type coercion and validation strictness, not extra field handling. + + This method temporarily modifies the model's configuration to allow extra fields + + TODO: Do we really need to support this behaviour? Consider removing this method in future. + """ if PYDANTIC_VERSION_2: try: with unittest.mock.patch.dict( @@ -148,12 +160,12 @@ def parse_obj_allow_extras(cls, obj: Any) -> Self: clear=False, ): cls.model_rebuild(force=True) # type: ignore - return cls.parse_obj(obj) + return cls.model_validate(obj) finally: cls.model_rebuild(force=True) # type: ignore else: with unittest.mock.patch.object(cls.Config, "extra", pydantic.Extra.allow): - return cls.parse_obj(obj) + return cls.model_validate(obj) class PermissiveConfigModel(ConfigModel): diff --git a/metadata-ingestion/src/datahub/configuration/connection_resolver.py b/metadata-ingestion/src/datahub/configuration/connection_resolver.py index 8d329753c3b37b..cba9d785d5c39c 100644 --- a/metadata-ingestion/src/datahub/configuration/connection_resolver.py +++ b/metadata-ingestion/src/datahub/configuration/connection_resolver.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Type -import pydantic +from pydantic import model_validator from datahub.ingestion.api.global_context import get_graph_context @@ -40,4 +40,4 @@ def _resolve_connection(cls: Type, values: dict) -> dict: # https://github.com/pydantic/pydantic/blob/v1.10.9/pydantic/main.py#L264 # This hack ensures that multiple validators do not overwrite each other. 
_resolve_connection.__name__ = f"{_resolve_connection.__name__}_{connection_field}" - return pydantic.root_validator(pre=True, allow_reuse=True)(_resolve_connection) + return model_validator(mode="before")(_resolve_connection) diff --git a/metadata-ingestion/src/datahub/configuration/git.py b/metadata-ingestion/src/datahub/configuration/git.py index df25949d782c14..e52316681c8b21 100644 --- a/metadata-ingestion/src/datahub/configuration/git.py +++ b/metadata-ingestion/src/datahub/configuration/git.py @@ -1,7 +1,14 @@ import pathlib +from copy import deepcopy from typing import Any, Dict, Optional, Union -from pydantic import Field, FilePath, SecretStr, validator +from pydantic import ( + Field, + FilePath, + SecretStr, + field_validator, + model_validator, +) from datahub.configuration.common import ConfigModel from datahub.configuration.validate_field_rename import pydantic_renamed_field @@ -41,7 +48,8 @@ class GitReference(ConfigModel): transform=lambda url: _GITHUB_URL_TEMPLATE, ) - @validator("repo", pre=True) + @field_validator("repo", mode="before") + @classmethod def simplify_repo_url(cls, repo: str) -> str: if repo.startswith("github.com/") or repo.startswith("gitlab.com"): repo = f"https://{repo}" @@ -53,21 +61,22 @@ def simplify_repo_url(cls, repo: str) -> str: return repo - @validator("url_template", always=True) - def infer_url_template(cls, url_template: Optional[str], values: dict) -> str: - if url_template is not None: - return url_template + @model_validator(mode="after") + def infer_url_template(self) -> "GitReference": + if self.url_template is not None: + return self - repo: str = values["repo"] - if repo.startswith(_GITHUB_PREFIX): - return _GITHUB_URL_TEMPLATE - elif repo.startswith(_GITLAB_PREFIX): - return _GITLAB_URL_TEMPLATE + if self.repo.startswith(_GITHUB_PREFIX): + self.url_template = _GITHUB_URL_TEMPLATE + elif self.repo.startswith(_GITLAB_PREFIX): + self.url_template = _GITLAB_URL_TEMPLATE else: raise ValueError( "Unable to infer URL template from repo. Please set url_template manually." ) + return self + def get_url_for_file_path(self, file_path: str) -> str: assert self.url_template if self.url_subdir: @@ -98,35 +107,43 @@ class GitInfo(GitReference): _fix_deploy_key_newlines = pydantic_multiline_string("deploy_key") - @validator("deploy_key", pre=True, always=True) + @model_validator(mode="before") + @classmethod def deploy_key_filled_from_deploy_key_file( - cls, v: Optional[SecretStr], values: Dict[str, Any] - ) -> Optional[SecretStr]: - if v is None: + cls, values: Dict[str, Any] + ) -> Dict[str, Any]: + # In-place update of the input dict would cause state contamination. + # So a deepcopy is performed first. 
+ values = deepcopy(values) + + if values.get("deploy_key") is None: deploy_key_file = values.get("deploy_key_file") if deploy_key_file is not None: with open(deploy_key_file) as fp: deploy_key = SecretStr(fp.read()) - return deploy_key - return v - - @validator("repo_ssh_locator", always=True) - def infer_repo_ssh_locator( - cls, repo_ssh_locator: Optional[str], values: dict - ) -> str: - if repo_ssh_locator is not None: - return repo_ssh_locator - - repo: str = values["repo"] - if repo.startswith(_GITHUB_PREFIX): - return f"git@github.com:{repo[len(_GITHUB_PREFIX) :]}.git" - elif repo.startswith(_GITLAB_PREFIX): - return f"git@gitlab.com:{repo[len(_GITLAB_PREFIX) :]}.git" + values["deploy_key"] = deploy_key + return values + + @model_validator(mode="after") + def infer_repo_ssh_locator(self) -> "GitInfo": + if self.repo_ssh_locator is not None: + return self + + if self.repo.startswith(_GITHUB_PREFIX): + self.repo_ssh_locator = ( + f"git@github.com:{self.repo[len(_GITHUB_PREFIX) :]}.git" + ) + elif self.repo.startswith(_GITLAB_PREFIX): + self.repo_ssh_locator = ( + f"git@gitlab.com:{self.repo[len(_GITLAB_PREFIX) :]}.git" + ) else: raise ValueError( "Unable to infer repo_ssh_locator from repo. Please set repo_ssh_locator manually." ) + return self + @property def branch_for_clone(self) -> Optional[str]: # If branch was manually set, we should use it. Otherwise return None. diff --git a/metadata-ingestion/src/datahub/configuration/import_resolver.py b/metadata-ingestion/src/datahub/configuration/import_resolver.py index f9c9c82b4ae529..7fad3062fe3133 100644 --- a/metadata-ingestion/src/datahub/configuration/import_resolver.py +++ b/metadata-ingestion/src/datahub/configuration/import_resolver.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Type, TypeVar, Union -import pydantic +from pydantic import field_validator from datahub.ingestion.api.registry import import_path @@ -15,4 +15,4 @@ def _pydantic_resolver(cls: Type, v: Union[str, _T]) -> _T: def pydantic_resolve_key(field: str) -> "V1Validator": - return pydantic.validator(field, pre=True, allow_reuse=True)(_pydantic_resolver) + return field_validator(field, mode="before")(_pydantic_resolver) diff --git a/metadata-ingestion/src/datahub/configuration/kafka.py b/metadata-ingestion/src/datahub/configuration/kafka.py index 67c3dff5330f3d..ec5ce534ab7490 100644 --- a/metadata-ingestion/src/datahub/configuration/kafka.py +++ b/metadata-ingestion/src/datahub/configuration/kafka.py @@ -1,4 +1,4 @@ -from pydantic import Field, validator +from pydantic import Field, field_validator from datahub.configuration.common import ConfigModel, ConfigurationError from datahub.configuration.env_vars import ( @@ -42,7 +42,8 @@ class _KafkaConnectionConfig(ConfigModel): description="The request timeout used when interacting with the Kafka APIs.", ) - @validator("bootstrap") + @field_validator("bootstrap", mode="after") + @classmethod def bootstrap_host_colon_port_comma(cls, val: str) -> str: for entry in val.split(","): validate_host_port(entry) @@ -57,7 +58,7 @@ class KafkaConsumerConnectionConfig(_KafkaConnectionConfig): description="Extra consumer config serialized as JSON. These options will be passed into Kafka's DeserializingConsumer. 
See https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#deserializingconsumer and https://github.com/edenhill/librdkafka/blob/master/CONFIGURATION.md .", ) - @validator("consumer_config") + @field_validator("consumer_config", mode="after") @classmethod def resolve_callback(cls, value: dict) -> dict: if CallableConsumerConfig.is_callable_config(value): diff --git a/metadata-ingestion/src/datahub/configuration/time_window_config.py b/metadata-ingestion/src/datahub/configuration/time_window_config.py index 5fabcf904d3219..d54e7cb989d506 100644 --- a/metadata-ingestion/src/datahub/configuration/time_window_config.py +++ b/metadata-ingestion/src/datahub/configuration/time_window_config.py @@ -1,10 +1,9 @@ import enum from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List +from typing import Any, List import humanfriendly -import pydantic -from pydantic.fields import Field +from pydantic import Field, ValidationInfo, field_validator, model_validator from datahub.configuration.common import ConfigModel from datahub.configuration.datetimes import parse_absolute_time, parse_relative_timespan @@ -52,45 +51,46 @@ class BaseTimeWindowConfig(ConfigModel): description="Earliest date of lineage/usage to consider. Default: Last full day in UTC (or hour, depending on `bucket_duration`). You can also specify relative time with respect to end_time such as '-7 days' Or '-7d'.", ) # type: ignore - @pydantic.validator("start_time", pre=True, always=True) - def default_start_time( - cls, v: Any, values: Dict[str, Any], **kwargs: Any - ) -> datetime: - if v is None: - return get_time_bucket( - values["end_time"] - - get_bucket_duration_delta(values["bucket_duration"]), - values["bucket_duration"], - ) - elif isinstance(v, str): + @field_validator("start_time", mode="before") + @classmethod + def parse_start_time(cls, v: Any, info: ValidationInfo) -> Any: + if isinstance(v, str): # This is where start_time str is resolved to datetime try: delta = parse_relative_timespan(v) assert delta < timedelta(0), ( "Relative start time should start with minus sign (-) e.g. '-2 days'." ) - assert abs(delta) >= get_bucket_duration_delta( - values["bucket_duration"] - ), ( + bucket_duration = info.data.get("bucket_duration", BucketDuration.DAY) + assert abs(delta) >= get_bucket_duration_delta(bucket_duration), ( "Relative start time should be in terms of configured bucket duration. e.g '-2 days' or '-2 hours'." ) - # The end_time's default value is not yet populated, in which case - # we can just manually generate it here. - if "end_time" not in values: - values["end_time"] = datetime.now(tz=timezone.utc) + # We need end_time, but it might not be set yet + # In that case, we'll use the default + end_time = info.data.get("end_time") + if end_time is None: + end_time = datetime.now(tz=timezone.utc) - return get_time_bucket( - values["end_time"] + delta, values["bucket_duration"] - ) + return get_time_bucket(end_time + delta, bucket_duration) except humanfriendly.InvalidTimespan: # We do not floor start_time to the bucket start time if absolute start time is specified. # If user has specified absolute start time in recipe, it's most likely that he means it. 
return parse_absolute_time(v) - return v - @pydantic.validator("start_time", "end_time") + @model_validator(mode="after") + def default_start_time(self) -> "BaseTimeWindowConfig": + # Only calculate start_time if it was None (not provided by user) + if self.start_time is None: + self.start_time = get_time_bucket( + self.end_time - get_bucket_duration_delta(self.bucket_duration), + self.bucket_duration, + ) + return self + + @field_validator("start_time", "end_time", mode="after") + @classmethod def ensure_timestamps_in_utc(cls, v: datetime) -> datetime: if v.tzinfo is None: raise ValueError( diff --git a/metadata-ingestion/src/datahub/configuration/validate_field_deprecation.py b/metadata-ingestion/src/datahub/configuration/validate_field_deprecation.py index 5143034bfd734f..79e676caee699a 100644 --- a/metadata-ingestion/src/datahub/configuration/validate_field_deprecation.py +++ b/metadata-ingestion/src/datahub/configuration/validate_field_deprecation.py @@ -1,7 +1,7 @@ import warnings from typing import TYPE_CHECKING, Any, Optional, Type -import pydantic +from pydantic import model_validator from datahub.configuration.common import ConfigurationWarning from datahub.utilities.global_warning_util import add_global_warning @@ -34,4 +34,4 @@ def _validate_deprecated(cls: Type, values: dict) -> dict: # https://github.com/pydantic/pydantic/blob/v1.10.9/pydantic/main.py#L264 # This hack ensures that multiple field deprecated do not overwrite each other. _validate_deprecated.__name__ = f"{_validate_deprecated.__name__}_{field}" - return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_deprecated) + return model_validator(mode="before")(_validate_deprecated) diff --git a/metadata-ingestion/src/datahub/configuration/validate_field_removal.py b/metadata-ingestion/src/datahub/configuration/validate_field_removal.py index 0433730dc799cd..13ee062af300d5 100644 --- a/metadata-ingestion/src/datahub/configuration/validate_field_removal.py +++ b/metadata-ingestion/src/datahub/configuration/validate_field_removal.py @@ -1,7 +1,7 @@ import warnings from typing import TYPE_CHECKING, Type -import pydantic +from pydantic import model_validator from datahub.configuration.common import ConfigurationWarning @@ -31,4 +31,4 @@ def _validate_field_removal(cls: Type, values: dict) -> dict: # https://github.com/pydantic/pydantic/blob/v1.10.9/pydantic/main.py#L264 # This hack ensures that multiple field removals do not overwrite each other. _validate_field_removal.__name__ = f"{_validate_field_removal.__name__}_{field}" - return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_removal) + return model_validator(mode="before")(_validate_field_removal) diff --git a/metadata-ingestion/src/datahub/configuration/validate_field_rename.py b/metadata-ingestion/src/datahub/configuration/validate_field_rename.py index 2551986a7d4cc9..4be2516a731f07 100644 --- a/metadata-ingestion/src/datahub/configuration/validate_field_rename.py +++ b/metadata-ingestion/src/datahub/configuration/validate_field_rename.py @@ -1,7 +1,7 @@ import warnings from typing import TYPE_CHECKING, Callable, Type, TypeVar -import pydantic +from pydantic import model_validator from datahub.configuration.common import ConfigurationWarning from datahub.utilities.global_warning_util import add_global_warning @@ -52,4 +52,4 @@ def _validate_field_rename(cls: Type, values: dict) -> dict: # validator with pre=True gets all the values that were passed in. 
# Given that a renamed field doesn't show up in the fields list, we can't use # the field-level validator, even with a different field name. - return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename) + return model_validator(mode="before")(_validate_field_rename) diff --git a/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py b/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py index 54094052e6ecb1..b7589a8ffa7b31 100644 --- a/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py +++ b/metadata-ingestion/src/datahub/configuration/validate_multiline_string.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Optional, Type, Union import pydantic +from pydantic import field_validator if TYPE_CHECKING: from pydantic.deprecated.class_validators import V1Validator @@ -31,4 +32,4 @@ def _validate_field( # https://github.com/pydantic/pydantic/blob/v1.10.9/pydantic/main.py#L264 # This hack ensures that multiple field deprecated do not overwrite each other. _validate_field.__name__ = f"{_validate_field.__name__}_{field}" - return pydantic.validator(field, pre=True, allow_reuse=True)(_validate_field) + return field_validator(field, mode="before")(_validate_field) diff --git a/metadata-ingestion/src/datahub/emitter/kafka_emitter.py b/metadata-ingestion/src/datahub/emitter/kafka_emitter.py index 781930011b78fb..d1502b16d7ef47 100644 --- a/metadata-ingestion/src/datahub/emitter/kafka_emitter.py +++ b/metadata-ingestion/src/datahub/emitter/kafka_emitter.py @@ -6,6 +6,7 @@ from confluent_kafka.schema_registry import SchemaRegistryClient from confluent_kafka.schema_registry.avro import AvroSerializer from confluent_kafka.serialization import SerializationContext, StringSerializer +from pydantic import field_validator from datahub.configuration.common import ConfigModel from datahub.configuration.kafka import KafkaProducerConnectionConfig @@ -49,7 +50,8 @@ class KafkaEmitterConfig(ConfigModel): }, ) - @pydantic.validator("topic_routes") + @field_validator("topic_routes", mode="after") + @classmethod def validate_topic_routes(cls, v: Dict[str, str]) -> Dict[str, str]: assert MCE_KEY in v, f"topic_routes must contain a route for {MCE_KEY}" assert MCP_KEY in v, f"topic_routes must contain a route for {MCP_KEY}" diff --git a/metadata-ingestion/src/datahub/emitter/rest_emitter.py b/metadata-ingestion/src/datahub/emitter/rest_emitter.py index a94878e6092210..ebea50ff972a37 100644 --- a/metadata-ingestion/src/datahub/emitter/rest_emitter.py +++ b/metadata-ingestion/src/datahub/emitter/rest_emitter.py @@ -145,8 +145,7 @@ class EmitMode(ConfigEnum): ASYNC_WAIT = auto() -_DEFAULT_EMIT_MODE = pydantic.parse_obj_as( - EmitMode, +_DEFAULT_EMIT_MODE = pydantic.TypeAdapter(EmitMode).validate_python( get_emit_mode() or EmitMode.SYNC_PRIMARY, ) @@ -156,8 +155,7 @@ class RestSinkEndpoint(ConfigEnum): OPENAPI = auto() -DEFAULT_REST_EMITTER_ENDPOINT = pydantic.parse_obj_as( - RestSinkEndpoint, +DEFAULT_REST_EMITTER_ENDPOINT = pydantic.TypeAdapter(RestSinkEndpoint).validate_python( get_rest_emitter_default_endpoint() or RestSinkEndpoint.RESTLI, ) diff --git a/metadata-ingestion/src/datahub/ingestion/api/decorators.py b/metadata-ingestion/src/datahub/ingestion/api/decorators.py index 4172ca47254f87..024c7d6035fe49 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/decorators.py +++ b/metadata-ingestion/src/datahub/ingestion/api/decorators.py @@ -17,7 +17,7 @@ def config_class(config_cls: Type) -> Callable[[Type], Type]: """Adds a 
get_config_class method to the decorated class""" def default_create(cls: Type, config_dict: Dict, ctx: PipelineContext) -> Type: - config = config_cls.parse_obj(config_dict) + config = config_cls.model_validate(config_dict) return cls(config=config, ctx=ctx) def wrapper(cls: Type) -> Type: diff --git a/metadata-ingestion/src/datahub/ingestion/api/report.py b/metadata-ingestion/src/datahub/ingestion/api/report.py index 5147dd8f4d8e6b..4ff7a05e14cd59 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/report.py +++ b/metadata-ingestion/src/datahub/ingestion/api/report.py @@ -65,7 +65,7 @@ def to_pure_python_obj(some_val: Any) -> Any: if isinstance(some_val, SupportsAsObj): return some_val.as_obj() elif isinstance(some_val, pydantic.BaseModel): - return Report.to_pure_python_obj(some_val.dict()) + return Report.to_pure_python_obj(some_val.model_dump()) elif dataclasses.is_dataclass(some_val) and not isinstance(some_val, type): # The `is_dataclass` function returns `True` for both instances and classes. # We need an extra check to ensure an instance was passed in. diff --git a/metadata-ingestion/src/datahub/ingestion/api/sink.py b/metadata-ingestion/src/datahub/ingestion/api/sink.py index 7b2668b4d0a50f..989dcafeb2e006 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/sink.py +++ b/metadata-ingestion/src/datahub/ingestion/api/sink.py @@ -123,7 +123,7 @@ def __post_init__(self) -> None: @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "Self": - return cls(ctx, cls.get_config_class().parse_obj(config_dict)) + return cls(ctx, cls.get_config_class().model_validate(config_dict)) def handle_work_unit_start(self, workunit: WorkUnit) -> None: """Called at the start of each new workunit. diff --git a/metadata-ingestion/src/datahub/ingestion/api/source.py b/metadata-ingestion/src/datahub/ingestion/api/source.py index 70367dea8ea813..447cf70bc0bb91 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/source.py +++ b/metadata-ingestion/src/datahub/ingestion/api/source.py @@ -480,7 +480,7 @@ def __init__(self, config_dict: dict, ctx: PipelineContext) -> None: config_class = self.get_config_class() self.ctx = ctx - self.config = config_class.parse_obj(config_dict) + self.config = config_class.model_validate(config_dict) @abstractmethod def get_records(self, workunit: WorkUnitType) -> Iterable[RecordEnvelope]: diff --git a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py index ba03083854e785..dcda990cbe9936 100644 --- a/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py +++ b/metadata-ingestion/src/datahub/ingestion/glossary/datahub_classifier.py @@ -3,7 +3,7 @@ from datahub_classify.helper_classes import ColumnInfo from datahub_classify.infotype_predictor import predict_infotypes from datahub_classify.reference_input import input1 as default_config -from pydantic import validator +from pydantic import field_validator from pydantic.fields import Field from datahub.configuration.common import ConfigModel @@ -90,7 +90,7 @@ class Config: DEFAULT_CLASSIFIER_CONFIG = { - k: InfoTypeConfig.parse_obj(v) for k, v in default_config.items() + k: InfoTypeConfig.model_validate(v) for k, v in default_config.items() } @@ -114,8 +114,11 @@ class DataHubClassifierConfig(ConfigModel): description="Minimum number of non-null column values required to process `values` prediction factor.", ) - @validator("info_types_config") - def 
input_config_selectively_overrides_default_config(cls, info_types_config): + @field_validator("info_types_config", mode="after") + @classmethod + def input_config_selectively_overrides_default_config( + cls, info_types_config: Dict[str, Any] + ) -> Dict[str, Any]: for infotype, infotype_config in DEFAULT_CLASSIFIER_CONFIG.items(): if infotype not in info_types_config: # if config for some info type is not provided by user, use default config for that info type. @@ -125,7 +128,7 @@ def input_config_selectively_overrides_default_config(cls, info_types_config): # use default config for that prediction factor. for factor, weight in ( info_types_config[infotype] - .Prediction_Factors_and_Weights.dict() + .Prediction_Factors_and_Weights.model_dump() .items() ): if ( @@ -146,7 +149,7 @@ def input_config_selectively_overrides_default_config(cls, info_types_config): for ( factor, weight, - ) in custom_infotype_config.Prediction_Factors_and_Weights.dict().items(): + ) in custom_infotype_config.Prediction_Factors_and_Weights.model_dump().items(): if weight > 0: assert getattr(custom_infotype_config, factor) is not None, ( f"Missing Configuration for Prediction Factor {factor} for Custom Info Type {custom_infotype}" @@ -173,7 +176,7 @@ def __init__(self, config: DataHubClassifierConfig): def create(cls, config_dict: Optional[Dict[str, Any]]) -> "DataHubClassifier": # This could be replaced by parsing to particular class, if required if config_dict is not None: - config = DataHubClassifierConfig.parse_obj(config_dict) + config = DataHubClassifierConfig.model_validate(config_dict) else: config = DataHubClassifierConfig() return cls(config) @@ -183,7 +186,7 @@ def classify(self, columns: List[ColumnInfo]) -> List[ColumnInfo]: column_infos=columns, confidence_level_threshold=self.config.confidence_level_threshold, global_config={ - k: v.dict() for k, v in self.config.info_types_config.items() + k: v.model_dump() for k, v in self.config.info_types_config.items() }, infotypes=self.config.info_types, minimum_values_threshold=self.config.minimum_values_threshold, diff --git a/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py b/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py index 071c6e38adbcd6..8621ec7e3b058b 100644 --- a/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py +++ b/metadata-ingestion/src/datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py @@ -82,7 +82,7 @@ def create( ctx: PipelineContext, sink: Sink, ) -> PipelineRunListener: - reporter_config = DatahubIngestionRunSummaryProviderConfig.parse_obj( + reporter_config = DatahubIngestionRunSummaryProviderConfig.model_validate( config_dict or {} ) if reporter_config.sink: diff --git a/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py b/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py index 40a95c01bdfc41..68a9c715c6f133 100644 --- a/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py +++ b/metadata-ingestion/src/datahub/ingestion/reporting/file_reporter.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict -from pydantic import validator +from pydantic import field_validator from datahub.configuration.common import ConfigModel from datahub.ingestion.api.common import PipelineContext @@ -16,8 +16,9 @@ class FileReporterConfig(ConfigModel): filename: str format: str = "json" - @validator("format") - def only_json_supported(cls, v): + 
@field_validator("format", mode="after") + @classmethod + def only_json_supported(cls, v: str) -> str: if v and v.lower() != "json": raise ValueError( f"Format {v} is not yet supported. Only json is supported at this time" @@ -33,7 +34,7 @@ def create( ctx: PipelineContext, sink: Sink, ) -> PipelineRunListener: - reporter_config = FileReporterConfig.parse_obj(config_dict) + reporter_config = FileReporterConfig.model_validate(config_dict) return cls(reporter_config) def __init__(self, reporter_config: FileReporterConfig) -> None: diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py index 46092053a4e672..5f908a95fde553 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/pipeline.py +++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline.py @@ -215,7 +215,7 @@ def __init__( sink_class = sink_registry.get(self.sink_type) with _add_init_error_context(f"configure the sink ({self.sink_type})"): - sink_config = self.config.sink.dict().get("config") or {} + sink_config = self.config.sink.model_dump().get("config") or {} self.sink = exit_stack.enter_context( sink_class.create(sink_config, self.ctx) ) @@ -245,7 +245,7 @@ def __init__( ): self.source = inner_exit_stack.enter_context( source_class.create( - self.config.source.dict().get("config", {}), self.ctx + self.config.source.model_dump().get("config", {}), self.ctx ) ) logger.debug( @@ -288,7 +288,7 @@ def _configure_transforms(self) -> None: for transformer in self.config.transformers: transformer_type = transformer.type transformer_class = transform_registry.get(transformer_type) - transformer_config = transformer.dict().get("config", {}) + transformer_config = transformer.model_dump().get("config", {}) self.transformers.append( transformer_class.create(transformer_config, self.ctx) ) @@ -310,12 +310,12 @@ def _configure_reporting(self, report_to: Optional[str]) -> None: reporter.type for reporter in self.config.reporting ]: self.config.reporting.append( - ReporterConfig.parse_obj({"type": "datahub"}) + ReporterConfig.model_validate({"type": "datahub"}) ) elif report_to: # we assume this is a file name, and add the file reporter self.config.reporting.append( - ReporterConfig.parse_obj( + ReporterConfig.model_validate( {"type": "file", "config": {"filename": report_to}} ) ) @@ -323,7 +323,7 @@ def _configure_reporting(self, report_to: Optional[str]) -> None: for reporter in self.config.reporting: reporter_type = reporter.type reporter_class = reporting_provider_registry.get(reporter_type) - reporter_config_dict = reporter.dict().get("config", {}) + reporter_config_dict = reporter.model_dump().get("config", {}) try: self.reporters.append( reporter_class.create( diff --git a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py index 682ae1932beee3..034d4d30a85b44 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py +++ b/metadata-ingestion/src/datahub/ingestion/run/pipeline_config.py @@ -2,9 +2,9 @@ import logging import random import string -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional -from pydantic import Field, validator +from pydantic import Field, model_validator from datahub.configuration.common import ConfigModel, DynamicTypedConfig, HiddenFromDocs from datahub.ingestion.graph.config import DatahubClientConfig @@ -96,30 +96,28 @@ class PipelineConfig(ConfigModel): None # the raw dict that was parsed to construct 
this config ) - @validator("run_id", pre=True, always=True) - def run_id_should_be_semantic( - cls, v: Optional[str], values: Dict[str, Any], **kwargs: Any - ) -> str: - if v == DEFAULT_RUN_ID: + @model_validator(mode="after") + def run_id_should_be_semantic(self) -> "PipelineConfig": + if self.run_id == DEFAULT_RUN_ID: source_type = None - if "source" in values and hasattr(values["source"], "type"): - source_type = values["source"].type + if hasattr(self.source, "type"): + source_type = self.source.type - return _generate_run_id(source_type) + self.run_id = _generate_run_id(source_type) else: - assert v is not None - return v + assert self.run_id is not None + return self @classmethod def from_dict( cls, resolved_dict: dict, raw_dict: Optional[dict] = None ) -> "PipelineConfig": - config = cls.parse_obj(resolved_dict) + config = cls.model_validate(resolved_dict) config._raw_dict = raw_dict return config def get_raw_dict(self) -> Dict: result = self._raw_dict if result is None: - result = self.dict() + result = self.model_dump() return result diff --git a/metadata-ingestion/src/datahub/ingestion/run/sink_callback.py b/metadata-ingestion/src/datahub/ingestion/run/sink_callback.py index 3d02c64b7a77c1..8a81cfb5a3ac55 100644 --- a/metadata-ingestion/src/datahub/ingestion/run/sink_callback.py +++ b/metadata-ingestion/src/datahub/ingestion/run/sink_callback.py @@ -39,7 +39,7 @@ def on_failure( class DeadLetterQueueCallback(WriteCallback, Closeable): def __init__(self, ctx: PipelineContext, config: Optional[FileSinkConfig]) -> None: if not config: - config = FileSinkConfig.parse_obj({"filename": "failed_events.json"}) + config = FileSinkConfig.model_validate({"filename": "failed_events.json"}) self.file_sink: FileSink = FileSink(ctx, config) self.file_sink_lock = threading.Lock() self.logging_callback = LoggingCallback(name="failure-queue") diff --git a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py index 7e2ec94fa4106b..a0a49529929dcc 100644 --- a/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py +++ b/metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py @@ -9,6 +9,7 @@ from typing import List, Optional, Tuple, Union import pydantic +from pydantic import field_validator from datahub.configuration.common import ( ConfigEnum, @@ -63,8 +64,8 @@ class RestSinkMode(ConfigEnum): ASYNC_BATCH = auto() -_DEFAULT_REST_SINK_MODE = pydantic.parse_obj_as( - RestSinkMode, get_rest_sink_default_mode() or RestSinkMode.ASYNC_BATCH +_DEFAULT_REST_SINK_MODE = pydantic.TypeAdapter(RestSinkMode).validate_python( + get_rest_sink_default_mode() or RestSinkMode.ASYNC_BATCH ) @@ -80,8 +81,9 @@ class DatahubRestSinkConfig(DatahubClientConfig): # Only applies in async batch mode. 
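        # Note on validate_max_per_batch below: pydantic v2 has no `always=True`, so the
        # migrated validator uses @field_validator(..., mode="before"), which receives the
        # raw input before it is coerced to PositiveInt. (In v2, default values only pass
        # through field validators when Field(validate_default=True) is set.)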
max_per_batch: pydantic.PositiveInt = 100 - @pydantic.validator("max_per_batch", always=True) - def validate_max_per_batch(cls, v): + @field_validator("max_per_batch", mode="before") + @classmethod + def validate_max_per_batch(cls, v: int) -> int: if v > BATCH_INGEST_MAX_PAYLOAD_LENGTH: raise ValueError( f"max_per_batch must be less than or equal to {BATCH_INGEST_MAX_PAYLOAD_LENGTH}" diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/config.py b/metadata-ingestion/src/datahub/ingestion/source/abs/config.py index 0df1644ddcffa2..41cef8e1879813 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/config.py @@ -1,7 +1,7 @@ import logging from typing import Any, Dict, List, Optional, Union -import pydantic +from pydantic import ValidationInfo, field_validator, model_validator from pydantic.fields import Field from datahub.configuration.common import AllowDenyPattern @@ -105,9 +105,10 @@ def is_profiling_enabled(self) -> bool: self.profiling.operation_config ) - @pydantic.validator("path_specs", always=True) + @field_validator("path_specs", mode="before") + @classmethod def check_path_specs_and_infer_platform( - cls, path_specs: List[PathSpec], values: Dict + cls, path_specs: List[PathSpec], info: ValidationInfo ) -> List[PathSpec]: if len(path_specs) == 0: raise ValueError("path_specs must not be empty") @@ -124,38 +125,37 @@ def check_path_specs_and_infer_platform( # Ensure abs configs aren't used for file sources. if guessed_platform != "abs" and ( - values.get("use_abs_container_properties") - or values.get("use_abs_blob_tags") - or values.get("use_abs_blob_properties") + info.data.get("use_abs_container_properties") + or info.data.get("use_abs_blob_tags") + or info.data.get("use_abs_blob_properties") ): raise ValueError( "Cannot grab abs blob/container tags when platform is not abs. Remove the flag or use abs." ) # Infer platform if not specified. 
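        # Note: pydantic v2 replaces the v1 `values` dict with ValidationInfo; `info.data`
        # holds only the fields declared (and already validated) before this one, which is
        # why the platform check/inference below goes through `info.data` instead of a
        # plain `values` dict.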
- if values.get("platform") and values["platform"] != guessed_platform: + if info.data.get("platform") and info.data["platform"] != guessed_platform: raise ValueError( - f"All path_specs belong to {guessed_platform} platform, but platform is set to {values['platform']}" + f"All path_specs belong to {guessed_platform} platform, but platform is set to {info.data['platform']}" ) else: logger.debug(f'Setting config "platform": {guessed_platform}') - values["platform"] = guessed_platform + info.data["platform"] = guessed_platform return path_specs - @pydantic.validator("platform", always=True) - def platform_not_empty(cls, platform: Any, values: dict) -> str: - inferred_platform = values.get("platform") # we may have inferred it above + @field_validator("platform", mode="before") + @classmethod + def platform_not_empty(cls, platform: Any, info: ValidationInfo) -> str: + inferred_platform = info.data.get("platform") # we may have inferred it above platform = platform or inferred_platform if not platform: raise ValueError("platform must not be empty") return platform - @pydantic.root_validator(skip_on_failure=True) - def ensure_profiling_pattern_is_passed_to_profiling( - cls, values: Dict[str, Any] - ) -> Dict[str, Any]: - profiling: Optional[DataLakeProfilerConfig] = values.get("profiling") + @model_validator(mode="after") + def ensure_profiling_pattern_is_passed_to_profiling(self) -> "DataLakeSourceConfig": + profiling = self.profiling if profiling is not None and profiling.enabled: - profiling._allow_deny_patterns = values["profile_patterns"] - return values + profiling._allow_deny_patterns = self.profile_patterns + return self diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py b/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py index 58e930eb6e809c..b1f050b51d25c1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/datalake_profiler_config.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, Optional +from typing import Optional import pydantic +from pydantic import model_validator from pydantic.fields import Field from datahub.configuration import ConfigModel @@ -72,21 +73,18 @@ class DataLakeProfilerConfig(ConfigModel): description="Whether to profile for the sample values for all columns.", ) - @pydantic.root_validator(skip_on_failure=True) - def ensure_field_level_settings_are_normalized( - cls: "DataLakeProfilerConfig", values: Dict[str, Any] - ) -> Dict[str, Any]: - max_num_fields_to_profile_key = "max_number_of_fields_to_profile" - max_num_fields_to_profile = values.get(max_num_fields_to_profile_key) + @model_validator(mode="after") + def ensure_field_level_settings_are_normalized(self) -> "DataLakeProfilerConfig": + max_num_fields_to_profile = self.max_number_of_fields_to_profile # Disable all field-level metrics. 
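        # Note: @model_validator(mode="after") receives the constructed model instance, so
        # cross-field normalization is done by mutating attributes on `self` and returning
        # `self`, rather than by editing the `values` dict as the v1 root_validator did.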
- if values.get("profile_table_level_only"): - for field_level_metric in cls.__fields__: - if field_level_metric.startswith("include_field_"): - values.setdefault(field_level_metric, False) + if self.profile_table_level_only: + for field_name in self.__fields__: + if field_name.startswith("include_field_"): + setattr(self, field_name, False) assert max_num_fields_to_profile is None, ( - f"{max_num_fields_to_profile_key} should be set to None" + "max_number_of_fields_to_profile should be set to None" ) - return values + return self diff --git a/metadata-ingestion/src/datahub/ingestion/source/abs/source.py b/metadata-ingestion/src/datahub/ingestion/source/abs/source.py index b78ae5282bb313..04b6d4ba9129bb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/abs/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/abs/source.py @@ -149,7 +149,7 @@ def __init__(self, config: DataLakeSourceConfig, ctx: PipelineContext): self.report = DataLakeSourceReport() self.profiling_times_taken = [] config_report = { - config_option: config.dict().get(config_option) + config_option: config.model_dump().get(config_option) for config_option in config_options_to_report } config_report = { @@ -164,7 +164,7 @@ def __init__(self, config: DataLakeSourceConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = DataLakeSourceConfig.parse_obj(config_dict) + config = DataLakeSourceConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py index 00ef03f4d0e725..c57fa7980de45f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py @@ -246,7 +246,7 @@ def assume_role( **dict( RoleSessionName="DatahubIngestionSource", ), - **{k: v for k, v in role.dict().items() if v is not None}, + **{k: v for k, v in role.model_dump().items() if v is not None}, } assumed_role_object = sts_client.assume_role( diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py index 697f2c5c721ff3..5ec5797b48da54 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/glue.py @@ -21,7 +21,7 @@ import botocore.exceptions import yaml -from pydantic import validator +from pydantic import field_validator from pydantic.fields import Field from datahub.api.entities.dataset.dataset import Dataset @@ -221,7 +221,8 @@ def s3_client(self): def lakeformation_client(self): return self.get_lakeformation_client() - @validator("glue_s3_lineage_direction") + @field_validator("glue_s3_lineage_direction", mode="after") + @classmethod def check_direction(cls, v: str) -> str: if v.lower() not in ["upstream", "downstream"]: raise ValueError( @@ -229,7 +230,8 @@ def check_direction(cls, v: str) -> str: ) return v.lower() - @validator("platform") + @field_validator("platform", mode="after") + @classmethod def platform_validator(cls, v: str) -> str: if not v or v in VALID_PLATFORMS: return v @@ -473,7 +475,7 @@ def get_glue_arn( @classmethod def create(cls, config_dict, ctx): - config = GlueSourceConfig.parse_obj(config_dict) + config = GlueSourceConfig.model_validate(config_dict) return cls(config, ctx) @property diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py 
b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py index 55b8f4d889072d..1c7885b165a214 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py @@ -66,7 +66,7 @@ def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = SagemakerSourceConfig.parse_obj(config_dict) + config = SagemakerSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py b/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py index 24d4e6526e1afc..8a1d75cfb438b3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/azure/azure_common.py @@ -1,9 +1,9 @@ -from typing import Dict, Optional, Union +from typing import Optional, Union from azure.identity import ClientSecretCredential from azure.storage.blob import BlobServiceClient from azure.storage.filedatalake import DataLakeServiceClient, FileSystemClient -from pydantic import Field, root_validator +from pydantic import Field, model_validator from datahub.configuration import ConfigModel from datahub.configuration.common import ConfigurationError @@ -81,18 +81,14 @@ def get_credentials( ) return self.sas_token if self.sas_token is not None else self.account_key - @root_validator(skip_on_failure=True) - def _check_credential_values(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def _check_credential_values(self) -> "AzureConnectionConfig": if ( - values.get("account_key") - or values.get("sas_token") - or ( - values.get("client_id") - and values.get("client_secret") - and values.get("tenant_id") - ) + self.account_key + or self.sas_token + or (self.client_id and self.client_secret and self.tenant_id) ): - return values + return self raise ConfigurationError( "credentials missing, requires one combination of account_key or sas_token or (client_id and client_secret and tenant_id)" ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py index d3b94d3808240f..5fd251b7087260 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py @@ -211,7 +211,7 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "BigqueryV2Source": - config = BigQueryV2Config.parse_obj(config_dict) + config = BigQueryV2Config.model_validate(config_dict) return cls(ctx, config) @staticmethod diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py index 314076b8ad6207..e13cfa0f44f3b6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_config.py @@ -2,9 +2,16 @@ import re from copy import deepcopy from datetime import timedelta -from typing import Dict, List, Optional, Union - -from pydantic import Field, PositiveInt, PrivateAttr, root_validator, validator +from typing import Any, Dict, List, Optional, Union + +from pydantic 
import ( + Field, + PositiveInt, + PrivateAttr, + ValidationInfo, + field_validator, + model_validator, +) from datahub.configuration.common import AllowDenyPattern, ConfigModel, HiddenFromDocs from datahub.configuration.env_vars import get_bigquery_schema_parallelism @@ -63,8 +70,9 @@ class BigQueryBaseConfig(ConfigModel): description="The regex pattern to match sharded tables and group as one table. This is a very low level config parameter, only change if you know what you are doing, ", ) - @validator("sharded_table_pattern") - def sharded_table_pattern_is_a_valid_regexp(cls, v): + @field_validator("sharded_table_pattern", mode="after") + @classmethod + def sharded_table_pattern_is_a_valid_regexp(cls, v: str) -> str: try: re.compile(v) except Exception as e: @@ -73,7 +81,8 @@ def sharded_table_pattern_is_a_valid_regexp(cls, v): ) from e return v - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def project_id_backward_compatibility_configs_set(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -188,12 +197,11 @@ class BigQueryFilterConfig(SQLFilterConfig): default=AllowDenyPattern.allow_all(), ) - @root_validator(pre=False, skip_on_failure=True) - def backward_compatibility_configs_set(cls, values: Dict) -> Dict: - # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests - values = deepcopy(values) - dataset_pattern: Optional[AllowDenyPattern] = values.get("dataset_pattern") - schema_pattern = values.get("schema_pattern") + @model_validator(mode="after") + def backward_compatibility_configs_set(self) -> Any: + dataset_pattern = self.dataset_pattern + schema_pattern = self.schema_pattern + if ( dataset_pattern == AllowDenyPattern.allow_all() and schema_pattern != AllowDenyPattern.allow_all() @@ -202,7 +210,7 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: "dataset_pattern is not set but schema_pattern is set, using schema_pattern as dataset_pattern. " "schema_pattern will be deprecated, please use dataset_pattern instead." ) - values["dataset_pattern"] = schema_pattern + self.dataset_pattern = schema_pattern dataset_pattern = schema_pattern elif ( dataset_pattern != AllowDenyPattern.allow_all() @@ -213,7 +221,7 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: " please use dataset_pattern only." ) - match_fully_qualified_names = values.get("match_fully_qualified_names") + match_fully_qualified_names = self.match_fully_qualified_names if ( dataset_pattern is not None @@ -243,7 +251,7 @@ def backward_compatibility_configs_set(cls, values: Dict) -> Dict: " of the form `.`." 
) - return values + return self class BigQueryIdentifierConfig( @@ -478,7 +486,8 @@ def have_table_data_read_permission(self) -> bool: _include_view_column_lineage = pydantic_removed_field("include_view_column_lineage") _lineage_parse_view_ddl = pydantic_removed_field("lineage_parse_view_ddl") - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_include_schema_metadata(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -498,30 +507,33 @@ def set_include_schema_metadata(cls, values: Dict) -> Dict: return values - @root_validator(skip_on_failure=True) + @model_validator(mode="before") + @classmethod def profile_default_settings(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) # Extra default SQLAlchemy option for better connection pooling and threading. # https://docs.sqlalchemy.org/en/14/core/pooling.html#sqlalchemy.pool.QueuePool.params.max_overflow - values["options"].setdefault("max_overflow", -1) + values.setdefault("options", {}).setdefault("max_overflow", -1) return values - @validator("bigquery_audit_metadata_datasets") + @field_validator("bigquery_audit_metadata_datasets", mode="after") + @classmethod def validate_bigquery_audit_metadata_datasets( - cls, v: Optional[List[str]], values: Dict + cls, v: Optional[List[str]], info: ValidationInfo ) -> Optional[List[str]]: - if values.get("use_exported_bigquery_audit_metadata"): + if info.data.get("use_exported_bigquery_audit_metadata"): assert v and len(v) > 0, ( "`bigquery_audit_metadata_datasets` should be set if using `use_exported_bigquery_audit_metadata: True`." ) return v - @validator("upstream_lineage_in_report") - def validate_upstream_lineage_in_report(cls, v: bool, values: Dict) -> bool: - if v and values.get("use_queries_v2", True): + @field_validator("upstream_lineage_in_report", mode="after") + @classmethod + def validate_upstream_lineage_in_report(cls, v: bool, info: ValidationInfo) -> bool: + if v and info.data.get("use_queries_v2", True): logging.warning( "`upstream_lineage_in_report` is enabled but will be ignored because `use_queries_v2` is enabled." "This debugging feature only works with the legacy lineage approach (`use_queries_v2: false`)." @@ -529,11 +541,12 @@ def validate_upstream_lineage_in_report(cls, v: bool, values: Dict) -> bool: return v - @root_validator(pre=False, skip_on_failure=True) - def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict: - if values.get("use_queries_v2"): - if values.get("enable_stateful_lineage_ingestion") or values.get( - "enable_stateful_usage_ingestion" + @model_validator(mode="after") + def validate_queries_v2_stateful_ingestion(self) -> "BigQueryV2Config": + if self.use_queries_v2: + if ( + self.enable_stateful_lineage_ingestion + or self.enable_stateful_usage_ingestion ): logger.warning( "enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion are deprecated " @@ -541,7 +554,7 @@ def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict: "For queries v2, use enable_stateful_time_window instead to enable stateful ingestion " "for the unified time window extraction (lineage + usage + operations + queries)." 
) - return values + return self def get_table_pattern(self, pattern: List[str]) -> str: return "|".join(pattern) if pattern else "" diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py index 09a58a40680db8..616c4a69ea3bb0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py @@ -80,7 +80,7 @@ def __init__(self, ctx: PipelineContext, config: BigQueryQueriesSourceConfig): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> Self: - config = BigQueryQueriesSourceConfig.parse_obj(config_dict) + config = BigQueryQueriesSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py index d79cd83e13c121..d6adcd312044e4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py +++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py @@ -109,7 +109,7 @@ def __init__(self, ctx: PipelineContext, config: CassandraSourceConfig): @classmethod def create(cls, config_dict, ctx): - config = CassandraSourceConfig.parse_obj(config_dict) + config = CassandraSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_platform(self) -> str: diff --git a/metadata-ingestion/src/datahub/ingestion/source/common/gcp_credentials_config.py b/metadata-ingestion/src/datahub/ingestion/source/common/gcp_credentials_config.py index 59d71ba276322e..f35508d6ea8064 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/common/gcp_credentials_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/common/gcp_credentials_config.py @@ -1,8 +1,8 @@ import json import tempfile -from typing import Any, Dict, Optional +from typing import Dict, Optional -from pydantic import Field, root_validator +from pydantic import Field, model_validator from datahub.configuration import ConfigModel from datahub.configuration.validate_multiline_string import pydantic_multiline_string @@ -37,16 +37,16 @@ class GCPCredential(ConfigModel): _fix_private_key_newlines = pydantic_multiline_string("private_key") - @root_validator(skip_on_failure=True) - def validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if values.get("client_x509_cert_url") is None: - values["client_x509_cert_url"] = ( - f"https://www.googleapis.com/robot/v1/metadata/x509/{values['client_email']}" + @model_validator(mode="after") + def validate_config(self) -> "GCPCredential": + if self.client_x509_cert_url is None: + self.client_x509_cert_url = ( + f"https://www.googleapis.com/robot/v1/metadata/x509/{self.client_email}" ) - return values + return self def create_credential_temp_file(self, project_id: Optional[str] = None) -> str: - configs = self.dict() + configs = self.model_dump() if project_id: configs["project_id"] = project_id with tempfile.NamedTemporaryFile(delete=False) as fp: @@ -55,7 +55,7 @@ def create_credential_temp_file(self, project_id: Optional[str] = None) -> str: return fp.name def to_dict(self, project_id: Optional[str] = None) -> Dict[str, str]: - configs = self.dict() + configs = self.model_dump() if project_id: configs["project_id"] = project_id return configs diff --git 
a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py index 9fe5c8e4a3df21..5809887e15041f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py +++ b/metadata-ingestion/src/datahub/ingestion/source/data_lake_common/path_spec.py @@ -3,11 +3,11 @@ import os import re from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import parse -import pydantic from cached_property import cached_property +from pydantic import field_validator, model_validator from pydantic.fields import Field from wcmatch import pathlib @@ -65,7 +65,8 @@ class SortKey(ConfigModel): description="The date format to use when sorting. This is used to parse the date from the key. The format should follow the java [SimpleDateFormat](https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html) format.", ) - @pydantic.validator("date_format", always=True) + @field_validator("date_format", mode="before") + @classmethod def convert_date_format_to_python_format(cls, v: Optional[str]) -> Optional[str]: if v is None: return None @@ -86,7 +87,7 @@ class Config: arbitrary_types_allowed = True include: str = Field( - description="Path to table. Name variable `{table}` is used to mark the folder with dataset. In absence of `{table}`, file level dataset will be created. Check below examples for more details." + description="Path to table. Name variable `{table}` is used to mark the folder with dataset. In absence of `{table}`, file level dataset will be created. Check below examples for more details.", ) exclude: Optional[List[str]] = Field( [], @@ -260,20 +261,80 @@ def get_folder_named_vars( ) -> Union[None, parse.Result, parse.Match]: return self.compiled_folder_include.parse(path) - @pydantic.root_validator(skip_on_failure=True) - def validate_no_double_stars(cls, values: Dict) -> Dict: - if "include" not in values: - return values + @model_validator(mode="after") + def validate_path_spec_comprehensive(self): + """ + Comprehensive model validator that handles multiple interdependent validations. + + Consolidates related validation logic to avoid order dependencies between multiple + model validators and ensures reliable cross-field validation. This approach is + preferred over multiple separate validators when: + + 1. Validations depend on multiple fields (e.g., sample_files depends on include) + 2. One validation modifies a field that another validation checks + 3. Field validators can't reliably access other field values or defaults + 4. Order of execution between validators is important but undefined + + By combining related validations, we ensure they execute in the correct sequence + and have access to all field values after Pydantic has processed defaults. 
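        In the pydantic v1 version this logic was spread across several field- and
        root-validators that relied on `always=True` and the shared `values` dict,
        neither of which carries over directly to v2.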
+ """ + # Handle autodetect_partitions logic first + if self.autodetect_partitions: + include = self.include + if include.endswith("/"): + include = include[:-1] + if include.endswith("{table}"): + self.include = include + "/**" + # Allow double stars when we add them for autodetect_partitions + self.allow_double_stars = True + + # Handle table_name logic + if self.table_name is None and "{table}" in self.include: + self.table_name = "{table}" + elif self.table_name is not None: + parsable_include = PathSpec.get_parsable_include(self.include) + compiled_include = parse.compile(parsable_include) + if not all( + x in compiled_include.named_fields + for x in parse.compile(self.table_name).named_fields + ): + raise ValueError( + f"Not all named variables used in path_spec.table_name {self.table_name} are specified in path_spec.include {self.include}" + ) + + # Handle sample_files logic - turn off sampling for non-cloud URIs + is_s3 = is_s3_uri(self.include) + is_gcs = is_gcs_uri(self.include) + is_abs = is_abs_uri(self.include) + if not is_s3 and not is_gcs and not is_abs: + # Sampling only makes sense on s3 and gcs currently + self.sample_files = False + + # Validate double stars + if "**" in self.include and not self.allow_double_stars: + raise ValueError("path_spec.include cannot contain '**'") + # Validate file extension + include_ext = os.path.splitext(self.include)[1].strip(".") + if not include_ext: + include_ext = ( + "*" # if no extension is provided, we assume all files are allowed + ) if ( - values.get("include") - and "**" in values["include"] - and not values.get("allow_double_stars") + include_ext not in self.file_types + and include_ext not in ["*", ""] + and not self.default_extension + and include_ext not in SUPPORTED_COMPRESSIONS ): - raise ValueError("path_spec.include cannot contain '**'") - return values + raise ValueError( + f"file type specified ({include_ext}) in path_spec.include is not in specified file " + f'types. Please select one from {self.file_types} or specify ".*" to allow all types' + ) - @pydantic.validator("file_types", always=True) + return self + + @field_validator("file_types", mode="before") + @classmethod def validate_file_types(cls, v: Optional[List[str]]) -> List[str]: if v is None: return SUPPORTED_FILE_TYPES @@ -285,50 +346,24 @@ def validate_file_types(cls, v: Optional[List[str]]) -> List[str]: ) return v - @pydantic.validator("default_extension") - def validate_default_extension(cls, v): + @field_validator("default_extension", mode="after") + @classmethod + def validate_default_extension(cls, v: Optional[str]) -> Optional[str]: if v is not None and v not in SUPPORTED_FILE_TYPES: raise ValueError( f"default extension {v} not in supported default file extension. 
Please specify one from {SUPPORTED_FILE_TYPES}" ) return v - @pydantic.validator("sample_files", always=True) - def turn_off_sampling_for_non_s3(cls, v, values): - is_s3 = is_s3_uri(values.get("include") or "") - is_gcs = is_gcs_uri(values.get("include") or "") - is_abs = is_abs_uri(values.get("include") or "") - if not is_s3 and not is_gcs and not is_abs: - # Sampling only makes sense on s3 and gcs currently - v = False - return v - - @pydantic.validator("exclude", each_item=True) - def no_named_fields_in_exclude(cls, v: str) -> str: - if len(parse.compile(v).named_fields) != 0: - raise ValueError( - f"path_spec.exclude {v} should not contain any named variables" - ) - return v - - @pydantic.validator("table_name", always=True) - def table_name_in_include(cls, v, values): - if "include" not in values: - return v - - parsable_include = PathSpec.get_parsable_include(values["include"]) - compiled_include = parse.compile(parsable_include) - + @field_validator("exclude", mode="after") + @classmethod + def no_named_fields_in_exclude(cls, v: Optional[List[str]]) -> Optional[List[str]]: if v is None: - if "{table}" in values["include"]: - v = "{table}" - else: - if not all( - x in compiled_include.named_fields - for x in parse.compile(v).named_fields - ): + return v + for item in v: + if len(parse.compile(item).named_fields) != 0: raise ValueError( - f"Not all named variables used in path_spec.table_name {v} are specified in path_spec.include {values['include']}" + f"path_spec.exclude {item} should not contain any named variables" ) return v @@ -479,45 +514,6 @@ def glob_include(self): logger.debug(f"Setting _glob_include: {glob_include}") return glob_include - @pydantic.root_validator(skip_on_failure=True) - @staticmethod - def validate_path_spec(values: Dict) -> Dict[str, Any]: - # validate that main fields are populated - required_fields = ["include", "file_types", "default_extension"] - for f in required_fields: - if f not in values: - logger.debug( - f"Failed to validate because {f} wasn't populated correctly" - ) - return values - - if values["include"] and values["autodetect_partitions"]: - include = values["include"] - if include.endswith("/"): - include = include[:-1] - - if include.endswith("{table}"): - values["include"] = include + "/**" - - include_ext = os.path.splitext(values["include"])[1].strip(".") - if not include_ext: - include_ext = ( - "*" # if no extension is provided, we assume all files are allowed - ) - - if ( - include_ext not in values["file_types"] - and include_ext not in ["*", ""] - and not values["default_extension"] - and include_ext not in SUPPORTED_COMPRESSIONS - ): - raise ValueError( - f"file type specified ({include_ext}) in path_spec.include is not in specified file " - f'types. 
Please select one from {values.get("file_types")} or specify ".*" to allow all types' - ) - - return values - def _extract_table_name(self, named_vars: dict) -> str: if self.table_name is None: raise ValueError("path_spec.table_name is not set") diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py index e32e57f2fa959e..670086a03456dd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/config.py @@ -2,7 +2,7 @@ from typing import Optional, Set import pydantic -from pydantic import Field, root_validator +from pydantic import Field, model_validator from datahub.configuration.common import AllowDenyPattern, HiddenFromDocs from datahub.configuration.kafka import KafkaConsumerConnectionConfig @@ -132,20 +132,20 @@ class DataHubSourceConfig(StatefulIngestionConfigBase): default=True, description="Copy system metadata from the source system" ) - @root_validator(skip_on_failure=True) - def check_ingesting_data(cls, values): + @model_validator(mode="after") + def check_ingesting_data(self): if ( - not values.get("database_connection") - and not values.get("kafka_connection") - and not values.get("pull_from_datahub_api") + not self.database_connection + and not self.kafka_connection + and not self.pull_from_datahub_api ): raise ValueError( "Your current config will not ingest any data." " Please specify at least one of `database_connection` or `kafka_connection`, ideally both." ) - return values + return self - @pydantic.validator("database_connection") + @pydantic.field_validator("database_connection") def validate_mysql_scheme( cls, v: SQLAlchemyConnectionConfig ) -> SQLAlchemyConnectionConfig: diff --git a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py index 1325341d67c0d3..a6200125b68af9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/datahub/datahub_source.py @@ -62,7 +62,7 @@ def __init__(self, config: DataHubSourceConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "DataHubSource": - config: DataHubSourceConfig = DataHubSourceConfig.parse_obj(config_dict) + config: DataHubSourceConfig = DataHubSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_report(self) -> SourceReport: diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py index 00484b8079859c..6055d1c76e66bc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_cloud.py @@ -1,4 +1,5 @@ import logging +from copy import deepcopy from datetime import datetime from json import JSONDecodeError from typing import Dict, List, Literal, Optional, Tuple @@ -6,7 +7,7 @@ import dateutil.parser import requests -from pydantic import Field, root_validator +from pydantic import Field, model_validator from datahub.ingestion.api.decorators import ( SourceCapability, @@ -68,8 +69,13 @@ class DBTCloudConfig(DBTCommonConfig): description='Where should the "View in dbt" link point to - either the "Explore" UI or the dbt Cloud IDE', ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_metadata_endpoint(cls, 
values: dict) -> dict: + # In-place update of the input dict would cause state contamination. + # So a deepcopy is performed first. + values = deepcopy(values) + if values.get("access_url") and not values.get("metadata_endpoint"): metadata_endpoint = infer_metadata_endpoint(values["access_url"]) if metadata_endpoint is None: @@ -271,7 +277,7 @@ class DBTCloudSource(DBTSourceBase, TestableSource): @classmethod def create(cls, config_dict, ctx): - config = DBTCloudConfig.parse_obj(config_dict) + config = DBTCloudConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py index 069836bac15f5b..62be321b8d4f57 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py @@ -1,6 +1,7 @@ import logging import re from abc import abstractmethod +from copy import deepcopy from dataclasses import dataclass, field from datetime import datetime from enum import auto @@ -8,7 +9,7 @@ import more_itertools import pydantic -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator from pydantic.fields import Field from datahub.api.entities.dataprocess.dataprocess_instance import ( @@ -194,22 +195,26 @@ class DBTEntitiesEnabled(ConfigModel): "Only supported with dbt core.", ) - @root_validator(skip_on_failure=True) - def process_only_directive(cls, values): + @model_validator(mode="after") + def process_only_directive(self) -> "DBTEntitiesEnabled": # Checks that at most one is set to ONLY, and then sets the others to NO. - - only_values = [k for k in values if values.get(k) == EmitDirective.ONLY] + only_values = [ + k for k, v in self.model_dump().items() if v == EmitDirective.ONLY + ] if len(only_values) > 1: raise ValueError( f"Cannot have more than 1 type of entity emission set to ONLY. Found {only_values}" ) if len(only_values) == 1: - for k in values: - values[k] = EmitDirective.NO - values[only_values[0]] = EmitDirective.YES + # Set all fields to NO first + for field_name in self.model_dump(): + setattr(self, field_name, EmitDirective.NO) - return values + # Set the ONLY one to YES + setattr(self, only_values[0], EmitDirective.YES) + + return self def _node_type_allow_map(self): # Node type comes from dbt's node types. @@ -412,7 +417,8 @@ class DBTCommonConfig( "This ensures that lineage is generated reliably, but will lose any documentation associated only with the source.", ) - @validator("target_platform") + @field_validator("target_platform", mode="after") + @classmethod def validate_target_platform_value(cls, target_platform: str) -> str: if target_platform.lower() == DBT_PLATFORM: raise ValueError( @@ -421,15 +427,21 @@ def validate_target_platform_value(cls, target_platform: str) -> str: ) return target_platform - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_convert_column_urns_to_lowercase_default_for_snowflake( cls, values: dict ) -> dict: + # In-place update of the input dict would cause state contamination. + # So a deepcopy is performed first. 
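        # (A model_validator with mode="before" is handed the caller's raw input dict
        # itself, so without this copy the setdefault below would also mutate the dict
        # the caller still holds.)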
+ values = deepcopy(values) + if values.get("target_platform", "").lower() == "snowflake": values.setdefault("convert_column_urns_to_lowercase", True) return values - @validator("write_semantics") + @field_validator("write_semantics", mode="after") + @classmethod def validate_write_semantics(cls, write_semantics: str) -> str: if write_semantics.lower() not in {"patch", "override"}: raise ValueError( @@ -439,10 +451,9 @@ def validate_write_semantics(cls, write_semantics: str) -> str: ) return write_semantics - @validator("meta_mapping") - def meta_mapping_validator( - cls, meta_mapping: Dict[str, Any], values: Dict, **kwargs: Any - ) -> Dict[str, Any]: + @field_validator("meta_mapping", mode="after") + @classmethod + def meta_mapping_validator(cls, meta_mapping: Dict[str, Any]) -> Dict[str, Any]: for k, v in meta_mapping.items(): if "match" not in v: raise ValueError( @@ -458,44 +469,35 @@ def meta_mapping_validator( mce_builder.validate_ownership_type(owner_category) return meta_mapping - @validator("include_column_lineage") - def validate_include_column_lineage( - cls, include_column_lineage: bool, values: Dict - ) -> bool: - if include_column_lineage and not values.get("infer_dbt_schemas"): + @model_validator(mode="after") + def validate_include_column_lineage(self) -> "DBTCommonConfig": + if self.include_column_lineage and not self.infer_dbt_schemas: raise ValueError( "`infer_dbt_schemas` must be enabled to use `include_column_lineage`" ) - return include_column_lineage - - @validator("skip_sources_in_lineage", always=True) - def validate_skip_sources_in_lineage( - cls, skip_sources_in_lineage: bool, values: Dict - ) -> bool: - entities_enabled: Optional[DBTEntitiesEnabled] = values.get("entities_enabled") - prefer_sql_parser_lineage: Optional[bool] = values.get( - "prefer_sql_parser_lineage" - ) + return self - if prefer_sql_parser_lineage and not skip_sources_in_lineage: + @model_validator(mode="after") + def validate_skip_sources_in_lineage(self) -> "DBTCommonConfig": + if self.prefer_sql_parser_lineage and not self.skip_sources_in_lineage: raise ValueError( "`prefer_sql_parser_lineage` requires that `skip_sources_in_lineage` is enabled." ) if ( - skip_sources_in_lineage - and entities_enabled - and entities_enabled.sources == EmitDirective.YES + self.skip_sources_in_lineage + and self.entities_enabled + and self.entities_enabled.sources == EmitDirective.YES # When `prefer_sql_parser_lineage` is enabled, it's ok to have `skip_sources_in_lineage` enabled # without also disabling sources. - and not prefer_sql_parser_lineage + and not self.prefer_sql_parser_lineage ): raise ValueError( "When `skip_sources_in_lineage` is enabled, `entities_enabled.sources` must be set to NO." 
) - return skip_sources_in_lineage + return self @dataclass diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py index a71061807b72a7..9d7d72640b183a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_core.py @@ -9,7 +9,7 @@ import dateutil.parser import requests from packaging import version -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, model_validator from datahub.configuration.git import GitReference from datahub.configuration.validate_field_rename import pydantic_renamed_field @@ -99,26 +99,24 @@ class DBTCoreConfig(DBTCommonConfig): _github_info_deprecated = pydantic_renamed_field("github_info", "git_info") - @validator("aws_connection", always=True) - def aws_connection_needed_if_s3_uris_present( - cls, aws_connection: Optional[AwsConnectionConfig], values: Dict, **kwargs: Any - ) -> Optional[AwsConnectionConfig]: + @model_validator(mode="after") + def aws_connection_needed_if_s3_uris_present(self) -> "DBTCoreConfig": # first check if there are fields that contain s3 uris uris = [ - values.get(f) + getattr(self, f, None) for f in [ "manifest_path", "catalog_path", "sources_path", ] - ] + values.get("run_results_paths", []) + ] + (self.run_results_paths or []) s3_uris = [uri for uri in uris if is_s3_uri(uri or "")] - if s3_uris and aws_connection is None: + if s3_uris and self.aws_connection is None: raise ValueError( f"Please provide aws_connection configuration, since s3 uris have been provided {s3_uris}" ) - return aws_connection + return self def get_columns( @@ -426,13 +424,13 @@ def load_run_results( ) return all_nodes - dbt_metadata = DBTRunMetadata.parse_obj(test_results_json.get("metadata", {})) + dbt_metadata = DBTRunMetadata.model_validate(test_results_json.get("metadata", {})) all_nodes_map: Dict[str, DBTNode] = {x.dbt_name: x for x in all_nodes} results = test_results_json.get("results", []) for result in results: - run_result = DBTRunResult.parse_obj(result) + run_result = DBTRunResult.model_validate(result) id = run_result.unique_id if id.startswith("test."): @@ -477,7 +475,7 @@ def __init__(self, config: DBTCommonConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod diff --git a/metadata-ingestion/src/datahub/ingestion/source/debug/datahub_debug.py b/metadata-ingestion/src/datahub/ingestion/source/debug/datahub_debug.py index d564afcc56218c..0a2d5e876096f2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/debug/datahub_debug.py +++ b/metadata-ingestion/src/datahub/ingestion/source/debug/datahub_debug.py @@ -46,7 +46,7 @@ def __init__(self, ctx: PipelineContext, config: DataHubDebugSourceConfig): @classmethod def create(cls, config_dict, ctx): - config = DataHubDebugSourceConfig.parse_obj(config_dict) + config = DataHubDebugSourceConfig.model_validate(config_dict) return cls(ctx, config) def perform_dns_probe(self, url: str) -> None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py index c9d8ef7833f975..3b86b8c47a434c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/delta_lake/config.py @@ 
-1,9 +1,8 @@ import logging from typing import Optional -import pydantic from cached_property import cached_property -from pydantic import Field +from pydantic import Field, field_validator from typing_extensions import Literal from datahub.configuration.common import AllowDenyPattern, ConfigModel @@ -98,8 +97,11 @@ def complete_path(self): return complete_path - @pydantic.validator("version_history_lookback") - def negative_version_history_implies_no_limit(cls, v): + @field_validator("version_history_lookback", mode="after") + @classmethod + def negative_version_history_implies_no_limit( + cls, v: Optional[int] + ) -> Optional[int]: if v and v < 0: return None return v diff --git a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py index 50c8d8f122ca43..bd353ff21bb931 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dremio/dremio_config.py @@ -2,7 +2,7 @@ from typing import List, Literal, Optional import certifi -from pydantic import Field, validator +from pydantic import Field, ValidationInfo, field_validator from datahub.configuration.common import AllowDenyPattern, ConfigModel, HiddenFromDocs from datahub.configuration.source_common import ( @@ -78,8 +78,9 @@ class DremioConnectionConfig(ConfigModel): description="ID of Dremio Cloud Project. Found in Project Settings in the Dremio Cloud UI", ) - @validator("authentication_method") - def validate_auth_method(cls, value): + @field_validator("authentication_method", mode="after") + @classmethod + def validate_auth_method(cls, value: str) -> str: allowed_methods = ["password", "PAT"] if value not in allowed_methods: raise ValueError( @@ -87,9 +88,12 @@ def validate_auth_method(cls, value): ) return value - @validator("password") - def validate_password(cls, value, values): - if values.get("authentication_method") == "PAT" and not value: + @field_validator("password", mode="after") + @classmethod + def validate_password( + cls, value: Optional[str], info: ValidationInfo + ) -> Optional[str]: + if info.data.get("authentication_method") == "PAT" and not value: raise ValueError( "Password (Personal Access Token) is required when using PAT authentication", ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py b/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py index e6a82e802f6e34..03953ee81ba634 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py +++ b/metadata-ingestion/src/datahub/ingestion/source/dynamodb/dynamodb.py @@ -200,7 +200,7 @@ def __init__(self, ctx: PipelineContext, config: DynamoDBConfig, platform: str): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "DynamoDBSource": - config = DynamoDBConfig.parse_obj(config_dict) + config = DynamoDBConfig.model_validate(config_dict) return cls(ctx, config, "dynamodb") def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py b/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py index 785047a9a43745..1ea4533b467d86 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py +++ b/metadata-ingestion/src/datahub/ingestion/source/elastic_search.py @@ -8,7 +8,7 @@ from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type, Union from elasticsearch import 
Elasticsearch -from pydantic import validator +from pydantic import field_validator from pydantic.fields import Field from datahub.configuration.common import AllowDenyPattern, ConfigModel @@ -330,7 +330,8 @@ def is_profiling_enabled(self) -> bool: self.profiling.operation_config ) - @validator("host") + @field_validator("host", mode="after") + @classmethod def host_colon_port_comma(cls, host_val: str) -> str: for entry in host_val.split(","): entry = remove_protocol(entry) @@ -382,7 +383,7 @@ def __init__(self, config: ElasticsearchSourceConfig, ctx: PipelineContext): def create( cls, config_dict: Dict[str, Any], ctx: PipelineContext ) -> "ElasticsearchSource": - config = ElasticsearchSourceConfig.parse_obj(config_dict) + config = ElasticsearchSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/excel/source.py b/metadata-ingestion/src/datahub/ingestion/source/excel/source.py index 62568231dab562..8fd634012b80c2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/excel/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/excel/source.py @@ -156,7 +156,7 @@ def __init__(self, ctx: PipelineContext, config: ExcelSourceConfig): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "ExcelSource": - config = ExcelSourceConfig.parse_obj(config_dict) + config = ExcelSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/feast.py b/metadata-ingestion/src/datahub/ingestion/source/feast.py index feb3cd2d14c30f..ea52df319d3b7b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/feast.py +++ b/metadata-ingestion/src/datahub/ingestion/source/feast.py @@ -462,7 +462,7 @@ def _create_owner_association(self, owner: str) -> Optional[OwnerClass]: @classmethod def create(cls, config_dict, ctx): - config = FeastRepositorySourceConfig.parse_obj(config_dict) + config = FeastRepositorySourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/file.py b/metadata-ingestion/src/datahub/ingestion/source/file.py index 5074a3ed5bb263..b484b1ef1da4f9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/file.py +++ b/metadata-ingestion/src/datahub/ingestion/source/file.py @@ -9,7 +9,7 @@ from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union import ijson -from pydantic import validator +from pydantic import field_validator from pydantic.fields import Field from datahub.configuration.common import ConfigEnum @@ -103,7 +103,8 @@ class FileSourceConfig(StatefulIngestionConfigBase): stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None - @validator("file_extension", always=True) + @field_validator("file_extension", mode="after") + @classmethod def add_leading_dot_to_extension(cls, v: str) -> str: if v: if v.startswith("."): @@ -205,7 +206,7 @@ def __init__(self, ctx: PipelineContext, config: FileSourceConfig): @classmethod def create(cls, config_dict, ctx): - config = FileSourceConfig.parse_obj(config_dict) + config = FileSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_filenames(self) -> Iterable[FileInfo]: @@ -358,7 +359,7 @@ def iterate_generic_file( @staticmethod 
def test_connection(config_dict: dict) -> TestConnectionReport: - config = FileSourceConfig.parse_obj(config_dict) + config = FileSourceConfig.model_validate(config_dict) exists = os.path.exists(config.path) if not exists: return TestConnectionReport( diff --git a/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py b/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py index 7113fac447f5a3..0f54a84473f915 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/fivetran/config.py @@ -1,10 +1,10 @@ import dataclasses import logging import warnings -from typing import Dict, Optional +from typing import Any, Dict, Optional import pydantic -from pydantic import Field, root_validator +from pydantic import Field, field_validator, model_validator from typing_extensions import Literal from datahub.configuration.common import ( @@ -98,7 +98,8 @@ class DatabricksDestinationConfig(UnityCatalogConnectionConfig): catalog: str = Field(description="The fivetran connector log catalog.") log_schema: str = Field(description="The fivetran connector log schema.") - @pydantic.validator("warehouse_id") + @field_validator("warehouse_id", mode="after") + @classmethod def warehouse_id_should_not_be_empty(cls, warehouse_id: Optional[str]) -> str: if warehouse_id is None or (warehouse_id and warehouse_id.strip() == ""): raise ValueError("Fivetran requires warehouse_id to be set") @@ -141,29 +142,28 @@ class FivetranLogConfig(ConfigModel): "destination_config", "snowflake_destination_config" ) - @root_validator(skip_on_failure=True) - def validate_destination_platfrom_and_config(cls, values: Dict) -> Dict: - destination_platform = values["destination_platform"] - if destination_platform == "snowflake": - if "snowflake_destination_config" not in values: + @model_validator(mode="after") + def validate_destination_platform_and_config(self) -> "FivetranLogConfig": + if self.destination_platform == "snowflake": + if self.snowflake_destination_config is None: raise ValueError( "If destination platform is 'snowflake', user must provide snowflake destination configuration in the recipe." ) - elif destination_platform == "bigquery": - if "bigquery_destination_config" not in values: + elif self.destination_platform == "bigquery": + if self.bigquery_destination_config is None: raise ValueError( "If destination platform is 'bigquery', user must provide bigquery destination configuration in the recipe." ) - elif destination_platform == "databricks": - if "databricks_destination_config" not in values: + elif self.destination_platform == "databricks": + if self.databricks_destination_config is None: raise ValueError( "If destination platform is 'databricks', user must provide databricks destination configuration in the recipe." ) else: raise ValueError( - f"Destination platform '{destination_platform}' is not yet supported." + f"Destination platform '{self.destination_platform}' is not yet supported." 
) - return values + return self @dataclasses.dataclass @@ -267,8 +267,9 @@ class FivetranSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin description="Fivetran REST API configuration, used to provide wider support for connections.", ) - @pydantic.root_validator(pre=True) - def compat_sources_to_database(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def compat_sources_to_database(cls, values: Any) -> Any: if "sources_to_database" in values: warnings.warn( "The sources_to_database field is deprecated, please use sources_to_platform_instance instead.", diff --git a/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py b/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py index b585d2e5600400..736327f4afc26c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py +++ b/metadata-ingestion/src/datahub/ingestion/source/fivetran/fivetran.py @@ -234,12 +234,12 @@ def _extend_lineage(self, connector: Connector, datajob: DataJob) -> Dict[str, s return dict( **{ f"source.{k}": str(v) - for k, v in source_details.dict().items() + for k, v in source_details.model_dump().items() if v is not None and not isinstance(v, bool) }, **{ f"destination.{k}": str(v) - for k, v in destination_details.dict().items() + for k, v in destination_details.model_dump().items() if v is not None and not isinstance(v, bool) }, ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py index 91b4e58ce3c3a0..e8b423db276701 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gc/datahub_gc.py @@ -127,7 +127,7 @@ def __init__(self, ctx: PipelineContext, config: DataHubGcSourceConfig): @classmethod def create(cls, config_dict, ctx): - config = DataHubGcSourceConfig.parse_obj(config_dict) + config = DataHubGcSourceConfig.model_validate(config_dict) return cls(ctx, config) # auto_work_unit_report is overriden to disable a couple of automation like auto status aspect, etc. which is not needed her. diff --git a/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py b/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py index 51e6589d7a89e2..7e42f84685e04b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/gcs/gcs_source.py @@ -1,7 +1,7 @@ import logging -from typing import Dict, Iterable, List, Optional +from typing import Iterable, List, Optional -from pydantic import Field, SecretStr, validator +from pydantic import Field, SecretStr, model_validator from datahub.configuration.common import ConfigModel from datahub.configuration.source_common import DatasetSourceConfigMixin @@ -64,18 +64,16 @@ class GCSSourceConfig( stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None - @validator("path_specs", always=True) - def check_path_specs_and_infer_platform( - cls, path_specs: List[PathSpec], values: Dict - ) -> List[PathSpec]: - if len(path_specs) == 0: + @model_validator(mode="after") + def check_path_specs_and_infer_platform(self) -> "GCSSourceConfig": + if len(self.path_specs) == 0: raise ValueError("path_specs must not be empty") # Check that all path specs have the gs:// prefix. 
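    # Note: this check moved from a v1 field validator (which needed the `values` dict)
    # to @model_validator(mode="after"); the parsed PathSpec objects are then available
    # directly as self.path_specs.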
- if any([not is_gcs_uri(path_spec.include) for path_spec in path_specs]): + if any([not is_gcs_uri(path_spec.include) for path_spec in self.path_specs]): raise ValueError("All path_spec.include should start with gs://") - return path_specs + return self class GCSSourceReport(DataLakeSourceReport): @@ -105,7 +103,7 @@ def __init__(self, config: GCSSourceConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = GCSSourceConfig.parse_obj(config_dict) + config = GCSSourceConfig.model_validate(config_dict) return cls(config, ctx) def create_equivalent_s3_config(self): diff --git a/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py b/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py index 420da1201906cf..c593f36ffbc5d5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/ge_profiling_config.py @@ -4,6 +4,7 @@ from typing import Annotated, Any, Dict, List, Optional import pydantic +from pydantic import model_validator from pydantic.fields import Field from datahub.configuration.common import AllowDenyPattern, ConfigModel, SupportedSources @@ -212,7 +213,8 @@ class GEProfilingConfig(GEProfilingBaseConfig): description="Whether to profile complex types like structs, arrays and maps. ", ) - @pydantic.root_validator(pre=True) + @model_validator(mode="before") + @classmethod def deprecate_bigquery_temp_table_schema(cls, values): # TODO: Update docs to remove mention of this field. if "bigquery_temp_table_schema" in values: @@ -222,16 +224,17 @@ def deprecate_bigquery_temp_table_schema(cls, values): del values["bigquery_temp_table_schema"] return values - @pydantic.root_validator(pre=True) + @model_validator(mode="before") + @classmethod def ensure_field_level_settings_are_normalized( - cls: "GEProfilingConfig", values: Dict[str, Any] + cls, values: Dict[str, Any] ) -> Dict[str, Any]: max_num_fields_to_profile_key = "max_number_of_fields_to_profile" max_num_fields_to_profile = values.get(max_num_fields_to_profile_key) # Disable all field-level metrics. 
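# Illustrative sketch (hypothetical model, not from the DataHub codebase) of the
# root_validator(pre=True) -> model_validator(mode="before") migration pattern used above:
# "before" validators receive the raw input dict and must be classmethods in Pydantic v2.
from typing import Any, Dict
from pydantic import BaseModel, model_validator

class ExampleProfilingConfig(BaseModel):
    profile_table_level_only: bool = False
    max_number_of_fields_to_profile: int = 10

    @model_validator(mode="before")
    @classmethod
    def normalize_legacy_keys(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Rename a hypothetical legacy key before field parsing happens.
        if "max_fields" in values:
            values["max_number_of_fields_to_profile"] = values.pop("max_fields")
        return values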
if values.get("profile_table_level_only"): - for field_level_metric in cls.__fields__: + for field_level_metric in cls.model_fields: if field_level_metric.startswith("include_field_"): if values.get(field_level_metric): raise ValueError( @@ -267,7 +270,7 @@ def any_field_level_metrics_enabled(self) -> bool: ) def config_for_telemetry(self) -> Dict[str, Any]: - config_dict = self.dict() + config_dict = self.model_dump() return { flag: config_dict[flag] diff --git a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_api.py b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_api.py index 380e2b96967e4d..95bd857eca7a48 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_api.py @@ -69,7 +69,7 @@ def get_folders(self) -> List[Folder]: if not batch: break - folders.extend(Folder.parse_obj(folder) for folder in batch) + folders.extend(Folder.model_validate(folder) for folder in batch) page += 1 except requests.exceptions.RequestException as e: self.report.report_failure( @@ -88,7 +88,7 @@ def get_dashboard(self, uid: str) -> Optional[Dashboard]: try: response = self.session.get(f"{self.base_url}/api/dashboards/uid/{uid}") response.raise_for_status() - return Dashboard.parse_obj(response.json()) + return Dashboard.model_validate(response.json()) except requests.exceptions.RequestException as e: self.report.warning( title="Dashboard Fetch Error", diff --git a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_config.py b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_config.py index cd1d6b4f6d2bfa..e84cdab6da9f9a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_config.py @@ -1,6 +1,6 @@ from typing import Dict, Optional -from pydantic import Field, SecretStr, validator +from pydantic import Field, SecretStr, field_validator from datahub.configuration.common import AllowDenyPattern, HiddenFromDocs from datahub.configuration.source_common import ( @@ -99,6 +99,7 @@ class GrafanaSourceConfig( description="Map of Grafana datasource types/UIDs to platform connection configs for lineage extraction", ) - @validator("url", allow_reuse=True) - def remove_trailing_slash(cls, v): + @field_validator("url", mode="after") + @classmethod + def remove_trailing_slash(cls, v: str) -> str: return config_clean.remove_trailing_slashes(v) diff --git a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py index b5e258c7da18ae..dc26e7c4a0f631 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/grafana/grafana_source.py @@ -171,7 +171,7 @@ def __init__(self, config: GrafanaSourceConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "GrafanaSource": - config = GrafanaSourceConfig.parse_obj(config_dict) + config = GrafanaSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/grafana/models.py b/metadata-ingestion/src/datahub/ingestion/source/grafana/models.py index 90780f7f847d13..c035c91037f31b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/grafana/models.py +++ 
b/metadata-ingestion/src/datahub/ingestion/source/grafana/models.py @@ -79,18 +79,29 @@ def extract_panels(panels_data: List[Dict[str, Any]]) -> List[Panel]: for panel_data in panels_data: if panel_data.get("type") == "row" and "panels" in panel_data: panels.extend( - Panel.parse_obj(p) + Panel.model_validate(p) for p in panel_data["panels"] if p.get("type") != "row" ) elif panel_data.get("type") != "row": - panels.append(Panel.parse_obj(panel_data)) + panels.append(Panel.model_validate(panel_data)) return panels @classmethod - def parse_obj(cls, data: Dict[str, Any]) -> "Dashboard": + def model_validate( + cls, + obj: Any, + *, + strict: Optional[bool] = None, + from_attributes: Optional[bool] = None, + context: Optional[Any] = None, + by_alias: Optional[bool] = None, + by_name: Optional[bool] = None, + ) -> "Dashboard": """Custom parsing to handle nested panel extraction.""" - dashboard_data = data.get("dashboard", {}) + # Handle both direct dashboard data and nested structure with 'dashboard' key + dashboard_data = obj.get("dashboard", obj) + _panel_data = dashboard_data.get("panels", []) panels = [] try: @@ -113,7 +124,14 @@ def parse_obj(cls, data: Dict[str, Any]) -> "Dashboard": if "refresh" in dashboard_dict and isinstance(dashboard_dict["refresh"], bool): dashboard_dict["refresh"] = str(dashboard_dict["refresh"]) - return super().parse_obj(dashboard_dict) + return super().model_validate( + dashboard_dict, + strict=strict, + from_attributes=from_attributes, + context=context, + by_alias=by_alias, + by_name=by_name, + ) class Folder(_GrafanaBaseModel): diff --git a/metadata-ingestion/src/datahub/ingestion/source/hex/api.py b/metadata-ingestion/src/datahub/ingestion/source/hex/api.py index 4b043a58b49282..58ba64b75252a7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/hex/api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/hex/api.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Generator, List, Optional, Union import requests -from pydantic import BaseModel, Field, ValidationError, validator +from pydantic import BaseModel, Field, ValidationError, field_validator from requests.adapters import HTTPAdapter from typing_extensions import assert_never from urllib3.util.retry import Retry @@ -50,7 +50,8 @@ class HexApiProjectAnalytics(BaseModel): default=None, alias="publishedResultsUpdatedAt" ) - @validator("last_viewed_at", "published_results_updated_at", pre=True) + @field_validator("last_viewed_at", "published_results_updated_at", mode="before") + @classmethod def parse_datetime(cls, value): if value is None: return None @@ -167,14 +168,15 @@ class HexApiProjectApiResource(BaseModel): class Config: extra = "ignore" # Allow extra fields in the JSON - @validator( + @field_validator( "created_at", "last_edited_at", "last_published_at", "archived_at", "trashed_at", - pre=True, + mode="before", ) + @classmethod def parse_datetime(cls, value): if value is None: return None @@ -292,7 +294,7 @@ def _fetch_projects_page( ) response.raise_for_status() - api_response = HexApiProjectsListResponse.parse_obj(response.json()) + api_response = HexApiProjectsListResponse.model_validate(response.json()) logger.info(f"Fetched {len(api_response.values)} items") params["after"] = ( api_response.pagination.after if api_response.pagination else None diff --git a/metadata-ingestion/src/datahub/ingestion/source/hex/hex.py b/metadata-ingestion/src/datahub/ingestion/source/hex/hex.py index 8b2509630bb42f..4a6ef3af4af32e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/hex/hex.py 
+++ b/metadata-ingestion/src/datahub/ingestion/source/hex/hex.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta, timezone from typing import Any, Dict, Iterable, List, Optional -from pydantic import Field, SecretStr, root_validator +from pydantic import Field, SecretStr, model_validator from typing_extensions import assert_never from datahub.configuration.common import AllowDenyPattern @@ -120,7 +120,8 @@ class HexSourceConfig( description="Number of items to fetch per DataHub API call.", ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def validate_lineage_times(cls, data: Dict[str, Any]) -> Dict[str, Any]: # In-place update of the input dict would cause state contamination. This was discovered through test failures # in test_hex.py where the same dict is reused. @@ -238,7 +239,7 @@ def __init__(self, config: HexSourceConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: Dict[str, Any], ctx: PipelineContext) -> "HexSource": - config = HexSourceConfig.parse_obj(config_dict) + config = HexSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py index 887419b31f4ecd..3d9cdffb0a9c70 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py @@ -161,7 +161,7 @@ def __init__(self, config: IcebergSourceConfig, ctx: PipelineContext) -> None: @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "IcebergSource": - config = IcebergSourceConfig.parse_obj(config_dict) + config = IcebergSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py index 7bc2e377e94b82..593b9af1327ad6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg_common.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional from humanfriendly import format_timespan -from pydantic import Field, validator +from pydantic import Field, field_validator from pyiceberg.catalog import Catalog, load_catalog from pyiceberg.catalog.rest import RestCatalog from requests.adapters import HTTPAdapter @@ -108,7 +108,8 @@ class IcebergSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin) default=1, description="How many threads will be processing tables" ) - @validator("catalog", pre=True, always=True) + @field_validator("catalog", mode="before") + @classmethod def handle_deprecated_catalog_format(cls, value): # Once support for deprecated format is dropped, we can remove this validator. 
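# Illustrative sketch (hypothetical model, not from the DataHub codebase) of the two
# field_validator modes used in this migration: mode="before" replaces v1's pre=True and
# sees raw input, while mode="after" replaces the v1 default and sees the parsed value.
from typing import Any, Dict
from pydantic import BaseModel, field_validator

class ExampleCatalogConfig(BaseModel):
    catalog: Dict[str, Dict[str, str]]

    @field_validator("catalog", mode="before")
    @classmethod
    def upconvert_legacy_format(cls, value: Any) -> Any:
        # A hypothetical legacy shape {"name": ..., "conf": {...}} is converted to the new mapping.
        if isinstance(value, dict) and "name" in value and "conf" in value:
            return {value["name"]: value["conf"]}
        return value

    @field_validator("catalog", mode="after")
    @classmethod
    def exactly_one_entry(cls, value: Dict[str, Dict[str, str]]) -> Dict[str, Dict[str, str]]:
        # Invariant checks run after type coercion has already happened.
        if len(value) != 1:
            raise ValueError("The catalog must contain exactly one entry.")
        return value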
if ( @@ -131,7 +132,8 @@ def handle_deprecated_catalog_format(cls, value): # In case the input is already the new format or is invalid return value - @validator("catalog") + @field_validator("catalog", mode="after") + @classmethod def validate_catalog_size(cls, value): if len(value) != 1: raise ValueError("The catalog must contain exactly one entry.") diff --git a/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py b/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py index 76322141a7ba07..2fd921893ad3dd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py +++ b/metadata-ingestion/src/datahub/ingestion/source/identity/azure_ad.py @@ -254,7 +254,7 @@ class AzureADSource(StatefulIngestionSourceBase): @classmethod def create(cls, config_dict, ctx): - config = AzureADConfig.parse_obj(config_dict) + config = AzureADConfig.model_validate(config_dict) return cls(config, ctx) def __init__(self, config: AzureADConfig, ctx: PipelineContext): diff --git a/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py b/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py index 1435b696140023..894642afb03fa2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py +++ b/metadata-ingestion/src/datahub/ingestion/source/identity/okta.py @@ -11,7 +11,7 @@ from okta.client import Client as OktaClient from okta.exceptions import OktaAPIException from okta.models import Group, GroupProfile, User, UserProfile, UserStatus -from pydantic import validator +from pydantic import model_validator from pydantic.fields import Field from datahub.emitter.mcp import MetadataChangeProposalWrapper @@ -157,21 +157,21 @@ class OktaConfig(StatefulIngestionConfigBase): mask_group_id: bool = True mask_user_id: bool = True - @validator("okta_users_search") - def okta_users_one_of_filter_or_search(cls, v, values): - if v and values["okta_users_filter"]: + @model_validator(mode="after") + def okta_users_one_of_filter_or_search(self) -> "OktaConfig": + if self.okta_users_search and self.okta_users_filter: raise ValueError( "Only one of okta_users_filter or okta_users_search can be set" ) - return v + return self - @validator("okta_groups_search") - def okta_groups_one_of_filter_or_search(cls, v, values): - if v and values["okta_groups_filter"]: + @model_validator(mode="after") + def okta_groups_one_of_filter_or_search(self) -> "OktaConfig": + if self.okta_groups_search and self.okta_groups_filter: raise ValueError( "Only one of okta_groups_filter or okta_groups_search can be set" ) - return v + return self @dataclass @@ -288,7 +288,7 @@ class OktaSource(StatefulIngestionSourceBase): @classmethod def create(cls, config_dict, ctx): - config = OktaConfig.parse_obj(config_dict) + config = OktaConfig.model_validate(config_dict) return cls(config, ctx) def __init__(self, config: OktaConfig, ctx: PipelineContext): diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py index 51b1050e5e8816..2b2d26203b7c95 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py +++ b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py @@ -267,7 +267,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "KafkaSource": - config: KafkaSourceConfig = KafkaSourceConfig.parse_obj(config_dict) + config: KafkaSourceConfig = 
KafkaSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/ldap.py b/metadata-ingestion/src/datahub/ingestion/source/ldap.py index 7fc10795193cf1..ff31508101c7d0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/ldap.py +++ b/metadata-ingestion/src/datahub/ingestion/source/ldap.py @@ -242,7 +242,7 @@ def __init__(self, ctx: PipelineContext, config: LDAPSourceConfig): @classmethod def create(cls, config_dict: Dict[str, Any], ctx: PipelineContext) -> "LDAPSource": """Factory method.""" - config = LDAPSourceConfig.parse_obj(config_dict) + config = LDAPSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py index b30097adb7b828..681594184d0d7c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_common.py @@ -28,7 +28,7 @@ User, WriteQuery, ) -from pydantic import validator +from pydantic import field_validator import datahub.emitter.mce_builder as builder from datahub.api.entities.platformresource.platform_resource import ( @@ -202,8 +202,9 @@ def get_mapping(self, config: LookerCommonConfig) -> ViewNamingPatternMapping: folder_path=os.path.dirname(self.file_path), ) - @validator("view_name") - def remove_quotes(cls, v): + @field_validator("view_name", mode="after") + @classmethod + def remove_quotes(cls, v: str) -> str: # Sanitize the name. v = v.replace('"', "").replace("`", "") return v @@ -931,8 +932,9 @@ class LookerExplore: source_file: Optional[str] = None tags: List[str] = dataclasses_field(default_factory=list) - @validator("name") - def remove_quotes(cls, v): + @field_validator("name", mode="after") + @classmethod + def remove_quotes(cls, v: str) -> str: # Sanitize the name. 
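# Illustrative sketch (hypothetical model, not from the DataHub codebase) of migrating a v1
# cross-field @validator that read other fields from the `values` dict into a v2
# @model_validator(mode="after") that reads attributes off the fully constructed instance.
from typing import Optional
from pydantic import BaseModel, model_validator

class ExampleSearchConfig(BaseModel):
    users_filter: Optional[str] = None
    users_search: Optional[str] = None

    @model_validator(mode="after")
    def only_one_of_filter_or_search(self) -> "ExampleSearchConfig":
        # All fields are populated here, so no values.get(...) lookups are needed.
        if self.users_search and self.users_filter:
            raise ValueError("Only one of users_filter or users_search can be set")
        return self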
v = v.replace('"', "").replace("`", "") return v diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py index 59627fd735bad1..2342cc1a91d04d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/looker_config.py @@ -1,11 +1,11 @@ import dataclasses import os import re -from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union, cast +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import pydantic from looker_sdk.sdk.api40.models import DBConnection -from pydantic import Field, model_validator, validator +from pydantic import Field, field_validator, model_validator from datahub.configuration import ConfigModel from datahub.configuration.common import ( @@ -198,17 +198,20 @@ class LookerConnectionDefinition(ConfigModel): "the top level Looker configuration", ) - @validator("platform_env") + @field_validator("platform_env", mode="after") + @classmethod def platform_env_must_be_one_of(cls, v: Optional[str]) -> Optional[str]: if v is not None: return EnvConfigMixin.env_must_be_one_of(v) return v - @validator("platform", "default_db", "default_schema") - def lower_everything(cls, v): + @field_validator("platform", "default_db", "default_schema", mode="after") + @classmethod + def lower_everything(cls, v: Optional[str]) -> Optional[str]: """We lower case all strings passed in to avoid casing issues later""" if v is not None: return v.lower() + return v @classmethod def from_looker_connection( @@ -326,22 +329,20 @@ class LookerDashboardSourceConfig( "Dashboards will only be ingested if they're allowed by both this config and dashboard_pattern.", ) - @validator("external_base_url", pre=True, always=True) + @model_validator(mode="before") + @classmethod def external_url_defaults_to_api_config_base_url( - cls, v: Optional[str], *, values: Dict[str, Any], **kwargs: Dict[str, Any] - ) -> Optional[str]: - return v or values.get("base_url") - - @validator("extract_independent_looks", always=True) - def stateful_ingestion_should_be_enabled( - cls, v: Optional[bool], *, values: Dict[str, Any], **kwargs: Dict[str, Any] - ) -> Optional[bool]: - stateful_ingestion: StatefulStaleMetadataRemovalConfig = cast( - StatefulStaleMetadataRemovalConfig, values.get("stateful_ingestion") - ) - if v is True and ( - stateful_ingestion is None or stateful_ingestion.enabled is False + cls, values: Dict[str, Any] + ) -> Dict[str, Any]: + if "external_base_url" not in values or values["external_base_url"] is None: + values["external_base_url"] = values.get("base_url") + return values + + @model_validator(mode="after") + def stateful_ingestion_should_be_enabled(self): + if self.extract_independent_looks is True and ( + self.stateful_ingestion is None or self.stateful_ingestion.enabled is False ): raise ValueError("stateful_ingestion.enabled should be set to true") - return v + return self diff --git a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py index c22aa9c011e2f8..f7076e4980ec2e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/looker/lookml_config.py @@ -1,10 +1,11 @@ import logging +from copy import deepcopy from dataclasses import dataclass, field as dataclass_field from datetime import timedelta from typing 
import Any, Dict, Literal, Optional, Union import pydantic -from pydantic import root_validator, validator +from pydantic import model_validator from pydantic.fields import Field from datahub.configuration.common import AllowDenyPattern @@ -210,75 +211,74 @@ class LookMLSourceConfig( "All if comments are evaluated to true for configured looker_environment value", ) - @validator("connection_to_platform_map", pre=True) - def convert_string_to_connection_def(cls, conn_map): - # Previous version of config supported strings in connection map. This upconverts strings to ConnectionMap - for key in conn_map: - if isinstance(conn_map[key], str): - platform = conn_map[key] - if "." in platform: - platform_db_split = conn_map[key].split(".") - connection = LookerConnectionDefinition( - platform=platform_db_split[0], - default_db=platform_db_split[1], - default_schema="", - ) - conn_map[key] = connection - else: - logger.warning( - f"Connection map for {key} provides platform {platform} but does not provide a default " - f"database name. This might result in failed resolution" - ) - conn_map[key] = LookerConnectionDefinition( - platform=platform, default_db="", default_schema="" - ) - return conn_map + @model_validator(mode="before") + @classmethod + def convert_string_to_connection_def(cls, values: Dict[str, Any]) -> Dict[str, Any]: + values = deepcopy(values) + conn_map = values.get("connection_to_platform_map") + if conn_map: + # Previous version of config supported strings in connection map. This upconverts strings to ConnectionMap + for key in conn_map: + if isinstance(conn_map[key], str): + platform = conn_map[key] + if "." in platform: + platform_db_split = conn_map[key].split(".") + connection = LookerConnectionDefinition( + platform=platform_db_split[0], + default_db=platform_db_split[1], + default_schema="", + ) + conn_map[key] = connection + else: + logger.warning( + f"Connection map for {key} provides platform {platform} but does not provide a default " + f"database name. This might result in failed resolution" + ) + conn_map[key] = LookerConnectionDefinition( + platform=platform, default_db="", default_schema="" + ) + return values - @root_validator(skip_on_failure=True) - def check_either_connection_map_or_connection_provided(cls, values): + @model_validator(mode="after") + def check_either_connection_map_or_connection_provided(self): """Validate that we must either have a connection map or an api credential""" - if not values.get("connection_to_platform_map", {}) and not values.get( - "api", {} - ): + if not (self.connection_to_platform_map or {}) and not (self.api): raise ValueError( "Neither api not connection_to_platform_map config was found. LookML source requires either api " "credentials for Looker or a map of connection names to platform identifiers to work correctly" ) - return values + return self - @root_validator(skip_on_failure=True) - def check_either_project_name_or_api_provided(cls, values): + @model_validator(mode="after") + def check_either_project_name_or_api_provided(self): """Validate that we must either have a project name or an api credential to fetch project names""" - if not values.get("project_name") and not values.get("api"): + if not self.project_name and not self.api: raise ValueError( "Neither project_name not an API credential was found. LookML source requires either api credentials " "for Looker or a project_name to accurately name views and models." 
) - return values + return self - @root_validator(skip_on_failure=True) - def check_api_provided_for_view_lineage(cls, values): + @model_validator(mode="after") + def check_api_provided_for_view_lineage(self): """Validate that we must have an api credential to use Looker API for view's column lineage""" - if not values.get("api") and values.get("use_api_for_view_lineage"): + if not self.api and self.use_api_for_view_lineage: raise ValueError( "API credential was not found. LookML source requires api credentials " "for Looker to use Looker APIs for view's column lineage extraction." "Set `use_api_for_view_lineage` to False to skip using Looker APIs." ) - return values + return self - @validator("base_folder", always=True) - def check_base_folder_if_not_provided( - cls, v: Optional[pydantic.DirectoryPath], values: Dict[str, Any] - ) -> Optional[pydantic.DirectoryPath]: - if v is None: - git_info: Optional[GitInfo] = values.get("git_info") - if git_info: - if not git_info.deploy_key: + @model_validator(mode="after") + def check_base_folder_if_not_provided(self): + if self.base_folder is None: + if self.git_info: + if not self.git_info.deploy_key: logger.warning( "git_info is provided, but no SSH key is present. If the repo is not public, we'll fail to " "clone it." ) else: raise ValueError("Neither base_folder nor git_info has been provided.") - return v + return self diff --git a/metadata-ingestion/src/datahub/ingestion/source/metabase.py b/metadata-ingestion/src/datahub/ingestion/source/metabase.py index f87ecf63c07fc4..d9a0c0b249ab32 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metabase.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metabase.py @@ -9,7 +9,7 @@ import dateutil.parser as dp import pydantic import requests -from pydantic import Field, root_validator, validator +from pydantic import Field, field_validator, model_validator from requests.models import HTTPError import datahub.emitter.mce_builder as builder @@ -115,16 +115,16 @@ class MetabaseConfig( ) stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None - @validator("connect_uri", "display_uri") + @field_validator("connect_uri", "display_uri", mode="after") + @classmethod def remove_trailing_slash(cls, v): return config_clean.remove_trailing_slashes(v) - @root_validator(skip_on_failure=True) - def default_display_uri_to_connect_uri(cls, values): - base = values.get("display_uri") - if base is None: - values["display_uri"] = values.get("connect_uri") - return values + @model_validator(mode="after") + def default_display_uri_to_connect_uri(self) -> "MetabaseConfig": + if self.display_uri is None: + self.display_uri = self.connect_uri + return self @dataclass diff --git a/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py b/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py index 945973fcbebae7..b617338238a6cf 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metadata/business_glossary.py @@ -563,7 +563,7 @@ class BusinessGlossaryFileSource(Source): @classmethod def create(cls, config_dict, ctx): - config = BusinessGlossarySourceConfig.parse_obj(config_dict) + config = BusinessGlossarySourceConfig.model_validate(config_dict) return cls(ctx, config) @classmethod @@ -571,7 +571,7 @@ def load_glossary_config( cls, file_name: Union[str, pathlib.Path] ) -> BusinessGlossaryConfig: config = load_config_file(file_name, resolve_env_vars=True) - 
glossary_cfg = BusinessGlossaryConfig.parse_obj(config) + glossary_cfg = BusinessGlossaryConfig.model_validate(config) return glossary_cfg def get_workunits_internal( diff --git a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py index b7b3f52b59f754..b9cac84942067b 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/metadata/lineage.py @@ -3,7 +3,7 @@ from functools import partial from typing import Any, Dict, Iterable, List, Optional -from pydantic import validator +from pydantic import field_validator from pydantic.fields import Field import datahub.metadata.schema_classes as models @@ -51,7 +51,8 @@ class EntityConfig(EnvConfigMixin): platform: str platform_instance: Optional[str] = None - @validator("type") + @field_validator("type", mode="after") + @classmethod def type_must_be_supported(cls, v: str) -> str: allowed_types = ["dataset"] if v not in allowed_types: @@ -60,7 +61,8 @@ def type_must_be_supported(cls, v: str) -> str: ) return v - @validator("name") + @field_validator("name", mode="after") + @classmethod def validate_name(cls, v: str) -> str: if v.startswith("urn:li:"): raise ValueError( @@ -77,7 +79,8 @@ class FineGrainedLineageConfig(ConfigModel): transformOperation: Optional[str] confidenceScore: Optional[float] = 1.0 - @validator("upstreamType") + @field_validator("upstreamType", mode="after") + @classmethod def upstream_type_must_be_supported(cls, v: str) -> str: allowed_types = [ FineGrainedLineageUpstreamTypeClass.FIELD_SET, @@ -90,7 +93,8 @@ def upstream_type_must_be_supported(cls, v: str) -> str: ) return v - @validator("downstreamType") + @field_validator("downstreamType", mode="after") + @classmethod def downstream_type_must_be_supported(cls, v: str) -> str: allowed_types = [ FineGrainedLineageDownstreamTypeClass.FIELD_SET, @@ -124,7 +128,8 @@ class LineageFileSourceConfig(ConfigModel): class LineageConfig(VersionedConfig): lineage: List[EntityNodeConfig] - @validator("version") + @field_validator("version", mode="after") + @classmethod def version_must_be_1(cls, v): if v != "1": raise ValueError("Only version 1 is supported") @@ -148,13 +153,13 @@ class LineageFileSource(Source): def create( cls, config_dict: Dict[str, Any], ctx: PipelineContext ) -> "LineageFileSource": - config = LineageFileSourceConfig.parse_obj(config_dict) + config = LineageFileSourceConfig.model_validate(config_dict) return cls(ctx, config) @staticmethod def load_lineage_config(file_name: str) -> LineageConfig: config = load_config_file(file_name, resolve_env_vars=True) - lineage_config = LineageConfig.parse_obj(config) + lineage_config = LineageConfig.model_validate(config) return lineage_config def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py index e49bb34c75e8c5..5b495e84675817 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mlflow.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mlflow.py @@ -892,5 +892,5 @@ def _is_valid_platform(self, platform: Optional[str]) -> bool: @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "MLflowSource": - config = MLflowConfig.parse_obj(config_dict) + config = MLflowConfig.model_validate(config_dict) return cls(ctx, config) diff --git 
a/metadata-ingestion/src/datahub/ingestion/source/mode.py b/metadata-ingestion/src/datahub/ingestion/source/mode.py index c79e612056a984..b71492fb494823 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mode.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mode.py @@ -26,7 +26,7 @@ import tenacity import yaml from liquid import Template, Undefined -from pydantic import Field, validator +from pydantic import Field, field_validator from requests.adapters import HTTPAdapter, Retry from requests.exceptions import ConnectionError from requests.models import HTTPBasicAuth, HTTPError @@ -218,11 +218,13 @@ class ModeConfig( default=False, description="Exclude archived reports" ) - @validator("connect_uri") + @field_validator("connect_uri", mode="after") + @classmethod def remove_trailing_slash(cls, v): return config_clean.remove_trailing_slashes(v) - @validator("items_per_page") + @field_validator("items_per_page", mode="after") + @classmethod def validate_items_per_page(cls, v): if 1 <= v <= DEFAULT_API_ITEMS_PER_PAGE: return v @@ -1824,7 +1826,7 @@ def emit_dataset_mces(self): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "ModeSource": - config: ModeConfig = ModeConfig.parse_obj(config_dict) + config: ModeConfig = ModeConfig.model_validate(config_dict) return cls(ctx, config) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/mongodb.py b/metadata-ingestion/src/datahub/ingestion/source/mongodb.py index b98d9d26ca47b8..50d6360a46444a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/mongodb.py +++ b/metadata-ingestion/src/datahub/ingestion/source/mongodb.py @@ -6,7 +6,7 @@ import bson.timestamp import pymongo.collection from packaging import version -from pydantic import PositiveInt, validator +from pydantic import PositiveInt, field_validator from pydantic.fields import Field from pymongo.mongo_client import MongoClient @@ -138,7 +138,8 @@ class MongoDBConfig( # Custom Stateful Ingestion settings stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None - @validator("maxDocumentSize") + @field_validator("maxDocumentSize", mode="after") + @classmethod def check_max_doc_size_filter_is_valid(cls, doc_size_filter_value): if doc_size_filter_value > 16793600: raise ValueError("maxDocumentSize must be a positive value <= 16793600.") @@ -311,7 +312,7 @@ def __init__(self, ctx: PipelineContext, config: MongoDBConfig): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "MongoDBSource": - config = MongoDBConfig.parse_obj(config_dict) + config = MongoDBConfig.model_validate(config_dict) return cls(ctx, config) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py b/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py index 0c95ae56c30906..990f9048009176 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/neo4j/neo4j_source.py @@ -78,7 +78,7 @@ def __init__(self, config: Neo4jConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "Neo4jSource": - config = Neo4jConfig.parse_obj(config_dict) + config = Neo4jConfig.model_validate(config_dict) return cls(config, ctx) def create_schema_field_tuple( diff --git a/metadata-ingestion/src/datahub/ingestion/source/nifi.py 
b/metadata-ingestion/src/datahub/ingestion/source/nifi.py index 446253a818e330..5cf0fa1a3e3212 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/nifi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/nifi.py @@ -13,7 +13,7 @@ from cached_property import cached_property from dateutil import parser from packaging import version -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator from pydantic.fields import Field from requests import Response from requests.adapters import HTTPAdapter @@ -165,39 +165,33 @@ class NifiSourceConfig(StatefulIngestionConfigBase, EnvConfigMixin): " When disabled, re-states lineage on each run.", ) - @root_validator(skip_on_failure=True) - def validate_auth_params(cls, values): - if values.get("auth") is NifiAuthType.CLIENT_CERT and not values.get( - "client_cert_file" - ): + @model_validator(mode="after") + def validate_auth_params(self) -> "NifiSourceConfig": + if self.auth is NifiAuthType.CLIENT_CERT and not self.client_cert_file: raise ValueError( "Config `client_cert_file` is required for CLIENT_CERT auth" ) - elif values.get("auth") in ( + elif self.auth in ( NifiAuthType.SINGLE_USER, NifiAuthType.BASIC_AUTH, - ) and (not values.get("username") or not values.get("password")): + ) and (not self.username or not self.password): raise ValueError( - f"Config `username` and `password` is required for {values.get('auth').value} auth" + f"Config `username` and `password` is required for {self.auth.value} auth" ) - return values - - @root_validator(skip_on_failure=True) - def validator_site_url_to_site_name(cls, values): - site_url_to_site_name = values.get("site_url_to_site_name") - site_url = values.get("site_url") - site_name = values.get("site_name") + return self - if site_url_to_site_name is None: - site_url_to_site_name = {} - values["site_url_to_site_name"] = site_url_to_site_name + @model_validator(mode="after") + def validator_site_url_to_site_name(self) -> "NifiSourceConfig": + if self.site_url_to_site_name is None: + self.site_url_to_site_name = {} - if site_url not in site_url_to_site_name: - site_url_to_site_name[site_url] = site_name + if self.site_url not in self.site_url_to_site_name: + self.site_url_to_site_name[self.site_url] = self.site_name - return values + return self - @validator("site_url") + @field_validator("site_url", mode="after") + @classmethod def validator_site_url(cls, site_url: str) -> str: assert site_url.startswith(("http://", "https://")), ( "site_url must start with http:// or https://" diff --git a/metadata-ingestion/src/datahub/ingestion/source/openapi.py b/metadata-ingestion/src/datahub/ingestion/source/openapi.py index 8b1ee66162671b..89c06d28e167b3 100755 --- a/metadata-ingestion/src/datahub/ingestion/source/openapi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/openapi.py @@ -4,7 +4,7 @@ from abc import ABC from typing import Dict, Iterable, List, Optional, Tuple -from pydantic import validator +from pydantic import model_validator from pydantic.fields import Field from datahub.configuration.common import ConfigModel @@ -86,13 +86,11 @@ class OpenApiConfig(ConfigModel): default=True, description="Enable SSL certificate verification" ) - @validator("bearer_token", always=True) - def ensure_only_one_token( - cls, bearer_token: Optional[str], values: Dict - ) -> Optional[str]: - if bearer_token is not None and values.get("token") is not None: + @model_validator(mode="after") + def ensure_only_one_token(self) -> "OpenApiConfig": + if self.bearer_token 
is not None and self.token is not None: raise ValueError("Unable to use 'token' and 'bearer_token' together.") - return bearer_token + return self def get_swagger(self) -> Dict: if self.get_token or self.token or self.bearer_token is not None: @@ -463,5 +461,5 @@ def __init__(self, config: OpenApiConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = OpenApiConfig.parse_obj(config_dict) + config = OpenApiConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py index 9eb07b930bfa70..e59842b63f286a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/config.py @@ -4,7 +4,7 @@ from typing import Dict, List, Literal, Optional, Union import pydantic -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator import datahub.emitter.mce_builder as builder from datahub.configuration.common import AllowDenyPattern, ConfigModel, HiddenFromDocs @@ -540,8 +540,8 @@ class PowerBiDashboardSourceConfig( description="timeout in seconds for Metadata Rest Api.", ) - @root_validator(skip_on_failure=True) - def validate_extract_column_level_lineage(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_extract_column_level_lineage(self) -> "PowerBiDashboardSourceConfig": flags = [ "native_query_parsing", "enable_advance_lineage_sql_construct", @@ -549,26 +549,23 @@ def validate_extract_column_level_lineage(cls, values: Dict) -> Dict: "extract_dataset_schema", ] - if ( - "extract_column_level_lineage" in values - and values["extract_column_level_lineage"] is False - ): + if self.extract_column_level_lineage is False: # Flag is not set. skip validation - return values + return self logger.debug(f"Validating additional flags: {flags}") is_flag_enabled: bool = True for flag in flags: - if flag not in values or values[flag] is False: + if not getattr(self, flag, True): is_flag_enabled = False if not is_flag_enabled: raise ValueError(f"Enable all these flags in recipe: {flags} ") - return values + return self - @validator("dataset_type_mapping") + @field_validator("dataset_type_mapping", mode="after") @classmethod def map_data_platform(cls, value): # For backward compatibility convert input PostgreSql to PostgreSQL @@ -580,28 +577,32 @@ def map_data_platform(cls, value): return value - @root_validator(skip_on_failure=True) - def workspace_id_backward_compatibility(cls, values: Dict) -> Dict: - workspace_id = values.get("workspace_id") - workspace_id_pattern = values.get("workspace_id_pattern") - - if workspace_id_pattern == AllowDenyPattern.allow_all() and workspace_id: + @model_validator(mode="after") + def workspace_id_backward_compatibility(self) -> "PowerBiDashboardSourceConfig": + if ( + self.workspace_id_pattern == AllowDenyPattern.allow_all() + and self.workspace_id + ): logger.warning( "workspace_id_pattern is not set but workspace_id is set, setting workspace_id as " "workspace_id_pattern. workspace_id will be deprecated, please use workspace_id_pattern instead." 
) - values["workspace_id_pattern"] = AllowDenyPattern( - allow=[f"^{workspace_id}$"] + self.workspace_id_pattern = AllowDenyPattern( + allow=[f"^{self.workspace_id}$"] ) - elif workspace_id_pattern != AllowDenyPattern.allow_all() and workspace_id: + elif ( + self.workspace_id_pattern != AllowDenyPattern.allow_all() + and self.workspace_id + ): logger.warning( "workspace_id will be ignored in favour of workspace_id_pattern. workspace_id will be deprecated, " "please use workspace_id_pattern only." ) - values.pop("workspace_id") - return values + self.workspace_id = None + return self - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def raise_error_for_dataset_type_mapping(cls, values: Dict) -> Dict: if ( values.get("dataset_type_mapping") is not None @@ -613,18 +614,18 @@ def raise_error_for_dataset_type_mapping(cls, values: Dict) -> Dict: return values - @root_validator(skip_on_failure=True) - def validate_extract_dataset_schema(cls, values: Dict) -> Dict: - if values.get("extract_dataset_schema") is False: + @model_validator(mode="after") + def validate_extract_dataset_schema(self) -> "PowerBiDashboardSourceConfig": + if self.extract_dataset_schema is False: add_global_warning( "Please use `extract_dataset_schema: true`, otherwise dataset schema extraction will be skipped." ) - return values + return self - @root_validator(skip_on_failure=True) - def validate_dsn_to_database_schema(cls, values: Dict) -> Dict: - if values.get("dsn_to_database_schema") is not None: - dsn_mapping = values.get("dsn_to_database_schema") + @model_validator(mode="after") + def validate_dsn_to_database_schema(self) -> "PowerBiDashboardSourceConfig": + if self.dsn_to_database_schema is not None: + dsn_mapping = self.dsn_to_database_schema if not isinstance(dsn_mapping, dict): raise ValueError("dsn_to_database_schema must contain key-value pairs") @@ -639,4 +640,4 @@ def validate_dsn_to_database_schema(cls, values: Dict) -> Dict: f"dsn_to_database_schema invalid mapping value: {value}" ) - return values + return self diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py index 6d51e853a2fb06..0e735806504e3a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/dataplatform_instance_resolver.py @@ -41,7 +41,7 @@ def get_platform_instance( if isinstance(platform, PlatformDetail): return platform - return PlatformDetail.parse_obj({}) + return PlatformDetail.model_validate({}) class ResolvePlatformInstanceFromServerToPlatformInstance( @@ -56,7 +56,7 @@ def get_platform_instance( ] if data_platform_detail.data_platform_server in self.config.server_to_platform_instance - else PlatformDetail.parse_obj({}) + else PlatformDetail.model_validate({}) ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py index 9ba30c5d191a17..c3b9b98b004896 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi/powerbi.py @@ -1316,7 +1316,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: @classmethod def create(cls, config_dict, ctx): - config = PowerBiDashboardSourceConfig.parse_obj(config_dict) + config = PowerBiDashboardSourceConfig.model_validate(config_dict) return 
cls(config, ctx) def get_allowed_workspaces(self) -> List[powerbi_data_classes.Workspace]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py index 224cc70083faf0..bc65321fb0c346 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server.py @@ -213,7 +213,7 @@ def get_all_reports(self) -> List[Any]: if response_dict.get("value"): reports.extend( - report_types_mapping[report_type].parse_obj(report) + report_types_mapping[report_type].model_validate(report) for report in response_dict.get("value") ) @@ -517,7 +517,7 @@ def __init__( @classmethod def create(cls, config_dict, ctx): - config = PowerBiReportServerDashboardSourceConfig.parse_obj(config_dict) + config = PowerBiReportServerDashboardSourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py index f7260a194f8360..9ab5f3a2ba41a0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py +++ b/metadata-ingestion/src/datahub/ingestion/source/powerbi_report_server/report_server_domain.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, model_validator from datahub.ingestion.source.powerbi_report_server.constants import ( RelationshipDirection, @@ -30,11 +30,13 @@ class CatalogItem(BaseModel): has_data_sources: bool = Field(False, alias="HasDataSources") data_sources: Optional[List["DataSource"]] = Field(None, alias="DataSources") - @validator("display_name", always=True) - def validate_diplay_name(cls, value, values): - if values["created_by"]: - return values["created_by"].split("\\")[-1] - return "" + @model_validator(mode="after") + def validate_diplay_name(self): + if self.created_by: + self.display_name = self.created_by.split("\\")[-1] + else: + self.display_name = "" + return self def get_urn_part(self): return f"reports.{self.id}" diff --git a/metadata-ingestion/src/datahub/ingestion/source/preset.py b/metadata-ingestion/src/datahub/ingestion/source/preset.py index a4e73932a99a06..3738e050b8dbb4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/preset.py +++ b/metadata-ingestion/src/datahub/ingestion/source/preset.py @@ -2,7 +2,7 @@ from typing import Dict, Optional import requests -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator from pydantic.fields import Field from datahub.emitter.mce_builder import DEFAULT_ENV @@ -55,16 +55,16 @@ class PresetConfig(SupersetConfig): description="Can be used to change mapping for database names in superset to what you have in datahub", ) - @validator("connect_uri", "display_uri") + @field_validator("connect_uri", "display_uri", mode="after") + @classmethod def remove_trailing_slash(cls, v): return config_clean.remove_trailing_slashes(v) - @root_validator(skip_on_failure=True) - def default_display_uri_to_connect_uri(cls, values): - base = values.get("display_uri") - if base is None: - values["display_uri"] 
= values.get("connect_uri") - return values + @model_validator(mode="after") + def default_display_uri_to_connect_uri(self) -> "PresetConfig": + if self.display_uri is None: + self.display_uri = self.connect_uri + return self @platform_name("Preset") diff --git a/metadata-ingestion/src/datahub/ingestion/source/pulsar.py b/metadata-ingestion/src/datahub/ingestion/source/pulsar.py index e9ef6fa8fa1af2..0cd5d3c36b243e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/pulsar.py +++ b/metadata-ingestion/src/datahub/ingestion/source/pulsar.py @@ -235,7 +235,7 @@ def _get_pulsar_metadata(self, url): @classmethod def create(cls, config_dict, ctx): - config = PulsarSourceConfig.parse_obj(config_dict) + config = PulsarSourceConfig.model_validate(config_dict) # Do not include each individual partition for partitioned topics, if config.exclude_individual_partitions: diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py index ad48f446bbef7c..8865a758ab8052 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/data_classes.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Dict, List, Optional, Type, Union -from pydantic import BaseModel, ConfigDict, Field, root_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from datahub.emitter.mcp_builder import ContainerKey from datahub.ingestion.source.qlik_sense.config import QLIK_DATETIME_FORMAT, Constant @@ -92,7 +92,8 @@ class Space(_QlikBaseModel): updatedAt: datetime ownerId: Optional[str] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -121,7 +122,8 @@ class SchemaField(_QlikBaseModel): primaryKey: Optional[bool] = None nullable: Optional[bool] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -138,7 +140,8 @@ class QlikDataset(Item): itemId: str datasetSchema: List[SchemaField] - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -174,7 +177,8 @@ class Chart(_QlikBaseModel): qDimension: List[AxisProperty] qMeasure: List[AxisProperty] - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -193,7 +197,8 @@ class Sheet(_QlikBaseModel): updatedAt: datetime charts: List[Chart] = [] - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -220,7 +225,8 @@ class QlikTable(_QlikBaseModel): databaseName: Optional[str] = None schemaName: Optional[str] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod 
def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -239,7 +245,8 @@ class App(Item): sheets: List[Sheet] = [] tables: List[QlikTable] = [] - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py index 10b062c98c147f..3e9842a1ae8bf5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_api.py @@ -56,7 +56,7 @@ def get_spaces(self) -> List[Space]: response.raise_for_status() response_dict = response.json() for space_dict in response_dict[Constant.DATA]: - space = Space.parse_obj(space_dict) + space = Space.model_validate(space_dict) spaces.append(space) self.spaces[space.id] = space.name if Constant.NEXT in response_dict[Constant.LINKS]: @@ -64,7 +64,7 @@ def get_spaces(self) -> List[Space]: else: break # Add personal space entity - spaces.append(Space.parse_obj(PERSONAL_SPACE_DICT)) + spaces.append(Space.model_validate(PERSONAL_SPACE_DICT)) self.spaces[PERSONAL_SPACE_DICT[Constant.ID]] = PERSONAL_SPACE_DICT[ Constant.NAME ] @@ -78,7 +78,7 @@ def _get_dataset(self, dataset_id: str, item_id: str) -> Optional[QlikDataset]: response.raise_for_status() response_dict = response.json() response_dict[Constant.ITEMID] = item_id - return QlikDataset.parse_obj(response_dict) + return QlikDataset.model_validate(response_dict) except Exception as e: self._log_http_error( message=f"Unable to fetch dataset with id {dataset_id}. Exception: {e}" @@ -119,7 +119,7 @@ def _get_chart( f"Chart with id {chart_id} of sheet {sheet_id} does not have hypercube. q_layout: {q_layout}" ) return None - return Chart.parse_obj(q_layout) + return Chart.model_validate(q_layout) except Exception as e: self._log_http_error( message=f"Unable to fetch chart {chart_id} of sheet {sheet_id}. Exception: {e}" @@ -140,7 +140,7 @@ def _get_sheet( if Constant.OWNERID not in sheet_dict[Constant.QMETA]: # That means sheet is private sheet return None - sheet = Sheet.parse_obj(sheet_dict[Constant.QMETA]) + sheet = Sheet.model_validate(sheet_dict[Constant.QMETA]) if Constant.QCHILDLIST not in sheet_dict: logger.warning( f"Sheet {sheet.title} with id {sheet_id} does not have any charts. 
sheet_dict: {sheet_dict}" @@ -222,7 +222,7 @@ def _get_app_used_tables( return [] response = websocket_connection.websocket_send_request(method="GetLayout") for table_dict in response[Constant.QLAYOUT][Constant.TABLES]: - tables.append(QlikTable.parse_obj(table_dict)) + tables.append(QlikTable.model_validate(table_dict)) websocket_connection.handle.pop() self._add_qri_of_tables(tables, app_id) except Exception as e: @@ -270,7 +270,7 @@ def _get_app(self, app_id: str) -> Optional[App]: response = websocket_connection.websocket_send_request( method="GetAppLayout" ) - app = App.parse_obj(response[Constant.QLAYOUT]) + app = App.model_validate(response[Constant.QLAYOUT]) app.sheets = self._get_app_sheets(websocket_connection, app_id) app.tables = self._get_app_used_tables(websocket_connection, app_id) websocket_connection.close_websocket() diff --git a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py index 7c9b07793032d3..97374567c04164 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py +++ b/metadata-ingestion/src/datahub/ingestion/source/qlik_sense/qlik_sense.py @@ -148,7 +148,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: @classmethod def create(cls, config_dict, ctx): - config = QlikSourceConfig.parse_obj(config_dict) + config = QlikSourceConfig.model_validate(config_dict) return cls(config, ctx) def _gen_space_key(self, space_id: str) -> SpaceKey: diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py index 41e56017b752e8..44ed9c2e6bfdb9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/config.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any, Dict, List, Optional -from pydantic import root_validator +from pydantic import model_validator from pydantic.fields import Field from datahub.configuration import ConfigModel @@ -182,7 +182,8 @@ class RedshiftConfig( description="Whether to skip EXTERNAL tables.", ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_email_is_set_on_usage(cls, values): if values.get("include_usage_statistics"): assert "email_domain" in values and values["email_domain"], ( @@ -190,31 +191,28 @@ def check_email_is_set_on_usage(cls, values): ) return values - @root_validator(skip_on_failure=True) - def check_database_is_set(cls, values): - assert values.get("database"), "database must be set" - return values - - @root_validator(skip_on_failure=True) - def backward_compatibility_configs_set(cls, values: Dict) -> Dict: - match_fully_qualified_names = values.get("match_fully_qualified_names") - - schema_pattern: Optional[AllowDenyPattern] = values.get("schema_pattern") + @model_validator(mode="after") + def check_database_is_set(self) -> "RedshiftConfig": + assert self.database, "database must be set" + return self + @model_validator(mode="after") + def backward_compatibility_configs_set(self) -> "RedshiftConfig": if ( - schema_pattern is not None - and schema_pattern != AllowDenyPattern.allow_all() - and match_fully_qualified_names is not None - and not match_fully_qualified_names + self.schema_pattern is not None + and self.schema_pattern != AllowDenyPattern.allow_all() + and self.match_fully_qualified_names is not None + and not self.match_fully_qualified_names ): logger.warning( 
"Please update `schema_pattern` to match against fully qualified schema name `.` and set config `match_fully_qualified_names : True`." "Current default `match_fully_qualified_names: False` is only to maintain backward compatibility. " "The config option `match_fully_qualified_names` will be deprecated in future and the default behavior will assume `match_fully_qualified_names: True`." ) - return values + return self - @root_validator(skip_on_failure=True) + @model_validator(mode="before") + @classmethod def connection_config_compatibility_set(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) @@ -231,8 +229,8 @@ def connection_config_compatibility_set(cls, values: Dict) -> Dict: if "options" in values and "connect_args" in values["options"]: values["extra_client_options"] = values["options"]["connect_args"] - if values["extra_client_options"]: - if values["options"]: + if values.get("extra_client_options"): + if values.get("options"): values["options"]["connect_args"] = values["extra_client_options"] else: values["options"] = {"connect_args": values["extra_client_options"]} diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py index 7c0f6a39e927ed..3a4ec6470b90c5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py @@ -236,7 +236,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: RedshiftConfig.Config.extra = ( pydantic.Extra.allow ) # we are okay with extra fields during this stage - config = RedshiftConfig.parse_obj(config_dict) + config = RedshiftConfig.model_validate(config_dict) # source = RedshiftSource(config, report) connection: redshift_connector.Connection = ( RedshiftSource.get_redshift_connection(config) @@ -316,7 +316,7 @@ def __init__(self, config: RedshiftConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = RedshiftConfig.parse_obj(config_dict) + config = RedshiftConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py index 5900276ed59cfe..963e4f276813b0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py @@ -1,12 +1,12 @@ import collections import logging import time -from datetime import datetime +from datetime import datetime, timezone from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import cachetools -import pydantic.error_wrappers import redshift_connector +from pydantic import ValidationError, field_validator from pydantic.fields import Field from pydantic.main import BaseModel @@ -64,6 +64,26 @@ class RedshiftAccessEvent(BaseModel): starttime: datetime endtime: datetime + @field_validator("starttime", "endtime", mode="before") + @classmethod + def ensure_utc_datetime(cls, v): + """Ensure datetime fields are treated as UTC for consistency with Pydantic V1 behavior. + + Pydantic V2 assumes local timezone for naive datetime strings, whereas Pydantic V1 assumed UTC. + This validator restores V1 behavior to maintain timestamp consistency. 
+ """ + if isinstance(v, str): + # Parse as naive datetime, then assume UTC (matching V1 behavior) + dt = datetime.fromisoformat(v) + if dt.tzinfo is None: + # Treat naive datetime as UTC (this was the V1 behavior) + dt = dt.replace(tzinfo=timezone.utc) + return dt + elif isinstance(v, datetime) and v.tzinfo is None: + # If we get a naive datetime object, assume UTC + return v.replace(tzinfo=timezone.utc) + return v + class RedshiftUsageExtractor: """ @@ -291,7 +311,7 @@ def _gen_access_events_from_history_query( else None ), ) - except pydantic.error_wrappers.ValidationError as e: + except ValidationError as e: logging.warning( f"Validation error on access event creation from row {row}. The error was: {e} Skipping ...." ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/config.py b/metadata-ingestion/src/datahub/ingestion/source/s3/config.py index eac93c5059459f..4420a71d9427ad 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/s3/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/config.py @@ -1,7 +1,7 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Union -import pydantic +from pydantic import ValidationInfo, field_validator, model_validator from pydantic.fields import Field from datahub.configuration.common import AllowDenyPattern @@ -12,7 +12,6 @@ from datahub.configuration.validate_field_rename import pydantic_renamed_field from datahub.ingestion.source.aws.aws_common import AwsConnectionConfig from datahub.ingestion.source.data_lake_common.config import PathSpecsConfigMixin -from datahub.ingestion.source.data_lake_common.path_spec import PathSpec from datahub.ingestion.source.s3.datalake_profiler_config import DataLakeProfilerConfig from datahub.ingestion.source.state.stale_entity_removal_handler import ( StatefulStaleMetadataRemovalConfig, @@ -117,69 +116,91 @@ def is_profiling_enabled(self) -> bool: self.profiling.operation_config ) - @pydantic.validator("path_specs", always=True) - def check_path_specs_and_infer_platform( - cls, path_specs: List[PathSpec], values: Dict - ) -> List[PathSpec]: + @field_validator("path_specs", mode="before") + @classmethod + def check_path_specs(cls, path_specs: Any, info: ValidationInfo) -> Any: if len(path_specs) == 0: raise ValueError("path_specs must not be empty") - # Check that all path specs have the same platform. - guessed_platforms = { - "s3" if path_spec.is_s3 else "file" for path_spec in path_specs - } - if len(guessed_platforms) > 1: - raise ValueError( - f"Cannot have multiple platforms in path_specs: {guessed_platforms}" - ) - guessed_platform = guessed_platforms.pop() - - # Ensure s3 configs aren't used for file sources. - if guessed_platform != "s3" and ( - values.get("use_s3_object_tags") or values.get("use_s3_bucket_tags") - ): - raise ValueError( - "Cannot grab s3 object/bucket tags when platform is not s3. Remove the flag or use s3." - ) - - # Infer platform if not specified. 
- if values.get("platform") and values["platform"] != guessed_platform: - raise ValueError( - f"All path_specs belong to {guessed_platform} platform, but platform is set to {values['platform']}" - ) - else: - logger.debug(f'Setting config "platform": {guessed_platform}') - values["platform"] = guessed_platform + # Basic validation - path specs consistency and S3 config validation is now handled in model_validator return path_specs - @pydantic.validator("platform", always=True) - def platform_valid(cls, platform: Any, values: dict) -> str: - inferred_platform = values.get("platform") # we may have inferred it above - platform = platform or inferred_platform - if not platform: - raise ValueError("platform must not be empty") - - if platform != "s3" and values.get("use_s3_bucket_tags"): - raise ValueError( - "Cannot grab s3 bucket tags when platform is not s3. Remove the flag or ingest from s3." - ) - if platform != "s3" and values.get("use_s3_object_tags"): - raise ValueError( - "Cannot grab s3 object tags when platform is not s3. Remove the flag or ingest from s3." - ) - if platform != "s3" and values.get("use_s3_content_type"): - raise ValueError( - "Cannot grab s3 object content type when platform is not s3. Remove the flag or ingest from s3." - ) - - return platform - - @pydantic.root_validator(skip_on_failure=True) - def ensure_profiling_pattern_is_passed_to_profiling( - cls, values: Dict[str, Any] - ) -> Dict[str, Any]: - profiling: Optional[DataLakeProfilerConfig] = values.get("profiling") + @model_validator(mode="after") + def ensure_profiling_pattern_is_passed_to_profiling(self) -> "DataLakeSourceConfig": + profiling = self.profiling if profiling is not None and profiling.enabled: - profiling._allow_deny_patterns = values["profile_patterns"] - return values + profiling._allow_deny_patterns = self.profile_patterns + return self + + @model_validator(mode="after") + def validate_platform_and_config_consistency(self) -> "DataLakeSourceConfig": + """Infer platform from path_specs and validate config consistency.""" + # Track whether platform was explicitly provided + platform_was_explicit = bool(self.platform) + + # Infer platform from path_specs if not explicitly set + if not self.platform and self.path_specs: + guessed_platforms = set() + for path_spec in self.path_specs: + if ( + hasattr(path_spec, "include") + and path_spec.include + and path_spec.include.startswith("s3://") + ): + guessed_platforms.add("s3") + else: + guessed_platforms.add("file") + + # Ensure all path specs belong to the same platform + if len(guessed_platforms) > 1: + raise ValueError( + f"Cannot have multiple platforms in path_specs: {guessed_platforms}" + ) + + if guessed_platforms: + guessed_platform = guessed_platforms.pop() + logger.debug(f"Inferred platform: {guessed_platform}") + self.platform = guessed_platform + else: + self.platform = "file" + elif not self.platform: + self.platform = "file" + + # Validate platform consistency only when platform was inferred (not explicitly set) + # This allows sources like GCS to set platform="gcs" with s3:// URIs for correct container subtypes + if not platform_was_explicit and self.platform and self.path_specs: + expected_platforms = set() + for path_spec in self.path_specs: + if ( + hasattr(path_spec, "include") + and path_spec.include + and path_spec.include.startswith("s3://") + ): + expected_platforms.add("s3") + else: + expected_platforms.add("file") + + if len(expected_platforms) == 1: + expected_platform = expected_platforms.pop() + if self.platform != 
expected_platform: + raise ValueError( + f"All path_specs belong to {expected_platform} platform, but platform was inferred as {self.platform}" + ) + + # Validate S3-specific configurations + if self.platform != "s3": + if self.use_s3_bucket_tags: + raise ValueError( + "Cannot grab s3 bucket tags when platform is not s3. Remove the flag or ingest from s3." + ) + if self.use_s3_object_tags: + raise ValueError( + "Cannot grab s3 object tags when platform is not s3. Remove the flag or ingest from s3." + ) + if self.use_s3_content_type: + raise ValueError( + "Cannot grab s3 object content type when platform is not s3. Remove the flag or ingest from s3." + ) + + return self diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py b/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py index 58e930eb6e809c..b1f050b51d25c1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/datalake_profiler_config.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, Optional +from typing import Optional import pydantic +from pydantic import model_validator from pydantic.fields import Field from datahub.configuration import ConfigModel @@ -72,21 +73,18 @@ class DataLakeProfilerConfig(ConfigModel): description="Whether to profile for the sample values for all columns.", ) - @pydantic.root_validator(skip_on_failure=True) - def ensure_field_level_settings_are_normalized( - cls: "DataLakeProfilerConfig", values: Dict[str, Any] - ) -> Dict[str, Any]: - max_num_fields_to_profile_key = "max_number_of_fields_to_profile" - max_num_fields_to_profile = values.get(max_num_fields_to_profile_key) + @model_validator(mode="after") + def ensure_field_level_settings_are_normalized(self) -> "DataLakeProfilerConfig": + max_num_fields_to_profile = self.max_number_of_fields_to_profile # Disable all field-level metrics. 
- if values.get("profile_table_level_only"): - for field_level_metric in cls.__fields__: - if field_level_metric.startswith("include_field_"): - values.setdefault(field_level_metric, False) + if self.profile_table_level_only: + for field_name in self.__fields__: + if field_name.startswith("include_field_"): + setattr(self, field_name, False) assert max_num_fields_to_profile is None, ( - f"{max_num_fields_to_profile_key} should be set to None" + "max_number_of_fields_to_profile should be set to None" ) - return values + return self diff --git a/metadata-ingestion/src/datahub/ingestion/source/s3/source.py b/metadata-ingestion/src/datahub/ingestion/source/s3/source.py index c5314d624b7286..ea55fa2e311cd9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/s3/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/s3/source.py @@ -53,8 +53,11 @@ from datahub.ingestion.source.data_lake_common.object_store import ( create_object_store_adapter, ) -from datahub.ingestion.source.data_lake_common.path_spec import FolderTraversalMethod -from datahub.ingestion.source.s3.config import DataLakeSourceConfig, PathSpec +from datahub.ingestion.source.data_lake_common.path_spec import ( + FolderTraversalMethod, + PathSpec, +) +from datahub.ingestion.source.s3.config import DataLakeSourceConfig from datahub.ingestion.source.s3.report import DataLakeSourceReport from datahub.ingestion.source.schema_inference import avro, csv_tsv, json, parquet from datahub.ingestion.source.schema_inference.base import SchemaInferenceBase @@ -261,7 +264,7 @@ def __init__(self, config: DataLakeSourceConfig, ctx: PipelineContext): ) config_report = { - config_option: config.dict().get(config_option) + config_option: config.model_dump().get(config_option) for config_option in config_options_to_report } config_report = { @@ -278,7 +281,7 @@ def __init__(self, config: DataLakeSourceConfig, ctx: PipelineContext): telemetry.telemetry_instance.ping( "data_lake_profiling_config", { - config_flag: config.profiling.dict().get(config_flag) + config_flag: config.profiling.model_dump().get(config_flag) for config_flag in profiling_flags_to_report }, ) @@ -370,7 +373,7 @@ def init_spark(self): @classmethod def create(cls, config_dict, ctx): - config = DataLakeSourceConfig.parse_obj(config_dict) + config = DataLakeSourceConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py b/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py index efa422a8ffcf28..c45a75035916c6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sac/sac.py @@ -8,7 +8,7 @@ import pyodata.v2.model import pyodata.v2.service from authlib.integrations.requests_client import OAuth2Session -from pydantic import Field, SecretStr, validator +from pydantic import Field, SecretStr, field_validator from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry @@ -159,7 +159,8 @@ class SACSourceConfig( description="Template for generating dataset urns of consumed queries, the placeholder {query} can be used within the template for inserting the name of the query", ) - @validator("tenant_url", "token_url") + @field_validator("tenant_url", "token_url", mode="after") + @classmethod def remove_trailing_slash(cls, v): return config_clean.remove_trailing_slashes(v) @@ -209,7 +210,7 @@ def close(self) -> None: @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "SACSource": - config = 
SACSourceConfig.parse_obj(config_dict) + config = SACSourceConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod @@ -217,7 +218,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: test_report = TestConnectionReport() try: - config = SACSourceConfig.parse_obj(config_dict) + config = SACSourceConfig.model_validate(config_dict) # when creating the pyodata.Client, the metadata is automatically parsed and validated session, _ = SACSource.get_sac_connection(config) diff --git a/metadata-ingestion/src/datahub/ingestion/source/salesforce.py b/metadata-ingestion/src/datahub/ingestion/source/salesforce.py index 0d3bef914c562e..70fabde774ac63 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/salesforce.py +++ b/metadata-ingestion/src/datahub/ingestion/source/salesforce.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Iterable, List, Literal, Optional, TypedDict import requests -from pydantic import Field, validator +from pydantic import Field, field_validator from simple_salesforce import Salesforce from simple_salesforce.exceptions import SalesforceAuthenticationFailed @@ -172,7 +172,8 @@ def is_profiling_enabled(self) -> bool: self.profiling.operation_config ) - @validator("instance_url") + @field_validator("instance_url", mode="after") + @classmethod def remove_trailing_slash(cls, v): return config_clean.remove_trailing_slashes(v) diff --git a/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py b/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py index 300d914eb048c7..711919263c207c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py +++ b/metadata-ingestion/src/datahub/ingestion/source/schema/json_schema.py @@ -12,7 +12,7 @@ import jsonref import requests -from pydantic import AnyHttpUrl, DirectoryPath, FilePath, validator +from pydantic import AnyHttpUrl, DirectoryPath, FilePath, field_validator from pydantic.fields import Field import datahub.metadata.schema_classes as models @@ -90,7 +90,7 @@ class JsonSchemaSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMix description="Use this if URI-s need to be modified during reference resolution. 
Simple string match - replace capabilities are supported.", ) - @validator("path") + @field_validator("path", mode="after") def download_http_url_to_temp_file(cls, v): if isinstance(v, AnyHttpUrl): try: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sigma/data_classes.py b/metadata-ingestion/src/datahub/ingestion/source/sigma/data_classes.py index 01a57b165c007e..8e44567e87e782 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sigma/data_classes.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sigma/data_classes.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Dict, List, Optional -from pydantic import BaseModel, root_validator +from pydantic import BaseModel, model_validator from datahub.emitter.mcp_builder import ContainerKey @@ -22,7 +22,8 @@ class Workspace(BaseModel): createdAt: datetime updatedAt: datetime - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def update_values(cls, values: Dict) -> Dict: # Create a copy to avoid modifying the input dictionary, preventing state contamination in tests values = deepcopy(values) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py index e24ea23515c358..1d9b9cff0276e3 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma.py @@ -150,7 +150,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: @classmethod def create(cls, config_dict, ctx): - config = SigmaSourceConfig.parse_obj(config_dict) + config = SigmaSourceConfig.model_validate(config_dict) return cls(config, ctx) def _gen_workbook_key(self, workbook_id: str) -> WorkbookKey: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py index 935ef5fddafb1a..c08bc9202fee85 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sigma/sigma_api.py @@ -108,7 +108,7 @@ def get_workspace(self, workspace_id: str) -> Optional[Workspace]: self.report.non_accessible_workspaces_count += 1 return None response.raise_for_status() - workspace = Workspace.parse_obj(response.json()) + workspace = Workspace.model_validate(response.json()) self.workspaces[workspace.workspaceId] = workspace return workspace except Exception as e: @@ -127,7 +127,7 @@ def fill_workspaces(self) -> None: response_dict = response.json() for workspace_dict in response_dict[Constant.ENTRIES]: self.workspaces[workspace_dict[Constant.WORKSPACEID]] = ( - Workspace.parse_obj(workspace_dict) + Workspace.model_validate(workspace_dict) ) if response_dict[Constant.NEXTPAGE]: url = f"{workspace_url}&page={response_dict[Constant.NEXTPAGE]}" @@ -197,7 +197,7 @@ def _get_files_metadata(self, file_type: str) -> Dict[str, File]: response.raise_for_status() response_dict = response.json() for file_dict in response_dict[Constant.ENTRIES]: - file = File.parse_obj(file_dict) + file = File.model_validate(file_dict) file.workspaceId = self.get_workspace_id_from_file_path( file.parentId, file.path ) @@ -225,7 +225,7 @@ def get_sigma_datasets(self) -> List[SigmaDataset]: response.raise_for_status() response_dict = response.json() for dataset_dict in response_dict[Constant.ENTRIES]: - dataset = SigmaDataset.parse_obj(dataset_dict) + dataset = SigmaDataset.model_validate(dataset_dict) if dataset.datasetId not in 
dataset_files_metadata: self.report.datasets.dropped( @@ -354,7 +354,7 @@ def get_page_elements(self, workbook: Workbook, page: Page) -> List[Element]: element_dict[Constant.URL] = ( f"{workbook.url}?:nodeId={element_dict[Constant.ELEMENTID]}&:fullScreen=true" ) - element = Element.parse_obj(element_dict) + element = Element.model_validate(element_dict) if ( self.config.extract_lineage and self.config.workbook_lineage_pattern.allowed(workbook.name) @@ -379,7 +379,7 @@ def get_workbook_pages(self, workbook: Workbook) -> List[Page]: ) response.raise_for_status() for page_dict in response.json()[Constant.ENTRIES]: - page = Page.parse_obj(page_dict) + page = Page.model_validate(page_dict) page.elements = self.get_page_elements(workbook, page) pages.append(page) return pages @@ -400,7 +400,7 @@ def get_sigma_workbooks(self) -> List[Workbook]: response.raise_for_status() response_dict = response.json() for workbook_dict in response_dict[Constant.ENTRIES]: - workbook = Workbook.parse_obj(workbook_dict) + workbook = Workbook.model_validate(workbook_dict) if workbook.workbookId not in workbook_files_metadata: # Due to a bug in the Sigma API, it seems like the /files endpoint does not diff --git a/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py b/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py index 24c591a890e828..4c013f873a10ea 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py +++ b/metadata-ingestion/src/datahub/ingestion/source/slack/slack.py @@ -260,7 +260,7 @@ def __init__(self, ctx: PipelineContext, config: SlackSourceConfig): @classmethod def create(cls, config_dict, ctx): - config = SlackSourceConfig.parse_obj(config_dict) + config = SlackSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_slack_client(self) -> WebClient: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snaplogic/snaplogic.py b/metadata-ingestion/src/datahub/ingestion/source/snaplogic/snaplogic.py index 59166776eff09a..c8e1f5d6b841e1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snaplogic/snaplogic.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snaplogic/snaplogic.py @@ -351,5 +351,5 @@ def close(self) -> None: @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "SnaplogicSource": - config = SnaplogicConfig.parse_obj(config_dict) + config = SnaplogicConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py index a7c008d932a713..9b61bba73e0915 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_assertion.py @@ -91,7 +91,7 @@ def _process_result_row( self, result_row: dict, discovered_datasets: List[str] ) -> Optional[MetadataChangeProposalWrapper]: try: - result = DataQualityMonitoringResult.parse_obj(result_row) + result = DataQualityMonitoringResult.model_validate(result_row) assertion_guid = result.METRIC_NAME.split("__")[-1].lower() status = bool(result.VALUE) # 1 if PASS, 0 if FAIL assertee = self.identifiers.get_dataset_identifier( diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 03ef7afbce713e..35c1f43b001494 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Set import pydantic -from pydantic import Field, root_validator, validator +from pydantic import Field, ValidationInfo, field_validator, model_validator from datahub.configuration.common import AllowDenyPattern, ConfigModel, HiddenFromDocs from datahub.configuration.pattern_utils import UUID_REGEX @@ -122,10 +122,10 @@ class SnowflakeFilterConfig(SQLFilterConfig): description="Whether `schema_pattern` is matched against fully qualified schema name `<catalog>.<schema>`.", ) - @root_validator(pre=False, skip_on_failure=True) - def validate_legacy_schema_pattern(cls, values: Dict) -> Dict: - schema_pattern: Optional[AllowDenyPattern] = values.get("schema_pattern") - match_fully_qualified_names = values.get("match_fully_qualified_names") + @model_validator(mode="after") + def validate_legacy_schema_pattern(self) -> "SnowflakeFilterConfig": + schema_pattern: Optional[AllowDenyPattern] = self.schema_pattern + match_fully_qualified_names = self.match_fully_qualified_names if ( schema_pattern is not None @@ -145,7 +145,7 @@ def validate_legacy_schema_pattern(cls, values: Dict) -> Dict: assert isinstance(schema_pattern, AllowDenyPattern) schema_pattern.deny.append(r".*INFORMATION_SCHEMA$") - return values + return self class SnowflakeIdentifierConfig( @@ -391,7 +391,8 @@ class SnowflakeV2Config( "This may be required in the case of _eg_ temporary tables being created in a different database than the ones in the database_name patterns.", ) - @validator("convert_urns_to_lowercase") + @field_validator("convert_urns_to_lowercase", mode="after") + @classmethod def validate_convert_urns_to_lowercase(cls, v): if not v: add_global_warning( @@ -400,30 +401,31 @@ def validate_convert_urns_to_lowercase(cls, v): return v - @validator("include_column_lineage") - def validate_include_column_lineage(cls, v, values): - if not values.get("include_table_lineage") and v: + @field_validator("include_column_lineage", mode="after") + @classmethod + def validate_include_column_lineage(cls, v, info): + if not info.data.get("include_table_lineage") and v: raise ValueError( "include_table_lineage must be True for include_column_lineage to be set." ) return v - @root_validator(pre=False, skip_on_failure=True) - def validate_unsupported_configs(cls, values: Dict) -> Dict: - value = values.get("include_read_operational_stats") - if value is not None and value: + @model_validator(mode="after") + def validate_unsupported_configs(self) -> "SnowflakeV2Config": + if ( + hasattr(self, "include_read_operational_stats") + and self.include_read_operational_stats + ): raise ValueError( "include_read_operational_stats is not supported. 
Set `include_read_operational_stats` to False.", ) - include_technical_schema = values.get("include_technical_schema") - include_profiles = ( - values.get("profiling") is not None and values["profiling"].enabled - ) + include_technical_schema = self.include_technical_schema + include_profiles = self.profiling is not None and self.profiling.enabled delete_detection_enabled = ( - values.get("stateful_ingestion") is not None - and values["stateful_ingestion"].enabled - and values["stateful_ingestion"].remove_stale_metadata + self.stateful_ingestion is not None + and self.stateful_ingestion.enabled + and self.stateful_ingestion.remove_stale_metadata ) # TODO: Allow profiling irrespective of basic schema extraction, @@ -435,13 +437,14 @@ def validate_unsupported_configs(cls, values: Dict) -> Dict: "Cannot perform Deletion Detection or Profiling without extracting snowflake technical schema. Set `include_technical_schema` to True or disable Deletion Detection and Profiling." ) - return values + return self - @validator("shares") + @field_validator("shares", mode="after") + @classmethod def validate_shares( - cls, shares: Optional[Dict[str, SnowflakeShareConfig]], values: Dict + cls, shares: Optional[Dict[str, SnowflakeShareConfig]], info: ValidationInfo ) -> Optional[Dict[str, SnowflakeShareConfig]]: - current_platform_instance = values.get("platform_instance") + current_platform_instance = info.data.get("platform_instance") if shares: # Check: platform_instance should be present @@ -479,11 +482,12 @@ def validate_shares( return shares - @root_validator(pre=False, skip_on_failure=True) - def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict: - if values.get("use_queries_v2"): - if values.get("enable_stateful_lineage_ingestion") or values.get( - "enable_stateful_usage_ingestion" + @model_validator(mode="after") + def validate_queries_v2_stateful_ingestion(self) -> "SnowflakeV2Config": + if self.use_queries_v2: + if ( + self.enable_stateful_lineage_ingestion + or self.enable_stateful_usage_ingestion ): logger.warning( "enable_stateful_lineage_ingestion and enable_stateful_usage_ingestion are deprecated " @@ -491,7 +495,7 @@ def validate_queries_v2_stateful_ingestion(cls, values: Dict) -> Dict: "For queries v2, use enable_stateful_time_window instead to enable stateful ingestion " "for the unified time window extraction (lineage + usage + operations + queries)." 
) - return values + return self def outbounds(self) -> Dict[str, Set[DatabaseId]]: """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py index a4cc15dd34583e..bfaf3e9ddecce5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py @@ -6,6 +6,7 @@ import snowflake.connector from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization +from pydantic import field_validator, model_validator from snowflake.connector import SnowflakeConnection as NativeSnowflakeConnection from snowflake.connector.cursor import DictCursor from snowflake.connector.network import ( @@ -125,26 +126,28 @@ def get_account(self) -> str: rename_host_port_to_account_id = pydantic_renamed_field("host_port", "account_id") # type: ignore[pydantic-field] - @pydantic.validator("account_id") - def validate_account_id(cls, account_id: str, values: Dict) -> str: + @field_validator("account_id", mode="after") + @classmethod + def validate_account_id(cls, account_id: str, info: pydantic.ValidationInfo) -> str: account_id = remove_protocol(account_id) account_id = remove_trailing_slashes(account_id) # Get the domain from config, fallback to default - domain = values.get("snowflake_domain", DEFAULT_SNOWFLAKE_DOMAIN) + domain = info.data.get("snowflake_domain", DEFAULT_SNOWFLAKE_DOMAIN) snowflake_host_suffix = f".{domain}" account_id = remove_suffix(account_id, snowflake_host_suffix) return account_id - @pydantic.validator("authentication_type", always=True) - def authenticator_type_is_valid(cls, v, values): + @field_validator("authentication_type", mode="before") + @classmethod + def authenticator_type_is_valid(cls, v: Any, info: pydantic.ValidationInfo) -> Any: if v not in _VALID_AUTH_TYPES: raise ValueError( f"unsupported authenticator type '{v}' was provided," f" use one of {list(_VALID_AUTH_TYPES.keys())}" ) if ( - values.get("private_key") is not None - or values.get("private_key_path") is not None + info.data.get("private_key") is not None + or info.data.get("private_key_path") is not None ) and v != "KEY_PAIR_AUTHENTICATOR": raise ValueError( f"Either `private_key` and `private_key_path` is set but `authentication_type` is {v}. " @@ -153,21 +156,22 @@ def authenticator_type_is_valid(cls, v, values): if v == "KEY_PAIR_AUTHENTICATOR": # If we are using key pair auth, we need the private key path and password to be set if ( - values.get("private_key") is None - and values.get("private_key_path") is None + info.data.get("private_key") is None + and info.data.get("private_key_path") is None ): raise ValueError( f"Both `private_key` and `private_key_path` are none. 
" f"At least one should be set when using {v} authentication" ) elif v == "OAUTH_AUTHENTICATOR": - cls._check_oauth_config(values.get("oauth_config")) + cls._check_oauth_config(info.data.get("oauth_config")) logger.info(f"using authenticator type '{v}'") return v - @pydantic.validator("token", always=True) - def validate_token_oauth_config(cls, v, values): - auth_type = values.get("authentication_type") + @field_validator("token", mode="before") + @classmethod + def validate_token_oauth_config(cls, v: Any, info: pydantic.ValidationInfo) -> Any: + auth_type = info.data.get("authentication_type") if auth_type == "OAUTH_AUTHENTICATOR_TOKEN": if not v: raise ValueError("Token required for OAUTH_AUTHENTICATOR_TOKEN.") @@ -177,6 +181,24 @@ def validate_token_oauth_config(cls, v, values): ) return v + @model_validator(mode="after") + def validate_authentication_config(self): + """Validate authentication configuration consistency.""" + # Check token requirement for OAUTH_AUTHENTICATOR_TOKEN + if self.authentication_type == "OAUTH_AUTHENTICATOR_TOKEN": + if not self.token: + raise ValueError("Token required for OAUTH_AUTHENTICATOR_TOKEN.") + + # Check private key authentication consistency + if self.private_key is not None or self.private_key_path is not None: + if self.authentication_type != "KEY_PAIR_AUTHENTICATOR": + raise ValueError( + f"Either `private_key` and `private_key_path` is set but `authentication_type` is {self.authentication_type}. " + f"Should be set to 'KEY_PAIR_AUTHENTICATOR' when using key pair authentication" + ) + + return self + @staticmethod def _check_oauth_config(oauth_config: Optional[OAuthConfiguration]) -> None: if oauth_config is None: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index 651607d0ba6868..f7a8b22473a4be 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -14,7 +14,7 @@ Type, ) -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator from datahub.configuration.datetimes import parse_absolute_time from datahub.ingestion.api.closeable import Closeable @@ -70,7 +70,7 @@ def _parse_from_json(cls: Type, v: Any) -> dict: return json.loads(v) return v - return validator(field, pre=True, allow_reuse=True)(_parse_from_json) + return field_validator(field, mode="before")(_parse_from_json) class UpstreamColumnNode(BaseModel): @@ -379,7 +379,7 @@ def _process_upstream_lineage_row( # To avoid that causing a pydantic error we are setting it to an empty list # instead of a list with an empty object db_row["QUERIES"] = "[]" - return UpstreamLineageEdge.parse_obj(db_row) + return UpstreamLineageEdge.model_validate(db_row) except Exception as e: self.report.num_upstream_lineage_edge_parsing_failed += 1 upstream_tables = db_row.get("UPSTREAM_TABLES") diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 76bcac49395db7..a7babecf048582 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -806,7 +806,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeQueriesSourceConfig): @classmethod def create(cls, 
config_dict: dict, ctx: PipelineContext) -> Self: - config = SnowflakeQueriesSourceConfig.parse_obj(config_dict) + config = SnowflakeQueriesSourceConfig.model_validate(config_dict) return cls(ctx, config) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py index d10685ed7b2f7b..8423e8f843546a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py @@ -386,7 +386,7 @@ def __init__(self, config, ctx): @classmethod def create(cls, config_dict, ctx): - config = AthenaConfig.parse_obj(config_dict) + config = AthenaConfig.model_validate(config_dict) return cls(config, ctx) # overwrite this method to allow to specify the usage of a custom dialect diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py b/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py index 07d84a9dc3b735..ff5f57fc3dc2f1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/clickhouse.py @@ -10,6 +10,7 @@ import pydantic from clickhouse_sqlalchemy.drivers import base from clickhouse_sqlalchemy.drivers.base import ClickHouseDialect +from pydantic import model_validator from pydantic.fields import Field from sqlalchemy import create_engine, text from sqlalchemy.engine import reflection @@ -175,7 +176,8 @@ def get_sql_alchemy_url( return str(url) # pre = True because we want to take some decision before pydantic initialize the configuration to default values - @pydantic.root_validator(pre=True) + @model_validator(mode="before") + @classmethod def projects_backward_compatibility(cls, values: Dict) -> Dict: secure = values.get("secure") protocol = values.get("protocol") @@ -423,7 +425,7 @@ def __init__(self, config, ctx): @classmethod def create(cls, config_dict, ctx): - config = ClickHouseConfig.parse_obj(config_dict) + config = ClickHouseConfig.model_validate(config_dict) return cls(config, ctx) def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py b/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py index 1a12417e4995ac..7c5f6f237a0c05 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/cockroachdb.py @@ -39,5 +39,5 @@ def get_platform(self): @classmethod def create(cls, config_dict, ctx): - config = CockroachDBConfig.parse_obj(config_dict) + config = CockroachDBConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py b/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py index 281237c0535a90..a86639fd02f07d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/druid.py @@ -77,5 +77,5 @@ def __init__(self, config, ctx): @classmethod def create(cls, config_dict, ctx): - config = DruidConfig.parse_obj(config_dict) + config = DruidConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py b/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py index 9d7cc022ace6e8..4e0c87a9dd6ae1 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/hana.py @@ -36,5 +36,5 @@ def __init__(self, config: HanaConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "HanaSource": - config = HanaConfig.parse_obj(config_dict) + config = HanaConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py b/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py index 6a615c15058d71..cf1fed641509ce 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/hive.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from urllib.parse import urlparse -from pydantic import validator +from pydantic import field_validator from pydantic.fields import Field # This import verifies that the dependencies are available. @@ -674,11 +674,13 @@ class HiveConfig(TwoTierSQLAlchemyConfig): description="Platform instance for the storage system", ) - @validator("host_port") - def clean_host_port(cls, v): + @field_validator("host_port", mode="after") + @classmethod + def clean_host_port(cls, v: str) -> str: return config_clean.remove_protocol(v) - @validator("hive_storage_lineage_direction") + @field_validator("hive_storage_lineage_direction", mode="after") + @classmethod def _validate_direction(cls, v: str) -> str: """Validate the lineage direction.""" if v.lower() not in ["upstream", "downstream"]: @@ -725,7 +727,7 @@ def __init__(self, config, ctx): @classmethod def create(cls, config_dict, ctx): - config = HiveConfig.parse_obj(config_dict) + config = HiveConfig.model_validate(config_dict) return cls(config, ctx) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py b/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py index d4d4e28546771f..dee453aef25379 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/hive_metastore.py @@ -351,7 +351,7 @@ def get_db_name(self, inspector: Inspector) -> str: @classmethod def create(cls, config_dict, ctx): - config = HiveMetastore.parse_obj(config_dict) + config = HiveMetastore.model_validate(config_dict) return cls(config, ctx) def gen_database_containers( diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py index 4913617febf53b..51d1a6997198ab 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mssql/source.py @@ -3,8 +3,8 @@ import urllib.parse from typing import Any, Dict, Iterable, List, Optional, Tuple -import pydantic import sqlalchemy.dialects.mssql +from pydantic import ValidationInfo, field_validator from pydantic.fields import Field from sqlalchemy import create_engine, inspect from sqlalchemy.engine.base import Connection @@ -140,11 +140,18 @@ class SQLServerConfig(BasicSQLAlchemyConfig): description="Indicates if the SQL Server instance is running on AWS RDS. 
When None (default), automatic detection will be attempted using server name analysis.", ) - @pydantic.validator("uri_args") - def passwords_match(cls, v, values, **kwargs): - if values["use_odbc"] and not values["sqlalchemy_uri"] and "driver" not in v: + @field_validator("uri_args", mode="after") + @classmethod + def passwords_match( + cls, v: Dict[str, Any], info: ValidationInfo, **kwargs: Any + ) -> Dict[str, Any]: + if ( + info.data["use_odbc"] + and not info.data["sqlalchemy_uri"] + and "driver" not in v + ): raise ValueError("uri_args must contain a 'driver' option") - elif not values["use_odbc"] and v: + elif not info.data["use_odbc"] and v: raise ValueError("uri_args is not supported when ODBC is disabled") return v @@ -314,7 +321,7 @@ def _populate_column_descriptions(self, conn: Connection, db_name: str) -> None: @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "SQLServerSource": - config = SQLServerConfig.parse_obj(config_dict) + config = SQLServerConfig.model_validate(config_dict) return cls(config, ctx) # override to get table descriptions diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py b/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py index 3b4aea1d40cc78..f9954cf06eb9fb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/mysql.py @@ -150,7 +150,7 @@ def get_platform(self): @classmethod def create(cls, config_dict, ctx): - config = MySQLConfig.parse_obj(config_dict) + config = MySQLConfig.model_validate(config_dict) return cls(config, ctx) def _setup_rds_iam_event_listener( diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py index 455ac23fe04162..79f5acc3f883e0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/oracle.py @@ -10,8 +10,8 @@ from unittest.mock import patch import oracledb -import pydantic import sqlalchemy.engine +from pydantic import ValidationInfo, field_validator from pydantic.fields import Field from sqlalchemy import event, sql from sqlalchemy.dialects.oracle.base import ischema_names @@ -101,25 +101,32 @@ class OracleConfig(BasicSQLAlchemyConfig): "On Linux, this value is ignored, as ldconfig or LD_LIBRARY_PATH will define the location.", ) - @pydantic.validator("service_name") - def check_service_name(cls, v, values): - if values.get("database") and v: + @field_validator("service_name", mode="after") + @classmethod + def check_service_name( + cls, v: Optional[str], info: ValidationInfo + ) -> Optional[str]: + if info.data.get("database") and v: raise ValueError( "specify one of 'database' and 'service_name', but not both" ) return v - @pydantic.validator("data_dictionary_mode") - def check_data_dictionary_mode(cls, value): + @field_validator("data_dictionary_mode", mode="after") + @classmethod + def check_data_dictionary_mode(cls, value: str) -> str: if value not in ("ALL", "DBA"): raise ValueError("Specify one of data dictionary views mode: 'ALL', 'DBA'.") return value - @pydantic.validator("thick_mode_lib_dir", always=True) - def check_thick_mode_lib_dir(cls, v, values): + @field_validator("thick_mode_lib_dir", mode="before") + @classmethod + def check_thick_mode_lib_dir( + cls, v: Optional[str], info: ValidationInfo + ) -> Optional[str]: if ( v is None - and values.get("enable_thick_mode") + and info.data.get("enable_thick_mode") and 
(platform.system() == "Darwin" or platform.system() == "Windows") ): raise ValueError( @@ -659,7 +666,7 @@ def __init__(self, config, ctx): @classmethod def create(cls, config_dict, ctx): - config = OracleConfig.parse_obj(config_dict) + config = OracleConfig.model_validate(config_dict) return cls(config, ctx) def get_db_name(self, inspector: Inspector) -> str: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py index aaf62b31a02acd..3c04e78c936a09 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/postgres.py @@ -212,7 +212,7 @@ def get_platform(self): @classmethod def create(cls, config_dict, ctx): - config = PostgresConfig.parse_obj(config_dict) + config = PostgresConfig.model_validate(config_dict) return cls(config, ctx) def _setup_rds_iam_event_listener( @@ -288,7 +288,7 @@ def _get_view_lineage_elements( return {} for row in results: - data.append(ViewLineageEntry.parse_obj(row)) + data.append(ViewLineageEntry.model_validate(row)) lineage_elements: Dict[Tuple[str, str], List[str]] = defaultdict(list) # Loop over the lineages in the JSON data. diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py b/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py index 7582a051dec67c..401271ab75e1f0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/presto.py @@ -115,7 +115,7 @@ def __init__(self, config: PrestoConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = PrestoConfig.parse_obj(config_dict) + config = PrestoConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py index 691ae26d5465cf..be2190a20e9f63 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_config.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional import pydantic -from pydantic import Field +from pydantic import Field, model_validator from datahub.configuration.common import AllowDenyPattern, ConfigModel from datahub.configuration.source_common import ( @@ -49,7 +49,8 @@ class SQLFilterConfig(ConfigModel): description="Regex patterns for views to filter in ingestion. Note: Defaults to table_pattern if not specified. Specify regex to match the entire view name in database.schema.view format. e.g. 
to match all views starting with customer in Customer database and public schema, use the regex 'Customer.public.customer.*'", ) - @pydantic.root_validator(pre=True) + @model_validator(mode="before") + @classmethod def view_pattern_is_table_pattern_unless_specified( cls, values: Dict[str, Any] ) -> Dict[str, Any]: @@ -120,11 +121,9 @@ def is_profiling_enabled(self) -> bool: self.profiling.operation_config ) - @pydantic.root_validator(skip_on_failure=True) - def ensure_profiling_pattern_is_passed_to_profiling( - cls, values: Dict[str, Any] - ) -> Dict[str, Any]: - profiling: Optional[GEProfilingConfig] = values.get("profiling") + @model_validator(mode="after") + def ensure_profiling_pattern_is_passed_to_profiling(self): + profiling = self.profiling # Note: isinstance() check is required here as unity-catalog source reuses # SQLCommonConfig with different profiling config than GEProfilingConfig if ( @@ -132,8 +131,8 @@ def ensure_profiling_pattern_is_passed_to_profiling( and isinstance(profiling, GEProfilingConfig) and profiling.enabled ): - profiling._allow_deny_patterns = values["profile_pattern"] - return values + profiling._allow_deny_patterns = self.profile_pattern + return self @abstractmethod def get_sql_alchemy_url(self): diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py index 78b0dcf9b7be82..d1e4b483180e44 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_generic.py @@ -85,5 +85,5 @@ def __init__(self, config: SQLAlchemyGenericConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = SQLAlchemyGenericConfig.parse_obj(config_dict) + config = SQLAlchemyGenericConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py b/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py index 2f6d9e04dd0ceb..96e3830ba9a1ef 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/teradata.py @@ -860,7 +860,7 @@ def _add_default_options(self, sql_config: SQLCommonConfig) -> None: @classmethod def create(cls, config_dict, ctx): - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) return cls(config, ctx) def _init_schema_resolver(self) -> SchemaResolver: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py index bae22d49ac8cb9..fa5032a5a2f5a9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/trino.py @@ -413,7 +413,7 @@ def _process_view( @classmethod def create(cls, config_dict, ctx): - config = TrinoConfig.parse_obj(config_dict) + config = TrinoConfig.model_validate(config_dict) return cls(config, ctx) def get_schema_fields_for_column( diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py index 7581b5b3667946..e276a04b5fe8dc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/vertica.py @@ -5,7 +5,7 @@ import pydantic import pytest -from pydantic import validator +from pydantic import field_validator from vertica_sqlalchemy_dialect.base import 
VerticaInspector from datahub.configuration.common import AllowDenyPattern @@ -105,8 +105,9 @@ class VerticaConfig(BasicSQLAlchemyConfig): # defaults scheme: str = pydantic.Field(default="vertica+vertica_python") - @validator("host_port") - def clean_host_port(cls, v): + @field_validator("host_port", mode="after") + @classmethod + def clean_host_port(cls, v: str) -> str: return config_clean.remove_protocol(v) @@ -138,7 +139,7 @@ def __init__(self, config: VerticaConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: Dict, ctx: PipelineContext) -> "VerticaSource": - config = VerticaConfig.parse_obj(config_dict) + config = VerticaConfig.model_validate(config_dict) return cls(config, ctx) def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py index bab840fd9e2869..332129872c4787 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql_queries.py @@ -5,10 +5,10 @@ from dataclasses import dataclass, field from datetime import datetime from functools import partial -from typing import ClassVar, Iterable, List, Optional, Union, cast +from typing import Any, ClassVar, Iterable, List, Optional, Union, cast import smart_open -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator from datahub.configuration.common import HiddenFromDocs from datahub.configuration.datetimes import parse_user_datetime @@ -450,19 +450,22 @@ class QueryEntry(BaseModel): class Config: arbitrary_types_allowed = True - @validator("timestamp", pre=True) - def parse_timestamp(cls, v): + @field_validator("timestamp", mode="before") + @classmethod + def parse_timestamp(cls, v: Any) -> Any: return None if v is None else parse_user_datetime(str(v)) - @validator("user", pre=True) - def parse_user(cls, v): + @field_validator("user", mode="before") + @classmethod + def parse_user(cls, v: Any) -> Any: if v is None: return None return v if isinstance(v, CorpUserUrn) else CorpUserUrn(v) - @validator("downstream_tables", "upstream_tables", pre=True) - def parse_tables(cls, v): + @field_validator("downstream_tables", "upstream_tables", mode="before") + @classmethod + def parse_tables(cls, v: Any) -> Any: if not v: return [] diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py index 27cb100fb4841a..ff79ca0f4996f7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py @@ -163,7 +163,7 @@ def _from_utf8_bytes( ) state_as_dict["version"] = checkpoint_aspect.state.formatVersion state_as_dict["serde"] = checkpoint_aspect.state.serde - return state_class.parse_obj(state_as_dict) + return state_class.model_validate(state_as_dict) @staticmethod def _from_base85_json_bytes( @@ -179,7 +179,7 @@ def _from_base85_json_bytes( state_as_dict = json.loads(state_uncompressed.decode("utf-8")) state_as_dict["version"] = checkpoint_aspect.state.formatVersion state_as_dict["serde"] = checkpoint_aspect.state.serde - return state_class.parse_obj(state_as_dict) + return state_class.model_validate(state_as_dict) def to_checkpoint_aspect( self, max_allowed_state_size: int diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py 
b/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py index 67f68df1b4fffc..37a1241991140f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/entity_removal_state.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Type import pydantic +from pydantic import model_validator from datahub.emitter.mce_builder import make_assertion_urn, make_container_urn from datahub.ingestion.source.state.checkpoint import CheckpointStateBase @@ -59,7 +60,7 @@ def _validate_field_rename(cls: Type, values: dict) -> dict: return values - return pydantic.root_validator(pre=True, allow_reuse=True)(_validate_field_rename) + return model_validator(mode="before")(_validate_field_rename) class GenericCheckpointState(CheckpointStateBase): diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py index 2d31231eb53217..15de4fbe56ecca 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Generic, Optional, Type, TypeVar import pydantic -from pydantic import root_validator +from pydantic import model_validator from pydantic.fields import Field from datahub.configuration.common import ( @@ -73,14 +73,14 @@ class StatefulIngestionConfig(ConfigModel): description="If set to True, ignores the current checkpoint state.", ) - @pydantic.root_validator(skip_on_failure=True) - def validate_config(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if values.get("enabled"): - if values.get("state_provider") is None: - values["state_provider"] = DynamicTypedStateProviderConfig( + @model_validator(mode="after") + def validate_config(self) -> "StatefulIngestionConfig": + if self.enabled: + if self.state_provider is None: + self.state_provider = DynamicTypedStateProviderConfig( type="datahub", config={} ) - return values + return self CustomConfig = TypeVar("CustomConfig", bound=StatefulIngestionConfig) @@ -110,17 +110,19 @@ class StatefulLineageConfigMixin(ConfigModel): "store_last_lineage_extraction_timestamp", "enable_stateful_lineage_ingestion" ) - @root_validator(skip_on_failure=True) - def lineage_stateful_option_validator(cls, values: Dict) -> Dict: - sti = values.get("stateful_ingestion") - if not sti or not sti.enabled: - if values.get("enable_stateful_lineage_ingestion"): - logger.warning( - "Stateful ingestion is disabled, disabling enable_stateful_lineage_ingestion config option as well" - ) - values["enable_stateful_lineage_ingestion"] = False - - return values + @model_validator(mode="after") + def lineage_stateful_option_validator(self) -> "StatefulLineageConfigMixin": + try: + sti = getattr(self, "stateful_ingestion", None) + if not sti or not getattr(sti, "enabled", False): + if getattr(self, "enable_stateful_lineage_ingestion", False): + logger.warning( + "Stateful ingestion is disabled, disabling enable_stateful_lineage_ingestion config option as well" + ) + self.enable_stateful_lineage_ingestion = False + except (AttributeError, RecursionError) as e: + logger.debug(f"Skipping stateful lineage validation due to: {e}") + return self class StatefulProfilingConfigMixin(ConfigModel): @@ -135,16 +137,19 @@ class StatefulProfilingConfigMixin(ConfigModel): 
"store_last_profiling_timestamps", "enable_stateful_profiling" ) - @root_validator(skip_on_failure=True) - def profiling_stateful_option_validator(cls, values: Dict) -> Dict: - sti = values.get("stateful_ingestion") - if not sti or not sti.enabled: - if values.get("enable_stateful_profiling"): - logger.warning( - "Stateful ingestion is disabled, disabling enable_stateful_profiling config option as well" - ) - values["enable_stateful_profiling"] = False - return values + @model_validator(mode="after") + def profiling_stateful_option_validator(self) -> "StatefulProfilingConfigMixin": + try: + sti = getattr(self, "stateful_ingestion", None) + if not sti or not getattr(sti, "enabled", False): + if getattr(self, "enable_stateful_profiling", False): + logger.warning( + "Stateful ingestion is disabled, disabling enable_stateful_profiling config option as well" + ) + self.enable_stateful_profiling = False + except (AttributeError, RecursionError) as e: + logger.debug(f"Skipping stateful profiling validation due to: {e}") + return self class StatefulUsageConfigMixin(BaseTimeWindowConfig): @@ -161,16 +166,21 @@ class StatefulUsageConfigMixin(BaseTimeWindowConfig): "store_last_usage_extraction_timestamp", "enable_stateful_usage_ingestion" ) - @root_validator(skip_on_failure=True) - def last_usage_extraction_stateful_option_validator(cls, values: Dict) -> Dict: - sti = values.get("stateful_ingestion") - if not sti or not sti.enabled: - if values.get("enable_stateful_usage_ingestion"): - logger.warning( - "Stateful ingestion is disabled, disabling enable_stateful_usage_ingestion config option as well" - ) - values["enable_stateful_usage_ingestion"] = False - return values + @model_validator(mode="after") + def last_usage_extraction_stateful_option_validator( + self, + ) -> "StatefulUsageConfigMixin": + try: + sti = getattr(self, "stateful_ingestion", None) + if not sti or not getattr(sti, "enabled", False): + if getattr(self, "enable_stateful_usage_ingestion", False): + logger.warning( + "Stateful ingestion is disabled, disabling enable_stateful_usage_ingestion config option as well" + ) + self.enable_stateful_usage_ingestion = False + except (AttributeError, RecursionError) as e: + logger.debug(f"Skipping stateful usage validation due to: {e}") + return self class StatefulTimeWindowConfigMixin(BaseTimeWindowConfig): @@ -185,16 +195,16 @@ class StatefulTimeWindowConfigMixin(BaseTimeWindowConfig): "and queries together from a single audit log and uses a unified time window.", ) - @root_validator(skip_on_failure=True) - def time_window_stateful_option_validator(cls, values: Dict) -> Dict: - sti = values.get("stateful_ingestion") - if not sti or not sti.enabled: - if values.get("enable_stateful_time_window"): + @model_validator(mode="after") + def time_window_stateful_option_validator(self) -> "StatefulTimeWindowConfigMixin": + sti = getattr(self, "stateful_ingestion", None) + if not sti or not getattr(sti, "enabled", False): + if getattr(self, "enable_stateful_time_window", False): logger.warning( "Stateful ingestion is disabled, disabling enable_stateful_time_window config option as well" ) - values["enable_stateful_time_window"] = False - return values + self.enable_stateful_time_window = False + return self @dataclass diff --git a/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py b/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py index 53b4880e51ff32..0fe29a15a439c3 100644 --- 
a/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py @@ -40,7 +40,7 @@ def __init__( def create( cls, config_dict: Dict[str, Any], ctx: PipelineContext ) -> "DatahubIngestionCheckpointingProvider": - config = DatahubIngestionStateProviderConfig.parse_obj(config_dict) + config = DatahubIngestionStateProviderConfig.model_validate(config_dict) if config.datahub_api is not None: return cls(DataHubGraph(config.datahub_api)) elif ctx.graph: diff --git a/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py b/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py index 55f0903b9c91c7..2096f11009bbf1 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state_provider/file_ingestion_checkpointing_provider.py @@ -32,7 +32,7 @@ def __init__(self, config: FileIngestionStateProviderConfig): def create( cls, config_dict: Dict[str, Any], ctx: PipelineContext ) -> "FileIngestionCheckpointingProvider": - config = FileIngestionStateProviderConfig.parse_obj(config_dict) + config = FileIngestionStateProviderConfig.model_validate(config_dict) return cls(config) def get_latest_checkpoint( diff --git a/metadata-ingestion/src/datahub/ingestion/source/superset.py b/metadata-ingestion/src/datahub/ingestion/source/superset.py index d12ce91570b97c..d7a90e7a645ca5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/superset.py +++ b/metadata-ingestion/src/datahub/ingestion/source/superset.py @@ -9,7 +9,7 @@ import dateutil.parser as dp import requests import sqlglot -from pydantic import BaseModel, root_validator, validator +from pydantic import BaseModel, field_validator, model_validator from pydantic.fields import Field from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry @@ -246,16 +246,16 @@ class Config: # This is required to allow preset configs to get parsed extra = "allow" - @validator("connect_uri", "display_uri") - def remove_trailing_slash(cls, v): + @field_validator("connect_uri", "display_uri", mode="after") + @classmethod + def remove_trailing_slash(cls, v: str) -> str: return config_clean.remove_trailing_slashes(v) - @root_validator(skip_on_failure=True) - def default_display_uri_to_connect_uri(cls, values): - base = values.get("display_uri") - if base is None: - values["display_uri"] = values.get("connect_uri") - return values + @model_validator(mode="after") + def default_display_uri_to_connect_uri(self) -> "SupersetConfig": + if self.display_uri is None: + self.display_uri = self.connect_uri + return self def get_metric_name(metric): diff --git a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py index 0bb6e3e19358f9..5ead74b3ddb203 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py +++ b/metadata-ingestion/src/datahub/ingestion/source/tableau/tableau.py @@ -25,7 +25,7 @@ import dateutil.parser as dp import tableauserverclient as TSC -from pydantic import root_validator, validator +from pydantic import field_validator, model_validator from pydantic.fields import Field from requests.adapters import HTTPAdapter from tableauserverclient import ( @@ -257,8 +257,9 @@ class 
TableauConnectionConfig(ConfigModel): description="When enabled, extracts column-level lineage from Tableau Datasources", ) - @validator("connect_uri") - def remove_trailing_slash(cls, v): + @field_validator("connect_uri", mode="after") + @classmethod + def remove_trailing_slash(cls, v: str) -> str: return config_clean.remove_trailing_slashes(v) def get_tableau_auth( @@ -652,8 +653,9 @@ class TableauConfig( "fetch_size", ) - # pre = True because we want to take some decision before pydantic initialize the configuration to default values - @root_validator(pre=True) + # mode = "before" because we want to take some decision before pydantic initialize the configuration to default values + @model_validator(mode="before") + @classmethod def projects_backward_compatibility(cls, values: Dict) -> Dict: # In-place update of the input dict would cause state contamination. This was discovered through test failures # in test_hex.py where the same dict is reused. @@ -683,27 +685,23 @@ def projects_backward_compatibility(cls, values: Dict) -> Dict: return values - @root_validator(skip_on_failure=True) - def validate_config_values(cls, values: Dict) -> Dict: - tags_for_hidden_assets = values.get("tags_for_hidden_assets") - ingest_tags = values.get("ingest_tags") + @model_validator(mode="after") + def validate_config_values(self) -> "TableauConfig": if ( - not ingest_tags - and tags_for_hidden_assets - and len(tags_for_hidden_assets) > 0 + not self.ingest_tags + and self.tags_for_hidden_assets + and len(self.tags_for_hidden_assets) > 0 ): raise ValueError( "tags_for_hidden_assets is only allowed with ingest_tags enabled. Be aware that this will overwrite tags entered from the UI." ) - use_email_as_username = values.get("use_email_as_username") - ingest_owner = values.get("ingest_owner") - if use_email_as_username and not ingest_owner: + if self.use_email_as_username and not self.ingest_owner: raise ValueError( "use_email_as_username requires ingest_owner to be enabled." ) - return values + return self class WorkbookKey(ContainerKey): diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/config.py b/metadata-ingestion/src/datahub/ingestion/source/unity/config.py index c57517b31a71e8..17e0929d36b0eb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/config.py @@ -1,10 +1,10 @@ import logging import os from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import pydantic -from pydantic import Field +from pydantic import Field, field_validator, model_validator from typing_extensions import Literal from datahub.configuration.common import ( @@ -397,13 +397,15 @@ def is_ge_profiling(self) -> bool: default=None, description="Unity Catalog Stateful Ingestion Config." 
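`@root_validator(pre=True)` maps onto `@model_validator(mode="before")` plus `@classmethod`: the validator still sees the raw input mapping before any field parsing, which is what backward-compatibility shims like `projects_backward_compatibility` above rely on. A sketch of that shape with a hypothetical rename from `project` to `projects` (not the real Tableau fields):

from typing import Any, Dict, List, Optional
from pydantic import BaseModel, model_validator


class ProjectsSketchConfig(BaseModel):
    projects: Optional[List[str]] = None

    # v1: @root_validator(pre=True)
    # v2: @model_validator(mode="before") + @classmethod, still operating on the raw dict
    @model_validator(mode="before")
    @classmethod
    def rename_project_to_projects(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        values = dict(values)  # copy to avoid mutating the caller's dict
        if "project" in values and "projects" not in values:
            values["projects"] = [values.pop("project")]
        return values


assert ProjectsSketchConfig.model_validate({"project": "default"}).projects == ["default"]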
) - @pydantic.validator("start_time") + @field_validator("start_time", mode="after") + @classmethod def within_thirty_days(cls, v: datetime) -> datetime: if (datetime.now(timezone.utc) - v).days > 30: raise ValueError("Query history is only maintained for 30 days.") return v - @pydantic.validator("workspace_url") + @field_validator("workspace_url", mode="after") + @classmethod def workspace_url_should_start_with_http_scheme(cls, workspace_url: str) -> str: if not workspace_url.lower().startswith(("http://", "https://")): raise ValueError( @@ -411,7 +413,8 @@ def workspace_url_should_start_with_http_scheme(cls, workspace_url: str) -> str: ) return workspace_url - @pydantic.validator("include_metastore") + @field_validator("include_metastore", mode="after") + @classmethod def include_metastore_warning(cls, v: bool) -> bool: if v: msg = ( @@ -424,60 +427,56 @@ def include_metastore_warning(cls, v: bool) -> bool: add_global_warning(msg) return v - @pydantic.root_validator(skip_on_failure=True) - def set_warehouse_id_from_profiling(cls, values: Dict[str, Any]) -> Dict[str, Any]: - profiling: Optional[ - Union[UnityCatalogGEProfilerConfig, UnityCatalogAnalyzeProfilerConfig] - ] = values.get("profiling") - if not values.get("warehouse_id") and profiling and profiling.warehouse_id: - values["warehouse_id"] = profiling.warehouse_id + @model_validator(mode="after") + def set_warehouse_id_from_profiling(self): + profiling = self.profiling + if not self.warehouse_id and profiling and profiling.warehouse_id: + self.warehouse_id = profiling.warehouse_id if ( - values.get("warehouse_id") + self.warehouse_id and profiling and profiling.warehouse_id - and values["warehouse_id"] != profiling.warehouse_id + and self.warehouse_id != profiling.warehouse_id ): raise ValueError( "When `warehouse_id` is set, it must match the `warehouse_id` in `profiling`." 
) - if values.get("warehouse_id") and profiling and not profiling.warehouse_id: - profiling.warehouse_id = values["warehouse_id"] + if self.warehouse_id and profiling and not profiling.warehouse_id: + profiling.warehouse_id = self.warehouse_id if profiling and profiling.enabled and not profiling.warehouse_id: raise ValueError("warehouse_id must be set when profiling is enabled.") - return values + return self - @pydantic.root_validator(skip_on_failure=True) - def validate_lineage_data_source_with_warehouse( - cls, values: Dict[str, Any] - ) -> Dict[str, Any]: - lineage_data_source = values.get("lineage_data_source", LineageDataSource.AUTO) - warehouse_id = values.get("warehouse_id") + @model_validator(mode="after") + def validate_lineage_data_source_with_warehouse(self): + lineage_data_source = self.lineage_data_source or LineageDataSource.AUTO - if lineage_data_source == LineageDataSource.SYSTEM_TABLES and not warehouse_id: + if ( + lineage_data_source == LineageDataSource.SYSTEM_TABLES + and not self.warehouse_id + ): raise ValueError( f"lineage_data_source='{LineageDataSource.SYSTEM_TABLES.value}' requires warehouse_id to be set" ) - return values + return self - @pydantic.root_validator(skip_on_failure=True) - def validate_usage_data_source_with_warehouse( - cls, values: Dict[str, Any] - ) -> Dict[str, Any]: - usage_data_source = values.get("usage_data_source", UsageDataSource.AUTO) - warehouse_id = values.get("warehouse_id") + @model_validator(mode="after") + def validate_usage_data_source_with_warehouse(self): + usage_data_source = self.usage_data_source or UsageDataSource.AUTO - if usage_data_source == UsageDataSource.SYSTEM_TABLES and not warehouse_id: + if usage_data_source == UsageDataSource.SYSTEM_TABLES and not self.warehouse_id: raise ValueError( f"usage_data_source='{UsageDataSource.SYSTEM_TABLES.value}' requires warehouse_id to be set" ) - return values + return self - @pydantic.validator("schema_pattern", always=True) + @field_validator("schema_pattern", mode="before") + @classmethod def schema_pattern_should__always_deny_information_schema( cls, v: AllowDenyPattern ) -> AllowDenyPattern: diff --git a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py index 5224dcb9b7edb9..505f7e77351c57 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/unity/source.py +++ b/metadata-ingestion/src/datahub/ingestion/source/unity/source.py @@ -319,7 +319,7 @@ def test_connection(config_dict: dict) -> TestConnectionReport: @classmethod def create(cls, config_dict, ctx): - config = UnityCatalogSourceConfig.parse_obj(config_dict) + config = UnityCatalogSourceConfig.model_validate(config_dict) return cls(ctx=ctx, config=config) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py b/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py index d35b067983e26c..b67393e175d6dd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/usage/clickhouse_usage.py @@ -115,7 +115,7 @@ class ClickHouseUsageSource(Source): @classmethod def create(cls, config_dict, ctx): - config = ClickHouseUsageConfig.parse_obj(config_dict) + config = ClickHouseUsageConfig.model_validate(config_dict) return cls(ctx, config) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: diff --git 
a/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py b/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py index 53139e1590b926..c01de8a90e3673 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/usage/starburst_trino_usage.py @@ -133,7 +133,7 @@ class TrinoUsageSource(Source): @classmethod def create(cls, config_dict, ctx): - config = TrinoUsageConfig.parse_obj(config_dict) + config = TrinoUsageConfig.model_validate(config_dict) return cls(ctx, config) def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py index dcb831d3ce31f5..4f94c483b292d0 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py @@ -15,6 +15,7 @@ ) import pydantic +from pydantic import ValidationInfo, field_validator from pydantic.fields import Field import datahub.emitter.mce_builder as builder @@ -226,10 +227,11 @@ class BaseUsageConfig(BaseTimeWindowConfig): default=True, description="Whether to ingest the top_n_queries." ) - @pydantic.validator("top_n_queries") - def ensure_top_n_queries_is_not_too_big(cls, v: int, values: dict) -> int: + @field_validator("top_n_queries", mode="after") + @classmethod + def ensure_top_n_queries_is_not_too_big(cls, v: int, info: ValidationInfo) -> int: minimum_query_size = 20 - + values = info.data max_queries = int(values["queries_character_limit"] / minimum_query_size) if v > max_queries: raise ValueError( diff --git a/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py b/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py index f0f0ab95ca8119..d5661743244d1a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py +++ b/metadata-ingestion/src/datahub/ingestion/source_config/csv_enricher.py @@ -1,6 +1,5 @@ -from typing import Any, Dict - import pydantic +from pydantic import field_validator from datahub.configuration.common import ConfigModel @@ -21,7 +20,8 @@ class CSVEnricherConfig(ConfigModel): description="Delimiter to use when parsing array fields (tags, terms and owners)", ) - @pydantic.validator("write_semantics") + @field_validator("write_semantics", mode="after") + @classmethod def validate_write_semantics(cls, write_semantics: str) -> str: if write_semantics.lower() not in {"patch", "override"}: raise ValueError( @@ -31,9 +31,10 @@ def validate_write_semantics(cls, write_semantics: str) -> str: ) return write_semantics - @pydantic.validator("array_delimiter") - def validator_diff(cls, array_delimiter: str, values: Dict[str, Any]) -> str: - if array_delimiter == values["delimiter"]: + @field_validator("array_delimiter", mode="after") + @classmethod + def validator_diff(cls, array_delimiter: str, info: pydantic.ValidationInfo) -> str: + if array_delimiter == info.data["delimiter"]: raise ValueError( "array_delimiter and delimiter are the same. Please choose different delimiters." 
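Validators that previously inspected sibling fields through the `values` argument now take a `pydantic.ValidationInfo` parameter and read `info.data`, as `ensure_top_n_queries_is_not_too_big` and `validator_diff` do above; note that `info.data` only contains fields declared (and already validated) before the current one. A sketch with hypothetical delimiter fields:

from pydantic import BaseModel, ValidationInfo, field_validator


class DelimiterSketchConfig(BaseModel):
    delimiter: str = ","
    array_delimiter: str = "|"

    # v1: def check(cls, v, values): ... values["delimiter"]
    # v2: the second argument is a ValidationInfo; earlier fields live in info.data
    @field_validator("array_delimiter", mode="after")
    @classmethod
    def delimiters_must_differ(cls, v: str, info: ValidationInfo) -> str:
        if v == info.data.get("delimiter"):
            raise ValueError("array_delimiter and delimiter must be different")
        return v


DelimiterSketchConfig(delimiter=",", array_delimiter="|")  # ok
# DelimiterSketchConfig(delimiter=",", array_delimiter=",")  # raises pydantic.ValidationError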
) diff --git a/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py b/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py index 1846dcb4fdd3d0..2957a1323e9d14 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source_config/operation_config.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional import cachetools -import pydantic +from pydantic import field_validator, model_validator from pydantic.fields import Field from datahub.configuration.common import ConfigModel @@ -26,7 +26,8 @@ class OperationConfig(ConfigModel): description="Number between 1 to 31 for date of month (both inclusive). If not specified, defaults to Nothing and this field does not take affect.", ) - @pydantic.root_validator(pre=True) + @model_validator(mode="before") + @classmethod def lower_freq_configs_are_set(cls, values: Dict[str, Any]) -> Dict[str, Any]: lower_freq_profile_enabled = values.get("lower_freq_profile_enabled") profile_day_of_week = values.get("profile_day_of_week") @@ -41,7 +42,8 @@ def lower_freq_configs_are_set(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - @pydantic.validator("profile_day_of_week") + @field_validator("profile_day_of_week", mode="after") + @classmethod def validate_profile_day_of_week(cls, v: Optional[int]) -> Optional[int]: profile_day_of_week = v if profile_day_of_week is None: @@ -52,7 +54,8 @@ def validate_profile_day_of_week(cls, v: Optional[int]) -> Optional[int]: ) return profile_day_of_week - @pydantic.validator("profile_date_of_month") + @field_validator("profile_date_of_month", mode="after") + @classmethod def validate_profile_date_of_month(cls, v: Optional[int]) -> Optional[int]: profile_date_of_month = v if profile_date_of_month is None: diff --git a/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py b/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py index bbeb9ee556aa94..380ab09618bb11 100644 --- a/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py +++ b/metadata-ingestion/src/datahub/ingestion/source_config/pulsar.py @@ -3,7 +3,7 @@ from urllib.parse import urlparse import pydantic -from pydantic import Field, validator +from pydantic import Field, model_validator from datahub.configuration.common import AllowDenyPattern from datahub.configuration.source_common import ( @@ -100,27 +100,23 @@ class PulsarSourceConfig( default_factory=dict, description="Placeholder for OpenId discovery document" ) - @validator("token") - def ensure_only_issuer_or_token( - cls, token: Optional[str], values: Dict[str, Optional[str]] - ) -> Optional[str]: - if token is not None and values.get("issuer_url") is not None: + @model_validator(mode="after") + def ensure_only_issuer_or_token(self) -> "PulsarSourceConfig": + if self.token is not None and self.issuer_url is not None: raise ValueError( "Expected only one authentication method, either issuer_url or token." 
) - return token - - @validator("client_secret", always=True) - def ensure_client_id_and_secret_for_issuer_url( - cls, client_secret: Optional[str], values: Dict[str, Optional[str]] - ) -> Optional[str]: - if values.get("issuer_url") is not None and ( - client_secret is None or values.get("client_id") is None + return self + + @model_validator(mode="after") + def ensure_client_id_and_secret_for_issuer_url(self) -> "PulsarSourceConfig": + if self.issuer_url is not None and ( + self.client_secret is None or self.client_id is None ): raise ValueError( "Missing configuration: client_id and client_secret are mandatory when issuer_url is set." ) - return client_secret + return self @pydantic.field_validator("web_service_url", mode="after") @classmethod diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py index 55989cf17f2691..a497417a813e8c 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_browse_path.py @@ -32,7 +32,7 @@ def __init__(self, config: AddDatasetBrowsePathConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "AddDatasetBrowsePathTransformer": - config = AddDatasetBrowsePathConfig.parse_obj(config_dict) + config = AddDatasetBrowsePathConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py index e18aa063878b7b..6182fe8b9b8c23 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_dataproduct.py @@ -1,7 +1,7 @@ import logging from typing import Callable, Dict, List, Optional, Union -import pydantic +from pydantic import model_validator from datahub.configuration.common import ConfigModel, KeyValuePattern from datahub.configuration.import_resolver import pydantic_resolve_key @@ -39,7 +39,7 @@ def __init__(self, config: AddDatasetDataProductConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetDataProduct": - config = AddDatasetDataProductConfig.parse_obj(config_dict) + config = AddDatasetDataProductConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( @@ -116,7 +116,7 @@ def __init__(self, config: SimpleDatasetDataProductConfig, ctx: PipelineContext) def create( cls, config_dict: dict, ctx: PipelineContext ) -> "SimpleAddDatasetDataProduct": - config = SimpleDatasetDataProductConfig.parse_obj(config_dict) + config = SimpleDatasetDataProductConfig.model_validate(config_dict) return cls(config, ctx) @@ -124,7 +124,8 @@ class PatternDatasetDataProductConfig(ConfigModel): dataset_to_data_product_urns_pattern: KeyValuePattern = KeyValuePattern.all() is_container: bool = False - @pydantic.root_validator(pre=True) + @model_validator(mode="before") + @classmethod def validate_pattern_value(cls, values: Dict) -> Dict: rules = values["dataset_to_data_product_urns_pattern"]["rules"] for key, value in rules.items(): @@ -156,5 +157,5 @@ def __init__(self, config: PatternDatasetDataProductConfig, ctx: PipelineContext def create( cls, config_dict: dict, ctx: PipelineContext ) -> "PatternAddDatasetDataProduct": - config = 
PatternDatasetDataProductConfig.parse_obj(config_dict) + config = PatternDatasetDataProductConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py index c89a7301bd0363..ab3d610856a429 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_ownership.py @@ -55,7 +55,7 @@ def __init__(self, config: AddDatasetOwnershipConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetOwnership": - config = AddDatasetOwnershipConfig.parse_obj(config_dict) + config = AddDatasetOwnershipConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod @@ -209,7 +209,7 @@ def __init__(self, config: SimpleDatasetOwnershipConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "SimpleAddDatasetOwnership": - config = SimpleDatasetOwnershipConfig.parse_obj(config_dict) + config = SimpleDatasetOwnershipConfig.model_validate(config_dict) return cls(config, ctx) @@ -247,5 +247,5 @@ def __init__(self, config: PatternDatasetOwnershipConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "PatternAddDatasetOwnership": - config = PatternDatasetOwnershipConfig.parse_obj(config_dict) + config = PatternDatasetOwnershipConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py index 4b9b4c9e6f5da6..0e406e0d061ee8 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_properties.py @@ -50,7 +50,7 @@ def __init__( @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetProperties": - config = AddDatasetPropertiesConfig.parse_obj(config_dict) + config = AddDatasetPropertiesConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod @@ -144,5 +144,5 @@ def __init__(self, config: SimpleAddDatasetPropertiesConfig, ctx: PipelineContex def create( cls, config_dict: dict, ctx: PipelineContext ) -> "SimpleAddDatasetProperties": - config = SimpleAddDatasetPropertiesConfig.parse_obj(config_dict) + config = SimpleAddDatasetPropertiesConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py index d2687ebc5e76f6..43be3d8ef0ec86 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_tags.py @@ -38,7 +38,7 @@ def __init__(self, config: AddDatasetSchemaTagsConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetSchemaTags": - config = AddDatasetSchemaTagsConfig.parse_obj(config_dict) + config = AddDatasetSchemaTagsConfig.model_validate(config_dict) return cls(config, ctx) def extend_field( @@ -142,5 +142,5 @@ def __init__(self, config: PatternDatasetTagsConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> 
"PatternAddDatasetSchemaTags": - config = PatternDatasetTagsConfig.parse_obj(config_dict) + config = PatternDatasetTagsConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py index d17a39bee6cfbf..32ec90dbd6d9ba 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_schema_terms.py @@ -39,7 +39,7 @@ def __init__(self, config: AddDatasetSchemaTermsConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetSchemaTerms": - config = AddDatasetSchemaTermsConfig.parse_obj(config_dict) + config = AddDatasetSchemaTermsConfig.model_validate(config_dict) return cls(config, ctx) def extend_field( @@ -162,5 +162,5 @@ def __init__(self, config: PatternDatasetTermsConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "PatternAddDatasetSchemaTerms": - config = PatternDatasetTermsConfig.parse_obj(config_dict) + config = PatternDatasetTermsConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py index 355ca7a373653f..7772e075d756de 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_tags.py @@ -41,7 +41,7 @@ def __init__(self, config: AddDatasetTagsConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetTags": - config = AddDatasetTagsConfig.parse_obj(config_dict) + config = AddDatasetTagsConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( @@ -104,7 +104,7 @@ def __init__(self, config: SimpleDatasetTagConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "SimpleAddDatasetTags": - config = SimpleDatasetTagConfig.parse_obj(config_dict) + config = SimpleDatasetTagConfig.model_validate(config_dict) return cls(config, ctx) @@ -128,5 +128,5 @@ def __init__(self, config: PatternDatasetTagsConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "PatternAddDatasetTags": - config = PatternDatasetTagsConfig.parse_obj(config_dict) + config = PatternDatasetTagsConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py index 3daf52e32ed4bb..a65cfe57c40387 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/add_dataset_terms.py @@ -39,7 +39,7 @@ def __init__(self, config: AddDatasetTermsConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetTerms": - config = AddDatasetTermsConfig.parse_obj(config_dict) + config = AddDatasetTermsConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod @@ -120,7 +120,7 @@ def __init__(self, config: SimpleDatasetTermsConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> 
"SimpleAddDatasetTerms": - config = SimpleDatasetTermsConfig.parse_obj(config_dict) + config = SimpleDatasetTermsConfig.model_validate(config_dict) return cls(config, ctx) @@ -147,5 +147,5 @@ def __init__(self, config: PatternDatasetTermsConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "PatternAddDatasetTerms": - config = PatternDatasetTermsConfig.parse_obj(config_dict) + config = PatternDatasetTermsConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py index c02fd2c520184d..108185bb2a25ce 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain.py @@ -67,7 +67,7 @@ def __init__(self, config: AddDatasetDomainSemanticsConfig, ctx: PipelineContext @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "AddDatasetDomain": - config = AddDatasetDomainSemanticsConfig.parse_obj(config_dict) + config = AddDatasetDomainSemanticsConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod @@ -208,7 +208,7 @@ def __init__( def create( cls, config_dict: dict, ctx: PipelineContext ) -> "SimpleAddDatasetDomain": - config = SimpleDatasetDomainSemanticsConfig.parse_obj(config_dict) + config = SimpleDatasetDomainSemanticsConfig.model_validate(config_dict) return cls(config, ctx) @@ -238,5 +238,5 @@ def resolve_domain(domain_urn: str) -> DomainsClass: def create( cls, config_dict: dict, ctx: PipelineContext ) -> "PatternAddDatasetDomain": - config = PatternDatasetDomainSemanticsConfig.parse_obj(config_dict) + config = PatternDatasetDomainSemanticsConfig.model_validate(config_dict) return cls(config, ctx) diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py index bb2f318dcac8b8..8bc1e0be1e2a5d 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/dataset_domain_based_on_tags.py @@ -27,7 +27,7 @@ def __init__(self, config: DatasetTagDomainMapperConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "DatasetTagDomainMapper": - config = DatasetTagDomainMapperConfig.parse_obj(config_dict) + config = DatasetTagDomainMapperConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py b/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py index 4b64d38a9b42fa..3dcf0d3251f436 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/extract_dataset_tags.py @@ -29,7 +29,7 @@ def __init__(self, config: ExtractDatasetTagsConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "ExtractDatasetTags": - config = ExtractDatasetTagsConfig.parse_obj(config_dict) + config = ExtractDatasetTagsConfig.model_validate(config_dict) return cls(config, ctx) def _get_tags_to_add(self, entity_urn: str) -> List[TagAssociationClass]: diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py 
b/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py index 32707dcd3a372f..1b59357ad300a9 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/extract_ownership_from_tags.py @@ -62,7 +62,7 @@ def __init__(self, config: ExtractOwnersFromTagsConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "ExtractOwnersFromTagsTransformer": - config = ExtractOwnersFromTagsConfig.parse_obj(config_dict) + config = ExtractOwnersFromTagsConfig.model_validate(config_dict) return cls(config, ctx) def get_owner_urn(self, owner_str: str) -> str: diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py b/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py index 00ef29183a0c9a..95e4a494137a45 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/mark_dataset_status.py @@ -24,7 +24,7 @@ def __init__(self, config: MarkDatasetStatusConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "MarkDatasetStatus": - config = MarkDatasetStatusConfig.parse_obj(config_dict) + config = MarkDatasetStatusConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py index a3d41c8e91ec52..9df647a8942adf 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_dataset_usage_user.py @@ -38,7 +38,7 @@ def __init__( def create( cls, config_dict: dict, ctx: PipelineContext ) -> "PatternCleanupDatasetUsageUser": - config = PatternCleanupDatasetUsageUserConfig.parse_obj(config_dict) + config = PatternCleanupDatasetUsageUserConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py index 3b4491290516c2..46827d4a2fee73 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/pattern_cleanup_ownership.py @@ -37,7 +37,7 @@ def __init__(self, config: PatternCleanUpOwnershipConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> "PatternCleanUpOwnership": - config = PatternCleanUpOwnershipConfig.parse_obj(config_dict) + config = PatternCleanUpOwnershipConfig.model_validate(config_dict) return cls(config, ctx) def _get_current_owner_urns(self, entity_urn: str) -> Set[str]: diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py b/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py index 934e2a13d56314..dae09d8d13958e 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/remove_dataset_ownership.py @@ -21,7 +21,7 @@ def __init__(self, config: ClearDatasetOwnershipConfig, ctx: PipelineContext): def create( cls, config_dict: dict, ctx: PipelineContext ) -> 
"SimpleRemoveDatasetOwnership": - config = ClearDatasetOwnershipConfig.parse_obj(config_dict) + config = ClearDatasetOwnershipConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py b/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py index f6847f234aefe6..7a4f9dd119a0d7 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/replace_external_url.py @@ -47,7 +47,7 @@ def __init__( def create( cls, config_dict: dict, ctx: PipelineContext ) -> "ReplaceExternalUrlDataset": - config = ReplaceExternalUrlConfig.parse_obj(config_dict) + config = ReplaceExternalUrlConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( @@ -97,7 +97,7 @@ def __init__( def create( cls, config_dict: dict, ctx: PipelineContext ) -> "ReplaceExternalUrlContainer": - config = ReplaceExternalUrlConfig.parse_obj(config_dict) + config = ReplaceExternalUrlConfig.model_validate(config_dict) return cls(config, ctx) def transform_aspect( diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/set_browse_path.py b/metadata-ingestion/src/datahub/ingestion/transformer/set_browse_path.py index 06e4869084cb54..026a7b3fb16699 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/set_browse_path.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/set_browse_path.py @@ -42,7 +42,7 @@ def entity_types(self) -> List[str]: def create( cls, config_dict: dict, ctx: PipelineContext ) -> "SetBrowsePathTransformer": - config = SetBrowsePathTransformerConfig.parse_obj(config_dict) + config = SetBrowsePathTransformerConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod diff --git a/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py b/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py index 65cf2ac3614ae0..0220d5cccc1cbc 100644 --- a/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py +++ b/metadata-ingestion/src/datahub/ingestion/transformer/tags_to_terms.py @@ -32,7 +32,7 @@ def __init__(self, config: TagsToTermMapperConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "TagsToTermMapper": - config = TagsToTermMapperConfig.parse_obj(config_dict) + config = TagsToTermMapperConfig.model_validate(config_dict) return cls(config, ctx) @staticmethod diff --git a/metadata-ingestion/src/datahub/lite/duckdb_lite.py b/metadata-ingestion/src/datahub/lite/duckdb_lite.py index c0d74df89f9517..7930b8ad779820 100644 --- a/metadata-ingestion/src/datahub/lite/duckdb_lite.py +++ b/metadata-ingestion/src/datahub/lite/duckdb_lite.py @@ -42,7 +42,7 @@ class DuckDBLite(DataHubLiteLocal[DuckDBLiteConfig]): @classmethod def create(cls, config_dict: dict) -> "DuckDBLite": - config: DuckDBLiteConfig = DuckDBLiteConfig.parse_obj(config_dict) + config: DuckDBLiteConfig = DuckDBLiteConfig.model_validate(config_dict) return DuckDBLite(config) def __init__(self, config: DuckDBLiteConfig) -> None: diff --git a/metadata-ingestion/src/datahub/lite/lite_util.py b/metadata-ingestion/src/datahub/lite/lite_util.py index 251aab340d1c78..531a9066c70eba 100644 --- a/metadata-ingestion/src/datahub/lite/lite_util.py +++ b/metadata-ingestion/src/datahub/lite/lite_util.py @@ -92,7 +92,7 @@ def reindex(self) -> None: def get_datahub_lite(config_dict: dict, read_only: bool = False) 
-> "DataHubLiteLocal": - lite_local_config = LiteLocalConfig.parse_obj(config_dict) + lite_local_config = LiteLocalConfig.model_validate(config_dict) lite_type = lite_local_config.type try: @@ -102,7 +102,7 @@ def get_datahub_lite(config_dict: dict, read_only: bool = False) -> "DataHubLite f"Failed to find a registered lite implementation for {lite_type}. Valid values are {[k for k in lite_registry.mapping]}" ) from e - lite_specific_config = lite_class.get_config_class().parse_obj( + lite_specific_config = lite_class.get_config_class().model_validate( lite_local_config.config ) lite = lite_class(lite_specific_config) diff --git a/metadata-ingestion/src/datahub/sdk/search_filters.py b/metadata-ingestion/src/datahub/sdk/search_filters.py index 632b692072e54c..da69f7667e4d1d 100644 --- a/metadata-ingestion/src/datahub/sdk/search_filters.py +++ b/metadata-ingestion/src/datahub/sdk/search_filters.py @@ -16,6 +16,7 @@ ) import pydantic +from pydantic import field_validator from datahub.configuration.common import ConfigModel from datahub.configuration.pydantic_migration_helpers import ( @@ -102,7 +103,8 @@ class _EntitySubtypeFilter(_BaseFilter): description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.", ) - @pydantic.validator("entity_subtype", pre=True) + @field_validator("entity_subtype", mode="before") + @classmethod def validate_entity_subtype(cls, v: str) -> List[str]: return [v] if not isinstance(v, list) else v @@ -141,10 +143,13 @@ class _PlatformFilter(_BaseFilter): platform: List[str] # TODO: Add validator to convert string -> list of strings - @pydantic.validator("platform", each_item=True) - def validate_platform(cls, v: str) -> str: + @field_validator("platform", mode="before") + @classmethod + def validate_platform(cls, v): # Subtle - we use the constructor instead of the from_string method # because coercion is acceptable here. + if isinstance(v, list): + return [str(DataPlatformUrn(item)) for item in v] return str(DataPlatformUrn(v)) def _build_rule(self) -> SearchFilterRule: @@ -161,8 +166,11 @@ def compile(self) -> _OrFilters: class _DomainFilter(_BaseFilter): domain: List[str] - @pydantic.validator("domain", each_item=True) - def validate_domain(cls, v: str) -> str: + @field_validator("domain", mode="before") + @classmethod + def validate_domain(cls, v): + if isinstance(v, list): + return [str(DomainUrn.from_string(item)) for item in v] return str(DomainUrn.from_string(v)) def _build_rule(self) -> SearchFilterRule: @@ -183,8 +191,11 @@ class _ContainerFilter(_BaseFilter): description="If true, only entities that are direct descendants of the container will be returned.", ) - @pydantic.validator("container", each_item=True) - def validate_container(cls, v: str) -> str: + @field_validator("container", mode="before") + @classmethod + def validate_container(cls, v): + if isinstance(v, list): + return [str(ContainerUrn.from_string(item)) for item in v] return str(ContainerUrn.from_string(v)) @classmethod @@ -249,17 +260,25 @@ class _OwnerFilter(_BaseFilter): description="The owner to filter on. 
Should be user or group URNs.", ) - @pydantic.validator("owner", each_item=True) - def validate_owner(cls, v: str) -> str: - if not v.startswith("urn:li:"): - raise ValueError(f"Owner must be a valid User or Group URN, got: {v}") - _type = guess_entity_type(v) - if _type == CorpUserUrn.ENTITY_TYPE: - return str(CorpUserUrn.from_string(v)) - elif _type == CorpGroupUrn.ENTITY_TYPE: - return str(CorpGroupUrn.from_string(v)) - else: - raise ValueError(f"Owner must be a valid User or Group URN, got: {v}") + @field_validator("owner", mode="before") + @classmethod + def validate_owner(cls, v): + validated = [] + for owner in v: + if not owner.startswith("urn:li:"): + raise ValueError( + f"Owner must be a valid User or Group URN, got: {owner}" + ) + _type = guess_entity_type(owner) + if _type == CorpUserUrn.ENTITY_TYPE: + validated.append(str(CorpUserUrn.from_string(owner))) + elif _type == CorpGroupUrn.ENTITY_TYPE: + validated.append(str(CorpGroupUrn.from_string(owner))) + else: + raise ValueError( + f"Owner must be a valid User or Group URN, got: {owner}" + ) + return validated def _build_rule(self) -> SearchFilterRule: return SearchFilterRule( @@ -279,17 +298,21 @@ class _GlossaryTermFilter(_BaseFilter): description="The glossary term to filter on. Should be glossary term URNs.", ) - @pydantic.validator("glossary_term", each_item=True) - def validate_glossary_term(cls, v: str) -> str: - if not v.startswith("urn:li:"): - raise ValueError(f"Glossary term must be a valid URN, got: {v}") - # Validate that it's a glossary term URN - _type = guess_entity_type(v) - if _type != "glossaryTerm": - raise ValueError( - f"Glossary term must be a valid glossary term URN, got: {v}" - ) - return v + @field_validator("glossary_term", mode="before") + @classmethod + def validate_glossary_term(cls, v): + validated = [] + for term in v: + if not term.startswith("urn:li:"): + raise ValueError(f"Glossary term must be a valid URN, got: {term}") + # Validate that it's a glossary term URN + _type = guess_entity_type(term) + if _type != "glossaryTerm": + raise ValueError( + f"Glossary term must be a valid glossary term URN, got: {term}" + ) + validated.append(term) + return validated def _build_rule(self) -> SearchFilterRule: return SearchFilterRule( @@ -309,15 +332,19 @@ class _TagFilter(_BaseFilter): description="The tag to filter on. 
Should be tag URNs.", ) - @pydantic.validator("tag", each_item=True) - def validate_tag(cls, v: str) -> str: - if not v.startswith("urn:li:"): - raise ValueError(f"Tag must be a valid URN, got: {v}") - # Validate that it's a tag URN - _type = guess_entity_type(v) - if _type != "tag": - raise ValueError(f"Tag must be a valid tag URN, got: {v}") - return v + @field_validator("tag", mode="before") + @classmethod + def validate_tag(cls, v): + validated = [] + for tag in v: + if not tag.startswith("urn:li:"): + raise ValueError(f"Tag must be a valid URN, got: {tag}") + # Validate that it's a tag URN + _type = guess_entity_type(tag) + if _type != "tag": + raise ValueError(f"Tag must be a valid tag URN, got: {tag}") + validated.append(tag) + return validated def _build_rule(self) -> SearchFilterRule: return SearchFilterRule( @@ -426,7 +453,8 @@ class _Not(_BaseFilter): not_: "Filter" = pydantic.Field(alias="not") - @pydantic.validator("not_", pre=False) + @field_validator("not_", mode="after") + @classmethod def validate_not(cls, v: "Filter") -> "Filter": inner_filter = v.compile() if len(inner_filter) != 1: @@ -571,7 +599,7 @@ def load_filters(obj: Any) -> Filter: if PYDANTIC_VERSION_2: return pydantic.TypeAdapter(Filter).validate_python(obj) # type: ignore else: - return pydantic.parse_obj_as(Filter, obj) # type: ignore + return pydantic.TypeAdapter(Filter).validate_python(obj) # type: ignore # We need FilterDsl for two reasons: diff --git a/metadata-ingestion/src/datahub/secret/datahub_secret_store.py b/metadata-ingestion/src/datahub/secret/datahub_secret_store.py index 0ae413b46c6d17..63a704a498d205 100644 --- a/metadata-ingestion/src/datahub/secret/datahub_secret_store.py +++ b/metadata-ingestion/src/datahub/secret/datahub_secret_store.py @@ -1,7 +1,7 @@ import logging from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator from datahub.ingestion.graph.client import DataHubGraph from datahub.ingestion.graph.config import DatahubClientConfig @@ -18,8 +18,11 @@ class DataHubSecretStoreConfig(BaseModel): class Config: arbitrary_types_allowed = True - @validator("graph_client") - def check_graph_connection(cls, v: DataHubGraph) -> DataHubGraph: + @field_validator("graph_client", mode="after") + @classmethod + def check_graph_connection( + cls, v: Optional[DataHubGraph] + ) -> Optional[DataHubGraph]: if v is not None: v.test_connection() return v @@ -63,7 +66,7 @@ def get_id(self) -> str: @classmethod def create(cls, config: Any) -> "DataHubSecretStore": - config = DataHubSecretStoreConfig.parse_obj(config) + config = DataHubSecretStoreConfig.model_validate(config) return cls(config) def close(self) -> None: diff --git a/metadata-ingestion/src/datahub/secret/file_secret_store.py b/metadata-ingestion/src/datahub/secret/file_secret_store.py index 4ce2e5a2f546c0..55be7103f9fb8a 100644 --- a/metadata-ingestion/src/datahub/secret/file_secret_store.py +++ b/metadata-ingestion/src/datahub/secret/file_secret_store.py @@ -45,5 +45,5 @@ def close(self) -> None: @classmethod def create(cls, config: Any) -> "FileSecretStore": - config = FileSecretStoreConfig.parse_obj(config) + config = FileSecretStoreConfig.model_validate(config) return cls(config) diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py index 59914e8802e825..9722410c77c47a 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py +++ 
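Pydantic v2 drops `each_item=True`, so the search-filter validators above switch to `mode="before"` and normalize the whole list themselves, as the diff does for platforms, domains, owners, glossary terms, and tags. A sketch of that per-item-to-whole-list rewrite with a hypothetical `tags` field:

from typing import Any, List
from pydantic import BaseModel, field_validator


class TagFilterSketch(BaseModel):
    tags: List[str]

    # v1: @validator("tags", each_item=True) received one item at a time
    # v2: the before-mode validator receives the raw value and loops explicitly
    @field_validator("tags", mode="before")
    @classmethod
    def normalize_tags(cls, v: Any) -> Any:
        items = v if isinstance(v, list) else [v]
        normalized = []
        for tag in items:
            if not str(tag).startswith("urn:li:tag:"):
                raise ValueError(f"Tag must be a tag URN, got: {tag}")
            normalized.append(str(tag))
        return normalized


assert TagFilterSketch(tags="urn:li:tag:pii").tags == ["urn:li:tag:pii"]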
b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py @@ -28,6 +28,7 @@ import sqlglot.optimizer.qualify import sqlglot.optimizer.qualify_columns import sqlglot.optimizer.unnest_subqueries +from pydantic import field_validator from datahub.cli.env_utils import get_boolean_env_variable from datahub.ingestion.graph.client import DataHubGraph @@ -141,7 +142,8 @@ class DownstreamColumnRef(_ParserBaseModel): column_type: Optional[SchemaFieldDataTypeClass] = None native_column_type: Optional[str] = None - @pydantic.validator("column_type", pre=True) + @field_validator("column_type", mode="before") + @classmethod def _load_column_type( cls, v: Optional[Union[dict, SchemaFieldDataTypeClass]] ) -> Optional[SchemaFieldDataTypeClass]: @@ -215,7 +217,8 @@ class SqlParsingDebugInfo(_ParserBaseModel): def error(self) -> Optional[Exception]: return self.table_error or self.column_error - @pydantic.validator("table_error", "column_error") + @field_validator("table_error", "column_error", mode="before") + @classmethod def remove_variables_from_error(cls, v: Optional[Exception]) -> Optional[Exception]: if v and v.__traceback__: # Remove local variables from the traceback to avoid memory leaks. diff --git a/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py index 43a1e525180be4..dfc1fd9c49c464 100644 --- a/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py +++ b/metadata-ingestion/src/datahub/testing/check_sql_parser_result.py @@ -60,8 +60,8 @@ def assert_sql_result_with_resolver( expected = SqlParsingResult.parse_raw(expected_file.read_text()) full_diff = deepdiff.DeepDiff( - expected.dict(), - res.dict(), + expected.model_dump(), + res.model_dump(), exclude_regex_paths=[ r"root.column_lineage\[\d+\].logic", ], diff --git a/metadata-ingestion/src/datahub/utilities/ingest_utils.py b/metadata-ingestion/src/datahub/utilities/ingest_utils.py index 363c82df9c82ef..7a9eca4a84ce72 100644 --- a/metadata-ingestion/src/datahub/utilities/ingest_utils.py +++ b/metadata-ingestion/src/datahub/utilities/ingest_utils.py @@ -48,7 +48,7 @@ def deploy_source_vars( deploy_options_raw = pipeline_config.pop("deployment", None) if deploy_options_raw is not None: - deploy_options = DeployOptions.parse_obj(deploy_options_raw) + deploy_options = DeployOptions.model_validate(deploy_options_raw) if name: logger.info(f"Overriding deployment name {deploy_options.name} with {name}") diff --git a/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py b/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py index 2a8beafb27f42f..d356fac0c352de 100644 --- a/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py +++ b/metadata-ingestion/tests/integration/azure_ad/test_azure_ad.py @@ -194,7 +194,7 @@ def overwrite_group_in_mocked_data( def test_azure_ad_config(): - config = AzureADConfig.parse_obj( + config = AzureADConfig.model_validate( dict( client_id="00000000-0000-0000-0000-000000000000", tenant_id="00000000-0000-0000-0000-000000000000", diff --git a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py index f4ffa02c287eca..77786d0746c2ca 100644 --- a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py +++ b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery.py @@ -84,7 +84,7 @@ def recipe(mcp_output_path: str, source_config_override: Optional[dict] = None) ) ], max_workers=1, - ).dict(), + 
).model_dump(), **source_config_override, }, }, diff --git a/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py b/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py index 79bfec3c234dfd..96ba5bb94c9388 100644 --- a/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py +++ b/metadata-ingestion/tests/integration/csv-enricher/test_csv_enricher.py @@ -10,7 +10,7 @@ def test_csv_enricher_config(): - config = CSVEnricherConfig.parse_obj( + config = CSVEnricherConfig.model_validate( dict( filename="../integration/csv_enricher/csv_enricher_test_data.csv", write_semantics="OVERRIDE", diff --git a/metadata-ingestion/tests/integration/fivetran/test_fivetran.py b/metadata-ingestion/tests/integration/fivetran/test_fivetran.py index 94b36f2e0bd57f..4f5a8b7a98a56c 100644 --- a/metadata-ingestion/tests/integration/fivetran/test_fivetran.py +++ b/metadata-ingestion/tests/integration/fivetran/test_fivetran.py @@ -525,7 +525,7 @@ def test_rename_destination_config(): ConfigurationWarning, match="destination_config is deprecated, please use snowflake_destination_config instead.", ): - config = FivetranSourceConfig.parse_obj(config_dict) + config = FivetranSourceConfig.model_validate(config_dict) assert config.fivetran_log_config.snowflake_destination_config is not None assert ( config.fivetran_log_config.snowflake_destination_config.account_id @@ -556,7 +556,7 @@ def test_compat_sources_to_database() -> None: ConfigurationWarning, match=r"sources_to_database.*deprecated", ): - config = FivetranSourceConfig.parse_obj(config_dict) + config = FivetranSourceConfig.model_validate(config_dict) assert config.sources_to_platform_instance == { "calendar_elected": PlatformDetail(env="DEV", database="my_db"), diff --git a/metadata-ingestion/tests/integration/git/test_git_clone.py b/metadata-ingestion/tests/integration/git/test_git_clone.py index 01e075930998a4..67ececbc2a221c 100644 --- a/metadata-ingestion/tests/integration/git/test_git_clone.py +++ b/metadata-ingestion/tests/integration/git/test_git_clone.py @@ -62,7 +62,7 @@ def test_base_url_guessing() -> None: # Deprecated: base_url. 
with pytest.warns(ConfigurationWarning, match="base_url is deprecated"): - config = GitInfo.parse_obj( + config = GitInfo.model_validate( dict( repo="https://github.com/datahub-project/datahub", branch="master", diff --git a/metadata-ingestion/tests/integration/kafka/test_kafka.py b/metadata-ingestion/tests/integration/kafka/test_kafka.py index 8ac216e6fa8376..915386a9ab7157 100644 --- a/metadata-ingestion/tests/integration/kafka/test_kafka.py +++ b/metadata-ingestion/tests/integration/kafka/test_kafka.py @@ -161,7 +161,7 @@ def test_kafka_source_oauth_cb_signature(): ConfigurationError, match=("oauth_cb function must accept single positional argument."), ): - KafkaSourceConfig.parse_obj( + KafkaSourceConfig.model_validate( { "connection": { "bootstrap": "foobar:9092", @@ -174,7 +174,7 @@ def test_kafka_source_oauth_cb_signature(): ConfigurationError, match=("oauth_cb function must accept single positional argument."), ): - KafkaSourceConfig.parse_obj( + KafkaSourceConfig.model_validate( { "connection": { "bootstrap": "foobar:9092", diff --git a/metadata-ingestion/tests/integration/lookml/test_lookml.py b/metadata-ingestion/tests/integration/lookml/test_lookml.py index 02d0b86ac1c0c2..6bcfcb62066256 100644 --- a/metadata-ingestion/tests/integration/lookml/test_lookml.py +++ b/metadata-ingestion/tests/integration/lookml/test_lookml.py @@ -185,7 +185,7 @@ def test_lookml_explore_refinement(pytestconfig, tmp_path, mock_time): looker_model=looker_model, looker_viewfile_loader=None, # type: ignore reporter=None, # type: ignore - source_config=LookMLSourceConfig.parse_obj( + source_config=LookMLSourceConfig.model_validate( { "process_refinements": "True", "base_folder": ".", @@ -769,7 +769,7 @@ def test_lookml_base_folder(): "client_secret": "this-is-also-fake", } - LookMLSourceConfig.parse_obj( + LookMLSourceConfig.model_validate( { "git_info": { "repo": "acryldata/long-tail-companions-looker", @@ -782,7 +782,7 @@ def test_lookml_base_folder(): with pytest.raises( pydantic.ValidationError, match=r"base_folder.+nor.+git_info.+provided" ): - LookMLSourceConfig.parse_obj({"api": fake_api}) + LookMLSourceConfig.model_validate({"api": fake_api}) @freeze_time(FROZEN_TIME) @@ -1263,7 +1263,7 @@ def test_unreachable_views(pytestconfig): } source = LookMLSource( - LookMLSourceConfig.parse_obj(config), + LookMLSourceConfig.model_validate(config), ctx=PipelineContext(run_id="lookml-source-test"), ) workunits: List[Union[MetadataWorkUnit, Entity]] = [ diff --git a/metadata-ingestion/tests/integration/okta/test_okta.py b/metadata-ingestion/tests/integration/okta/test_okta.py index d606a1e22cb8ad..d4762266a18f4d 100644 --- a/metadata-ingestion/tests/integration/okta/test_okta.py +++ b/metadata-ingestion/tests/integration/okta/test_okta.py @@ -77,7 +77,7 @@ def run_ingest( def test_okta_config(): - config = OktaConfig.parse_obj( + config = OktaConfig.model_validate( dict(okta_domain="test.okta.com", okta_api_token="test-token") ) diff --git a/metadata-ingestion/tests/integration/oracle/common.py b/metadata-ingestion/tests/integration/oracle/common.py index 17c1a43894fcc5..7fee637b1ac896 100644 --- a/metadata-ingestion/tests/integration/oracle/common.py +++ b/metadata-ingestion/tests/integration/oracle/common.py @@ -159,7 +159,7 @@ def get_recipe_source(self) -> dict: "source": { "type": "oracle", "config": { - **self.get_default_recipe_config().dict(), + **self.get_default_recipe_config().model_dump(), }, } } diff --git a/metadata-ingestion/tests/integration/powerbi/test_m_parser.py 
b/metadata-ingestion/tests/integration/powerbi/test_m_parser.py index 380b72763a2037..8bfa540f4d1276 100644 --- a/metadata-ingestion/tests/integration/powerbi/test_m_parser.py +++ b/metadata-ingestion/tests/integration/powerbi/test_m_parser.py @@ -101,7 +101,7 @@ def get_default_instances( ]: if override_config is None: override_config = {} - config: PowerBiDashboardSourceConfig = PowerBiDashboardSourceConfig.parse_obj( + config: PowerBiDashboardSourceConfig = PowerBiDashboardSourceConfig.model_validate( { "tenant_id": "fake", "client_id": "foo", diff --git a/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py b/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py index 3f6b8f57506334..80c0f26f365f0c 100644 --- a/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py +++ b/metadata-ingestion/tests/integration/redshift-usage/test_redshift_usage.py @@ -27,7 +27,7 @@ def test_redshift_usage_config(): - config = RedshiftConfig.parse_obj( + config = RedshiftConfig.model_validate( dict( host_port="xxxxx", database="xxxxx", diff --git a/metadata-ingestion/tests/integration/salesforce/test_salesforce.py b/metadata-ingestion/tests/integration/salesforce/test_salesforce.py index 5fcdc06fec054e..464b3d14424c9e 100644 --- a/metadata-ingestion/tests/integration/salesforce/test_salesforce.py +++ b/metadata-ingestion/tests/integration/salesforce/test_salesforce.py @@ -78,7 +78,7 @@ def test_latest_version(mock_sdk): mock_sf._call_salesforce = mocked_call mock_sdk.return_value = mock_sf - config = SalesforceConfig.parse_obj( + config = SalesforceConfig.model_validate( { "auth": "DIRECT_ACCESS_TOKEN", "instance_url": "https://mydomain.my.salesforce.com/", @@ -122,7 +122,7 @@ def test_custom_version(mock_sdk): mock_sf._call_salesforce = mocked_call mock_sdk.return_value = mock_sf - config = SalesforceConfig.parse_obj( + config = SalesforceConfig.model_validate( { "auth": "DIRECT_ACCESS_TOKEN", "api_version": "46.0", diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py index df5fba0adbb740..13132216f77bf9 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py @@ -33,7 +33,7 @@ def stateful_pipeline_config(include_tables: bool) -> PipelineConfig: include_tables=include_tables, incremental_lineage=False, use_queries_v2=False, - stateful_ingestion=StatefulStaleMetadataRemovalConfig.parse_obj( + stateful_ingestion=StatefulStaleMetadataRemovalConfig.model_validate( { "enabled": True, "remove_stale_metadata": True, diff --git a/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py b/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py index d1123107942bf5..4c65cac8ddc2f3 100644 --- a/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py +++ b/metadata-ingestion/tests/integration/starburst-trino-usage/test_starburst_trino_usage.py @@ -12,7 +12,7 @@ def test_trino_usage_config(): - config = TrinoUsageConfig.parse_obj( + config = TrinoUsageConfig.model_validate( dict( host_port="xxxxx", database="testcatalog", diff --git a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py index 90420a5037b8ba..8129333351ec33 100644 --- 
a/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py +++ b/metadata-ingestion/tests/integration/tableau/test_tableau_ingest.py @@ -1022,7 +1022,7 @@ def test_hidden_assets_without_ingest_tags(pytestconfig, tmp_path, mock_datahub_ ValidationError, match=r".*tags_for_hidden_assets is only allowed with ingest_tags enabled.*", ): - TableauConfig.parse_obj(new_config) + TableauConfig.model_validate(new_config) @freeze_time(FROZEN_TIME) diff --git a/metadata-ingestion/tests/integration/trino/test_trino.py b/metadata-ingestion/tests/integration/trino/test_trino.py index 785b631299369e..e70da281d51ed9 100644 --- a/metadata-ingestion/tests/integration/trino/test_trino.py +++ b/metadata-ingestion/tests/integration/trino/test_trino.py @@ -111,11 +111,11 @@ def test_trino_ingest( platform_instance="local_server", ) }, - ).dict(), + ).model_dump(), }, "sink": { "type": "file", - "config": FileSinkConfig(filename=str(events_file)).dict(), + "config": FileSinkConfig(filename=str(events_file)).model_dump(), }, } @@ -161,11 +161,11 @@ def test_trino_hive_ingest( ], max_workers=1, ), - ).dict(), + ).model_dump(), }, "sink": { "type": "file", - "config": FileSinkConfig(filename=str(events_file)).dict(), + "config": FileSinkConfig(filename=str(events_file)).model_dump(), }, } @@ -221,11 +221,11 @@ def test_trino_instance_ingest( platform_instance="local_server", ) }, - ).dict(), + ).model_dump(), }, "sink": { "type": "file", - "config": FileSinkConfig(filename=str(events_file)).dict(), + "config": FileSinkConfig(filename=str(events_file)).model_dump(), }, } diff --git a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py index 474bc7d1906992..cac955dfd14db3 100644 --- a/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py +++ b/metadata-ingestion/tests/unit/api/entities/platformresource/test_platform_resource.py @@ -178,7 +178,7 @@ class TestModel(BaseModel): == models.SerializedValueContentTypeClass.JSON ) assert platform_resource_info_mcp.aspect.value.blob == json.dumps( - test_model.dict(), sort_keys=True + test_model.model_dump(), sort_keys=True ).encode("utf-8") assert platform_resource_info_mcp.aspect.value.schemaType == "JSON" assert platform_resource_info_mcp.aspect.value.schemaRef == TestModel.__name__ diff --git a/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py b/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py index 2df57eeff8db5a..1bf4df43ae2959 100644 --- a/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py +++ b/metadata-ingestion/tests/unit/bigquery/test_bigquery_source.py @@ -62,7 +62,7 @@ def test_bigquery_uri(): - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", } @@ -71,14 +71,14 @@ def test_bigquery_uri(): def test_bigquery_uri_on_behalf(): - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( {"project_id": "test-project", "project_on_behalf": "test-project-on-behalf"} ) assert config.get_sql_alchemy_url() == "bigquery://test-project-on-behalf" def test_bigquery_dataset_pattern(): - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "dataset_pattern": { "allow": [ @@ -103,7 +103,7 @@ def test_bigquery_dataset_pattern(): r"project\.second_dataset", ] - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "dataset_pattern": { 
"allow": [ @@ -144,7 +144,7 @@ def test_bigquery_uri_with_credential(): "type": "service_account", } - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", "credential": { @@ -183,7 +183,7 @@ def test_get_projects_with_project_ids( ): client_mock = MagicMock() get_bq_client_mock.return_value = client_mock - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_ids": ["test-1", "test-2"], } @@ -199,7 +199,7 @@ def test_get_projects_with_project_ids( ] assert client_mock.list_projects.call_count == 0 - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( {"project_ids": ["test-1", "test-2"], "project_id": "test-3"} ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test2")) @@ -220,7 +220,7 @@ def test_get_projects_with_project_ids_overrides_project_id_pattern( get_projects_client, get_bigquery_client, ): - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_ids": ["test-project", "test-project-2"], "project_id_pattern": {"deny": ["^test-project$"]}, @@ -239,12 +239,12 @@ def test_get_projects_with_project_ids_overrides_project_id_pattern( def test_platform_instance_config_always_none(): - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( {"include_data_platform_instance": True, "platform_instance": "something"} ) assert config.platform_instance is None - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( dict(platform_instance="something", project_id="project_id") ) assert config.project_ids == ["project_id"] @@ -262,7 +262,7 @@ def test_get_dataplatform_instance_aspect_returns_project_id( f"urn:li:dataPlatformInstance:(urn:li:dataPlatform:bigquery,{project_id})" ) - config = BigQueryV2Config.parse_obj({"include_data_platform_instance": True}) + config = BigQueryV2Config.model_validate({"include_data_platform_instance": True}) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) schema_gen = source.bq_schema_extractor @@ -282,7 +282,7 @@ def test_get_dataplatform_instance_default_no_instance( get_projects_client, get_bq_client_mock, ): - config = BigQueryV2Config.parse_obj({}) + config = BigQueryV2Config.model_validate({}) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) schema_gen = source.bq_schema_extractor @@ -304,7 +304,7 @@ def test_get_projects_with_single_project_id( ): client_mock = MagicMock() get_bq_client_mock.return_value = client_mock - config = BigQueryV2Config.parse_obj({"project_id": "test-3"}) + config = BigQueryV2Config.model_validate({"project_id": "test-3"}) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test1")) assert get_projects( source.bq_schema_extractor.schema_api, @@ -342,7 +342,7 @@ def test_get_projects_by_list(get_projects_client, get_bigquery_client): client_mock.list_projects.side_effect = [first_page, second_page] - config = BigQueryV2Config.parse_obj({}) + config = BigQueryV2Config.model_validate({}) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test1")) assert get_projects( source.bq_schema_extractor.schema_api, @@ -368,7 +368,7 @@ def test_get_projects_filter_by_pattern( BigqueryProject("test-project-2", "Test Project 2"), ] - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( {"project_id_pattern": {"deny": ["^test-project$"]}} ) source = BigqueryV2Source(config=config, 
ctx=PipelineContext(run_id="test")) @@ -390,7 +390,7 @@ def test_get_projects_list_empty( ): get_projects_mock.return_value = [] - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( {"project_id_pattern": {"deny": ["^test-project$"]}} ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) @@ -415,7 +415,7 @@ def test_get_projects_list_failure( get_bq_client_mock.return_value = bq_client_mock bq_client_mock.list_projects.side_effect = GoogleAPICallError(error_str) - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( {"project_id_pattern": {"deny": ["^test-project$"]}} ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) @@ -440,7 +440,7 @@ def test_get_projects_list_fully_filtered( ): get_projects_mock.return_value = [BigqueryProject("test-project", "Test Project")] - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( {"project_id_pattern": {"deny": ["^test-project$"]}} ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) @@ -480,7 +480,7 @@ def test_gen_table_dataset_workunits( ): project_id = "test-project" dataset_name = "test-dataset" - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": project_id, "capture_table_label_as_tag": True, @@ -620,7 +620,7 @@ def test_get_datasets_for_project_id_with_timestamps( }[ref] # Create BigQuerySchemaApi instance - config = BigQueryV2Config.parse_obj({"project_id": project_id}) + config = BigQueryV2Config.model_validate({"project_id": project_id}) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test")) schema_api = source.bq_schema_extractor.schema_api @@ -671,7 +671,7 @@ def test_simple_upstream_table_generation(get_bq_client_mock, get_projects_clien ) ) - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", } @@ -707,7 +707,7 @@ def test_upstream_table_generation_with_temporary_table_without_temp_upstream( ) ) - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", } @@ -748,7 +748,7 @@ def test_upstream_table_column_lineage_with_temp_table( ) ) - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", } @@ -834,7 +834,7 @@ def test_upstream_table_generation_with_temporary_table_with_multiple_temp_upstr ) ) - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", } @@ -875,7 +875,7 @@ def test_table_processing_logic( ): client_mock = MagicMock() get_bq_client_mock.return_value = client_mock - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", } @@ -951,7 +951,7 @@ def test_table_processing_logic_date_named_tables( client_mock = MagicMock() get_bq_client_mock.return_value = client_mock # test that tables with date names are processed correctly - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", } @@ -1113,7 +1113,7 @@ def test_gen_view_dataset_workunits( ): project_id = "test-project" dataset_name = "test-dataset" - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": project_id, } @@ -1214,7 +1214,7 @@ def test_gen_snapshot_dataset_workunits( ): project_id = "test-project" dataset_name = "test-dataset" - config = 
BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": project_id, } @@ -1342,9 +1342,9 @@ def test_get_table_name(full_table_name: str, datahub_full_table_name: str) -> N def test_default_config_for_excluding_projects_and_datasets(): - config = BigQueryV2Config.parse_obj({}) + config = BigQueryV2Config.model_validate({}) assert config.exclude_empty_projects is False - config = BigQueryV2Config.parse_obj({"exclude_empty_projects": True}) + config = BigQueryV2Config.model_validate({"exclude_empty_projects": True}) assert config.exclude_empty_projects @@ -1378,11 +1378,13 @@ def get_datasets_for_project_id_side_effect( "include_table_lineage": False, } - config = BigQueryV2Config.parse_obj(base_config) + config = BigQueryV2Config.model_validate(base_config) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test-1")) assert len({wu.metadata.entityUrn for wu in source.get_workunits()}) == 2 # type: ignore - config = BigQueryV2Config.parse_obj({**base_config, "exclude_empty_projects": True}) + config = BigQueryV2Config.model_validate( + {**base_config, "exclude_empty_projects": True} + ) source = BigqueryV2Source(config=config, ctx=PipelineContext(run_id="test-2")) assert len({wu.metadata.entityUrn for wu in source.get_workunits()}) == 1 # type: ignore @@ -1393,21 +1395,21 @@ def test_bigquery_config_deprecated_schema_pattern(): "include_table_lineage": False, } - config = BigQueryV2Config.parse_obj(base_config) + config = BigQueryV2Config.model_validate(base_config) assert config.dataset_pattern == AllowDenyPattern(allow=[".*"]) # default config_with_schema_pattern = { **base_config, "schema_pattern": AllowDenyPattern(deny=[".*"]), } - config = BigQueryV2Config.parse_obj(config_with_schema_pattern) + config = BigQueryV2Config.model_validate(config_with_schema_pattern) assert config.dataset_pattern == AllowDenyPattern(deny=[".*"]) # schema_pattern config_with_dataset_pattern = { **base_config, "dataset_pattern": AllowDenyPattern(deny=["temp.*"]), } - config = BigQueryV2Config.parse_obj(config_with_dataset_pattern) + config = BigQueryV2Config.model_validate(config_with_dataset_pattern) assert config.dataset_pattern == AllowDenyPattern( deny=["temp.*"] ) # dataset_pattern @@ -1428,7 +1430,7 @@ def test_get_projects_with_project_labels( SimpleNamespace(project_id="qa", display_name="qa_project"), ] - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_labels": ["environment:dev", "environment:qa"], } diff --git a/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py b/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py index e6109e9292267a..961e7c512c1a40 100644 --- a/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py +++ b/metadata-ingestion/tests/unit/bigquery/test_bigqueryv2_usage_source.py @@ -33,7 +33,7 @@ def test_bigqueryv2_uri_with_credential(): "type": "service_account", } - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", "stateful_ingestion": {"enabled": False}, @@ -66,7 +66,7 @@ def test_bigqueryv2_uri_with_credential(): @freeze_time(FROZEN_TIME) def test_bigqueryv2_filters(): - config = BigQueryV2Config.parse_obj( + config = BigQueryV2Config.model_validate( { "project_id": "test-project", "credential": { diff --git a/metadata-ingestion/tests/unit/cli/docker/test_quickstart_version_mapping.py b/metadata-ingestion/tests/unit/cli/docker/test_quickstart_version_mapping.py index 
f32104389f7318..94f8fddc97d129 100644 --- a/metadata-ingestion/tests/unit/cli/docker/test_quickstart_version_mapping.py +++ b/metadata-ingestion/tests/unit/cli/docker/test_quickstart_version_mapping.py @@ -6,7 +6,7 @@ QuickstartVersionMappingConfig, ) -example_version_mapper = QuickstartVersionMappingConfig.parse_obj( +example_version_mapper = QuickstartVersionMappingConfig.model_validate( { "quickstart_version_map": { "default": { @@ -112,7 +112,7 @@ def test_quickstart_version_older_than_v1_2_0_uses_commit_hash(): This exercises line 168 in quickstart_versioning.py. """ # Create a version mapping with a version older than v1.2.0 - version_mapper = QuickstartVersionMappingConfig.parse_obj( + version_mapper = QuickstartVersionMappingConfig.model_validate( { "quickstart_version_map": { "v1.1.0": { diff --git a/metadata-ingestion/tests/unit/config/test_config_enum.py b/metadata-ingestion/tests/unit/config/test_config_enum.py index f2ff6467bf0f02..04e0d05121c6d2 100644 --- a/metadata-ingestion/tests/unit/config/test_config_enum.py +++ b/metadata-ingestion/tests/unit/config/test_config_enum.py @@ -20,11 +20,11 @@ class FruitConfig(ConfigModel): assert Fruit.ORANGE.value == "ORANGE" # Check that config loading works. - assert FruitConfig.parse_obj({}).fruit == Fruit.APPLE - assert FruitConfig.parse_obj({"fruit": "PEAR"}).fruit == Fruit.PEAR - assert FruitConfig.parse_obj({"fruit": "pear"}).fruit == Fruit.PEAR - assert FruitConfig.parse_obj({"fruit": "Orange"}).fruit == Fruit.ORANGE + assert FruitConfig.model_validate({}).fruit == Fruit.APPLE + assert FruitConfig.model_validate({"fruit": "PEAR"}).fruit == Fruit.PEAR + assert FruitConfig.model_validate({"fruit": "pear"}).fruit == Fruit.PEAR + assert FruitConfig.model_validate({"fruit": "Orange"}).fruit == Fruit.ORANGE # Check that errors are thrown. with pytest.raises(pydantic.ValidationError): - FruitConfig.parse_obj({"fruit": "banana"}) + FruitConfig.model_validate({"fruit": "banana"}) diff --git a/metadata-ingestion/tests/unit/config/test_config_model.py b/metadata-ingestion/tests/unit/config/test_config_model.py index 9a45ccb5c9859d..70285763c7c209 100644 --- a/metadata-ingestion/tests/unit/config/test_config_model.py +++ b/metadata-ingestion/tests/unit/config/test_config_model.py @@ -15,11 +15,11 @@ class MyConfig(ConfigModel): required: str optional: str = "bar" - MyConfig.parse_obj({"required": "foo"}) - MyConfig.parse_obj({"required": "foo", "optional": "baz"}) + MyConfig.model_validate({"required": "foo"}) + MyConfig.model_validate({"required": "foo", "optional": "baz"}) with pytest.raises(pydantic.ValidationError): - MyConfig.parse_obj({"required": "foo", "extra": "extra"}) + MyConfig.model_validate({"required": "foo", "extra": "extra"}) def test_extras_allowed(): diff --git a/metadata-ingestion/tests/unit/config/test_connection_resolver.py b/metadata-ingestion/tests/unit/config/test_connection_resolver.py index 0ff23f7cc842c7..f4811b6e0b049c 100644 --- a/metadata-ingestion/tests/unit/config/test_connection_resolver.py +++ b/metadata-ingestion/tests/unit/config/test_connection_resolver.py @@ -17,7 +17,7 @@ class MyConnectionType(ConfigModel): def test_auto_connection_resolver(): # Test a normal config. - config = MyConnectionType.parse_obj( + config = MyConnectionType.model_validate( {"username": "test_user", "password": "test_password"} ) assert config.username == "test_user" @@ -25,7 +25,7 @@ def test_auto_connection_resolver(): # No graph context -> should raise an error. 
with pytest.raises(pydantic.ValidationError, match=r"requires a .*graph"): - config = MyConnectionType.parse_obj( + config = MyConnectionType.model_validate( { "connection": "test_connection", } @@ -38,7 +38,7 @@ def test_auto_connection_resolver(): set_graph_context(fake_graph), pytest.raises(pydantic.ValidationError, match=r"not found"), ): - config = MyConnectionType.parse_obj( + config = MyConnectionType.model_validate( { "connection": "urn:li:dataHubConnection:missing-connection", } @@ -47,7 +47,7 @@ def test_auto_connection_resolver(): # Bad connection config -> should raise an error. fake_graph.get_connection_json.return_value = {"bad_key": "bad_value"} with set_graph_context(fake_graph), pytest.raises(pydantic.ValidationError): - config = MyConnectionType.parse_obj( + config = MyConnectionType.model_validate( { "connection": "urn:li:dataHubConnection:bad-connection", } @@ -59,7 +59,7 @@ def test_auto_connection_resolver(): "password": "test_password", } with set_graph_context(fake_graph): - config = MyConnectionType.parse_obj( + config = MyConnectionType.model_validate( { "connection": "urn:li:dataHubConnection:good-connection", "username": "override_user", diff --git a/metadata-ingestion/tests/unit/config/test_pydantic_validators.py b/metadata-ingestion/tests/unit/config/test_pydantic_validators.py index f687a2776f6e2d..1822f65ecc055f 100644 --- a/metadata-ingestion/tests/unit/config/test_pydantic_validators.py +++ b/metadata-ingestion/tests/unit/config/test_pydantic_validators.py @@ -21,18 +21,18 @@ class TestModel(ConfigModel): _validate_rename = pydantic_renamed_field("a", "b") - v = TestModel.parse_obj({"b": "original"}) + v = TestModel.model_validate({"b": "original"}) assert v.b == "original" with pytest.warns(ConfigurationWarning, match="a is deprecated"): - v = TestModel.parse_obj({"a": "renamed"}) + v = TestModel.model_validate({"a": "renamed"}) assert v.b == "renamed" with pytest.raises(ValidationError): - TestModel.parse_obj({"a": "foo", "b": "bar"}) + TestModel.model_validate({"a": "foo", "b": "bar"}) with pytest.raises(ValidationError): - TestModel.parse_obj({}) + TestModel.model_validate({}) def test_field_multiple_fields_rename(): @@ -43,29 +43,29 @@ class TestModel(ConfigModel): _validate_deprecated = pydantic_renamed_field("a", "b") _validate_deprecated1 = pydantic_renamed_field("a1", "b1") - v = TestModel.parse_obj({"b": "original", "b1": "original"}) + v = TestModel.model_validate({"b": "original", "b1": "original"}) assert v.b == "original" assert v.b1 == "original" with pytest.warns(ConfigurationWarning, match=r"a.* is deprecated"): - v = TestModel.parse_obj({"a": "renamed", "a1": "renamed"}) + v = TestModel.model_validate({"a": "renamed", "a1": "renamed"}) assert v.b == "renamed" assert v.b1 == "renamed" with pytest.raises(ValidationError): - TestModel.parse_obj({"a": "foo", "b": "bar", "b1": "ok"}) + TestModel.model_validate({"a": "foo", "b": "bar", "b1": "ok"}) with pytest.raises(ValidationError): - TestModel.parse_obj({"a1": "foo", "b1": "bar", "b": "ok"}) + TestModel.model_validate({"a1": "foo", "b1": "bar", "b": "ok"}) with pytest.raises(ValidationError): - TestModel.parse_obj({"b": "foo"}) + TestModel.model_validate({"b": "foo"}) with pytest.raises(ValidationError): - TestModel.parse_obj({"b1": "foo"}) + TestModel.model_validate({"b1": "foo"}) with pytest.raises(ValidationError): - TestModel.parse_obj({}) + TestModel.model_validate({}) def test_field_remove(): @@ -75,11 +75,13 @@ class TestModel(ConfigModel): _validate_removed_r1 = 
pydantic_removed_field("r1") _validate_removed_r2 = pydantic_removed_field("r2") - v = TestModel.parse_obj({"b": "original"}) + v = TestModel.model_validate({"b": "original"}) assert v.b == "original" with pytest.warns(ConfigurationWarning, match=r"r\d was removed"): - v = TestModel.parse_obj({"b": "original", "r1": "removed", "r2": "removed"}) + v = TestModel.model_validate( + {"b": "original", "r1": "removed", "r2": "removed"} + ) assert v.b == "original" @@ -94,11 +96,11 @@ class TestModel(ConfigModel): _validate_deprecated_d1 = pydantic_field_deprecated("d1") _validate_deprecated_d2 = pydantic_field_deprecated("d2") - v = TestModel.parse_obj({"b": "original"}) + v = TestModel.model_validate({"b": "original"}) assert v.b == "original" with pytest.warns(ConfigurationWarning, match=r"d\d.+ deprecated"): - v = TestModel.parse_obj( + v = TestModel.model_validate( {"b": "original", "d1": "deprecated", "d2": "deprecated"} ) assert v.b == "original" @@ -118,17 +120,17 @@ class TestModel(ConfigModel): _validate_s = pydantic_multiline_string("s") _validate_m = pydantic_multiline_string("m") - v = TestModel.parse_obj({"s": "foo\nbar"}) + v = TestModel.model_validate({"s": "foo\nbar"}) assert v.s == "foo\nbar" - v = TestModel.parse_obj({"s": "foo\\nbar"}) + v = TestModel.model_validate({"s": "foo\\nbar"}) assert v.s == "foo\nbar" - v = TestModel.parse_obj({"s": "normal", "m": "foo\\nbar"}) + v = TestModel.model_validate({"s": "normal", "m": "foo\\nbar"}) assert v.s == "normal" assert v.m assert v.m.get_secret_value() == "foo\nbar" - v = TestModel.parse_obj({"s": "normal", "m": pydantic.SecretStr("foo\\nbar")}) + v = TestModel.model_validate({"s": "normal", "m": pydantic.SecretStr("foo\\nbar")}) assert v.m assert v.m.get_secret_value() == "foo\nbar" diff --git a/metadata-ingestion/tests/unit/config/test_time_window_config.py b/metadata-ingestion/tests/unit/config/test_time_window_config.py index 847bda2511a0ce..e36865880f371d 100644 --- a/metadata-ingestion/tests/unit/config/test_time_window_config.py +++ b/metadata-ingestion/tests/unit/config/test_time_window_config.py @@ -11,35 +11,35 @@ @freeze_time(FROZEN_TIME) def test_default_start_end_time(): - config = BaseTimeWindowConfig.parse_obj({}) + config = BaseTimeWindowConfig.model_validate({}) assert config.start_time == datetime(2023, 8, 2, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) @freeze_time(FROZEN_TIME2) def test_default_start_end_time_hour_bucket_duration(): - config = BaseTimeWindowConfig.parse_obj({"bucket_duration": "HOUR"}) + config = BaseTimeWindowConfig.model_validate({"bucket_duration": "HOUR"}) assert config.start_time == datetime(2023, 8, 3, 8, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, 10, tzinfo=timezone.utc) @freeze_time(FROZEN_TIME) def test_relative_start_time(): - config = BaseTimeWindowConfig.parse_obj({"start_time": "-2 days"}) + config = BaseTimeWindowConfig.model_validate({"start_time": "-2 days"}) assert config.start_time == datetime(2023, 8, 1, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) - config = BaseTimeWindowConfig.parse_obj({"start_time": "-2d"}) + config = BaseTimeWindowConfig.model_validate({"start_time": "-2d"}) assert config.start_time == datetime(2023, 8, 1, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) - config = BaseTimeWindowConfig.parse_obj( + config = BaseTimeWindowConfig.model_validate( {"start_time": "-2 days", "end_time": 
"2023-07-07T09:00:00Z"} ) assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 7, 7, 9, tzinfo=timezone.utc) - config = BaseTimeWindowConfig.parse_obj( + config = BaseTimeWindowConfig.model_validate( {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"} ) assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) @@ -48,11 +48,11 @@ def test_relative_start_time(): @freeze_time(FROZEN_TIME) def test_absolute_start_time(): - config = BaseTimeWindowConfig.parse_obj({"start_time": "2023-07-01T00:00:00Z"}) + config = BaseTimeWindowConfig.model_validate({"start_time": "2023-07-01T00:00:00Z"}) assert config.start_time == datetime(2023, 7, 1, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) - config = BaseTimeWindowConfig.parse_obj({"start_time": "2023-07-01T09:00:00Z"}) + config = BaseTimeWindowConfig.model_validate({"start_time": "2023-07-01T09:00:00Z"}) assert config.start_time == datetime(2023, 7, 1, 9, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) @@ -60,21 +60,21 @@ def test_absolute_start_time(): @freeze_time(FROZEN_TIME) def test_invalid_relative_start_time(): with pytest.raises(ValueError, match="Unknown string format"): - BaseTimeWindowConfig.parse_obj({"start_time": "-2 das"}) + BaseTimeWindowConfig.model_validate({"start_time": "-2 das"}) with pytest.raises( ValueError, match="Relative start time should be in terms of configured bucket duration", ): - BaseTimeWindowConfig.parse_obj({"start_time": "-2"}) + BaseTimeWindowConfig.model_validate({"start_time": "-2"}) with pytest.raises( ValueError, match="Relative start time should start with minus sign" ): - BaseTimeWindowConfig.parse_obj({"start_time": "2d"}) + BaseTimeWindowConfig.model_validate({"start_time": "2d"}) with pytest.raises( ValueError, match="Relative start time should be in terms of configured bucket duration", ): - BaseTimeWindowConfig.parse_obj({"start_time": "-2m"}) + BaseTimeWindowConfig.model_validate({"start_time": "-2m"}) diff --git a/metadata-ingestion/tests/unit/data_lake/test_path_spec.py b/metadata-ingestion/tests/unit/data_lake/test_path_spec.py index 9b8d512102b761..d89cecfe6983d8 100644 --- a/metadata-ingestion/tests/unit/data_lake/test_path_spec.py +++ b/metadata-ingestion/tests/unit/data_lake/test_path_spec.py @@ -1,6 +1,7 @@ from unittest.mock import patch import pytest +from pydantic import ValidationError from datahub.configuration.common import AllowDenyPattern from datahub.ingestion.source.data_lake_common.path_spec import ( @@ -594,14 +595,16 @@ def test_table_name_custom_invalid() -> None: def test_validate_path_spec_missing_fields() -> None: """Test path_spec validation with missing required fields.""" - with patch( - "datahub.ingestion.source.data_lake_common.path_spec.logger" - ) as mock_logger: - # This should not raise an error but log debug messages - values = {"file_types": ["csv"]} # Missing include and default_extension - result = PathSpec.validate_path_spec(values) - assert result == values - mock_logger.debug.assert_called() + # This should raises a validation error since include is required + # In Pydantic v1 this was just logged as a debug message + # In Pydantic v2, keeping same behavior would require making include Optional with a default of None, whihc + # would be a breaking change in the model; also, along the code there are many places that assume include is always set. 
+ values = {"file_types": ["csv"]} # Missing include (now required) + + with pytest.raises(ValidationError) as exc_info: + PathSpec(**values) + # Verify that the validation error mentions the missing include field + assert "include" in str(exc_info.value) def test_validate_path_spec_autodetect_partitions() -> None: @@ -612,8 +615,8 @@ def test_validate_path_spec_autodetect_partitions() -> None: "file_types": ["csv"], "default_extension": None, } - result = PathSpec.validate_path_spec(values) - assert result["include"] == "s3://bucket/{table}/**" + result = PathSpec(**values) + assert result.include == "s3://bucket/{table}/**" def test_validate_path_spec_autodetect_partitions_with_slash() -> None: @@ -624,8 +627,8 @@ def test_validate_path_spec_autodetect_partitions_with_slash() -> None: "file_types": ["csv"], "default_extension": None, } - result = PathSpec.validate_path_spec(values) - assert result["include"] == "s3://bucket/{table}/**" + result = PathSpec(**values) + assert result.include == "s3://bucket/{table}/**" def test_validate_path_spec_invalid_extension() -> None: diff --git a/metadata-ingestion/tests/unit/dbt/test_dbt_source.py b/metadata-ingestion/tests/unit/dbt/test_dbt_source.py index f8731f91ec7e7c..5aee41d44ff57f 100644 --- a/metadata-ingestion/tests/unit/dbt/test_dbt_source.py +++ b/metadata-ingestion/tests/unit/dbt/test_dbt_source.py @@ -198,7 +198,7 @@ def test_dbt_entity_emission_configuration(): ValidationError, match="Cannot have more than 1 type of entity emission set to ONLY", ): - DBTCoreConfig.parse_obj(config_dict) + DBTCoreConfig.model_validate(config_dict) # valid config config_dict = { @@ -207,7 +207,7 @@ def test_dbt_entity_emission_configuration(): "target_platform": "dummy_platform", "entities_enabled": {"models": "Yes", "seeds": "Only"}, } - DBTCoreConfig.parse_obj(config_dict) + DBTCoreConfig.model_validate(config_dict) def test_dbt_config_skip_sources_in_lineage(): @@ -221,7 +221,7 @@ def test_dbt_config_skip_sources_in_lineage(): "target_platform": "dummy_platform", "skip_sources_in_lineage": True, } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) config_dict = { "manifest_path": "dummy_path", @@ -230,7 +230,7 @@ def test_dbt_config_skip_sources_in_lineage(): "skip_sources_in_lineage": True, "entities_enabled": {"sources": "NO"}, } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) assert config.skip_sources_in_lineage is True @@ -245,7 +245,7 @@ def test_dbt_config_prefer_sql_parser_lineage(): "target_platform": "dummy_platform", "prefer_sql_parser_lineage": True, } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) config_dict = { "manifest_path": "dummy_path", @@ -254,14 +254,14 @@ def test_dbt_config_prefer_sql_parser_lineage(): "skip_sources_in_lineage": True, "prefer_sql_parser_lineage": True, } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) assert config.skip_sources_in_lineage is True assert config.prefer_sql_parser_lineage is True def test_dbt_prefer_sql_parser_lineage_no_self_reference(): ctx = PipelineContext(run_id="test-run-id") - config = DBTCoreConfig.parse_obj( + config = DBTCoreConfig.model_validate( { **create_base_dbt_config(), "skip_sources_in_lineage": True, @@ -302,7 +302,7 @@ def test_dbt_prefer_sql_parser_lineage_no_self_reference(): def test_dbt_cll_skip_python_model() -> None: ctx = PipelineContext(run_id="test-run-id") - 
config = DBTCoreConfig.parse_obj(create_base_dbt_config()) + config = DBTCoreConfig.model_validate(create_base_dbt_config()) source: DBTCoreSource = DBTCoreSource(config, ctx) all_nodes_map = { "model1": DBTNode( @@ -341,7 +341,7 @@ def test_dbt_s3_config(): "target_platform": "dummy_platform", } with pytest.raises(ValidationError, match="provide aws_connection"): - DBTCoreConfig.parse_obj(config_dict) + DBTCoreConfig.model_validate(config_dict) # valid config config_dict = { @@ -350,7 +350,7 @@ def test_dbt_s3_config(): "target_platform": "dummy_platform", "aws_connection": {}, } - DBTCoreConfig.parse_obj(config_dict) + DBTCoreConfig.model_validate(config_dict) def test_default_convert_column_urns_to_lowercase(): @@ -361,14 +361,16 @@ def test_default_convert_column_urns_to_lowercase(): "entities_enabled": {"models": "Yes", "seeds": "Only"}, } - config = DBTCoreConfig.parse_obj({**config_dict}) + config = DBTCoreConfig.model_validate({**config_dict}) assert config.convert_column_urns_to_lowercase is False - config = DBTCoreConfig.parse_obj({**config_dict, "target_platform": "snowflake"}) + config = DBTCoreConfig.model_validate( + {**config_dict, "target_platform": "snowflake"} + ) assert config.convert_column_urns_to_lowercase is True # Check that we respect the user's setting if provided. - config = DBTCoreConfig.parse_obj( + config = DBTCoreConfig.model_validate( { **config_dict, "convert_column_urns_to_lowercase": False, @@ -387,7 +389,7 @@ def test_dbt_entity_emission_configuration_helpers(): "models": "Only", }, } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) assert config.entities_enabled.can_emit_node_type("model") assert not config.entities_enabled.can_emit_node_type("source") assert not config.entities_enabled.can_emit_node_type("test") @@ -400,7 +402,7 @@ def test_dbt_entity_emission_configuration_helpers(): "catalog_path": "dummy_path", "target_platform": "dummy_platform", } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) assert config.entities_enabled.can_emit_node_type("model") assert config.entities_enabled.can_emit_node_type("source") assert config.entities_enabled.can_emit_node_type("test") @@ -416,7 +418,7 @@ def test_dbt_entity_emission_configuration_helpers(): "test_results": "Only", }, } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) assert not config.entities_enabled.can_emit_node_type("model") assert not config.entities_enabled.can_emit_node_type("source") assert not config.entities_enabled.can_emit_node_type("test") @@ -436,7 +438,7 @@ def test_dbt_entity_emission_configuration_helpers(): "sources": "No", }, } - config = DBTCoreConfig.parse_obj(config_dict) + config = DBTCoreConfig.model_validate(config_dict) assert not config.entities_enabled.can_emit_node_type("model") assert not config.entities_enabled.can_emit_node_type("source") assert config.entities_enabled.can_emit_node_type("test") @@ -455,7 +457,7 @@ def test_dbt_cloud_config_access_url(): "run_id": "123456789", "target_platform": "dummy_platform", } - config = DBTCloudConfig.parse_obj(config_dict) + config = DBTCloudConfig.model_validate(config_dict) assert config.access_url == "https://emea.getdbt.com" assert config.metadata_endpoint == "https://metadata.emea.getdbt.com/graphql" @@ -471,7 +473,7 @@ def test_dbt_cloud_config_with_defined_metadata_endpoint(): "target_platform": "dummy_platform", "metadata_endpoint": 
"https://my-metadata-endpoint.my-dbt-cloud.dbt.com/graphql", } - config = DBTCloudConfig.parse_obj(config_dict) + config = DBTCloudConfig.model_validate(config_dict) assert config.access_url == "https://my-dbt-cloud.dbt.com" assert ( config.metadata_endpoint @@ -534,7 +536,7 @@ def test_include_database_name_default(): "catalog_path": "dummy_path", "target_platform": "dummy_platform", } - config = DBTCoreConfig.parse_obj({**config_dict}) + config = DBTCoreConfig.model_validate({**config_dict}) assert config.include_database_name is True @@ -548,7 +550,7 @@ def test_include_database_name(include_database_name: str, expected: bool) -> No "target_platform": "dummy_platform", } config_dict.update({"include_database_name": include_database_name}) - config = DBTCoreConfig.parse_obj({**config_dict}) + config = DBTCoreConfig.model_validate({**config_dict}) assert config.include_database_name is expected @@ -636,7 +638,7 @@ class SharedDBTNodeFields(TypedDict, total=False): # Test the method ctx = PipelineContext(run_id="test-run-id", pipeline_name="dbt-source") - config = DBTCoreConfig.parse_obj(create_base_dbt_config()) + config = DBTCoreConfig.model_validate(create_base_dbt_config()) source: DBTCoreSource = DBTCoreSource(config, ctx) result_nodes: List[DBTNode] = source._drop_duplicate_sources(original_nodes) diff --git a/metadata-ingestion/tests/unit/glue/test_aws_common.py b/metadata-ingestion/tests/unit/glue/test_aws_common.py index bfe9207e3a78ef..4b6daac95d54b9 100644 --- a/metadata-ingestion/tests/unit/glue/test_aws_common.py +++ b/metadata-ingestion/tests/unit/glue/test_aws_common.py @@ -74,7 +74,7 @@ def test_dict_method(self) -> None: RoleArn="arn:aws:iam::123456789012:role/TestRole", ExternalId="external-id-123", ) - config_dict = config.dict() + config_dict = config.model_dump() assert config_dict["RoleArn"] == "arn:aws:iam::123456789012:role/TestRole" assert config_dict["ExternalId"] == "external-id-123" diff --git a/metadata-ingestion/tests/unit/grafana/test_grafana_models.py b/metadata-ingestion/tests/unit/grafana/test_grafana_models.py index 9633713ff13f70..20f69d75136d33 100644 --- a/metadata-ingestion/tests/unit/grafana/test_grafana_models.py +++ b/metadata-ingestion/tests/unit/grafana/test_grafana_models.py @@ -19,7 +19,7 @@ def test_panel_basic(): "transformations": [], } - panel = Panel.parse_obj(panel_data) + panel = Panel.model_validate(panel_data) assert panel.id == "1" assert panel.title == "Test Panel" assert panel.description == "Test Description" @@ -34,7 +34,7 @@ def test_panel_with_datasource(): "datasource": {"type": "postgres", "uid": "abc123"}, } - panel = Panel.parse_obj(panel_data) + panel = Panel.model_validate(panel_data) assert panel.datasource_ref is not None assert panel.datasource_ref.type == "postgres" assert panel.datasource_ref.uid == "abc123" @@ -55,7 +55,7 @@ def test_dashboard_basic(): } } - dashboard = Dashboard.parse_obj(dashboard_data) + dashboard = Dashboard.model_validate(dashboard_data) assert dashboard.uid == "dash1" assert dashboard.title == "Test Dashboard" assert dashboard.version == "1" @@ -82,7 +82,7 @@ def test_dashboard_nested_panels(): } } - dashboard = Dashboard.parse_obj(dashboard_data) + dashboard = Dashboard.model_validate(dashboard_data) assert len(dashboard.panels) == 2 assert dashboard.panels[0].title == "Nested Panel" assert dashboard.panels[1].title == "Top Level Panel" @@ -91,7 +91,7 @@ def test_dashboard_nested_panels(): def test_folder(): folder_data = {"id": "1", "title": "Test Folder", "description": "Test Description"} - 
folder = Folder.parse_obj(folder_data) + folder = Folder.model_validate(folder_data) assert folder.id == "1" assert folder.title == "Test Folder" assert folder.description == "Test Description" diff --git a/metadata-ingestion/tests/unit/hex/test_api.py b/metadata-ingestion/tests/unit/hex/test_api.py index 3d705171c529be..5b83b01fc2d1be 100644 --- a/metadata-ingestion/tests/unit/hex/test_api.py +++ b/metadata-ingestion/tests/unit/hex/test_api.py @@ -119,7 +119,7 @@ def test_map_data_project(self): base_url=self.base_url, ) - hex_api_project = HexApiProjectApiResource.parse_obj(project_data) + hex_api_project = HexApiProjectApiResource.model_validate(project_data) result = hex_api._map_data_from_model(hex_api_project) # Verify the result @@ -182,7 +182,7 @@ def test_map_data_component(self): base_url=self.base_url, ) - hex_api_component = HexApiProjectApiResource.parse_obj(component_data) + hex_api_component = HexApiProjectApiResource.model_validate(component_data) result = hex_api._map_data_from_model(hex_api_component) # Verify the result @@ -247,13 +247,13 @@ def test_fetch_projects_failure_http_error(self): ) assert failures[0].context - @patch("datahub.ingestion.source.hex.api.HexApiProjectsListResponse.parse_obj") + @patch("datahub.ingestion.source.hex.api.HexApiProjectsListResponse.model_validate") def test_fetch_projects_failure_response_validation(self, mock_parse_obj): # Create a dummy http response mock_response = MagicMock() mock_response.json.return_value = {"whatever": "json"} # and simulate ValidationError when parsing the response - mock_parse_obj.side_effect = lambda _: HexApiProjectApiResource.parse_obj( + mock_parse_obj.side_effect = lambda _: HexApiProjectApiResource.model_validate( {} ) # will raise ValidationError @@ -285,7 +285,7 @@ def test_fetch_projects_failure_response_validation(self, mock_parse_obj): ) assert failures[0].context - @patch("datahub.ingestion.source.hex.api.HexApiProjectsListResponse.parse_obj") + @patch("datahub.ingestion.source.hex.api.HexApiProjectsListResponse.model_validate") @patch("datahub.ingestion.source.hex.api.HexApi._map_data_from_model") def test_fetch_projects_warning_model_mapping( self, mock_map_data_from_model, mock_parse_obj diff --git a/metadata-ingestion/tests/unit/hex/test_hex.py b/metadata-ingestion/tests/unit/hex/test_hex.py index ded2b078b0c9e6..f3a03d69d6873d 100644 --- a/metadata-ingestion/tests/unit/hex/test_hex.py +++ b/metadata-ingestion/tests/unit/hex/test_hex.py @@ -27,30 +27,30 @@ def test_required_fields(self): with self.assertRaises(ValueError): input_config = {**self.minimum_input_config} del input_config["workspace_name"] - HexSourceConfig.parse_obj(input_config) + HexSourceConfig.model_validate(input_config) with self.assertRaises(ValueError): input_config = {**self.minimum_input_config} del input_config["token"] - HexSourceConfig.parse_obj(input_config) + HexSourceConfig.model_validate(input_config) def test_minimum_config(self): - config = HexSourceConfig.parse_obj(self.minimum_input_config) + config = HexSourceConfig.model_validate(self.minimum_input_config) assert config assert config.workspace_name == "test-workspace" assert config.token.get_secret_value() == "test-token" def test_lineage_config(self): - config = HexSourceConfig.parse_obj(self.minimum_input_config) + config = HexSourceConfig.model_validate(self.minimum_input_config) assert config and config.include_lineage input_config = {**self.minimum_input_config, "include_lineage": False} - config = HexSourceConfig.parse_obj(input_config) + config = 
HexSourceConfig.model_validate(input_config) assert config and not config.include_lineage # default values for lineage_start_time and lineage_end_time - config = HexSourceConfig.parse_obj(self.minimum_input_config) + config = HexSourceConfig.model_validate(self.minimum_input_config) assert ( config.lineage_start_time and isinstance(config.lineage_start_time, datetime) @@ -72,7 +72,7 @@ def test_lineage_config(self): "lineage_start_time": "2025-03-24 12:00:00", "lineage_end_time": "2025-03-25 12:00:00", } - config = HexSourceConfig.parse_obj(input_config) + config = HexSourceConfig.model_validate(input_config) assert ( config.lineage_start_time and isinstance(config.lineage_start_time, datetime) @@ -94,7 +94,7 @@ def test_lineage_config(self): **self.minimum_input_config, "lineage_end_time": "2025-03-25 12:00:00", } - config = HexSourceConfig.parse_obj(input_config) + config = HexSourceConfig.model_validate(input_config) assert ( config.lineage_start_time and isinstance(config.lineage_start_time, datetime) @@ -117,7 +117,7 @@ def test_lineage_config(self): **self.minimum_input_config, "lineage_start_time": "2025-03-25 12:00:00", } - config = HexSourceConfig.parse_obj(input_config) + config = HexSourceConfig.model_validate(input_config) assert ( config.lineage_start_time and isinstance(config.lineage_start_time, datetime) @@ -139,7 +139,7 @@ def test_lineage_config(self): "lineage_start_time": "-3day", "lineage_end_time": "now", } - config = HexSourceConfig.parse_obj(input_config) + config = HexSourceConfig.model_validate(input_config) assert ( config.lineage_start_time and isinstance(config.lineage_start_time, datetime) diff --git a/metadata-ingestion/tests/unit/powerbi/test_config.py b/metadata-ingestion/tests/unit/powerbi/test_config.py index 31bb3a6f9f9e38..0263384b9e7521 100644 --- a/metadata-ingestion/tests/unit/powerbi/test_config.py +++ b/metadata-ingestion/tests/unit/powerbi/test_config.py @@ -113,7 +113,7 @@ def test_dsn_to_database_schema_config( expected_config_or_exception, Exception ): with pytest.raises(expected_config_or_exception): - PowerBiDashboardSourceConfig.parse_obj(config_dict) + PowerBiDashboardSourceConfig.model_validate(config_dict) else: - config = PowerBiDashboardSourceConfig.parse_obj(config_dict) + config = PowerBiDashboardSourceConfig.model_validate(config_dict) assert config == expected_config_or_exception diff --git a/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py b/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py index 2ab6208e2dcc68..d23f498e22f5f7 100644 --- a/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py +++ b/metadata-ingestion/tests/unit/reporting/test_datahub_ingestion_reporter.py @@ -49,7 +49,7 @@ def test_unique_key_gen(pipeline_config, expected_key): def test_default_config(): - typed_config = DatahubIngestionRunSummaryProviderConfig.parse_obj({}) + typed_config = DatahubIngestionRunSummaryProviderConfig.model_validate({}) assert typed_config.sink is None assert typed_config.report_recipe is True diff --git a/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py b/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py index 8154f179c36b8e..bd9c06c303bf82 100644 --- a/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py +++ b/metadata-ingestion/tests/unit/sdk/test_kafka_emitter.py @@ -14,7 +14,7 @@ class KafkaEmitterTest(unittest.TestCase): def test_kafka_emitter_config(self): - emitter_config = KafkaEmitterConfig.parse_obj( + emitter_config = 
KafkaEmitterConfig.model_validate( {"connection": {"bootstrap": "foobar:9092"}} ) assert emitter_config.topic_routes[MCE_KEY] == DEFAULT_MCE_KAFKA_TOPIC @@ -26,7 +26,7 @@ def test_kafka_emitter_config(self): def test_kafka_emitter_config_old_and_new(self): with pytest.raises(pydantic.ValidationError): - KafkaEmitterConfig.parse_obj( + KafkaEmitterConfig.model_validate( { "connection": {"bootstrap": "foobar:9092"}, "topic": "NewTopic", @@ -39,7 +39,7 @@ def test_kafka_emitter_config_old_and_new(self): """ def test_kafka_emitter_config_topic_upgrade(self): - emitter_config = KafkaEmitterConfig.parse_obj( + emitter_config = KafkaEmitterConfig.model_validate( {"connection": {"bootstrap": "foobar:9092"}, "topic": "NewTopic"} ) assert emitter_config.topic_routes[MCE_KEY] == "NewTopic" # MCE topic upgraded diff --git a/metadata-ingestion/tests/unit/sdk_v2/test_entity_client.py b/metadata-ingestion/tests/unit/sdk_v2/test_entity_client.py index ec903df0139a60..4f7c8f76e1e30f 100644 --- a/metadata-ingestion/tests/unit/sdk_v2/test_entity_client.py +++ b/metadata-ingestion/tests/unit/sdk_v2/test_entity_client.py @@ -44,7 +44,7 @@ def assert_client_golden(client: DataHubClient, golden_path: pathlib.Path) -> No def test_container_creation_flow(client: DataHubClient, mock_graph: Mock) -> None: # Create database and schema containers db = DatabaseKey(platform="snowflake", database="test_db") - schema = SchemaKey(**db.dict(), schema="test_schema") + schema = SchemaKey(**db.model_dump(), schema="test_schema") db_container = Container(db, display_name="test_db", subtype="Database") schema_container = Container(schema, display_name="test_schema", subtype="Schema") diff --git a/metadata-ingestion/tests/unit/snowflake/test_snowflake_dynamic_tables.py b/metadata-ingestion/tests/unit/snowflake/test_snowflake_dynamic_tables.py index 2e405a474447f1..2198fc7e75498a 100644 --- a/metadata-ingestion/tests/unit/snowflake/test_snowflake_dynamic_tables.py +++ b/metadata-ingestion/tests/unit/snowflake/test_snowflake_dynamic_tables.py @@ -216,7 +216,7 @@ def test_dynamic_table_lineage_extraction(mock_extractor_class): UpstreamLineageEdge, ) - result = UpstreamLineageEdge.parse_obj(mock_cursor.__iter__.return_value[0]) + result = UpstreamLineageEdge.model_validate(mock_cursor.__iter__.return_value[0]) # Verify the lineage information assert result.DOWNSTREAM_TABLE_NAME == "TEST_DB.PUBLIC.DYNAMIC_TABLE1" diff --git a/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py b/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py index e951427463f2c6..d78602ca1090a1 100644 --- a/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py +++ b/metadata-ingestion/tests/unit/snowflake/test_snowflake_source.py @@ -58,7 +58,7 @@ def test_snowflake_source_throws_error_on_account_id_missing(): with pytest.raises( ValidationError, match=re.compile(r"account_id.*Field required", re.DOTALL) ): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "username": "user", "password": "password", @@ -72,19 +72,19 @@ def test_no_client_id_invalid_oauth_config(): with pytest.raises( ValueError, match=re.compile(r"client_id.*Field required", re.DOTALL) ): - OAuthConfiguration.parse_obj(oauth_dict) + OAuthConfiguration.model_validate(oauth_dict) def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_false(): oauth_dict = default_oauth_dict.copy() del oauth_dict["client_secret"] - OAuthConfiguration.parse_obj(oauth_dict) + OAuthConfiguration.model_validate(oauth_dict) with 
pytest.raises( ValueError, match="'oauth_config.client_secret' was none but should be set when using use_certificate false for oauth_config", ): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", @@ -96,12 +96,12 @@ def test_snowflake_throws_error_on_client_secret_missing_if_use_certificate_is_f def test_snowflake_throws_error_on_encoded_oauth_private_key_missing_if_use_certificate_is_true(): oauth_dict = default_oauth_dict.copy() oauth_dict["use_certificate"] = True - OAuthConfiguration.parse_obj(oauth_dict) + OAuthConfiguration.model_validate(oauth_dict) with pytest.raises( ValueError, match="'base64_encoded_oauth_private_key' was none but should be set when using certificate for oauth_config", ): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", @@ -114,11 +114,11 @@ def test_snowflake_oauth_okta_does_not_support_certificate(): oauth_dict = default_oauth_dict.copy() oauth_dict["use_certificate"] = True oauth_dict["provider"] = "okta" - OAuthConfiguration.parse_obj(oauth_dict) + OAuthConfiguration.model_validate(oauth_dict) with pytest.raises( ValueError, match="Certificate authentication is not supported for Okta." ): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", @@ -130,7 +130,7 @@ def test_snowflake_oauth_okta_does_not_support_certificate(): def test_snowflake_oauth_happy_paths(): oauth_dict = default_oauth_dict.copy() oauth_dict["provider"] = "okta" - assert SnowflakeV2Config.parse_obj( + assert SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", @@ -141,7 +141,7 @@ def test_snowflake_oauth_happy_paths(): oauth_dict["provider"] = "microsoft" oauth_dict["encoded_oauth_public_key"] = "publickey" oauth_dict["encoded_oauth_private_key"] = "privatekey" - assert SnowflakeV2Config.parse_obj( + assert SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", @@ -151,7 +151,7 @@ def test_snowflake_oauth_happy_paths(): def test_snowflake_oauth_token_happy_path(): - assert SnowflakeV2Config.parse_obj( + assert SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN", @@ -166,7 +166,7 @@ def test_snowflake_oauth_token_without_token(): with pytest.raises( ValidationError, match="Token required for OAUTH_AUTHENTICATOR_TOKEN." ): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN", @@ -180,7 +180,7 @@ def test_snowflake_oauth_token_with_wrong_auth_type(): ValueError, match="Token can only be provided when using OAUTH_AUTHENTICATOR_TOKEN.", ): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR", @@ -194,7 +194,7 @@ def test_snowflake_oauth_token_with_empty_token(): with pytest.raises( ValidationError, match="Token required for OAUTH_AUTHENTICATOR_TOKEN." 
): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "account_id": "test", "authentication_type": "OAUTH_AUTHENTICATOR_TOKEN", @@ -212,17 +212,17 @@ def test_config_fetch_views_from_information_schema(): "username": "test_user", "password": "test_pass", } - config = SnowflakeV2Config.parse_obj(config_dict) + config = SnowflakeV2Config.model_validate(config_dict) assert config.fetch_views_from_information_schema is False # Test explicitly set to True config_dict_true = {**config_dict, "fetch_views_from_information_schema": True} - config = SnowflakeV2Config.parse_obj(config_dict_true) + config = SnowflakeV2Config.model_validate(config_dict_true) assert config.fetch_views_from_information_schema is True # Test explicitly set to False config_dict_false = {**config_dict, "fetch_views_from_information_schema": False} - config = SnowflakeV2Config.parse_obj(config_dict_false) + config = SnowflakeV2Config.model_validate(config_dict_false) assert config.fetch_views_from_information_schema is False @@ -239,17 +239,17 @@ def test_account_id_is_added_when_host_port_is_present(): config_dict = default_config_dict.copy() del config_dict["account_id"] config_dict["host_port"] = "acctname" - config = SnowflakeV2Config.parse_obj(config_dict) + config = SnowflakeV2Config.model_validate(config_dict) assert config.account_id == "acctname" def test_account_id_with_snowflake_host_suffix(): - config = SnowflakeV2Config.parse_obj(default_config_dict) + config = SnowflakeV2Config.model_validate(default_config_dict) assert config.account_id == "acctname" def test_snowflake_uri_default_authentication(): - config = SnowflakeV2Config.parse_obj(default_config_dict) + config = SnowflakeV2Config.model_validate(default_config_dict) assert config.get_sql_alchemy_url() == ( "snowflake://user:password@acctname" "?application=acryl_datahub" @@ -263,7 +263,7 @@ def test_snowflake_uri_external_browser_authentication(): config_dict = default_config_dict.copy() del config_dict["password"] config_dict["authentication_type"] = "EXTERNAL_BROWSER_AUTHENTICATOR" - config = SnowflakeV2Config.parse_obj(config_dict) + config = SnowflakeV2Config.model_validate(config_dict) assert config.get_sql_alchemy_url() == ( "snowflake://user@acctname" "?application=acryl_datahub" @@ -279,7 +279,7 @@ def test_snowflake_uri_key_pair_authentication(): config_dict["authentication_type"] = "KEY_PAIR_AUTHENTICATOR" config_dict["private_key_path"] = "/a/random/path" config_dict["private_key_password"] = "a_random_password" - config = SnowflakeV2Config.parse_obj(config_dict) + config = SnowflakeV2Config.model_validate(config_dict) assert config.get_sql_alchemy_url() == ( "snowflake://user@acctname" @@ -291,7 +291,7 @@ def test_snowflake_uri_key_pair_authentication(): def test_options_contain_connect_args(): - config = SnowflakeV2Config.parse_obj(default_config_dict) + config = SnowflakeV2Config.model_validate(default_config_dict) connect_args = config.get_options().get("connect_args") assert connect_args is not None @@ -302,7 +302,7 @@ def test_options_contain_connect_args(): def test_snowflake_connection_with_default_domain(mock_connect): """Test that connection uses default .com domain when not specified""" config_dict = default_config_dict.copy() - config = SnowflakeV2Config.parse_obj(config_dict) + config = SnowflakeV2Config.model_validate(config_dict) mock_connect.return_value = MagicMock() try: @@ -323,7 +323,7 @@ def test_snowflake_connection_with_china_domain(mock_connect): config_dict = default_config_dict.copy() 
config_dict["account_id"] = "test-account_cn" config_dict["snowflake_domain"] = "snowflakecomputing.cn" - config = SnowflakeV2Config.parse_obj(config_dict) + config = SnowflakeV2Config.model_validate(config_dict) mock_connect.return_value = MagicMock() try: @@ -344,11 +344,11 @@ def test_snowflake_config_with_column_lineage_no_table_lineage_throws_error(): ValidationError, match="include_table_lineage must be True for include_column_lineage to be set", ): - SnowflakeV2Config.parse_obj(config_dict) + SnowflakeV2Config.model_validate(config_dict) def test_snowflake_config_with_no_connect_args_returns_base_connect_args(): - config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(default_config_dict) + config: SnowflakeV2Config = SnowflakeV2Config.model_validate(default_config_dict) assert config.get_options()["connect_args"] is not None assert config.get_options()["connect_args"] == { CLIENT_PREFETCH_THREADS: 10, @@ -361,7 +361,7 @@ def test_private_key_set_but_auth_not_changed(): ValidationError, match="Either `private_key` and `private_key_path` is set but `authentication_type` is DEFAULT_AUTHENTICATOR. Should be set to 'KEY_PAIR_AUTHENTICATOR' when using key pair authentication", ): - SnowflakeV2Config.parse_obj( + SnowflakeV2Config.model_validate( { "account_id": "acctname", "private_key_path": "/a/random/path", @@ -374,7 +374,7 @@ def test_snowflake_config_with_connect_args_overrides_base_connect_args(): config_dict["connect_args"] = { CLIENT_PREFETCH_THREADS: 5, } - config: SnowflakeV2Config = SnowflakeV2Config.parse_obj(config_dict) + config: SnowflakeV2Config = SnowflakeV2Config.model_validate(config_dict) assert config.get_options()["connect_args"] is not None assert config.get_options()["connect_args"][CLIENT_PREFETCH_THREADS] == 5 assert config.get_options()["connect_args"][CLIENT_SESSION_KEEP_ALIVE] is True @@ -665,7 +665,7 @@ def test_snowflake_query_create_deny_regex_sql(): def test_snowflake_temporary_patterns_config_rename(): - conf = SnowflakeV2Config.parse_obj( + conf = SnowflakeV2Config.model_validate( { "account_id": "test", "username": "user", @@ -792,7 +792,7 @@ def test_snowflake_utils() -> None: def test_using_removed_fields_causes_no_error() -> None: - assert SnowflakeV2Config.parse_obj( + assert SnowflakeV2Config.model_validate( { "account_id": "test", "username": "snowflake", @@ -823,7 +823,7 @@ def test_snowflake_query_result_parsing(): } ], } - assert UpstreamLineageEdge.parse_obj(db_row) + assert UpstreamLineageEdge.model_validate(db_row) class TestDDLProcessing: @@ -1055,7 +1055,7 @@ def test_process_upstream_lineage_row_dynamic_table_moved(mock_extractor_class): def side_effect(self, row): # Create a new UpstreamLineageEdge with the updated table name - result = UpstreamLineageEdge.parse_obj(row) + result = UpstreamLineageEdge.model_validate(row) result.DOWNSTREAM_TABLE_NAME = "NEW_DB.NEW_SCHEMA.DYNAMIC_TABLE" return result diff --git a/metadata-ingestion/tests/unit/sql_queries/test_sql_queries.py b/metadata-ingestion/tests/unit/sql_queries/test_sql_queries.py index d6f1423d97d9d9..9e90a1925a4dbb 100644 --- a/metadata-ingestion/tests/unit/sql_queries/test_sql_queries.py +++ b/metadata-ingestion/tests/unit/sql_queries/test_sql_queries.py @@ -447,7 +447,7 @@ class TestSqlQueriesSourceConfig: def test_incremental_lineage_default(self): """Test that incremental_lineage defaults to False.""" config_dict = {"query_file": "test.jsonl", "platform": "snowflake"} - config = SqlQueriesSourceConfig.parse_obj(config_dict) + config = 
SqlQueriesSourceConfig.model_validate(config_dict) assert config.incremental_lineage is False def test_incremental_lineage_enabled(self): @@ -457,7 +457,7 @@ def test_incremental_lineage_enabled(self): "platform": "snowflake", "incremental_lineage": True, } - config = SqlQueriesSourceConfig.parse_obj(config_dict) + config = SqlQueriesSourceConfig.model_validate(config_dict) assert config.incremental_lineage is True def test_incremental_lineage_disabled_explicitly(self): @@ -467,7 +467,7 @@ def test_incremental_lineage_disabled_explicitly(self): "platform": "snowflake", "incremental_lineage": False, } - config = SqlQueriesSourceConfig.parse_obj(config_dict) + config = SqlQueriesSourceConfig.model_validate(config_dict) assert config.incremental_lineage is False @@ -612,7 +612,7 @@ def test_backward_compatibility(self, pipeline_context, temp_query_file): "usage": {"bucket_duration": "DAY"}, } - config = SqlQueriesSourceConfig.parse_obj(config_dict) + config = SqlQueriesSourceConfig.model_validate(config_dict) source = SqlQueriesSource(pipeline_context, config) # Should default to False diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py index 8b15fb9b70be49..24e6e6519a864a 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py @@ -1,8 +1,8 @@ from datetime import datetime, timezone from typing import Dict, List -import pydantic import pytest +from pydantic import model_validator from datahub.ingestion.source.state.checkpoint import Checkpoint, CheckpointStateBase from datahub.ingestion.source.state.sql_common_state import ( @@ -188,7 +188,8 @@ class PrevState(CheckpointStateBase): class NextState(CheckpointStateBase): list_stuff: List[str] - @pydantic.root_validator(pre=True, allow_reuse=True) + @model_validator(mode="before") + @classmethod def _migrate(cls, values: dict) -> dict: values.setdefault("list_stuff", []) values["list_stuff"] += values.pop("list_a", []) diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py index 3d5b97a8ef63bf..51e978ad135a17 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_sql_common_state.py @@ -36,7 +36,7 @@ def test_sql_common_state() -> None: def test_state_backward_compat() -> None: - state = BaseSQLAlchemyCheckpointState.parse_obj( + state = BaseSQLAlchemyCheckpointState.model_validate( dict( encoded_table_urns=["mysql||db1.t1||PROD"], encoded_view_urns=["mysql||db1.v1||PROD"], diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py index 2c1fc79d27934c..ee7a4d720c0b89 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_stateful_ingestion.py @@ -101,7 +101,7 @@ def __init__(self, config: DummySourceConfig, ctx: PipelineContext): @classmethod def create(cls, config_dict, ctx): - config = DummySourceConfig.parse_obj(config_dict) + config = DummySourceConfig.model_validate(config_dict) return cls(config, ctx) def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: 
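# Illustrative sketch, not part of this patch: a minimal pydantic model showing the v1 -> v2
# renames that the test hunks above and below rely on, assuming pydantic v2 is installed.
# The class, field, and validator names here are hypothetical stand-ins, not DataHub classes:
#   Model.parse_obj(d)        -> Model.model_validate(d)
#   instance.dict()           -> instance.model_dump()
#   @validator("field")       -> @field_validator("field") stacked with @classmethod
#   @root_validator(pre=True) -> @model_validator(mode="before") stacked with @classmethod
from typing import List

from pydantic import BaseModel, field_validator, model_validator


class ExampleSourceConfig(BaseModel):
    host_port: str
    allowed_databases: List[str] = []

    @field_validator("host_port", mode="after")
    @classmethod
    def _host_port_must_contain_port(cls, v: str) -> str:
        # Runs per field after type coercion, replacing a v1 @validator("host_port").
        if ":" not in v:
            raise ValueError(f"host_port {v!r} must look like host:port")
        return v

    @model_validator(mode="before")
    @classmethod
    def _migrate_legacy_field(cls, values: dict) -> dict:
        # Runs on the raw input before field validation, replacing a v1
        # @root_validator(pre=True) such as the checkpoint-state migration tested above.
        values.setdefault("allowed_databases", [])
        values["allowed_databases"] += values.pop("databases", [])
        return values


# v1 call sites: ExampleSourceConfig.parse_obj(raw); config.dict()
# v2 call sites:
config = ExampleSourceConfig.model_validate(
    {"host_port": "localhost:1025", "databases": ["db1"]}
)
assert config.model_dump() == {
    "host_port": "localhost:1025",
    "allowed_databases": ["db1"],
}
# Note: the @classmethod decorator is stacked beneath the validator decorator, matching the
# ordering used throughout the hunks in this patch.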
diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py b/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py index db8eab70ca2bac..56ff9eccf0cb68 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/test_configs.py @@ -48,7 +48,7 @@ timeout_sec=10, extra_headers={}, max_threads=10, - ) + ), ), ), False, diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py b/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py index 3b0e4e31d4b4a2..5eb4cfd7de48ac 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/test_kafka_state.py @@ -16,7 +16,7 @@ def test_kafka_common_state() -> None: def test_kafka_state_migration() -> None: - state = GenericCheckpointState.parse_obj( + state = GenericCheckpointState.model_validate( { "encoded_topic_urns": [ "kafka||test_topic1||test", diff --git a/metadata-ingestion/tests/unit/tableau/test_tableau_config.py b/metadata-ingestion/tests/unit/tableau/test_tableau_config.py index 741d89ef2130b8..7f845a9fbc8a79 100644 --- a/metadata-ingestion/tests/unit/tableau/test_tableau_config.py +++ b/metadata-ingestion/tests/unit/tableau/test_tableau_config.py @@ -63,7 +63,7 @@ def test_value_error_projects_and_project_pattern( ValidationError, match=r".*projects is deprecated. Please use project_path_pattern only.*", ): - TableauConfig.parse_obj(new_config) + TableauConfig.model_validate(new_config) def test_project_pattern_deprecation(pytestconfig, tmp_path, mock_datahub_graph): @@ -76,27 +76,27 @@ def test_project_pattern_deprecation(pytestconfig, tmp_path, mock_datahub_graph) ValidationError, match=r".*project_pattern is deprecated. 
Please use project_path_pattern only*", ): - TableauConfig.parse_obj(new_config) + TableauConfig.model_validate(new_config) def test_ingest_hidden_assets_bool(): config_dict = deepcopy(default_config) config_dict["ingest_hidden_assets"] = False - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) assert config.ingest_hidden_assets is False def test_ingest_hidden_assets_list(): config_dict = deepcopy(default_config) config_dict["ingest_hidden_assets"] = ["dashboard"] - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) assert config.ingest_hidden_assets == ["dashboard"] def test_ingest_hidden_assets_multiple(): config_dict = deepcopy(default_config) config_dict["ingest_hidden_assets"] = ["dashboard", "worksheet"] - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) assert config.ingest_hidden_assets == ["dashboard", "worksheet"] @@ -107,7 +107,7 @@ def test_ingest_hidden_assets_invalid(): ValidationError, match=re.compile(r"ingest_hidden_assets.*input_value='invalid'", re.DOTALL), ): - TableauConfig.parse_obj(config) + TableauConfig.model_validate(config) @freeze_time(FROZEN_TIME) @@ -133,7 +133,7 @@ def test_extract_project_hierarchy(extract_project_hierarchy, allowed_projects): config_dict["extract_project_hierarchy"] = extract_project_hierarchy - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) site_source = TableauSiteSource( config=config, @@ -196,7 +196,7 @@ def test_use_email_as_username_requires_ingest_owner(): ValidationError, match=r".*use_email_as_username requires ingest_owner to be enabled.*", ): - TableauConfig.parse_obj(config_dict) + TableauConfig.model_validate(config_dict) def test_use_email_as_username_valid_config(): @@ -205,7 +205,7 @@ def test_use_email_as_username_valid_config(): config_dict["ingest_owner"] = True config_dict["use_email_as_username"] = True - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) assert config.ingest_owner is True assert config.use_email_as_username is True @@ -213,5 +213,5 @@ def test_use_email_as_username_valid_config(): def test_use_email_as_username_default_false(): """Test that use_email_as_username defaults to False.""" config_dict = default_config.copy() - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) assert config.use_email_as_username is False diff --git a/metadata-ingestion/tests/unit/tableau/test_tableau_source.py b/metadata-ingestion/tests/unit/tableau/test_tableau_source.py index 9a3690d62e8138..ca3af82a5d9da3 100644 --- a/metadata-ingestion/tests/unit/tableau/test_tableau_source.py +++ b/metadata-ingestion/tests/unit/tableau/test_tableau_source.py @@ -178,7 +178,7 @@ def test_tableau_unsupported_csql(): context = PipelineContext(run_id="0", pipeline_name="test_tableau") config_dict = default_config.copy() del config_dict["stateful_ingestion"] - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) config.extract_lineage_from_unsupported_custom_sql_queries = True config.lineage_overrides = TableauLineageOverrides( database_override_map={"production database": "prod"} @@ -717,7 +717,7 @@ def test_get_owner_identifier_username(): """Test owner identifier extraction using username.""" config_dict = default_config.copy() config_dict["use_email_as_username"] = False - config = 
TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) context = PipelineContext(run_id="test", pipeline_name="test") site_source = TableauSiteSource( @@ -738,7 +738,7 @@ def test_get_owner_identifier_email(): """Test owner identifier extraction using email.""" config_dict = default_config.copy() config_dict["use_email_as_username"] = True - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) context = PipelineContext(run_id="test", pipeline_name="test") site_source = TableauSiteSource( @@ -759,7 +759,7 @@ def test_get_owner_identifier_email_fallback(): """Test owner identifier extraction falls back to username when email is not available.""" config_dict = default_config.copy() config_dict["use_email_as_username"] = True - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) context = PipelineContext(run_id="test", pipeline_name="test") site_source = TableauSiteSource( @@ -780,7 +780,7 @@ def test_get_owner_identifier_empty_dict(): """Test owner identifier extraction with empty owner dict.""" config_dict = default_config.copy() config_dict["use_email_as_username"] = True - config = TableauConfig.parse_obj(config_dict) + config = TableauConfig.model_validate(config_dict) context = PipelineContext(run_id="test", pipeline_name="test") site_source = TableauSiteSource( diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py index 9d2c3a0a87a0bb..24a0f1aea1d9f4 100644 --- a/metadata-ingestion/tests/unit/test_athena_source.py +++ b/metadata-ingestion/tests/unit/test_athena_source.py @@ -33,7 +33,7 @@ def test_athena_config_query_location_old_plus_new_value_not_allowed(): from datahub.ingestion.source.sql.athena import AthenaConfig with pytest.raises(ValueError): - AthenaConfig.parse_obj( + AthenaConfig.model_validate( { "aws_region": "us-west-1", "s3_staging_dir": "s3://sample-staging-dir/", @@ -46,7 +46,7 @@ def test_athena_config_query_location_old_plus_new_value_not_allowed(): def test_athena_config_staging_dir_is_set_as_query_result(): from datahub.ingestion.source.sql.athena import AthenaConfig - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "s3_staging_dir": "s3://sample-staging-dir/", @@ -54,7 +54,7 @@ def test_athena_config_staging_dir_is_set_as_query_result(): } ) - expected_config = AthenaConfig.parse_obj( + expected_config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://sample-staging-dir/", @@ -68,7 +68,7 @@ def test_athena_config_staging_dir_is_set_as_query_result(): def test_athena_uri(): from datahub.ingestion.source.sql.athena import AthenaConfig - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -91,7 +91,7 @@ def test_athena_get_table_properties(): from datahub.ingestion.source.sql.athena import AthenaConfig, AthenaSource - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "s3_staging_dir": "s3://sample-staging-dir/", @@ -295,7 +295,7 @@ def test_convert_simple_field_paths_to_v1_enabled(): """Test that emit_schema_fieldpaths_as_v1 correctly converts simple field paths when enabled""" # Test config with emit_schema_fieldpaths_as_v1 enabled - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": 
"us-west-1", "query_result_location": "s3://query-result-location/", @@ -359,7 +359,7 @@ def test_convert_simple_field_paths_to_v1_disabled(): """Test that emit_schema_fieldpaths_as_v1 keeps v2 field paths when disabled""" # Test config with emit_schema_fieldpaths_as_v1 disabled (default) - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -397,7 +397,7 @@ def test_convert_simple_field_paths_to_v1_complex_types_ignored(): """Test that complex types (arrays, maps, structs) are not affected by emit_schema_fieldpaths_as_v1""" # Test config with emit_schema_fieldpaths_as_v1 enabled - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -457,7 +457,7 @@ def test_convert_simple_field_paths_to_v1_with_partition_keys(): """Test that emit_schema_fieldpaths_as_v1 works correctly with partition keys""" # Test config with emit_schema_fieldpaths_as_v1 enabled - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -497,7 +497,7 @@ def test_convert_simple_field_paths_to_v1_default_behavior(): from datahub.ingestion.source.sql.athena import AthenaConfig # Test config without specifying emit_schema_fieldpaths_as_v1 - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -510,7 +510,7 @@ def test_convert_simple_field_paths_to_v1_default_behavior(): def test_get_partitions_returns_none_when_extract_partitions_disabled(): """Test that get_partitions returns None when extract_partitions is False""" - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -539,7 +539,7 @@ def test_get_partitions_returns_none_when_extract_partitions_disabled(): def test_get_partitions_attempts_extraction_when_extract_partitions_enabled(): """Test that get_partitions attempts partition extraction when extract_partitions is True""" - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -581,7 +581,7 @@ def test_get_partitions_attempts_extraction_when_extract_partitions_enabled(): def test_partition_profiling_sql_generation_single_key(): """Test that partition profiling generates valid SQL for single partition key and can be parsed by SQLGlot.""" - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -639,7 +639,7 @@ def test_partition_profiling_sql_generation_single_key(): def test_partition_profiling_sql_generation_multiple_keys(): """Test that partition profiling generates valid SQL for multiple partition keys and can be parsed by SQLGlot.""" - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -706,7 +706,7 @@ def test_partition_profiling_sql_generation_multiple_keys(): def test_partition_profiling_sql_generation_complex_schema_table_names(): """Test that partition profiling handles complex schema/table names correctly and generates valid SQL.""" - config = 
AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -866,7 +866,7 @@ def test_build_max_partition_query(): def test_partition_profiling_disabled_no_sql_generation(): """Test that when partition profiling is disabled, no complex SQL is generated.""" - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -1062,7 +1062,7 @@ def test_sanitize_identifier_integration_with_build_max_partition_query(): def test_sanitize_identifier_error_handling_in_get_partitions(): """Test that ValueError from _sanitize_identifier is handled gracefully in get_partitions method.""" - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", @@ -1103,7 +1103,7 @@ def test_sanitize_identifier_error_handling_in_generate_partition_profiler_query """Test that ValueError from _sanitize_identifier is handled gracefully in generate_partition_profiler_query.""" import logging - config = AthenaConfig.parse_obj( + config = AthenaConfig.model_validate( { "aws_region": "us-west-1", "query_result_location": "s3://query-result-location/", diff --git a/metadata-ingestion/tests/unit/test_cassandra_source.py b/metadata-ingestion/tests/unit/test_cassandra_source.py index 3232764b4f2057..9feae9bd2e2d35 100644 --- a/metadata-ingestion/tests/unit/test_cassandra_source.py +++ b/metadata-ingestion/tests/unit/test_cassandra_source.py @@ -99,7 +99,7 @@ def _get_base_config_dict() -> dict: def test_authenticate_no_ssl(): config_dict = _get_base_config_dict() - config = CassandraSourceConfig.parse_obj(config_dict) + config = CassandraSourceConfig.model_validate(config_dict) report = MagicMock(spec=SourceReport) api = CassandraAPI(config, report) @@ -116,7 +116,7 @@ def test_authenticate_no_ssl(): def test_authenticate_ssl_ca_certs(): config_dict = _get_base_config_dict() config_dict["ssl_ca_certs"] = "ca.pem" - config = CassandraSourceConfig.parse_obj(config_dict) + config = CassandraSourceConfig.model_validate(config_dict) report = MagicMock(spec=SourceReport) api = CassandraAPI(config, report) @@ -148,7 +148,7 @@ def test_authenticate_ssl_all_certs(): config_dict["ssl_ca_certs"] = "ca.pem" config_dict["ssl_certfile"] = "client.crt" config_dict["ssl_keyfile"] = "client.key" - config = CassandraSourceConfig.parse_obj(config_dict) + config = CassandraSourceConfig.model_validate(config_dict) report = MagicMock(spec=SourceReport) api = CassandraAPI(config, report) diff --git a/metadata-ingestion/tests/unit/test_classification.py b/metadata-ingestion/tests/unit/test_classification.py index c79ae5808b2a69..c2bf0d589c21eb 100644 --- a/metadata-ingestion/tests/unit/test_classification.py +++ b/metadata-ingestion/tests/unit/test_classification.py @@ -112,7 +112,7 @@ def test_incorrect_custom_info_type_config(): with pytest.raises( ValidationError, match="Missing Configuration for Prediction Factor" ): - DataHubClassifierConfig.parse_obj( + DataHubClassifierConfig.model_validate( { "confidence_level_threshold": 0.7, "info_types_config": { @@ -135,7 +135,7 @@ def test_incorrect_custom_info_type_config(): ) with pytest.raises(ValidationError, match="Invalid Prediction Type"): - DataHubClassifierConfig.parse_obj( + DataHubClassifierConfig.model_validate( { "confidence_level_threshold": 0.7, "info_types_config": { diff --git 
a/metadata-ingestion/tests/unit/test_clickhouse_source.py b/metadata-ingestion/tests/unit/test_clickhouse_source.py index 1b2ffb70c8d190..9182b8bc71cbac 100644 --- a/metadata-ingestion/tests/unit/test_clickhouse_source.py +++ b/metadata-ingestion/tests/unit/test_clickhouse_source.py @@ -2,7 +2,7 @@ def test_clickhouse_uri_https(): - config = ClickHouseConfig.parse_obj( + config = ClickHouseConfig.model_validate( { "username": "user", "password": "password", @@ -18,7 +18,7 @@ def test_clickhouse_uri_https(): def test_clickhouse_uri_native(): - config = ClickHouseConfig.parse_obj( + config = ClickHouseConfig.model_validate( { "username": "user", "password": "password", @@ -30,7 +30,7 @@ def test_clickhouse_uri_native(): def test_clickhouse_uri_native_secure(): - config = ClickHouseConfig.parse_obj( + config = ClickHouseConfig.model_validate( { "username": "user", "password": "password", @@ -47,7 +47,7 @@ def test_clickhouse_uri_native_secure(): def test_clickhouse_uri_default_password(): - config = ClickHouseConfig.parse_obj( + config = ClickHouseConfig.model_validate( { "username": "user", "host_port": "host:1111", @@ -59,7 +59,7 @@ def test_clickhouse_uri_default_password(): def test_clickhouse_uri_native_secure_backward_compatibility(): - config = ClickHouseConfig.parse_obj( + config = ClickHouseConfig.model_validate( { "username": "user", "password": "password", @@ -76,7 +76,7 @@ def test_clickhouse_uri_native_secure_backward_compatibility(): def test_clickhouse_uri_https_backward_compatibility(): - config = ClickHouseConfig.parse_obj( + config = ClickHouseConfig.model_validate( { "username": "user", "password": "password", diff --git a/metadata-ingestion/tests/unit/test_cockroach_source.py b/metadata-ingestion/tests/unit/test_cockroach_source.py index 113a62ff61975e..f91c2cf0d923f4 100644 --- a/metadata-ingestion/tests/unit/test_cockroach_source.py +++ b/metadata-ingestion/tests/unit/test_cockroach_source.py @@ -13,7 +13,7 @@ def _base_config(): def test_platform_correctly_set_cockroachdb(): source = CockroachDBSource( ctx=PipelineContext(run_id="cockroachdb-source-test"), - config=CockroachDBConfig.parse_obj(_base_config()), + config=CockroachDBConfig.model_validate(_base_config()), ) assert source.platform == "cockroachdb" @@ -21,6 +21,6 @@ def test_platform_correctly_set_cockroachdb(): def test_platform_correctly_set_postgres(): source = PostgresSource( ctx=PipelineContext(run_id="postgres-source-test"), - config=PostgresConfig.parse_obj(_base_config()), + config=PostgresConfig.model_validate(_base_config()), ) assert source.platform == "postgres" diff --git a/metadata-ingestion/tests/unit/test_confluent_schema_registry.py b/metadata-ingestion/tests/unit/test_confluent_schema_registry.py index c5c14b1136ca58..b73f8ee6eca6a5 100644 --- a/metadata-ingestion/tests/unit/test_confluent_schema_registry.py +++ b/metadata-ingestion/tests/unit/test_confluent_schema_registry.py @@ -60,7 +60,7 @@ def test_get_schema_str_replace_confluent_ref_avro(self): """ ) - kafka_source_config = KafkaSourceConfig.parse_obj( + kafka_source_config = KafkaSourceConfig.model_validate( { "connection": { "bootstrap": "localhost:9092", diff --git a/metadata-ingestion/tests/unit/test_druid_source.py b/metadata-ingestion/tests/unit/test_druid_source.py index 504fb13700a78c..c70a52d5d72666 100644 --- a/metadata-ingestion/tests/unit/test_druid_source.py +++ b/metadata-ingestion/tests/unit/test_druid_source.py @@ -2,12 +2,12 @@ def test_druid_uri(): - config = DruidConfig.parse_obj({"host_port": "localhost:8082"}) + 
config = DruidConfig.model_validate({"host_port": "localhost:8082"}) assert config.get_sql_alchemy_url() == "druid://localhost:8082/druid/v2/sql/" def test_druid_get_identifier(): - config = DruidConfig.parse_obj({"host_port": "localhost:8082"}) + config = DruidConfig.model_validate({"host_port": "localhost:8082"}) assert config.get_identifier("schema", "table") == "table" diff --git a/metadata-ingestion/tests/unit/test_elasticsearch_source.py b/metadata-ingestion/tests/unit/test_elasticsearch_source.py index 872dd882673b8a..f9a58da5c4ecff 100644 --- a/metadata-ingestion/tests/unit/test_elasticsearch_source.py +++ b/metadata-ingestion/tests/unit/test_elasticsearch_source.py @@ -19,7 +19,7 @@ def test_elasticsearch_throws_error_wrong_operation_config(): with pytest.raises(pydantic.ValidationError): - ElasticsearchSourceConfig.parse_obj( + ElasticsearchSourceConfig.model_validate( { "profiling": { "enabled": True, @@ -2478,14 +2478,14 @@ def test_host_port_parsing() -> None: bad_examples = ["localhost:abcd", "htttp://localhost:1234", "localhost:9200//"] for example in examples: config_dict = {"host": example} - config = ElasticsearchSourceConfig.parse_obj(config_dict) + config = ElasticsearchSourceConfig.model_validate(config_dict) assert config.host == example for bad_example in bad_examples: config_dict = {"host": bad_example} with pytest.raises(pydantic.ValidationError): - ElasticsearchSourceConfig.parse_obj(config_dict) + ElasticsearchSourceConfig.model_validate(config_dict) def test_collapse_urns() -> None: diff --git a/metadata-ingestion/tests/unit/test_file_lineage_source.py b/metadata-ingestion/tests/unit/test_file_lineage_source.py index ea161ac89f1413..a2d1eb1118d85e 100644 --- a/metadata-ingestion/tests/unit/test_file_lineage_source.py +++ b/metadata-ingestion/tests/unit/test_file_lineage_source.py @@ -49,7 +49,7 @@ def basic_mcp(): transformOperation: func1 """ config = yaml.safe_load(sample_lineage) - lineage_config: LineageConfig = LineageConfig.parse_obj(config) + lineage_config: LineageConfig = LineageConfig.model_validate(config) return _get_lineage_mcp(lineage_config.lineage[0], False) @@ -90,7 +90,7 @@ def unsupported_entity_type_mcp(): platform: kafka """ config = yaml.safe_load(sample_lineage) - return LineageConfig.parse_obj(config) + return LineageConfig.model_validate(config) def unsupported_upstream_entity_type_mcp(): @@ -114,7 +114,7 @@ def unsupported_upstream_entity_type_mcp(): platform: kafka """ config = yaml.safe_load(sample_lineage) - return LineageConfig.parse_obj(config) + return LineageConfig.model_validate(config) def unsupported_entity_env_mcp(): @@ -154,7 +154,7 @@ def unsupported_entity_env_mcp(): platform: kafka """ config = yaml.safe_load(sample_lineage) - return LineageConfig.parse_obj(config) + return LineageConfig.model_validate(config) def test_basic_lineage_entity_root_node_urn(basic_mcp): diff --git a/metadata-ingestion/tests/unit/test_ge_profiling_config.py b/metadata-ingestion/tests/unit/test_ge_profiling_config.py index f4d73a6ffe1e4e..d736a18e735657 100644 --- a/metadata-ingestion/tests/unit/test_ge_profiling_config.py +++ b/metadata-ingestion/tests/unit/test_ge_profiling_config.py @@ -4,12 +4,12 @@ def test_profile_table_level_only(): - config = GEProfilingConfig.parse_obj( + config = GEProfilingConfig.model_validate( {"enabled": True, "profile_table_level_only": True} ) assert config.any_field_level_metrics_enabled() is False - config = GEProfilingConfig.parse_obj( + config = GEProfilingConfig.model_validate( { "enabled": True, 
"profile_table_level_only": True, @@ -24,7 +24,7 @@ def test_profile_table_level_only_fails_with_field_metric_enabled(): ValueError, match="Cannot enable field-level metrics if profile_table_level_only is set", ): - GEProfilingConfig.parse_obj( + GEProfilingConfig.model_validate( { "enabled": True, "profile_table_level_only": True, diff --git a/metadata-ingestion/tests/unit/test_hana_source.py b/metadata-ingestion/tests/unit/test_hana_source.py index aa9d37069092e3..2538525af192eb 100644 --- a/metadata-ingestion/tests/unit/test_hana_source.py +++ b/metadata-ingestion/tests/unit/test_hana_source.py @@ -23,7 +23,7 @@ def test_platform_correctly_set_hana(): reason="The hdbcli dependency is not available for aarch64", ) def test_hana_uri_native(): - config = HanaConfig.parse_obj( + config = HanaConfig.model_validate( { "username": "user", "password": "password", @@ -39,7 +39,7 @@ def test_hana_uri_native(): reason="The hdbcli dependency is not available for aarch64", ) def test_hana_uri_native_db(): - config = HanaConfig.parse_obj( + config = HanaConfig.model_validate( { "username": "user", "password": "password", diff --git a/metadata-ingestion/tests/unit/test_hive_source.py b/metadata-ingestion/tests/unit/test_hive_source.py index 2eeebdc8cd1f09..1b3baed301999a 100644 --- a/metadata-ingestion/tests/unit/test_hive_source.py +++ b/metadata-ingestion/tests/unit/test_hive_source.py @@ -15,7 +15,7 @@ def test_hive_configuration_get_identifier_with_database(): "database": test_db_name, "scheme": "hive+https", } - hive_config = HiveConfig.parse_obj(config_dict) + hive_config = HiveConfig.model_validate(config_dict) expected_output = f"{test_db_name}" ctx = PipelineContext(run_id="test") hive_source = HiveSource(hive_config, ctx) diff --git a/metadata-ingestion/tests/unit/test_kafka_source.py b/metadata-ingestion/tests/unit/test_kafka_source.py index ea2979db934659..e88942219ff603 100644 --- a/metadata-ingestion/tests/unit/test_kafka_source.py +++ b/metadata-ingestion/tests/unit/test_kafka_source.py @@ -49,7 +49,7 @@ def mock_admin_client(): def test_kafka_source_configuration(mock_kafka): ctx = PipelineContext(run_id="test") kafka_source = KafkaSource( - KafkaSourceConfig.parse_obj({"connection": {"bootstrap": "foobar:9092"}}), + KafkaSourceConfig.model_validate({"connection": {"bootstrap": "foobar:9092"}}), ctx, ) kafka_source.close() @@ -65,7 +65,9 @@ def test_kafka_source_workunits_wildcard_topic(mock_kafka, mock_admin_client): ctx = PipelineContext(run_id="test") kafka_source = KafkaSource( - KafkaSourceConfig.parse_obj({"connection": {"bootstrap": "localhost:9092"}}), + KafkaSourceConfig.model_validate( + {"connection": {"bootstrap": "localhost:9092"}} + ), ctx, ) workunits = list(kafka_source.get_workunits()) @@ -807,7 +809,7 @@ def test_kafka_source_oauth_cb_configuration(): "in the format :." 
), ): - KafkaSourceConfig.parse_obj( + KafkaSourceConfig.model_validate( { "connection": { "bootstrap": "foobar:9092", diff --git a/metadata-ingestion/tests/unit/test_metabase_source.py b/metadata-ingestion/tests/unit/test_metabase_source.py index 096936e2b184ec..065cda0f8402a4 100644 --- a/metadata-ingestion/tests/unit/test_metabase_source.py +++ b/metadata-ingestion/tests/unit/test_metabase_source.py @@ -52,7 +52,7 @@ def test_get_platform_instance(): def test_set_display_uri(): display_uri = "some_host:1234" - config = MetabaseConfig.parse_obj({"display_uri": display_uri}) + config = MetabaseConfig.model_validate({"display_uri": display_uri}) assert config.connect_uri == "localhost:3000" assert config.display_uri == display_uri diff --git a/metadata-ingestion/tests/unit/test_nifi_source.py b/metadata-ingestion/tests/unit/test_nifi_source.py index d7e04956408620..9244fe41d5afe3 100644 --- a/metadata-ingestion/tests/unit/test_nifi_source.py +++ b/metadata-ingestion/tests/unit/test_nifi_source.py @@ -19,7 +19,7 @@ @typing.no_type_check def test_nifi_s3_provenance_event(): config_dict = {"site_url": "http://localhost:8080", "incremental_lineage": False} - nifi_config = NifiSourceConfig.parse_obj(config_dict) + nifi_config = NifiSourceConfig.model_validate(config_dict) ctx = PipelineContext(run_id="test") with ( @@ -290,7 +290,7 @@ def test_auth_without_password(auth): with pytest.raises( ValueError, match=f"`username` and `password` is required for {auth} auth" ): - NifiSourceConfig.parse_obj( + NifiSourceConfig.model_validate( { "site_url": "https://localhost:8443", "auth": auth, @@ -304,7 +304,7 @@ def test_auth_without_username_and_password(auth): with pytest.raises( ValueError, match=f"`username` and `password` is required for {auth} auth" ): - NifiSourceConfig.parse_obj( + NifiSourceConfig.model_validate( { "site_url": "https://localhost:8443", "auth": auth, @@ -316,7 +316,7 @@ def test_client_cert_auth_without_client_cert_file(): with pytest.raises( ValueError, match="`client_cert_file` is required for CLIENT_CERT auth" ): - NifiSourceConfig.parse_obj( + NifiSourceConfig.model_validate( { "site_url": "https://localhost:8443", "auth": "CLIENT_CERT", diff --git a/metadata-ingestion/tests/unit/test_oracle_source.py b/metadata-ingestion/tests/unit/test_oracle_source.py index 0477044354576b..c240c6e2a3906c 100644 --- a/metadata-ingestion/tests/unit/test_oracle_source.py +++ b/metadata-ingestion/tests/unit/test_oracle_source.py @@ -13,7 +13,7 @@ def test_oracle_config(): "host_port": "host:1521", } - config = OracleConfig.parse_obj( + config = OracleConfig.model_validate( { **base_config, "service_name": "svc01", @@ -25,7 +25,7 @@ def test_oracle_config(): ) with pytest.raises(ValueError): - config = OracleConfig.parse_obj( + config = OracleConfig.model_validate( { **base_config, "database": "db", diff --git a/metadata-ingestion/tests/unit/test_postgres_source.py b/metadata-ingestion/tests/unit/test_postgres_source.py index 25140cf1b997f8..c05b4140b11ee6 100644 --- a/metadata-ingestion/tests/unit/test_postgres_source.py +++ b/metadata-ingestion/tests/unit/test_postgres_source.py @@ -11,7 +11,7 @@ def _base_config(): @patch("datahub.ingestion.source.sql.postgres.create_engine") def test_initial_database(create_engine_mock): - config = PostgresConfig.parse_obj(_base_config()) + config = PostgresConfig.model_validate(_base_config()) assert config.initial_database == "postgres" source = PostgresSource(config, PipelineContext(run_id="test")) _ = list(source.get_inspectors()) @@ -24,7 +24,9 @@ 
def test_get_inspectors_multiple_databases(create_engine_mock): execute_mock = create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}] - config = PostgresConfig.parse_obj({**_base_config(), "initial_database": "db0"}) + config = PostgresConfig.model_validate( + {**_base_config(), "initial_database": "db0"} + ) source = PostgresSource(config, PipelineContext(run_id="test")) _ = list(source.get_inspectors()) assert create_engine_mock.call_count == 3 @@ -38,7 +40,7 @@ def tests_get_inspectors_with_database_provided(create_engine_mock): execute_mock = create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}] - config = PostgresConfig.parse_obj({**_base_config(), "database": "custom_db"}) + config = PostgresConfig.model_validate({**_base_config(), "database": "custom_db"}) source = PostgresSource(config, PipelineContext(run_id="test")) _ = list(source.get_inspectors()) assert create_engine_mock.call_count == 1 @@ -50,7 +52,7 @@ def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock): execute_mock = create_engine_mock.return_value.connect.return_value.__enter__.return_value.execute execute_mock.return_value = [{"datname": "db1"}, {"datname": "db2"}] - config = PostgresConfig.parse_obj( + config = PostgresConfig.model_validate( {**_base_config(), "sqlalchemy_uri": "custom_url"} ) source = PostgresSource(config, PipelineContext(run_id="test")) @@ -60,7 +62,7 @@ def tests_get_inspectors_with_sqlalchemy_uri_provided(create_engine_mock): def test_database_in_identifier(): - config = PostgresConfig.parse_obj({**_base_config(), "database": "postgres"}) + config = PostgresConfig.model_validate({**_base_config(), "database": "postgres"}) mock_inspector = mock.MagicMock() assert ( PostgresSource(config, PipelineContext(run_id="test")).get_identifier( @@ -71,7 +73,7 @@ def test_database_in_identifier(): def test_current_sqlalchemy_database_in_identifier(): - config = PostgresConfig.parse_obj({**_base_config()}) + config = PostgresConfig.model_validate({**_base_config()}) mock_inspector = mock.MagicMock() mock_inspector.engine.url.database = "current_db" assert ( diff --git a/metadata-ingestion/tests/unit/test_preset_source.py b/metadata-ingestion/tests/unit/test_preset_source.py index 97f46e743ea0ca..849bf966f59d0f 100644 --- a/metadata-ingestion/tests/unit/test_preset_source.py +++ b/metadata-ingestion/tests/unit/test_preset_source.py @@ -3,7 +3,7 @@ def test_default_values(): - config = PresetConfig.parse_obj({}) + config = PresetConfig.model_validate({}) assert config.connect_uri == "" assert config.manager_uri == "https://api.app.preset.io" @@ -20,7 +20,7 @@ def test_default_values(): def test_set_display_uri(): display_uri = "some_host:1234" - config = PresetConfig.parse_obj({"display_uri": display_uri}) + config = PresetConfig.model_validate({"display_uri": display_uri}) assert config.connect_uri == "" assert config.manager_uri == "https://api.app.preset.io" @@ -36,7 +36,7 @@ def test_preset_config_parsing(): } # Tests if SupersetConfig fields are parsed extra fields correctly - config = PresetConfig.parse_obj(preset_config) + config = PresetConfig.model_validate(preset_config) # Test Preset-specific fields assert config.api_key == "dummy_api_key" diff --git a/metadata-ingestion/tests/unit/test_sql_common.py b/metadata-ingestion/tests/unit/test_sql_common.py index 8f149450201731..9f3a1912b1e33c 
100644 --- a/metadata-ingestion/tests/unit/test_sql_common.py +++ b/metadata-ingestion/tests/unit/test_sql_common.py @@ -23,7 +23,7 @@ def get_sql_alchemy_url(self): class _TestSQLAlchemySource(SQLAlchemySource): @classmethod def create(cls, config_dict, ctx): - config = _TestSQLAlchemyConfig.parse_obj(config_dict) + config = _TestSQLAlchemyConfig.model_validate(config_dict) return cls(config, ctx, "TEST") diff --git a/metadata-ingestion/tests/unit/test_superset_source.py b/metadata-ingestion/tests/unit/test_superset_source.py index d28d1103572e25..36e693238b2451 100644 --- a/metadata-ingestion/tests/unit/test_superset_source.py +++ b/metadata-ingestion/tests/unit/test_superset_source.py @@ -5,7 +5,7 @@ def test_default_values(): - config = SupersetConfig.parse_obj({}) + config = SupersetConfig.model_validate({}) assert config.connect_uri == "http://localhost:8088" assert config.display_uri == "http://localhost:8088" @@ -22,7 +22,7 @@ def test_default_values(): def test_set_display_uri(): display_uri = "some_host:1234" - config = SupersetConfig.parse_obj({"display_uri": display_uri}) + config = SupersetConfig.model_validate({"display_uri": display_uri}) assert config.connect_uri == "http://localhost:8088" assert config.display_uri == display_uri diff --git a/metadata-ingestion/tests/unit/test_teradata_integration.py b/metadata-ingestion/tests/unit/test_teradata_integration.py index b5d19c52c02ee2..f17eb8f05cb354 100644 --- a/metadata-ingestion/tests/unit/test_teradata_integration.py +++ b/metadata-ingestion/tests/unit/test_teradata_integration.py @@ -37,7 +37,7 @@ class TestEndToEndWorkflow: def test_complete_metadata_extraction_workflow(self): """Test complete metadata extraction from initialization to work units.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -81,7 +81,7 @@ def test_lineage_extraction_with_historical_data(self): "start_time": "2024-01-01T00:00:00Z", "end_time": "2024-01-02T00:00:00Z", } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -154,7 +154,7 @@ def test_view_processing_with_threading(self): **_base_config(), "max_workers": 3, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -222,7 +222,7 @@ def test_error_handling_in_complex_workflow(self): **_base_config(), "max_workers": 1, # Use single-threaded processing for this test } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -301,7 +301,7 @@ def test_database_filtering_integration(self): "deny": ["allowed_db2"], # Deny one of the explicitly allowed }, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -347,7 +347,7 @@ def test_lineage_and_usage_statistics_integration(self): "include_usage_statistics": True, "include_queries": True, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -376,7 +376,7 @@ def 
test_profiling_configuration_integration(self): "profile_table_level_only": True, }, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -411,7 +411,7 @@ def test_large_query_result_processing(self): **_base_config(), "use_server_side_cursors": True, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -483,7 +483,7 @@ def mock_fetchmany(size): def test_query_with_special_characters(self): """Test processing queries with special characters and encoding.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -518,7 +518,7 @@ def test_query_with_special_characters(self): def test_multi_line_query_processing(self): """Test processing of multi-line queries.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -562,7 +562,7 @@ class TestResourceManagement: def test_engine_disposal_on_close(self): """Test that engines are properly disposed on close.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -598,7 +598,7 @@ def test_engine_disposal_on_close(self): def test_aggregator_cleanup_on_close(self): """Test that aggregator is properly cleaned up on close.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -625,7 +625,7 @@ def test_aggregator_cleanup_on_close(self): def test_connection_cleanup_in_error_scenarios(self): """Test that connections are cleaned up even when errors occur.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" diff --git a/metadata-ingestion/tests/unit/test_teradata_performance.py b/metadata-ingestion/tests/unit/test_teradata_performance.py index e99824d1d20d48..b2751541c5b68f 100644 --- a/metadata-ingestion/tests/unit/test_teradata_performance.py +++ b/metadata-ingestion/tests/unit/test_teradata_performance.py @@ -155,7 +155,7 @@ class TestMemoryOptimizations: def test_tables_cache_memory_efficiency(self): """Test that tables cache is memory efficient.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -183,7 +183,7 @@ def test_streaming_query_processing(self): **_base_config(), "use_server_side_cursors": True, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -212,7 +212,7 @@ def test_streaming_query_processing(self): def test_chunked_processing_batch_size(self): """Test that chunked processing uses appropriate batch sizes.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( 
"datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -288,7 +288,7 @@ def test_column_extraction_timing(self): def test_view_processing_timing_metrics(self): """Test that view processing timing is properly tracked.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -337,7 +337,7 @@ def test_connection_pool_metrics(self): **_base_config(), "max_workers": 2, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -359,7 +359,7 @@ def test_connection_pool_metrics(self): def test_database_level_metrics_tracking(self): """Test that database-level metrics are properly tracked.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -425,7 +425,7 @@ class TestThreadSafetyOptimizations: def test_tables_cache_lock_usage(self): """Test that tables cache uses locks for thread safety.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -445,7 +445,7 @@ def test_tables_cache_lock_usage(self): def test_report_lock_usage(self): """Test that report operations use locks for thread safety.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -465,7 +465,7 @@ def test_report_lock_usage(self): def test_pooled_engine_lock_usage(self): """Test that pooled engine creation uses locks.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" @@ -512,7 +512,7 @@ def test_parameterized_query_usage(self): def test_query_structure_optimization(self): """Test that queries are structured for optimal performance.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.ingestion.source.sql.teradata.SqlParsingAggregator" diff --git a/metadata-ingestion/tests/unit/test_teradata_source.py b/metadata-ingestion/tests/unit/test_teradata_source.py index d9d2108fe4f4b6..f140ada0cc042a 100644 --- a/metadata-ingestion/tests/unit/test_teradata_source.py +++ b/metadata-ingestion/tests/unit/test_teradata_source.py @@ -32,7 +32,7 @@ class TestTeradataConfig: def test_valid_config(self): """Test that valid configuration is accepted.""" config_dict = _base_config() - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) assert config.host_port == "localhost:1025" assert config.include_table_lineage is True @@ -45,13 +45,13 @@ def test_max_workers_validation_valid(self): **_base_config(), "max_workers": 8, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) assert config.max_workers == 8 def test_max_workers_default(self): """Test max_workers defaults to 10.""" config_dict = _base_config() - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) assert config.max_workers == 10 def test_max_workers_custom_value(self): @@ -60,13 
+60,13 @@ def test_max_workers_custom_value(self): **_base_config(), "max_workers": 5, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) assert config.max_workers == 5 def test_include_queries_default(self): """Test include_queries defaults to True.""" config_dict = _base_config() - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) assert config.include_queries is True def test_time_window_defaults_applied(self): @@ -79,7 +79,7 @@ def test_time_window_defaults_applied(self): "include_usage_statistics": True, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) assert config.start_time is not None assert config.end_time is not None @@ -98,7 +98,7 @@ def test_incremental_lineage_config_support(self): "incremental_lineage": True, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) assert hasattr(config, "incremental_lineage") assert config.incremental_lineage is True @@ -107,10 +107,10 @@ def test_incremental_lineage_config_support(self): **_base_config(), "incremental_lineage": False, } - config_false = TeradataConfig.parse_obj(config_dict_false) + config_false = TeradataConfig.model_validate(config_dict_false) assert config_false.incremental_lineage is False - config_default = TeradataConfig.parse_obj(_base_config()) + config_default = TeradataConfig.model_validate(_base_config()) assert config_default.incremental_lineage is False def test_config_inheritance_chain(self): @@ -122,7 +122,7 @@ def test_config_inheritance_chain(self): "incremental_lineage": True, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) # Verify inheritance from BaseTimeWindowConfig assert hasattr(config, "start_time") @@ -147,7 +147,7 @@ def test_user_original_recipe_compatibility(self): "stateful_ingestion": {"enabled": True, "fail_safe_threshold": 90}, } - config = TeradataConfig.parse_obj(user_recipe_config) + config = TeradataConfig.model_validate(user_recipe_config) assert config.host_port == "vmvantage1720:1025" assert config.username == "dbc" @@ -168,7 +168,7 @@ class TestTeradataSource: @patch("datahub.ingestion.source.sql.teradata.create_engine") def test_source_initialization(self, mock_create_engine): """Test source initializes correctly.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) ctx = PipelineContext(run_id="test") # Mock the engine creation @@ -207,7 +207,7 @@ def test_get_inspectors(self, mock_inspect, mock_create_engine): mock_engine.connect.return_value.__enter__.return_value = mock_connection mock_create_engine.return_value = mock_engine - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -231,7 +231,7 @@ def test_get_inspectors(self, mock_inspect, mock_create_engine): def test_cache_tables_and_views_thread_safety(self): """Test that cache operations are thread-safe.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -274,7 +274,7 @@ def test_cache_tables_and_views_thread_safety(self): def test_convert_entry_to_observed_query(self): """Test conversion of database entries to ObservedQuery 
objects.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -312,7 +312,7 @@ def test_convert_entry_to_observed_query(self): def test_convert_entry_to_observed_query_with_none_user(self): """Test ObservedQuery conversion handles None user correctly.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -339,7 +339,7 @@ def test_convert_entry_to_observed_query_with_none_user(self): def test_check_historical_table_exists_success(self): """Test historical table check when table exists.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -366,7 +366,7 @@ def test_check_historical_table_exists_success(self): def test_check_historical_table_exists_failure(self): """Test historical table check when table doesn't exist.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -394,7 +394,7 @@ def test_check_historical_table_exists_failure(self): def test_close_cleanup(self): """Test that close() properly cleans up resources.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -429,7 +429,7 @@ def test_make_lineage_queries_with_time_defaults(self): "include_usage_statistics": True, } - config = TeradataConfig.parse_obj(config_dict) + config = TeradataConfig.model_validate(config_dict) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -502,7 +502,7 @@ class TestMemoryEfficiency: def test_fetch_lineage_entries_chunked_streaming(self): """Test that lineage entries are processed in streaming fashion.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -547,7 +547,7 @@ class TestConcurrencySupport: def test_tables_cache_thread_safety(self): """Test that tables cache operations are thread-safe.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -570,7 +570,7 @@ def test_tables_cache_thread_safety(self): def test_cached_loop_tables_safe_access(self): """Test cached_loop_tables uses safe cache access.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -619,7 +619,7 @@ class TestStageTracking: def test_stage_tracking_in_cache_operation(self): """Test that table caching uses stage tracking.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) # Create source without mocking to test the actual stage tracking during init with ( @@ -635,7 +635,7 @@ def test_stage_tracking_in_cache_operation(self): def test_stage_tracking_in_aggregator_processing(self): """Test that aggregator processing uses stage tracking.""" - config = 
TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -673,7 +673,7 @@ class TestErrorHandling: def test_empty_lineage_entries(self): """Test handling of empty lineage entries.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -696,7 +696,7 @@ def test_empty_lineage_entries(self): def test_malformed_query_entry(self): """Test handling of malformed query entries.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -732,7 +732,7 @@ class TestLineageQuerySeparation: def test_make_lineage_queries_current_only(self): """Test that only current query is returned when historical lineage is disabled.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": False, @@ -762,7 +762,7 @@ def test_make_lineage_queries_current_only(self): def test_make_lineage_queries_with_historical_available(self): """Test that UNION query is returned when historical lineage is enabled and table exists.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": True, @@ -802,7 +802,7 @@ def test_make_lineage_queries_with_historical_available(self): def test_make_lineage_queries_with_historical_unavailable(self): """Test that only current query is returned when historical lineage is enabled but table doesn't exist.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": True, @@ -833,7 +833,7 @@ def test_make_lineage_queries_with_historical_unavailable(self): def test_make_lineage_queries_with_database_filter(self): """Test that database filters are correctly applied to UNION query.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": True, @@ -868,7 +868,7 @@ def test_make_lineage_queries_with_database_filter(self): def test_fetch_lineage_entries_chunked_multiple_queries(self): """Test that _fetch_lineage_entries_chunked handles multiple queries correctly.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": True, @@ -931,7 +931,7 @@ def test_fetch_lineage_entries_chunked_multiple_queries(self): def test_fetch_lineage_entries_chunked_single_query(self): """Test that _fetch_lineage_entries_chunked handles single query correctly.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": False, @@ -984,7 +984,7 @@ def test_fetch_lineage_entries_chunked_single_query(self): def test_fetch_lineage_entries_chunked_batch_processing(self): """Test that batch processing works correctly with configurable batch size.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": False, @@ -1045,7 +1045,7 @@ def test_fetch_lineage_entries_chunked_batch_processing(self): def test_end_to_end_separate_queries_integration(self): """Test end-to-end integration of separate queries in the aggregator flow.""" - config = 
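Several of the migrated tests pass recipe-style dicts that mix connection settings with extra flags such as `include_historical_lineage` and nested sections such as `stateful_ingestion`; `model_validate` validates the whole dict, nested sections included, in one call. A self-contained sketch of that behaviour using stand-in models rather than the real DataHub config classes:

from pydantic import BaseModel, Field


class _StatefulIngestion(BaseModel):  # stand-in, not DataHub's stateful ingestion config
    enabled: bool = False
    fail_safe_threshold: int = 100


class _Recipe(BaseModel):  # stand-in, not a real DataHub source config
    host_port: str
    include_historical_lineage: bool = False
    stateful_ingestion: _StatefulIngestion = Field(default_factory=_StatefulIngestion)


recipe = _Recipe.model_validate(
    {
        "host_port": "localhost:1025",
        "include_historical_lineage": True,
        "stateful_ingestion": {"enabled": True, "fail_safe_threshold": 90},
    }
)
assert recipe.include_historical_lineage is True
assert recipe.stateful_ingestion.enabled is True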
TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": True, @@ -1112,7 +1112,7 @@ def mock_fetch_generator(): def test_query_logging_and_progress_tracking(self): """Test that proper logging occurs when processing multiple queries.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": True, @@ -1186,7 +1186,7 @@ class TestQueryConstruction: def test_current_query_construction(self): """Test that the current query is constructed correctly.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "start_time": "2024-01-01T00:00:00Z", @@ -1218,7 +1218,7 @@ def test_current_query_construction(self): def test_historical_query_construction(self): """Test that the UNION query contains historical data correctly.""" - config = TeradataConfig.parse_obj( + config = TeradataConfig.model_validate( { **_base_config(), "include_historical_lineage": True, @@ -1260,7 +1260,7 @@ class TestStreamingQueryReconstruction: def test_reconstruct_queries_streaming_single_row_queries(self): """Test streaming reconstruction with single-row queries.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1301,7 +1301,7 @@ def test_reconstruct_queries_streaming_single_row_queries(self): def test_reconstruct_queries_streaming_multi_row_queries(self): """Test streaming reconstruction with multi-row queries.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1353,7 +1353,7 @@ def test_reconstruct_queries_streaming_multi_row_queries(self): def test_reconstruct_queries_streaming_mixed_queries(self): """Test streaming reconstruction with mixed single and multi-row queries.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1411,7 +1411,7 @@ def test_reconstruct_queries_streaming_mixed_queries(self): def test_reconstruct_queries_streaming_empty_entries(self): """Test streaming reconstruction with empty entries.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1431,7 +1431,7 @@ def test_reconstruct_queries_streaming_empty_entries(self): def test_reconstruct_queries_streaming_teradata_specific_transformations(self): """Test that Teradata-specific transformations are applied.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1463,7 +1463,7 @@ def test_reconstruct_queries_streaming_teradata_specific_transformations(self): def test_reconstruct_queries_streaming_metadata_preservation(self): """Test that all metadata fields are preserved correctly.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1506,7 +1506,7 @@ def test_reconstruct_queries_streaming_metadata_preservation(self): def 
test_reconstruct_queries_streaming_with_none_user(self): """Test streaming reconstruction handles None user correctly.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1534,7 +1534,7 @@ def test_reconstruct_queries_streaming_with_none_user(self): def test_reconstruct_queries_streaming_empty_query_text(self): """Test streaming reconstruction handles empty query text correctly.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" @@ -1564,7 +1564,7 @@ def test_reconstruct_queries_streaming_empty_query_text(self): def test_reconstruct_queries_streaming_space_joining_behavior(self): """Test that query parts are joined directly without adding spaces.""" - config = TeradataConfig.parse_obj(_base_config()) + config = TeradataConfig.model_validate(_base_config()) with patch( "datahub.sql_parsing.sql_parsing_aggregator.SqlParsingAggregator" diff --git a/metadata-ingestion/tests/unit/test_unity_catalog_config.py b/metadata-ingestion/tests/unit/test_unity_catalog_config.py index a1f6b4f9516e04..8c0b4e60d937a6 100644 --- a/metadata-ingestion/tests/unit/test_unity_catalog_config.py +++ b/metadata-ingestion/tests/unit/test_unity_catalog_config.py @@ -11,7 +11,7 @@ @freeze_time(FROZEN_TIME) def test_within_thirty_days(): - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://workspace_url", @@ -26,7 +26,7 @@ def test_within_thirty_days(): with pytest.raises( ValueError, match="Query history is only maintained for 30 days." 
): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://workspace_url", @@ -37,7 +37,7 @@ def test_within_thirty_days(): def test_profiling_requires_warehouses_id(): - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://workspace_url", @@ -51,7 +51,7 @@ def test_profiling_requires_warehouses_id(): ) assert config.profiling.enabled is True - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://workspace_url", @@ -63,7 +63,7 @@ def test_profiling_requires_warehouses_id(): assert config.profiling.enabled is False with pytest.raises(ValueError): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "include_hive_metastore": False, @@ -76,7 +76,7 @@ def test_profiling_requires_warehouses_id(): @freeze_time(FROZEN_TIME) def test_workspace_url_should_start_with_https(): with pytest.raises(ValueError, match="Workspace URL must start with http scheme"): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "workspace_url", @@ -86,7 +86,7 @@ def test_workspace_url_should_start_with_https(): def test_global_warehouse_id_is_set_from_profiling(): - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://XXXXXXXXXXXXXXXXXXXXX", @@ -106,7 +106,7 @@ def test_set_different_warehouse_id_from_profiling(): ValueError, match="When `warehouse_id` is set, it must match the `warehouse_id` in `profiling`.", ): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://XXXXXXXXXXXXXXXXXXXXX", @@ -122,7 +122,7 @@ def test_set_different_warehouse_id_from_profiling(): def test_warehouse_id_must_be_set_if_include_hive_metastore_is_true(): """Test that include_hive_metastore is auto-disabled when warehouse_id is missing.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://XXXXXXXXXXXXXXXXXXXXX", @@ -150,7 +150,7 @@ def test_warehouse_id_must_be_present_test_connection(): def test_set_profiling_warehouse_id_from_global(): - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://XXXXXXXXXXXXXXXXXXXXX", @@ -166,7 +166,7 @@ def test_set_profiling_warehouse_id_from_global(): def test_warehouse_id_auto_disables_tags_when_missing(): """Test that include_tags is automatically disabled when warehouse_id is missing.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -181,7 +181,7 @@ def test_warehouse_id_auto_disables_tags_when_missing(): def test_warehouse_id_not_required_when_tags_disabled(): """Test that warehouse_id is not required when include_tags=False.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -196,7 +196,7 @@ def test_warehouse_id_not_required_when_tags_disabled(): def test_warehouse_id_explicit_true_auto_disables(): """Test that explicitly setting include_tags=True gets 
auto-disabled when warehouse_id is missing.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -212,7 +212,7 @@ def test_warehouse_id_explicit_true_auto_disables(): def test_warehouse_id_with_tags_enabled_succeeds(): """Test that providing warehouse_id with include_tags=True succeeds.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -227,7 +227,7 @@ def test_warehouse_id_with_tags_enabled_succeeds(): def test_warehouse_id_validation_with_hive_metastore_precedence(): """Test that both hive_metastore and tags are auto-disabled when warehouse_id is missing.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -244,7 +244,7 @@ def test_warehouse_id_validation_with_hive_metastore_precedence(): def test_databricks_api_page_size_default(): """Test that databricks_api_page_size defaults to 0.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -257,7 +257,7 @@ def test_databricks_api_page_size_default(): def test_databricks_api_page_size_valid_values(): """Test that databricks_api_page_size accepts valid positive integers.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -268,7 +268,7 @@ def test_databricks_api_page_size_valid_values(): ) assert config.databricks_api_page_size == 100 - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -282,7 +282,7 @@ def test_databricks_api_page_size_valid_values(): def test_databricks_api_page_size_zero_allowed(): """Test that databricks_api_page_size allows zero (default behavior).""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -297,7 +297,7 @@ def test_databricks_api_page_size_zero_allowed(): def test_databricks_api_page_size_negative_invalid(): """Test that databricks_api_page_size rejects negative values.""" with pytest.raises(ValueError, match="Input should be greater than or equal to 0"): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -308,7 +308,7 @@ def test_databricks_api_page_size_negative_invalid(): ) with pytest.raises(ValueError, match="Input should be greater than or equal to 0"): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -321,7 +321,7 @@ def test_databricks_api_page_size_negative_invalid(): def test_include_ml_model_default(): """Test that include_ml_model_aliases defaults to False.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -334,7 +334,7 @@ def test_include_ml_model_default(): def test_include_ml_model_aliases_explicit_true(): """Test that include_ml_model_aliases can 
be set to True.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -347,7 +347,7 @@ def test_include_ml_model_aliases_explicit_true(): def test_ml_model_max_results_valid_values(): """Test that ml_model_max_results accepts valid positive integers.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -357,7 +357,7 @@ def test_ml_model_max_results_valid_values(): ) assert config.ml_model_max_results == 2000 - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -371,7 +371,7 @@ def test_ml_model_max_results_valid_values(): def test_ml_model_max_results_negative_invalid(): """Test that ml_model_max_results rejects negative values.""" with pytest.raises(ValueError): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -383,7 +383,7 @@ def test_ml_model_max_results_negative_invalid(): def test_lineage_data_source_default(): """Test that lineage_data_source defaults to AUTO.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -402,7 +402,7 @@ def test_lineage_data_source_system_tables_requires_warehouse_id(): ValueError, match="lineage_data_source='SYSTEM_TABLES' requires warehouse_id to be set", ): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -415,7 +415,7 @@ def test_lineage_data_source_system_tables_requires_warehouse_id(): def test_lineage_data_source_api_does_not_require_warehouse(): """Test that lineage_data_source=API does not require warehouse_id.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -432,7 +432,7 @@ def test_lineage_data_source_api_does_not_require_warehouse(): def test_usage_data_source_default(): """Test that usage_data_source defaults to AUTO.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -451,7 +451,7 @@ def test_usage_data_source_system_tables_requires_warehouse_id(): ValueError, match="usage_data_source='SYSTEM_TABLES' requires warehouse_id to be set", ): - UnityCatalogSourceConfig.parse_obj( + UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -464,7 +464,7 @@ def test_usage_data_source_system_tables_requires_warehouse_id(): def test_usage_data_source_api_does_not_require_warehouse(): """Test that usage_data_source=API does not require warehouse_id.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", @@ -481,7 +481,7 @@ def test_usage_data_source_api_does_not_require_warehouse(): def test_usage_data_source_can_be_set_with_warehouse(): """Test that usage_data_source can be set to SYSTEM_TABLES with warehouse_id.""" - config = UnityCatalogSourceConfig.parse_obj( + config = 
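The `pytest.raises(ValueError, match=...)` pattern used throughout these config tests keeps working after the migration because Pydantic v2's `ValidationError` still subclasses `ValueError`; what changes is the message wording, which is why the `match=` strings above use v2's "Input should be greater than or equal to 0" phrasing rather than v1's "ensure this value is greater than or equal to 0". A hedged, self-contained sketch with a stand-in model rather than the real `UnityCatalogSourceConfig`:

import pytest
from pydantic import BaseModel, Field, ValidationError


class _PageSizeConfig(BaseModel):  # stand-in, not the real UnityCatalogSourceConfig
    databricks_api_page_size: int = Field(default=0, ge=0)


def test_negative_page_size_rejected_example():
    # Pydantic v2's ValidationError is a ValueError subclass, so pytest.raises(ValueError)
    # still catches model_validate failures.
    assert issubclass(ValidationError, ValueError)
    with pytest.raises(ValueError, match="greater than or equal to 0"):
        _PageSizeConfig.model_validate({"databricks_api_page_size": -1})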
UnityCatalogSourceConfig.model_validate( { "token": "token", "workspace_url": "https://test.databricks.com", diff --git a/metadata-ingestion/tests/unit/test_unity_catalog_source.py b/metadata-ingestion/tests/unit/test_unity_catalog_source.py index 23056154756865..a44e74780b7358 100644 --- a/metadata-ingestion/tests/unit/test_unity_catalog_source.py +++ b/metadata-ingestion/tests/unit/test_unity_catalog_source.py @@ -12,7 +12,7 @@ class TestUnityCatalogSource: @pytest.fixture def minimal_config(self): """Create a minimal config for testing.""" - return UnityCatalogSourceConfig.parse_obj( + return UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", @@ -24,7 +24,7 @@ def minimal_config(self): @pytest.fixture def config_with_page_size(self): """Create a config with custom page size.""" - return UnityCatalogSourceConfig.parse_obj( + return UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", @@ -37,7 +37,7 @@ def config_with_page_size(self): @pytest.fixture def config_with_ml_model_settings(self): """Create a config with ML model settings.""" - return UnityCatalogSourceConfig.parse_obj( + return UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", @@ -110,7 +110,7 @@ def test_source_with_hive_metastore_disabled( self, mock_hive_proxy, mock_unity_proxy ): """Test that UnityCatalogSource works with hive metastore disabled.""" - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", @@ -174,7 +174,7 @@ def test_source_report_includes_ml_model_stats( mock_unity_instance.catalogs.return_value = [] mock_unity_instance.check_basic_connectivity.return_value = True - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", @@ -242,7 +242,7 @@ def test_process_ml_model_generates_workunits( Schema, ) - config = UnityCatalogSourceConfig.parse_obj( + config = UnityCatalogSourceConfig.model_validate( { "token": "test_token", "workspace_url": "https://test.databricks.com", diff --git a/metadata-ingestion/tests/unit/test_vertica_source.py b/metadata-ingestion/tests/unit/test_vertica_source.py index de888ddb559242..bc9335ed925433 100644 --- a/metadata-ingestion/tests/unit/test_vertica_source.py +++ b/metadata-ingestion/tests/unit/test_vertica_source.py @@ -2,7 +2,7 @@ def test_vertica_uri_https(): - config = VerticaConfig.parse_obj( + config = VerticaConfig.model_validate( { "username": "user", "password": "password",
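For source configs like `VerticaConfig` the change is the same one-line swap. If a follow-up test wanted to prove that the old spelling is at least flagged, a hedged sketch is below; it assumes Pydantic v2 raises its standard deprecation warning when the `parse_obj` alias is called, and it uses a stand-in model rather than `VerticaConfig` itself.

import warnings

from pydantic import BaseModel


class _Cfg(BaseModel):  # stand-in config model, not VerticaConfig
    username: str
    password: str


def test_parse_obj_alias_still_validates_example():
    data = {"username": "user", "password": "password"}
    # model_validate is the supported v2 entry point.
    assert _Cfg.model_validate(data).username == "user"
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _Cfg.parse_obj(data)  # deprecated v1 alias; assumed to warn in v2
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)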