Skip to content

Commit b1e8b5e

Browse files
sam-orJacobCoffeeadhtruongvkcku
authored
feat(type-coverage-generation): model type coverage batch generation (#390)
Co-authored-by: Jacob Coffee <[email protected]> Co-authored-by: Andrew Truong <[email protected]> Co-authored-by: guacs <[email protected]>
1 parent 70d49fd commit b1e8b5e

File tree

10 files changed

+720
-19
lines changed

10 files changed

+720
-19
lines changed

docs/examples/model_coverage/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Literal
5+
6+
from polyfactory.factories.dataclass_factory import DataclassFactory
7+
8+
9+
@dataclass
10+
class Car:
11+
model: str
12+
13+
14+
@dataclass
15+
class Boat:
16+
can_float: bool
17+
18+
19+
@dataclass
20+
class Profile:
21+
age: int
22+
favourite_color: Literal["red", "green", "blue"]
23+
vehicle: Car | Boat
24+
25+
26+
class ProfileFactory(DataclassFactory[Profile]):
27+
__model__ = Profile
28+
29+
30+
def test_profile_coverage() -> None:
31+
profiles = list(ProfileFactory.coverage())
32+
33+
assert profiles[0].favourite_color == "red"
34+
assert isinstance(profiles[0].vehicle, Car)
35+
assert profiles[1].favourite_color == "green"
36+
assert isinstance(profiles[1].vehicle, Boat)
37+
assert profiles[2].favourite_color == "blue"
38+
assert isinstance(profiles[2].vehicle, Car)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Literal
5+
6+
from polyfactory.factories.dataclass_factory import DataclassFactory
7+
8+
9+
@dataclass
10+
class Car:
11+
model: str
12+
13+
14+
@dataclass
15+
class Boat:
16+
can_float: bool
17+
18+
19+
@dataclass
20+
class Profile:
21+
age: int
22+
favourite_color: Literal["red", "green", "blue"]
23+
vehicle: Car | Boat
24+
25+
26+
@dataclass
27+
class SocialGroup:
28+
members: list[Profile]
29+
30+
31+
class SocialGroupFactory(DataclassFactory[SocialGroup]):
32+
__model__ = SocialGroup
33+
34+
35+
def test_social_group_coverage() -> None:
36+
groups = list(SocialGroupFactory.coverage())
37+
assert len(groups) == 3
38+
39+
for group in groups:
40+
assert len(group.members) == 1
41+
42+
assert groups[0].members[0].favourite_color == "red"
43+
assert isinstance(groups[0].members[0].vehicle, Car)
44+
assert groups[1].members[0].favourite_color == "green"
45+
assert isinstance(groups[1].members[0].vehicle, Boat)
46+
assert groups[2].members[0].favourite_color == "blue"
47+
assert isinstance(groups[2].members[0].vehicle, Car)

docs/usage/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ Usage Guide
1212
decorators
1313
fixtures
1414
handling_custom_types
15+
model_coverage

docs/usage/model_coverage.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
Model coverage generation
2+
=========================
3+
4+
The ``BaseFactory.coverage()`` function is an alternative approach to ``BaseFactory.batch()``, where the examples that are generated attempt to provide full coverage of all the forms a model can take with the minimum number of instances. For example:
5+
6+
.. literalinclude:: /examples/model_coverage/test_example_1.py
7+
:caption: Defining a factory and generating examples with coverage
8+
:language: python
9+
10+
As you can see in the above example, the ``Profile`` model has 3 options for ``favourite_color``, and 2 options for ``vehicle``. In the output you can expect to see instances of ``Profile`` that have each of these options. The largest variance dictates the length of the output, in this case ``favourite_color`` has the most, at 3 options, so expect to see 3 ``Profile`` instances.
11+
12+
13+
.. note::
14+
Notice that the same ``Car`` instance is used in the first and final generated example. When the coverage examples for a field are exhausted before another field, values for that field are re-used.
15+
16+
Notes on collection types
17+
-------------------------
18+
19+
When generating coverage for models with fields that are collections, in particular collections that contain sub-models, the contents of the collection will be the all coverage examples for that sub-model. For example:
20+
21+
.. literalinclude:: /examples/model_coverage/test_example_2.py
22+
:caption: Coverage output for the SocialGroup model
23+
:language: python
24+
25+
Known Limitations
26+
-----------------
27+
28+
- Recursive models will cause an error: ``RecursionError: maximum recursion depth exceeded``.
29+
- ``__min_collection_length__`` and ``__max_collection_length__`` are currently ignored in coverage generation.

polyfactory/factories/base.py

Lines changed: 160 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import typing
34
from abc import ABC, abstractmethod
45
from collections import Counter, abc, deque
56
from contextlib import suppress
@@ -22,6 +23,12 @@
2223
from os.path import realpath
2324
from pathlib import Path
2425
from random import Random
26+
27+
try:
28+
from types import NoneType
29+
except ImportError:
30+
NoneType = type(None) # type: ignore[misc,assignment]
31+
2532
from typing import (
2633
TYPE_CHECKING,
2734
Any,
@@ -46,13 +53,16 @@
4653
MIN_COLLECTION_LENGTH,
4754
RANDOMIZE_COLLECTION_LENGTH,
4855
)
49-
from polyfactory.exceptions import (
50-
ConfigurationException,
51-
MissingBuildKwargException,
52-
ParameterException,
53-
)
56+
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
5457
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
55-
from polyfactory.utils.helpers import get_collection_type, unwrap_annotation, unwrap_args, unwrap_optional
58+
from polyfactory.utils.helpers import (
59+
flatten_annotation,
60+
get_collection_type,
61+
unwrap_annotation,
62+
unwrap_args,
63+
unwrap_optional,
64+
)
65+
from polyfactory.utils.model_coverage import CoverageContainer, CoverageContainerCallable, resolve_kwargs_coverage
5666
from polyfactory.utils.predicates import (
5767
get_type_origin,
5868
is_any,
@@ -61,7 +71,7 @@
6171
is_safe_subclass,
6272
is_union,
6373
)
64-
from polyfactory.value_generators.complex_types import handle_collection_type
74+
from polyfactory.value_generators.complex_types import handle_collection_type, handle_collection_type_coverage
6575
from polyfactory.value_generators.constrained_collections import (
6676
handle_constrained_collection,
6777
handle_constrained_mapping,
@@ -263,6 +273,32 @@ def _handle_factory_field(cls, field_value: Any, field_build_parameters: Any | N
263273

264274
return field_value() if callable(field_value) else field_value
265275

276+
@classmethod
277+
def _handle_factory_field_coverage(cls, field_value: Any, field_build_parameters: Any | None = None) -> Any:
278+
"""Handle a value defined on the factory class itself.
279+
280+
:param field_value: A value defined as an attribute on the factory class.
281+
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.
282+
283+
:returns: An arbitrary value correlating with the given field_meta value.
284+
"""
285+
if is_safe_subclass(field_value, BaseFactory):
286+
if isinstance(field_build_parameters, Mapping):
287+
return CoverageContainer(field_value.coverage(**field_build_parameters))
288+
289+
if isinstance(field_build_parameters, Sequence):
290+
return [CoverageContainer(field_value.coverage(**parameter)) for parameter in field_build_parameters]
291+
292+
return CoverageContainer(field_value.coverage())
293+
294+
if isinstance(field_value, Use):
295+
return field_value.to_value()
296+
297+
if isinstance(field_value, Fixture):
298+
return CoverageContainerCallable(field_value.to_value)
299+
300+
return CoverageContainerCallable(field_value) if callable(field_value) else field_value
301+
266302
@classmethod
267303
def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
268304
"""Get a factory from registered factories or generate a factory dynamically.
@@ -635,6 +671,66 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
635671
msg,
636672
)
637673

674+
@classmethod
675+
def get_field_value_coverage( # noqa: C901
676+
cls,
677+
field_meta: FieldMeta,
678+
field_build_parameters: Any | None = None,
679+
) -> typing.Iterable[Any]:
680+
"""Return a field value on the subclass if existing, otherwise returns a mock value.
681+
682+
:param field_meta: FieldMeta instance.
683+
:param field_build_parameters: Any build parameters passed to the factory as kwarg values.
684+
685+
:returns: An iterable of values.
686+
687+
"""
688+
if cls.is_ignored_type(field_meta.annotation):
689+
return [None]
690+
691+
for unwrapped_annotation in flatten_annotation(field_meta.annotation):
692+
if unwrapped_annotation in (None, NoneType):
693+
yield None
694+
695+
elif is_literal(annotation=unwrapped_annotation) and (literal_args := get_args(unwrapped_annotation)):
696+
yield CoverageContainer(literal_args)
697+
698+
elif isinstance(unwrapped_annotation, EnumMeta):
699+
yield CoverageContainer(list(unwrapped_annotation))
700+
701+
elif field_meta.constraints:
702+
yield CoverageContainerCallable(
703+
cls.get_constrained_field_value,
704+
annotation=unwrapped_annotation,
705+
field_meta=field_meta,
706+
)
707+
708+
elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):
709+
yield CoverageContainer(
710+
cls._get_or_create_factory(model=unwrapped_annotation).coverage(
711+
**(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
712+
),
713+
)
714+
715+
elif (origin := get_type_origin(unwrapped_annotation)) and issubclass(origin, Collection):
716+
yield handle_collection_type_coverage(field_meta, origin, cls)
717+
718+
elif is_any(unwrapped_annotation) or isinstance(unwrapped_annotation, TypeVar):
719+
yield create_random_string(cls.__random__, min_length=1, max_length=10)
720+
721+
elif provider := cls.get_provider_map().get(unwrapped_annotation):
722+
yield CoverageContainerCallable(provider)
723+
724+
elif callable(unwrapped_annotation):
725+
# if value is a callable we can try to naively call it.
726+
# this will work for callables that do not require any parameters passed
727+
yield CoverageContainerCallable(unwrapped_annotation)
728+
else:
729+
msg = f"Unsupported type: {unwrapped_annotation!r}\n\nEither extend the providers map or add a factory function for this type."
730+
raise ParameterException(
731+
msg,
732+
)
733+
638734
@classmethod
639735
def should_set_none_value(cls, field_meta: FieldMeta) -> bool:
640736
"""Determine whether a given model field_meta should be set to None.
@@ -752,6 +848,50 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
752848

753849
return result
754850

851+
@classmethod
852+
def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
853+
"""Process the given kwargs and generate values for the factory's model.
854+
855+
:param kwargs: Any build kwargs.
856+
857+
:returns: A dictionary of build results.
858+
859+
"""
860+
result: dict[str, Any] = {**kwargs}
861+
generate_post: dict[str, PostGenerated] = {}
862+
863+
for field_meta in cls.get_model_fields():
864+
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
865+
866+
if cls.should_set_field_value(field_meta, **kwargs):
867+
if hasattr(cls, field_meta.name) and not hasattr(BaseFactory, field_meta.name):
868+
field_value = getattr(cls, field_meta.name)
869+
if isinstance(field_value, Ignore):
870+
continue
871+
872+
if isinstance(field_value, Require) and field_meta.name not in kwargs:
873+
msg = f"Require kwarg {field_meta.name} is missing"
874+
raise MissingBuildKwargException(msg)
875+
876+
if isinstance(field_value, PostGenerated):
877+
generate_post[field_meta.name] = field_value
878+
continue
879+
880+
result[field_meta.name] = cls._handle_factory_field_coverage(
881+
field_value=field_value,
882+
field_build_parameters=field_build_parameters,
883+
)
884+
continue
885+
886+
result[field_meta.name] = CoverageContainer(
887+
cls.get_field_value_coverage(field_meta, field_build_parameters=field_build_parameters),
888+
)
889+
890+
for resolved in resolve_kwargs_coverage(result):
891+
for field_name, post_generator in generate_post.items():
892+
resolved[field_name] = post_generator.to_value(field_name, resolved)
893+
yield resolved
894+
755895
@classmethod
756896
def build(cls, **kwargs: Any) -> T:
757897
"""Build an instance of the factory's __model__
@@ -776,6 +916,19 @@ def batch(cls, size: int, **kwargs: Any) -> list[T]:
776916
"""
777917
return [cls.build(**kwargs) for _ in range(size)]
778918

919+
@classmethod
920+
def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
921+
"""Build a batch of the factory's Meta.model will full coverage of the sub-types of the model.
922+
923+
:param kwargs: Any kwargs. If field_meta names are set in kwargs, their values will be used.
924+
925+
:returns: A iterator of instances of type T.
926+
927+
"""
928+
for data in cls.process_kwargs_coverage(**kwargs):
929+
instance = cls.__model__(**data)
930+
yield cast("T", instance)
931+
779932
@classmethod
780933
def create_sync(cls, **kwargs: Any) -> T:
781934
"""Build and persists synchronously a single model instance.

polyfactory/utils/helpers.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
import sys
44
from typing import TYPE_CHECKING, Any, Mapping
55

6+
try:
7+
from types import NoneType
8+
except ImportError:
9+
NoneType = type(None) # type: ignore[misc,assignment]
10+
611
from typing_extensions import get_args, get_origin
712

813
from polyfactory.constants import TYPE_MAPPING
@@ -52,7 +57,7 @@ def unwrap_optional(annotation: Any) -> Any:
5257
:returns: A type annotation
5358
"""
5459
while is_optional(annotation):
55-
annotation = next(arg for arg in get_args(annotation) if arg not in (type(None), None))
60+
annotation = next(arg for arg in get_args(annotation) if arg not in (NoneType, None))
5661
return annotation
5762

5863

@@ -77,6 +82,30 @@ def unwrap_annotation(annotation: Any, random: Random) -> Any:
7782
return annotation
7883

7984

85+
def flatten_annotation(annotation: Any) -> list[Any]:
86+
"""Flattens an annotation.
87+
88+
:param annotation: A type annotation.
89+
90+
:returns: The flattened annotations.
91+
"""
92+
flat = []
93+
if is_new_type(annotation):
94+
flat.extend(flatten_annotation(unwrap_new_type(annotation)))
95+
elif is_optional(annotation):
96+
flat.append(NoneType)
97+
flat.extend(flatten_annotation(arg) for arg in get_args(annotation) if arg not in (NoneType, None))
98+
elif is_annotated(annotation):
99+
flat.extend(flatten_annotation(get_args(annotation)[0]))
100+
elif is_union(annotation):
101+
for a in get_args(annotation):
102+
flat.extend(flatten_annotation(a))
103+
else:
104+
flat.append(annotation)
105+
106+
return flat
107+
108+
80109
def unwrap_args(annotation: Any, random: Random) -> tuple[Any, ...]:
81110
"""Unwrap the annotation and return any type args.
82111

0 commit comments

Comments
 (0)