Skip to content

Commit 274c40a

Browse files
Fixed issue where config with * could not be filled (#53)
1 parent 43c9281 commit 274c40a

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

confection/__init__.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ def copy_model_field(field: ModelField, type_: Any) -> ModelField:
704704
default=field.default,
705705
default_factory=field.default_factory,
706706
required=field.required,
707+
alias=field.alias,
707708
)
708709

709710

@@ -912,6 +913,15 @@ def _fill(
912913
# created via config blocks), only use its values
913914
validation[v_key] = list(validation[v_key].values())
914915
final[key] = list(final[key].values())
916+
917+
if ARGS_FIELD_ALIAS in schema.__fields__ and not resolve:
918+
# If we're not resolving the config, make sure that the field
919+
# expecting the promise is typed Any so it doesn't fail
920+
# validation if it doesn't receive the function return value
921+
field = schema.__fields__[ARGS_FIELD_ALIAS]
922+
schema.__fields__[ARGS_FIELD_ALIAS] = copy_model_field(
923+
field, Any
924+
)
915925
else:
916926
filled[key] = value
917927
# Prevent pydantic from consuming generator if part of a union
@@ -936,7 +946,12 @@ def _fill(
936946
# manually because .construct doesn't parse anything
937947
if schema.Config.extra in (Extra.forbid, Extra.ignore):
938948
fields = schema.__fields__.keys()
939-
exclude = [k for k in result.__fields_set__ if k not in fields]
949+
# If we have a reserved field, we need to use its alias
950+
field_set = [
951+
k if k != ARGS_FIELD else ARGS_FIELD_ALIAS
952+
for k in result.__fields_set__
953+
]
954+
exclude = [k for k in field_set if k not in fields]
940955
exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
941956
validation.update(result.dict(exclude=exclude_validation))
942957
filled, final = cls._update_from_parsed(validation, filled, final)

confection/tests/test_config.py

+22
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,28 @@ def catsie_567(*args: Optional[str], foo: str = "bar"):
424424
assert my_registry.resolve(cfg)["config"] == "^_^"
425425

426426

427+
def test_fill_config_positional_args_w_promise():
428+
@my_registry.cats("catsie.v568")
429+
def catsie_568(*args: str, foo: str = "bar"):
430+
assert args[0] == "^(*.*)^"
431+
assert foo == "baz"
432+
return args[0]
433+
434+
@my_registry.cats("cat_promise.v568")
435+
def cat_promise() -> str:
436+
return "^(*.*)^"
437+
438+
cfg = {
439+
"config": {
440+
"@cats": "catsie.v568",
441+
"*": {"promise": {"@cats": "cat_promise.v568"}},
442+
}
443+
}
444+
filled = my_registry.fill(cfg, validate=True)
445+
assert filled["config"]["foo"] == "bar"
446+
assert filled["config"]["*"] == {"promise": {"@cats": "cat_promise.v568"}}
447+
448+
427449
def test_make_config_positional_args_complex():
428450
@my_registry.cats("catsie.v890")
429451
def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):

0 commit comments

Comments
 (0)