Skip to content

Commit 1919da1

Browse files
authored
simplify recursive reference schemas, fix #60 (#130)
1 parent c0f6de0 commit 1919da1

File tree

7 files changed

+200
-195
lines changed

7 files changed

+200
-195
lines changed

pydantic_core/_types.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
import sys
44
from datetime import date, datetime, time
5-
from typing import Any, Callable, Dict, List, Sequence, Union
5+
from typing import Any, Callable, Dict, List, Union
66

77
if sys.version_info < (3, 11):
88
from typing_extensions import NotRequired, Required
99
else:
10-
from typing import NotRequired
10+
from typing import NotRequired, Required
1111

1212
if sys.version_info < (3, 8):
1313
from typing_extensions import Literal, TypedDict
@@ -19,9 +19,10 @@ class AnySchema(TypedDict):
1919
type: Literal['any']
2020

2121

22-
class BoolSchema(TypedDict):
23-
type: Literal['bool']
24-
strict: NotRequired[bool]
22+
class BoolSchema(TypedDict, total=False):
23+
type: Required[Literal['bool']]
24+
strict: bool
25+
ref: str
2526

2627

2728
class ConfigSchema(TypedDict, total=False):
@@ -39,6 +40,7 @@ class DictSchema(TypedDict, total=False):
3940
min_items: int
4041
max_items: int
4142
strict: bool
43+
ref: str
4244

4345

4446
class FloatSchema(TypedDict, total=False):
@@ -49,20 +51,22 @@ class FloatSchema(TypedDict, total=False):
4951
lt: float
5052
gt: float
5153
strict: bool
52-
default: float
54+
ref: str
5355

5456

5557
class FunctionSchema(TypedDict):
5658
type: Literal['function']
5759
mode: Literal['before', 'after', 'wrap']
5860
function: Callable[..., Any]
5961
schema: Schema
62+
ref: NotRequired[str]
6063

6164

6265
class FunctionPlainSchema(TypedDict):
6366
type: Literal['function']
6467
mode: Literal['plain']
6568
function: Callable[..., Any]
69+
ref: NotRequired[str]
6670

6771

6872
class IntSchema(TypedDict, total=False):
@@ -73,6 +77,7 @@ class IntSchema(TypedDict, total=False):
7377
lt: int
7478
gt: int
7579
strict: bool
80+
ref: str
7681

7782

7883
class ListSchema(TypedDict, total=False):
@@ -81,17 +86,20 @@ class ListSchema(TypedDict, total=False):
8186
min_items: int
8287
max_items: int
8388
strict: bool
89+
ref: str
8490

8591

8692
class LiteralSchema(TypedDict):
8793
type: Literal['literal']
88-
expected: Sequence[Any]
94+
expected: List[Any]
95+
ref: NotRequired[str]
8996

9097

9198
class ModelClassSchema(TypedDict):
9299
type: Literal['model-class']
93100
class_type: type
94101
schema: TypedDictSchema
102+
ref: NotRequired[str]
95103

96104

97105
class TypedDictField(TypedDict, total=False):
@@ -102,33 +110,30 @@ class TypedDictField(TypedDict, total=False):
102110
aliases: List[List[Union[str, int]]]
103111

104112

105-
class TypedDictSchema(TypedDict):
106-
type: Literal['typed-dict']
107-
fields: Dict[str, TypedDictField]
108-
extra_validator: NotRequired[Schema]
109-
config: NotRequired[ConfigSchema]
110-
return_fields_set: NotRequired[bool]
113+
class TypedDictSchema(TypedDict, total=False):
114+
type: Required[Literal['typed-dict']]
115+
fields: Required[Dict[str, TypedDictField]]
116+
extra_validator: Schema
117+
config: ConfigSchema
118+
return_fields_set: bool
119+
ref: str
111120

112121

113122
class NoneSchema(TypedDict):
114123
type: Literal['none']
124+
ref: NotRequired[str]
115125

116126

117-
class NullableSchema(TypedDict):
118-
type: Literal['nullable']
119-
schema: Schema
120-
strict: NotRequired[bool]
127+
class NullableSchema(TypedDict, total=False):
128+
type: Required[Literal['nullable']]
129+
schema: Required[Schema]
130+
strict: bool
131+
ref: str
121132

122133

123134
class RecursiveReferenceSchema(TypedDict):
124135
type: Literal['recursive-ref']
125-
name: str
126-
127-
128-
class RecursiveContainerSchema(TypedDict):
129-
type: Literal['recursive-container']
130-
name: str
131-
schema: Schema
136+
schema_ref: str
132137

133138

134139
class SetSchema(TypedDict, total=False):
@@ -137,6 +142,7 @@ class SetSchema(TypedDict, total=False):
137142
min_items: int
138143
max_items: int
139144
strict: bool
145+
ref: str
140146

141147

142148
class FrozenSetSchema(TypedDict, total=False):
@@ -145,6 +151,7 @@ class FrozenSetSchema(TypedDict, total=False):
145151
min_items: int
146152
max_items: int
147153
strict: bool
154+
ref: str
148155

149156

150157
class StringSchema(TypedDict, total=False):
@@ -156,20 +163,22 @@ class StringSchema(TypedDict, total=False):
156163
to_lower: bool
157164
to_upper: bool
158165
strict: bool
166+
ref: str
159167

160168

161-
class UnionSchema(TypedDict):
162-
type: Literal['union']
163-
choices: List[Schema]
164-
strict: NotRequired[bool]
165-
default: NotRequired[Any]
169+
class UnionSchema(TypedDict, total=False):
170+
type: Required[Literal['union']]
171+
choices: Required[List[Schema]]
172+
strict: bool
173+
ref: str
166174

167175

168176
class BytesSchema(TypedDict, total=False):
169177
type: Required[Literal['bytes']]
170178
max_length: int
171179
min_length: int
172180
strict: bool
181+
ref: str
173182

174183

175184
class DateSchema(TypedDict, total=False):
@@ -179,7 +188,7 @@ class DateSchema(TypedDict, total=False):
179188
ge: date
180189
lt: date
181190
gt: date
182-
default: date
191+
ref: str
183192

184193

185194
class TimeSchema(TypedDict, total=False):
@@ -189,7 +198,7 @@ class TimeSchema(TypedDict, total=False):
189198
ge: time
190199
lt: time
191200
gt: time
192-
default: time
201+
ref: str
193202

194203

195204
class DatetimeSchema(TypedDict, total=False):
@@ -199,13 +208,14 @@ class DatetimeSchema(TypedDict, total=False):
199208
ge: datetime
200209
lt: datetime
201210
gt: datetime
202-
default: datetime
211+
ref: str
203212

204213

205-
class TupleFixLenSchema(TypedDict):
206-
type: Literal['tuple-fix-len']
207-
items_schema: List[Schema]
208-
strict: NotRequired[bool]
214+
class TupleFixLenSchema(TypedDict, total=False):
215+
type: Required[Literal['tuple-fix-len']]
216+
items_schema: Required[List[Schema]]
217+
strict: bool
218+
ref: str
209219

210220

211221
class TupleVarLenSchema(TypedDict, total=False):
@@ -214,6 +224,7 @@ class TupleVarLenSchema(TypedDict, total=False):
214224
min_items: int
215225
max_items: int
216226
strict: bool
227+
ref: str
217228

218229

219230
# pydantic allows types to be defined via a simple string instead of dict with just `type`, e.g.
@@ -256,7 +267,6 @@ class TupleVarLenSchema(TypedDict, total=False):
256267
ModelClassSchema,
257268
NoneSchema,
258269
NullableSchema,
259-
RecursiveContainerSchema,
260270
RecursiveReferenceSchema,
261271
SetSchema,
262272
FrozenSetSchema,

src/validators/mod.rs

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -151,19 +151,35 @@ pub trait BuildValidator: Sized {
151151
) -> PyResult<CombinedValidator>;
152152
}
153153

154+
fn build_single_validator<'a, T: BuildValidator>(
155+
val_type: &str,
156+
schema_dict: &'a PyDict,
157+
config: Option<&'a PyDict>,
158+
build_context: &mut BuildContext,
159+
) -> PyResult<(CombinedValidator, &'a PyDict)> {
160+
build_context.incr_check_depth()?;
161+
162+
let val: CombinedValidator = if let Some(schema_ref) = schema_dict.get_as::<String>("ref")? {
163+
let slot_id = build_context.prepare_slot(schema_ref)?;
164+
let inner_val = T::build(schema_dict, config, build_context)
165+
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?;
166+
build_context.complete_slot(slot_id, inner_val);
167+
recursive::RecursiveContainerValidator::create(slot_id)
168+
} else {
169+
T::build(schema_dict, config, build_context)
170+
.map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?
171+
};
172+
173+
build_context.decr_depth();
174+
Ok((val, schema_dict))
175+
}
176+
154177
// macro to build the match statement for validator selection
155178
macro_rules! validator_match {
156179
($type:ident, $dict:ident, $config:ident, $build_context:ident, $($validator:path,)+) => {
157180
match $type {
158181
$(
159-
<$validator>::EXPECTED_TYPE => {
160-
$build_context.incr_check_depth()?;
161-
let val = <$validator>::build($dict, $config, $build_context).map_err(|err| {
162-
SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", $type, err))
163-
})?;
164-
$build_context.decr_depth();
165-
Ok((val, $dict))
166-
},
182+
<$validator>::EXPECTED_TYPE => build_single_validator::<$validator>($type, $dict, $config, $build_context),
167183
)+
168184
_ => {
169185
return py_error!(r#"Unknown schema type: "{}""#, $type)
@@ -221,7 +237,6 @@ pub fn build_validator<'a>(
221237
// functions - before, after, plain & wrap
222238
function::FunctionBuilder,
223239
// recursive (self-referencing) models
224-
recursive::RecursiveValidator,
225240
recursive::RecursiveRefValidator,
226241
// literals
227242
literal::LiteralBuilder,
@@ -294,7 +309,7 @@ pub enum CombinedValidator {
294309
FunctionPlain(function::FunctionPlainValidator),
295310
FunctionWrap(function::FunctionWrapValidator),
296311
// recursive (self-referencing) models
297-
Recursive(recursive::RecursiveValidator),
312+
Recursive(recursive::RecursiveContainerValidator),
298313
RecursiveRef(recursive::RecursiveRefValidator),
299314
// literals
300315
LiteralSingleString(literal::LiteralSingleStringValidator),
@@ -348,29 +363,26 @@ pub trait Validator: Send + Sync + Clone + Debug {
348363
fn get_name(&self, py: Python) -> String;
349364
}
350365

366+
#[derive(Default)]
351367
pub struct BuildContext {
352368
named_slots: Vec<(Option<String>, Option<CombinedValidator>)>,
353369
depth: usize,
354370
}
355371

356372
const MAX_DEPTH: usize = 100;
357373

358-
impl Default for BuildContext {
359-
fn default() -> Self {
360-
let named_slots: Vec<(Option<String>, Option<CombinedValidator>)> = Vec::new();
361-
BuildContext { named_slots, depth: 0 }
362-
}
363-
}
364-
365374
impl BuildContext {
366-
pub fn add_named_slot(&mut self, name: String, schema: &PyAny, config: Option<&PyDict>) -> PyResult<usize> {
375+
pub fn prepare_slot(&mut self, slot_ref: String) -> PyResult<usize> {
367376
let id = self.named_slots.len();
368-
self.named_slots.push((Some(name), None));
369-
let validator = build_validator(schema, config, self)?.0;
370-
self.named_slots[id] = (None, Some(validator));
377+
self.named_slots.push((Some(slot_ref), None));
371378
Ok(id)
372379
}
373380

381+
pub fn complete_slot(&mut self, slot_id: usize, validator: CombinedValidator) {
382+
let (name, _) = self.named_slots.get(slot_id).unwrap();
383+
self.named_slots[slot_id] = (name.clone(), Some(validator));
384+
}
385+
374386
pub fn incr_check_depth(&mut self) -> PyResult<()> {
375387
self.depth += 1;
376388
if self.depth > MAX_DEPTH {
@@ -384,14 +396,14 @@ impl BuildContext {
384396
self.depth -= 1;
385397
}
386398

387-
pub fn find_id(&self, name: &str) -> PyResult<usize> {
399+
pub fn find_slot_id(&self, slot_ref: &str) -> PyResult<usize> {
388400
let is_match = |(n, _): &(Option<String>, Option<CombinedValidator>)| match n {
389-
Some(n) => n == name,
401+
Some(n) => n == slot_ref,
390402
None => false,
391403
};
392404
match self.named_slots.iter().position(is_match) {
393405
Some(id) => Ok(id),
394-
None => py_error!("Recursive reference error: ref '{}' not found", name),
406+
None => py_error!("Recursive reference error: ref '{}' not found", slot_ref),
395407
}
396408
}
397409

0 commit comments

Comments
 (0)