diff --git a/pydantic_extra_types/mongo_objectId.py b/pydantic_extra_types/mongo_objectId.py new file mode 100644 index 00000000..758859a5 --- /dev/null +++ b/pydantic_extra_types/mongo_objectId.py @@ -0,0 +1,40 @@ +from typing import Any + +try: + from bson import ObjectId +except ModuleNotFoundError: # pragma: no cover + raise RuntimeError( + 'The `ObjectIdField` module requires "bson" to be installed. You can install it with "pip install ' + 'bson".' + ) +from pydantic_core import core_schema +from pydantic_core import PydanticCustomError + + +class ObjectIdField(str): + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: Any + ) -> core_schema.CoreSchema: + object_id_schema = core_schema.chain_schema( + [ + core_schema.str_schema(), + core_schema.no_info_plain_validator_function(cls.validate), + ] + ) + return core_schema.json_or_python_schema( + json_schema=object_id_schema, + python_schema=core_schema.union_schema( + [core_schema.is_instance_schema(ObjectId), object_id_schema] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda x: str(x) + ), + ) + + @classmethod + def validate(cls, value): + try: + return ObjectId(value) + except bson.errors.InvalidId as invalid_id: + raise PydanticCustomError('value_error', 'invalid format for MongoDB object identifier') from invalid_id \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d713e4d1..c4adf2ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,8 @@ all = [ 'pycountry>=23', 'python-ulid>=1,<2; python_version<"3.9"', 'python-ulid>=1,<3; python_version>="3.9"', - 'pendulum>=3.0.0,<4.0.0' + 'pendulum>=3.0.0,<4.0.0', + 'bson>=0.5; python_version>="3.9"', ] [project.urls] diff --git a/requirements/pyproject.txt b/requirements/pyproject.txt index 5b774721..e7f7514b 100644 --- a/requirements/pyproject.txt +++ b/requirements/pyproject.txt @@ -10,6 +10,8 @@ pendulum==3.0.0 # via pydantic-extra-types (pyproject.toml) phonenumbers==8.13.31 # via pydantic-extra-types (pyproject.toml) +bson==0.5.10 +# via pydantic-extra-types (pyproject.toml) pycountry==23.12.11 # via pydantic-extra-types (pyproject.toml) pydantic==2.6.3 diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 5daa2460..a18f9c49 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -18,6 +18,7 @@ from pydantic_extra_types.payment import PaymentCardNumber from pydantic_extra_types.pendulum_dt import DateTime from pydantic_extra_types.ulid import ULID +from pydantic_extra_types.mongo_objectId import ObjectIdField languages = [lang.alpha_3 for lang in pycountry.languages] language_families = [lang.alpha_3 for lang in pycountry.language_families] @@ -287,6 +288,20 @@ 'type': 'object', }, ), + ( + ObjectIdField, + { + 'properties': { + 'x': { + 'title': 'X', + 'type': 'string', + } + }, + 'required': ['x'], + 'title': 'Model', + 'type': 'object', + }, + ), ], ) def test_json_schema(cls, expected): diff --git a/tests/test_pyobjectid.py b/tests/test_pyobjectid.py new file mode 100644 index 00000000..f4976745 --- /dev/null +++ b/tests/test_pyobjectid.py @@ -0,0 +1,48 @@ +import pytest +from pydantic import BaseModel, ValidationError + +from pydantic_extra_types.mongo_objectId import ObjectIdField + + +class Something(BaseModel): + object_id: ObjectIdField + + +@pytest.mark.parametrize( + "object_id, result, valid", + [ + # Valid ObjectId for str format + ("611827f2878b88b49ebb69fc", "611827f2878b88b49ebb69fc", True), + ("611827f2878b88b49ebb69fd", "611827f2878b88b49ebb69fd", True), + # Invalid ObjectId for str format + ("611827f2878b88b49ebb69f", None, False), # Invalid ObjectId (short length) + ("611827f2878b88b49ebb69fca", None, False), # Invalid ObjectId (long length) + # Valid ObjectId for bytes format + ], +) +def test_format_for_object_id(object_id, result, valid): + if valid: + assert str(Something(object_id=object_id).object_id) == result + else: + with pytest.raises(ValidationError): + Something(object_id=object_id) + + +def test_json_schema(): + assert Something.model_json_schema(mode="validation") == { + "properties": {"object_id": {"title": "Object Id", "type": "string"}}, + "required": ["object_id"], + "title": "Something", + "type": "object", + } + assert Something.model_json_schema(mode="serialization") == { + "properties": { + "object_id": { + "anyOf": [{"type": "string"}], + "title": "Object Id", + } + }, + "required": ["object_id"], + "title": "Something", + "type": "object", + }