Skip to content

Commit c40a138

Browse files
Fix strict behavior for unions (#1638)
1 parent 21eef8d commit c40a138

File tree

6 files changed

+54
-37
lines changed

6 files changed

+54
-37
lines changed

python/pydantic_core/core_schema.py

-3
Original file line numberDiff line numberDiff line change
@@ -2525,7 +2525,6 @@ def union_schema(
25252525
custom_error_message: str | None = None,
25262526
custom_error_context: dict[str, str | int] | None = None,
25272527
mode: Literal['smart', 'left_to_right'] | None = None,
2528-
strict: bool | None = None,
25292528
ref: str | None = None,
25302529
metadata: dict[str, Any] | None = None,
25312530
serialization: SerSchema | None = None,
@@ -2551,7 +2550,6 @@ def union_schema(
25512550
mode: How to select which choice to return
25522551
* `smart` (default) will try to return the choice which is the closest match to the input value
25532552
* `left_to_right` will return the first choice in `choices` which succeeds validation
2554-
strict: Whether the underlying schemas should be validated with strict mode
25552553
ref: optional unique identifier of the schema, used to reference the schema in other places
25562554
metadata: Any other information you want to include with the schema, not used by pydantic-core
25572555
serialization: Custom serialization schema
@@ -2564,7 +2562,6 @@ def union_schema(
25642562
custom_error_message=custom_error_message,
25652563
custom_error_context=custom_error_context,
25662564
mode=mode,
2567-
strict=strict,
25682565
ref=ref,
25692566
metadata=metadata,
25702567
serialization=serialization,

src/validators/union.rs

+2-20
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use pyo3::{intern, PyTraverseError, PyVisit};
88
use smallvec::SmallVec;
99

1010
use crate::build_tools::py_schema_err;
11-
use crate::build_tools::{is_strict, schema_or_config};
11+
use crate::build_tools::schema_or_config;
1212
use crate::common::union::{Discriminator, SMALL_UNION_THRESHOLD};
1313
use crate::errors::{ErrorType, ToErrorValue, ValError, ValLineError, ValResult};
1414
use crate::input::{BorrowInput, Input, ValidatedDict};
@@ -43,7 +43,6 @@ pub struct UnionValidator {
4343
mode: UnionMode,
4444
choices: Vec<(CombinedValidator, Option<String>)>,
4545
custom_error: Option<CustomError>,
46-
strict: bool,
4746
name: String,
4847
}
4948

@@ -91,7 +90,6 @@ impl BuildValidator for UnionValidator {
9190
mode,
9291
choices,
9392
custom_error: CustomError::build(schema, config, definitions)?,
94-
strict: is_strict(schema, config)?,
9593
name: format!("{}[{descr}]", Self::EXPECTED_TYPE),
9694
}
9795
.into())
@@ -110,17 +108,11 @@ impl UnionValidator {
110108
let old_exactness = state.exactness;
111109
let old_fields_set_count = state.fields_set_count;
112110

113-
let strict = state.strict_or(self.strict);
114111
let mut errors = MaybeErrors::new(self.custom_error.as_ref());
115112

116113
let mut best_match: Option<(Py<PyAny>, Exactness, Option<usize>)> = None;
117114

118115
for (choice, label) in &self.choices {
119-
let state = &mut state.rebind_extra(|extra| {
120-
if strict {
121-
extra.strict = Some(strict);
122-
}
123-
});
124116
state.exactness = Some(Exactness::Exact);
125117
state.fields_set_count = None;
126118
let result = choice.validate(py, input, state);
@@ -197,14 +189,6 @@ impl UnionValidator {
197189
) -> ValResult<PyObject> {
198190
let mut errors = MaybeErrors::new(self.custom_error.as_ref());
199191

200-
let mut rebound_state;
201-
let state = if state.strict_or(self.strict) {
202-
rebound_state = state.rebind_extra(|extra| extra.strict = Some(true));
203-
&mut rebound_state
204-
} else {
205-
state
206-
};
207-
208192
for (validator, label) in &self.choices {
209193
match validator.validate(py, input, state) {
210194
Err(ValError::LineErrors(lines)) => errors.push(validator, label.as_deref(), lines),
@@ -300,7 +284,6 @@ pub struct TaggedUnionValidator {
300284
discriminator: Discriminator,
301285
lookup: LiteralLookup<CombinedValidator>,
302286
from_attributes: bool,
303-
strict: bool,
304287
custom_error: Option<CustomError>,
305288
tags_repr: String,
306289
discriminator_repr: String,
@@ -349,7 +332,6 @@ impl BuildValidator for TaggedUnionValidator {
349332
discriminator,
350333
lookup,
351334
from_attributes,
352-
strict: is_strict(schema, config)?,
353335
custom_error: CustomError::build(schema, config, definitions)?,
354336
tags_repr,
355337
discriminator_repr,
@@ -371,7 +353,7 @@ impl Validator for TaggedUnionValidator {
371353
match &self.discriminator {
372354
Discriminator::LookupKey(lookup_key) => {
373355
let from_attributes = state.extra().from_attributes.unwrap_or(self.from_attributes);
374-
let dict = input.validate_model_fields(self.strict, from_attributes)?;
356+
let dict = input.validate_model_fields(state.strict_or(false), from_attributes)?;
375357
// note this methods returns PyResult<Option<(data, data)>>, the outer Err is just for
376358
// errors when getting attributes which should be "raised"
377359
let tag = match dict.get_item(lookup_key)? {

tests/benchmarks/test_micro_benchmarks.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -686,16 +686,18 @@ def test_smart_union_coerce_core(self, benchmark):
686686
def test_strict_union_core(self, benchmark):
687687
v = SchemaValidator(
688688
schema=core_schema.union_schema(
689-
strict=True, choices=[core_schema.bool_schema(), core_schema.int_schema(), core_schema.str_schema()]
690-
)
689+
choices=[core_schema.bool_schema(), core_schema.int_schema(), core_schema.str_schema()]
690+
),
691+
config=CoreConfig(strict=True),
691692
)
692693

693694
benchmark(v.validate_python, 1)
694695

695696
@pytest.mark.benchmark(group='strict-union-error')
696697
def test_strict_union_error_core(self, benchmark):
697698
v = SchemaValidator(
698-
schema=core_schema.union_schema(strict=True, choices=[core_schema.bool_schema(), core_schema.str_schema()])
699+
schema=core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.str_schema()]),
700+
config=CoreConfig(strict=True),
699701
)
700702

701703
def validate_with_expected_error():

tests/validators/test_bytes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_constrained_bytes(py_and_json: PyAndJson, opts: dict[str, Any], input,
9191

9292

9393
def test_union():
94-
v = SchemaValidator(cs.union_schema(choices=[cs.str_schema(), cs.bytes_schema()], strict=True))
94+
v = SchemaValidator(cs.union_schema(choices=[cs.str_schema(strict=True), cs.bytes_schema(strict=True)]))
9595
assert v.validate_python('oh, a string') == 'oh, a string'
9696
assert v.validate_python(b'oh, bytes') == b'oh, bytes'
9797

tests/validators/test_definitions_recursive.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -611,11 +611,11 @@ def test_union_cycle(strict: bool):
611611
'foobar': core_schema.typed_dict_field(
612612
core_schema.list_schema(core_schema.definition_reference_schema('root-schema'))
613613
)
614-
}
614+
},
615+
strict=strict,
615616
)
616617
],
617618
auto_collapse=False,
618-
strict=strict,
619619
ref='root-schema',
620620
)
621621
],
@@ -700,11 +700,11 @@ def f(input_value, info):
700700
)
701701
],
702702
auto_collapse=False,
703-
strict=strict,
704703
ref='root-schema',
705704
)
706705
],
707-
)
706+
),
707+
config=CoreConfig(strict=strict),
708708
)
709709

710710
with pytest.raises(ValidationError) as exc_info:

tests/validators/test_union.py

+42-6
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99
from dirty_equals import IsFloat, IsInt
1010

11-
from pydantic_core import SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema
11+
from pydantic_core import CoreConfig, SchemaError, SchemaValidator, ValidationError, core_schema, validate_core_schema
1212

1313
from ..conftest import plain_repr
1414

@@ -262,16 +262,47 @@ def test_one_choice():
262262
assert v.validate_python('hello') == 'hello'
263263

264264

265-
def test_strict_union():
265+
def test_strict_union_flag() -> None:
266+
v = SchemaValidator(core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.int_schema()]))
267+
assert v.validate_python(1, strict=True) == 1
268+
assert v.validate_python(123, strict=True) == 123
269+
270+
with pytest.raises(ValidationError) as exc_info:
271+
v.validate_python('123', strict=True)
272+
273+
assert exc_info.value.errors(include_url=False) == [
274+
{'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'},
275+
{'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'},
276+
]
277+
278+
279+
def test_strict_union_config_level() -> None:
266280
v = SchemaValidator(
267-
core_schema.union_schema(strict=True, choices=[core_schema.bool_schema(), core_schema.int_schema()])
281+
core_schema.union_schema(choices=[core_schema.bool_schema(), core_schema.int_schema()]),
282+
config=CoreConfig(strict=True),
268283
)
284+
269285
assert v.validate_python(1) == 1
270286
assert v.validate_python(123) == 123
271287

272288
with pytest.raises(ValidationError) as exc_info:
273289
v.validate_python('123')
290+
assert exc_info.value.errors(include_url=False) == [
291+
{'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'},
292+
{'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'},
293+
]
274294

295+
296+
def test_strict_union_member_level() -> None:
297+
v = SchemaValidator(
298+
core_schema.union_schema(choices=[core_schema.bool_schema(strict=True), core_schema.int_schema(strict=True)])
299+
)
300+
301+
assert v.validate_python(1) == 1
302+
assert v.validate_python(123) == 123
303+
304+
with pytest.raises(ValidationError) as exc_info:
305+
v.validate_python('123')
275306
assert exc_info.value.errors(include_url=False) == [
276307
{'type': 'bool_type', 'loc': ('bool',), 'msg': 'Input should be a valid boolean', 'input': '123'},
277308
{'type': 'int_type', 'loc': ('int',), 'msg': 'Input should be a valid integer', 'input': '123'},
@@ -469,10 +500,10 @@ def test_left_to_right_union():
469500

470501

471502
def test_left_to_right_union_strict():
472-
choices = [core_schema.int_schema(), core_schema.float_schema()]
503+
choices = [core_schema.int_schema(strict=True), core_schema.float_schema(strict=True)]
473504

474505
# left_to_right union will select not cast if int first (strict int will not accept float)
475-
v = SchemaValidator(core_schema.union_schema(choices, mode='left_to_right', strict=True))
506+
v = SchemaValidator(core_schema.union_schema(choices, mode='left_to_right'))
476507
out = v.validate_python(1)
477508
assert out == 1
478509
assert isinstance(out, int)
@@ -482,7 +513,12 @@ def test_left_to_right_union_strict():
482513
assert isinstance(out, float)
483514

484515
# reversing union will select float always (as strict float will accept int)
485-
v = SchemaValidator(core_schema.union_schema(list(reversed(choices)), mode='left_to_right', strict=True))
516+
v = SchemaValidator(
517+
core_schema.union_schema(
518+
list(reversed(choices)),
519+
mode='left_to_right',
520+
)
521+
)
486522
out = v.validate_python(1.0)
487523
assert out == 1.0
488524
assert isinstance(out, float)

0 commit comments

Comments
 (0)