diff --git a/docs/examples/model_coverage/__init__.py b/docs/examples/model_coverage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docs/examples/model_coverage/test_example_1.py b/docs/examples/model_coverage/test_example_1.py new file mode 100644 index 00000000..ced96f4a --- /dev/null +++ b/docs/examples/model_coverage/test_example_1.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from polyfactory.factories.dataclass_factory import DataclassFactory + + +@dataclass +class Car: + model: str + + +@dataclass +class Boat: + can_float: bool + + +@dataclass +class Profile: + age: int + favourite_color: Literal["red", "green", "blue"] + vehicle: Car | Boat + + +class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + + +def test_profile_coverage() -> None: + profiles = list(ProfileFactory.coverage()) + + assert profiles[0].favourite_color == "red" + assert isinstance(profiles[0].vehicle, Car) + assert profiles[1].favourite_color == "green" + assert isinstance(profiles[1].vehicle, Boat) + assert profiles[2].favourite_color == "blue" + assert isinstance(profiles[2].vehicle, Car) diff --git a/docs/examples/model_coverage/test_example_2.py b/docs/examples/model_coverage/test_example_2.py new file mode 100644 index 00000000..bf67f959 --- /dev/null +++ b/docs/examples/model_coverage/test_example_2.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal + +from polyfactory.factories.dataclass_factory import DataclassFactory + + +@dataclass +class Car: + model: str + + +@dataclass +class Boat: + can_float: bool + + +@dataclass +class Profile: + age: int + favourite_color: Literal["red", "green", "blue"] + vehicle: Car | Boat + + +@dataclass +class SocialGroup: + members: list[Profile] + + +class SocialGroupFactory(DataclassFactory[SocialGroup]): + __model__ = SocialGroup + + +def test_social_group_coverage() -> None: + groups = list(SocialGroupFactory.coverage()) + assert len(groups) == 3 + + for group in groups: + assert len(group.members) == 1 + + assert groups[0].members[0].favourite_color == "red" + assert isinstance(groups[0].members[0].vehicle, Car) + assert groups[1].members[0].favourite_color == "green" + assert isinstance(groups[1].members[0].vehicle, Boat) + assert groups[2].members[0].favourite_color == "blue" + assert isinstance(groups[2].members[0].vehicle, Car) diff --git a/docs/usage/index.rst b/docs/usage/index.rst index 0e7f922a..891d11b7 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -12,3 +12,4 @@ Usage Guide decorators fixtures handling_custom_types + model_coverage diff --git a/docs/usage/model_coverage.rst b/docs/usage/model_coverage.rst new file mode 100644 index 00000000..753dedde --- /dev/null +++ b/docs/usage/model_coverage.rst @@ -0,0 +1,29 @@ +Model coverage generation +========================= + +The ``BaseFactory.coverage()`` function is an alternative approach to ``BaseFactory.batch()``, where the examples that are generated attempt to provide full coverage of all the forms a model can take with the minimum number of instances. For example: + +.. literalinclude:: /examples/model_coverage/test_example_1.py + :caption: Defining a factory and generating examples with coverage + :language: python + +As you can see in the above example, the ``Profile`` model has 3 options for ``favourite_color``, and 2 options for ``vehicle``. In the output you can expect to see instances of ``Profile`` that have each of these options. The largest variance dictates the length of the output, in this case ``favourite_color`` has the most, at 3 options, so expect to see 3 ``Profile`` instances. + + +.. note:: + Notice that the same ``Car`` instance is used in the first and final generated example. When the coverage examples for a field are exhausted before another field, values for that field are re-used. + +Notes on collection types +------------------------- + +When generating coverage for models with fields that are collections, in particular collections that contain sub-models, the contents of the collection will be the all coverage examples for that sub-model. For example: + +.. literalinclude:: /examples/model_coverage/test_example_2.py + :caption: Coverage output for the SocialGroup model + :language: python + +Known Limitations +----------------- + +- Recursive models will cause an error: ``RecursionError: maximum recursion depth exceeded``. +- ``__min_collection_length__`` and ``__max_collection_length__`` are currently ignored in coverage generation. diff --git a/polyfactory/factories/base.py b/polyfactory/factories/base.py index 3b8a28c5..07a6c97d 100644 --- a/polyfactory/factories/base.py +++ b/polyfactory/factories/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import typing from abc import ABC, abstractmethod from collections import Counter, abc, deque from contextlib import suppress @@ -22,6 +23,12 @@ from os.path import realpath from pathlib import Path from random import Random + +try: + from types import NoneType +except ImportError: + NoneType = type(None) # type: ignore[misc,assignment] + from typing import ( TYPE_CHECKING, Any, @@ -46,13 +53,16 @@ MIN_COLLECTION_LENGTH, RANDOMIZE_COLLECTION_LENGTH, ) -from polyfactory.exceptions import ( - ConfigurationException, - MissingBuildKwargException, - ParameterException, -) +from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use -from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional +from polyfactory.utils.helpers import ( + flatten_annotation, + get_collection_type, + unwrap_annotation, + unwrap_args, + unwrap_optional, +) +from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage from polyfactory.utils.predicates import ( get_type_origin, is_any, @@ -61,7 +71,7 @@ is_safe_subclass, is_union, ) -from polyfactory.value_generators.complex_types import handle_collection_type +from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage from polyfactory.value_generators.constrained_collections import ( handle_constrained_collection, handle_constrained_mapping, @@ -263,6 +273,32 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N return field_value() if callable(field_value) else field_value + @classmethod + def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any: + """Handle a value defined on the factory class itself. + + :param field_value: A value defined as an attribute on the factory class. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + + :returns: An arbitrary value correlating with the given field_meta value. + """ + if is_safe_subclass(field_value, BaseFactory): + if isinstance(field_build_parameters, Mapping): + return CoverageContainer(field_value.coverage(**field_build_parameters)) + + if isinstance(field_build_parameters, Sequence): + return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters] + + return CoverageContainer(field_value.coverage()) + + if isinstance(field_value, Use): + return field_value.to_value() + + if isinstance(field_value, Fixture): + return CoverageContainerCallable(field_value.to_value) + + return CoverageContainerCallable(field_value) if callable(field_value) else field_value + @classmethod def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]: """Get a factory from registered factories or generate a factory dynamically. @@ -635,6 +671,66 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912 msg, ) + @classmethod + def get_field_value_coverage( # noqa: C901 + cls, + field_meta: FieldMeta, + field_build_parameters: Any | None = None, + ) -> typing.Iterable[Any]: + """Return a field value on the subclass if existing, otherwise returns a mock value. + + :param field_meta: FieldMeta instance. + :param field_build_parameters: Any build parameters passed to the factory as kwarg values. + + :returns: An iterable of values. + + """ + if cls.is_ignored_type(field_meta.annotation): + return [None] + + for unwrapped_annotation in flatten_annotation(field_meta.annotation): + if unwrapped_annotation in (None, NoneType): + yield None + + elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)): + yield CoverageContainer(literal_args) + + elif isinstance(unwrapped_annotation, EnumMeta): + yield CoverageContainer(list(unwrapped_annotation)) + + elif field_meta.constraints: + yield CoverageContainerCallable( + cls.get_constrained_field_value, + annotation=unwrapped_annotation, + field_meta=field_meta, + ) + + elif BaseFactory.is_factory_type(annotation=unwrapped_annotation): + yield CoverageContainer( + cls._get_or_create_factory(model=unwrapped_annotation).coverage( + **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}), + ), + ) + + elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection): + yield handle_collection_type_coverage(field_meta, origin, cls) + + elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar): + yield create_random_string(cls.__random__, min_length=1, max_length=10) + + elif provider := cls.get_provider_map().get(unwrapped_annotation): + yield CoverageContainerCallable(provider) + + elif callable(unwrapped_annotation): + # if value is a callable we can try to naively call it. + # this will work for callables that do not require any parameters passed + yield CoverageContainerCallable(unwrapped_annotation) + else: + msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type." + raise ParameterException( + msg, + ) + @classmethod def should_set_none_value(cls, field_meta: FieldMeta) -> bool: """Determine whether a given model field_meta should be set to None. @@ -752,6 +848,50 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]: return result + @classmethod + def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]: + """Process the given kwargs and generate values for the factory's model. + + :param kwargs: Any build kwargs. + + :returns: A dictionary of build results. + + """ + result: dict[str, Any] = {**kwargs} + generate_post: dict[str, PostGenerated] = {} + + for field_meta in cls.get_model_fields(): + field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs) + + if cls.should_set_field_value(field_meta, **kwargs): + if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name): + field_value = getattr(cls, field_meta.name) + if isinstance(field_value, Ignore): + continue + + if isinstance(field_value, Require) and field_meta.name not in kwargs: + msg = f"Require kwarg {field_meta.name} is missing" + raise MissingBuildKwargException(msg) + + if isinstance(field_value, PostGenerated): + generate_post[field_meta.name] = field_value + continue + + result[field_meta.name] = cls._handle_factory_field_coverage( + field_value=field_value, + field_build_parameters=field_build_parameters, + ) + continue + + result[field_meta.name] = CoverageContainer( + cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters), + ) + + for resolved in resolve_kwargs_coverage(result): + for field_name, post_generator in generate_post.items(): + resolved[field_name] = post_generator.to_value(field_name, resolved) + yield resolved + @classmethod def build(cls, **kwargs: Any) -> T: """Build an instance of the factory's __model__ @@ -776,6 +916,19 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]: """ return [cls.build(**kwargs) for _ in range(size)] + @classmethod + def coverage(cls, **kwargs: Any) -> abc.Iterator[T]: + """Build a batch of the factory's Meta.model will full coverage of the sub-types of the model. + + :param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used. + + :returns: A iterator of instances of type T. + + """ + for data in cls.process_kwargs_coverage(**kwargs): + instance = cls.__model__(**data) + yield cast("T", instance) + @classmethod def create_sync(cls, **kwargs: Any) -> T: """Build and persists synchronously a single model instance. diff --git a/polyfactory/utils/helpers.py b/polyfactory/utils/helpers.py index 4cafc7de..da34a8b1 100644 --- a/polyfactory/utils/helpers.py +++ b/polyfactory/utils/helpers.py @@ -3,6 +3,11 @@ import sys from typing import TYPE_CHECKING, Any, Mapping +try: + from types import NoneType +except ImportError: + NoneType = type(None) # type: ignore[misc,assignment] + from typing_extensions import get_args, get_origin from polyfactory.constants import TYPE_MAPPING @@ -52,7 +57,7 @@ def unwrap_optional(annotation: Any) -> Any: :returns: A type annotation """ while is_optional(annotation): - annotation = next(arg for arg in get_args(annotation) if arg not in (type(None), None)) + annotation = next(arg for arg in get_args(annotation) if arg not in (NoneType, None)) return annotation @@ -77,6 +82,30 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any: return annotation +def flatten_annotation(annotation: Any) -> list[Any]: + """Flattens an annotation. + + :param annotation: A type annotation. + + :returns: The flattened annotations. + """ + flat = [] + if is_new_type(annotation): + flat.extend(flatten_annotation(unwrap_new_type(annotation))) + elif is_optional(annotation): + flat.append(NoneType) + flat.extend(flatten_annotation(arg) for arg in get_args(annotation) if arg not in (NoneType, None)) + elif is_annotated(annotation): + flat.extend(flatten_annotation(get_args(annotation)[0])) + elif is_union(annotation): + for a in get_args(annotation): + flat.extend(flatten_annotation(a)) + else: + flat.append(annotation) + + return flat + + def unwrap_args(annotation: Any, random: Random) -> tuple[Any, ...]: """Unwrap the annotation and return any type args. diff --git a/polyfactory/utils/model_coverage.py b/polyfactory/utils/model_coverage.py new file mode 100644 index 00000000..6fc39714 --- /dev/null +++ b/polyfactory/utils/model_coverage.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Mapping, MutableSequence +from typing import AbstractSet, Any, Generic, Set, TypeVar, cast + +from typing_extensions import ParamSpec + +from polyfactory.exceptions import ParameterException + + +class CoverageContainerBase(ABC): + """Base class for coverage container implementations. + + A coverage container is a wrapper providing values for a particular field. Coverage containers return field values and + track a "done" state to indicate that all coverage examples have been generated. + """ + + @abstractmethod + def next_value(self) -> Any: + """Provide the next value""" + ... + + @abstractmethod + def is_done(self) -> bool: + """Indicate if this container has provided every coverage example it has""" + ... + + +T = TypeVar("T") + + +class CoverageContainer(CoverageContainerBase, Generic[T]): + """A coverage container that wraps a collection of values. + + When calling ``next_value()`` a greater number of times than the length of the given collection will cause duplicate + examples to be returned (wraps around). + + If there are any coverage containers within the given collection, the values from those containers are essentially merged + into the parent container. + """ + + def __init__(self, instances: Iterable[T]) -> None: + self._pos = 0 + self._instances = list(instances) + if not self._instances: + msg = "CoverageContainer must have at least one instance" + raise ValueError(msg) + + def next_value(self) -> T: + value = self._instances[self._pos % len(self._instances)] + if isinstance(value, CoverageContainerBase): + result = value.next_value() + if value.is_done(): + # Only move onto the next instance if the sub-container is done + self._pos += 1 + return cast(T, result) + + self._pos += 1 + return value + + def is_done(self) -> bool: + return self._pos >= len(self._instances) + + def __repr__(self) -> str: + return f"CoverageContainer(instances={self._instances}, is_done={self.is_done()})" + + +P = ParamSpec("P") + + +class CoverageContainerCallable(CoverageContainerBase, Generic[T]): + """A coverage container that wraps a callable. + + When calling ``next_value()`` the wrapped callable is called to provide a value. + """ + + def __init__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> None: + self._func = func + self._args = args + self._kwargs = kwargs + + def next_value(self) -> T: + try: + return self._func(*self._args, **self._kwargs) + except Exception as e: # noqa: BLE001 + msg = f"Unsupported type: {self._func!r}\n\nEither extend the providers map or add a factory function for this type." + raise ParameterException(msg) from e + + def is_done(self) -> bool: + return True + + +def _resolve_next(unresolved: Any) -> tuple[Any, bool]: # noqa: C901 + if isinstance(unresolved, CoverageContainerBase): + result, done = _resolve_next(unresolved.next_value()) + return result, unresolved.is_done() and done + + if isinstance(unresolved, Mapping): + result = {} + done_status = True + for key, value in unresolved.items(): + val_resolved, val_done = _resolve_next(value) + key_resolved, key_done = _resolve_next(key) + result[key_resolved] = val_resolved + done_status = done_status and val_done and key_done + return result, done_status + + if isinstance(unresolved, (tuple, MutableSequence)): + result = [] + done_status = True + for value in unresolved: + resolved, done = _resolve_next(value) + result.append(resolved) + done_status = done_status and done + if isinstance(unresolved, tuple): + result = tuple(result) + return result, done_status + + if isinstance(unresolved, Set): + result = type(unresolved)() + done_status = True + for value in unresolved: + resolved, done = _resolve_next(value) + result.add(resolved) + done_status = done_status and done + return result, done_status + + if issubclass(type(unresolved), AbstractSet): + result = type(unresolved)() + done_status = True + resolved_values = [] + for value in unresolved: + resolved, done = _resolve_next(value) + resolved_values.append(resolved) + done_status = done_status and done + return result.union(resolved_values), done_status + + return unresolved, True + + +def resolve_kwargs_coverage(kwargs: dict[str, Any]) -> Iterator[dict[str, Any]]: + done = False + while not done: + resolved, done = _resolve_next(kwargs) + yield resolved diff --git a/polyfactory/value_generators/complex_types.py b/polyfactory/value_generators/complex_types.py index 46d39df3..29def9f9 100644 --- a/polyfactory/value_generators/complex_types.py +++ b/polyfactory/value_generators/complex_types.py @@ -1,20 +1,11 @@ from __future__ import annotations -from typing import ( - TYPE_CHECKING, - AbstractSet, - Any, - Iterable, - MutableMapping, - MutableSequence, - Set, - Tuple, - cast, -) +from typing import TYPE_CHECKING, AbstractSet, Any, Iterable, MutableMapping, MutableSequence, Set, Tuple, cast from typing_extensions import is_typeddict from polyfactory.field_meta import FieldMeta +from polyfactory.utils.model_coverage import CoverageContainer if TYPE_CHECKING: from polyfactory.factories.base import BaseFactory @@ -60,3 +51,56 @@ def handle_collection_type(field_meta: FieldMeta, container_type: type, factory: msg = f"Unsupported container type: {container_type}" raise NotImplementedError(msg) + + +def handle_collection_type_coverage( + field_meta: FieldMeta, + container_type: type, + factory: type[BaseFactory[Any]], +) -> Any: + """Handle coverage generation of container types recursively. + + :param container_type: A type that can accept type arguments. + :param factory: A factory. + :param field_meta: A field meta instance. + + :returns: An unresolved built result. + """ + container = container_type() + if not field_meta.children: + return container + + if issubclass(container_type, MutableMapping) or is_typeddict(container_type): + for key_field_meta, value_field_meta in cast( + Iterable[Tuple[FieldMeta, FieldMeta]], + zip(field_meta.children[::2], field_meta.children[1::2]), + ): + key = CoverageContainer(factory.get_field_value_coverage(key_field_meta)) + value = CoverageContainer(factory.get_field_value_coverage(value_field_meta)) + container[key] = value + return container + + if issubclass(container_type, MutableSequence): + container_instance = container_type() + for subfield_meta in field_meta.children: + container_instance.extend(factory.get_field_value_coverage(subfield_meta)) + + return container_instance + + if issubclass(container_type, Set): + set_instance = container_type() + for subfield_meta in field_meta.children: + set_instance = set_instance.union(factory.get_field_value_coverage(subfield_meta)) + + return set_instance + + if issubclass(container_type, AbstractSet): + return container.union(handle_collection_type_coverage(field_meta, set, factory)) + + if issubclass(container_type, tuple): + return container_type( + CoverageContainer(factory.get_field_value_coverage(subfield_meta)) for subfield_meta in field_meta.children + ) + + msg = f"Unsupported container type: {container_type}" + raise NotImplementedError(msg) diff --git a/tests/test_type_coverage_generation.py b/tests/test_type_coverage_generation.py new file mode 100644 index 00000000..047c9954 --- /dev/null +++ b/tests/test_type_coverage_generation.py @@ -0,0 +1,214 @@ +# ruff: noqa: UP007, UP006 +from __future__ import annotations + +from dataclasses import dataclass, make_dataclass +from datetime import date +from typing import Dict, FrozenSet, List, Literal, Set, Tuple, Union +from uuid import UUID + +import pytest +from typing_extensions import TypedDict + +from polyfactory.decorators import post_generated +from polyfactory.exceptions import ParameterException +from polyfactory.factories.dataclass_factory import DataclassFactory +from polyfactory.factories.typed_dict_factory import TypedDictFactory + + +def test_coverage_count() -> None: + @dataclass + class Profile: + name: str + high_score: Union[int, float] + dob: date + data: Union[str, date, int, float] + + class ProfileFactory(DataclassFactory[Profile]): + __model__ = Profile + + results = list(ProfileFactory.coverage()) + + assert len(results) == 4 + + for result in results: + assert isinstance(result, Profile) + + +def test_coverage_tuple() -> None: + @dataclass + class Pair: + tuple_: Tuple[Union[int, str], Tuple[Union[int, float], int]] + + class TupleFactory(DataclassFactory[Pair]): + __model__ = Pair + + results = list(TupleFactory.coverage()) + + assert len(results) == 2 + + a0, (b0, c0) = results[0].tuple_ + a1, (b1, c1) = results[1].tuple_ + + assert isinstance(a0, int) and isinstance(b0, int) and isinstance(c0, int) + assert isinstance(a1, str) and isinstance(b1, float) and isinstance(c1, int) + + +@pytest.mark.parametrize( + "collection_annotation", + (Set[Union[int, str]], List[Union[int, str]], FrozenSet[Union[int, str]]), +) +def test_coverage_collection(collection_annotation: type) -> None: + Collective = make_dataclass("Collective", [("collection", collection_annotation)]) + + class CollectiveFactory(DataclassFactory[Collective]): # type: ignore + __model__ = Collective + + results = list(CollectiveFactory.coverage()) + + assert len(results) == 1 + + result = results[0] + + collection = result.collection # type: ignore + + assert len(collection) == 2 + + v0, v1 = collection + assert {type(v0), type(v1)} == {int, str} + + +def test_coverage_literal() -> None: + @dataclass + class Literally: + literal: Literal["a", "b", 1, 2] + + class LiterallyFactory(DataclassFactory[Literally]): + __model__ = Literally + + results = list(LiterallyFactory.coverage()) + + assert len(results) == 4 + + assert results[0].literal == "a" + assert results[1].literal == "b" + assert results[2].literal == 1 + assert results[3].literal == 2 + + +def test_coverage_dict() -> None: + @dataclass + class Thesaurus: + dict_simple: Dict[str, int] + dict_more_key_types: Dict[Union[str, int, float], Union[int, str]] + dict_more_value_types: Dict[str, Union[int, str]] + + class ThesaurusFactory(DataclassFactory[Thesaurus]): + __model__ = Thesaurus + + results = list(ThesaurusFactory.coverage()) + + assert len(results) == 3 + + +@pytest.mark.skip(reason="Does not support recursive types yet.") +def test_coverage_recursive() -> None: + @dataclass + class Recursive: + r: Union[Recursive, None] + + class RecursiveFactory(DataclassFactory[Recursive]): + __model__ = Recursive + + results = list(RecursiveFactory.coverage()) + assert len(results) == 2 + + +def test_coverage_typed_dict() -> None: + class TypedThesaurus(TypedDict): + number: int + string: str + union: Union[int, str] + collection: List[Union[int, str]] + + class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): + __model__ = TypedThesaurus + + results = list(TypedThesaurusFactory.coverage()) + + assert len(results) == 2 + + example = TypedThesaurusFactory.build() + for result in results: + assert result.keys() == example.keys() + + +def test_coverage_typed_dict_field() -> None: + class TypedThesaurus(TypedDict): + number: int + string: str + union: Union[int, str] + collection: List[Union[int, str]] + + class TypedThesaurusFactory(TypedDictFactory[TypedThesaurus]): + __model__ = TypedThesaurus + + results = list(TypedThesaurusFactory.coverage()) + + assert len(results) == 2 + + example = TypedThesaurusFactory.build() + + for result in results: + assert result.keys() == example.keys() + + +def test_coverage_values_unique() -> None: + @dataclass + class Unique: + uuid: UUID + data: Union[int, str] + + class UniqueFactory(DataclassFactory[Unique]): + __model__ = Unique + + results = list(UniqueFactory.coverage()) + + assert len(results) == 2 + assert results[0].uuid != results[1].uuid + + +def test_coverage_post_generated() -> None: + @dataclass + class Model: + i: int + j: int + + class Factory(DataclassFactory[Model]): + __model__ = Model + + @post_generated + @classmethod + def i(cls, j: int) -> int: + return j + 10 + + results = list(Factory.coverage()) + assert len(results) == 1 + + assert results[0].i == results[0].j + 10 + + +class CustomInt: + def __init__(self, value: int) -> None: + self.value = value + + +def test_coverage_parameter_exception() -> None: + @dataclass + class Model: + i: CustomInt + + class Factory(DataclassFactory[Model]): + __model__ = Model + + with pytest.raises(ParameterException): + list(Factory.coverage())