@@ -2,10 +2,9 @@ use std::fmt::Debug;
2
2
3
3
use enum_dispatch:: enum_dispatch;
4
4
5
- use pyo3:: exceptions:: PyRecursionError ;
5
+ use pyo3:: exceptions:: { PyRecursionError , PyTypeError } ;
6
6
use pyo3:: prelude:: * ;
7
- use pyo3:: types:: { PyAny , PyDict } ;
8
- use serde_json:: from_str as parse_json;
7
+ use pyo3:: types:: { PyAny , PyByteArray , PyBytes , PyDict , PyString } ;
9
8
10
9
use crate :: build_tools:: { py_error, SchemaDict , SchemaError } ;
11
10
use crate :: errors:: { ErrorKind , ValError , ValLineError , ValResult , ValidationError } ;
@@ -99,8 +98,8 @@ impl SchemaValidator {
99
98
}
100
99
}
101
100
102
- pub fn validate_json ( & self , py : Python , input : String ) -> PyResult < PyObject > {
103
- match parse_json :: < JsonInput > ( & input) {
101
+ pub fn validate_json ( & self , py : Python , input : & PyAny ) -> PyResult < PyObject > {
102
+ match parse_json ( input) ? {
104
103
Ok ( input) => {
105
104
let r = self . validator . validate (
106
105
py,
@@ -112,15 +111,15 @@ impl SchemaValidator {
112
111
r. map_err ( |e| self . prepare_validation_err ( py, e) )
113
112
}
114
113
Err ( e) => {
115
- let line_err = ValLineError :: new ( ErrorKind :: InvalidJson { error : e. to_string ( ) } , & input) ;
114
+ let line_err = ValLineError :: new ( ErrorKind :: InvalidJson { error : e. to_string ( ) } , input) ;
116
115
let err = ValError :: LineErrors ( vec ! [ line_err] ) ;
117
116
Err ( self . prepare_validation_err ( py, err) )
118
117
}
119
118
}
120
119
}
121
120
122
- pub fn isinstance_json ( & self , py : Python , input : String ) -> PyResult < bool > {
123
- match parse_json :: < JsonInput > ( & input) {
121
+ pub fn isinstance_json ( & self , py : Python , input : & PyAny ) -> PyResult < bool > {
122
+ match parse_json ( input) ? {
124
123
Ok ( input) => {
125
124
match self . validator . validate (
126
125
py,
@@ -164,6 +163,18 @@ impl SchemaValidator {
164
163
}
165
164
}
166
165
166
+ fn parse_json ( input : & PyAny ) -> PyResult < serde_json:: Result < JsonInput > > {
167
+ if let Ok ( py_bytes) = input. cast_as :: < PyBytes > ( ) {
168
+ Ok ( serde_json:: from_slice ( py_bytes. as_bytes ( ) ) )
169
+ } else if let Ok ( py_str) = input. cast_as :: < PyString > ( ) {
170
+ Ok ( serde_json:: from_str ( & py_str. to_string_lossy ( ) ) )
171
+ } else if let Ok ( py_byte_array) = input. cast_as :: < PyByteArray > ( ) {
172
+ Ok ( serde_json:: from_slice ( unsafe { py_byte_array. as_bytes ( ) } ) )
173
+ } else {
174
+ Err ( PyTypeError :: new_err ( "JSON input must be str, bytes or bytearray" ) )
175
+ }
176
+ }
177
+
167
178
pub trait BuildValidator : Sized {
168
179
const EXPECTED_TYPE : & ' static str ;
169
180
0 commit comments