Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,16 +271,16 @@ impl PyMultiHostUrl {
// string representation of the URL, with punycode decoded when appropriate
pub fn unicode_string(&self) -> String {
if let Some(extra_urls) = &self.extra_urls {
let schema = self.ref_url.lib_url.scheme();
let host_offset = schema.len() + 3;
let scheme = self.ref_url.lib_url.scheme();
let host_offset = scheme.len() + 3;

let mut full_url = self.ref_url.unicode_string();
full_url.insert(host_offset, ',');

// special urls will have had a trailing slash added, non-special urls will not
// hence we need to remove the last char if the schema is special
// hence we need to remove the last char if the scheme is special
#[allow(clippy::bool_to_int_with_if)]
let sub = if schema_is_special(schema) { 1 } else { 0 };
let sub = if scheme_is_special(scheme) { 1 } else { 0 };

let hosts = extra_urls
.iter()
Expand All @@ -299,16 +299,16 @@ impl PyMultiHostUrl {

pub fn __str__(&self) -> String {
if let Some(extra_urls) = &self.extra_urls {
let schema = self.ref_url.lib_url.scheme();
let host_offset = schema.len() + 3;
let scheme = self.ref_url.lib_url.scheme();
let host_offset = scheme.len() + 3;

let mut full_url = self.ref_url.lib_url.to_string();
full_url.insert(host_offset, ',');

// special urls will have had a trailing slash added, non-special urls will not
// hence we need to remove the last char if the schema is special
// hence we need to remove the last char if the scheme is special
#[allow(clippy::bool_to_int_with_if)]
let sub = if schema_is_special(schema) { 1 } else { 0 };
let sub = if scheme_is_special(scheme) { 1 } else { 0 };

let hosts = extra_urls
.iter()
Expand Down Expand Up @@ -510,10 +510,10 @@ fn decode_punycode(domain: &str) -> Option<String> {
static PUNYCODE_PREFIX: &str = "xn--";

fn is_punnycode_domain(lib_url: &Url, domain: &str) -> bool {
schema_is_special(lib_url.scheme()) && domain.split('.').any(|part| part.starts_with(PUNYCODE_PREFIX))
scheme_is_special(lib_url.scheme()) && domain.split('.').any(|part| part.starts_with(PUNYCODE_PREFIX))
}

// based on https://github.com/servo/rust-url/blob/1c1e406874b3d2aa6f36c5d2f3a5c2ea74af9efb/url/src/parser.rs#L161-L167
pub fn schema_is_special(schema: &str) -> bool {
matches!(schema, "http" | "https" | "ws" | "wss" | "ftp" | "file")
pub fn scheme_is_special(scheme: &str) -> bool {
matches!(scheme, "http" | "https" | "ws" | "wss" | "ftp" | "file")
}
121 changes: 54 additions & 67 deletions src/validators/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@ use crate::errors::ToErrorValue;
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};
use crate::input::downcast_python_input;
use crate::input::Input;
use crate::input::ValidationMatch;
use crate::tools::SchemaDict;
use crate::url::{schema_is_special, PyMultiHostUrl, PyUrl};
use crate::url::{scheme_is_special, PyMultiHostUrl, PyUrl};

use super::literal::expected_repr_name;
use super::Exactness;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

type AllowedSchemas = Option<(AHashSet<String>, String)>;
type AllowedSchemes = Option<(AHashSet<String>, String)>;

#[derive(Debug, Clone)]
pub struct UrlValidator {
strict: bool,
max_length: Option<usize>,
allowed_schemes: AllowedSchemas,
allowed_schemes: AllowedSchemes,
host_required: bool,
default_host: Option<String>,
default_port: Option<u16>,
Expand All @@ -44,7 +45,7 @@ impl BuildValidator for UrlValidator {
config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let (allowed_schemes, name) = get_allowed_schemas(schema, Self::EXPECTED_TYPE)?;
let (allowed_schemes, name) = get_allowed_schemes(schema, Self::EXPECTED_TYPE)?;

Ok(Self {
strict: is_strict(schema, config)?,
Expand Down Expand Up @@ -107,31 +108,23 @@ impl Validator for UrlValidator {

impl UrlValidator {
fn get_url<'py>(&self, input: &(impl Input<'py> + ?Sized), strict: bool) -> ValResult<EitherUrl<'py>> {
match input.validate_str(strict, false) {
Ok(val_match) => {
let either_str = val_match.into_inner();
let cow = either_str.as_cow()?;
let url_str = cow.as_ref();

self.check_length(input, url_str)?;

parse_url(url_str, input, strict).map(EitherUrl::Rust)
}
Err(_) => {
// we don't need to worry about whether the url was parsed in strict mode before,
// even if it was, any syntax errors would have been fixed by the first validation
if let Some(py_url) = downcast_python_input::<PyUrl>(input) {
self.check_length(input, py_url.get().url().as_str())?;
Ok(EitherUrl::Py(py_url.clone()))
} else if let Some(multi_host_url) = downcast_python_input::<PyMultiHostUrl>(input) {
let url_str = multi_host_url.get().__str__();
self.check_length(input, &url_str)?;

parse_url(&url_str, input, strict).map(EitherUrl::Rust)
} else {
Err(ValError::new(ErrorTypeDefaults::UrlType, input))
}
}
if let Some(py_url) = downcast_python_input::<PyUrl>(input) {
// we don't need to worry about whether the url was parsed in strict mode before,
// even if it was, any syntax errors would have been fixed by the first validation
self.check_length(input, py_url.get().url().as_str())?;
Ok(EitherUrl::Py(py_url.clone()))
} else if let Some(multi_host_url) = downcast_python_input::<PyMultiHostUrl>(input) {
let url_str = multi_host_url.get().__str__();
self.check_length(input, &url_str)?;
parse_url(&url_str, input, strict).map(EitherUrl::Rust)
} else if let Ok(either_str) = input.validate_str(strict, false).map(ValidationMatch::into_inner) {
let cow = either_str.as_cow()?;
let url_str = cow.as_ref();

self.check_length(input, url_str)?;
parse_url(url_str, input, strict).map(EitherUrl::Rust)
} else {
Err(ValError::new(ErrorTypeDefaults::UrlType, input))
}
}

Expand Down Expand Up @@ -192,7 +185,7 @@ impl CopyFromPyUrl for EitherUrl<'_> {
pub struct MultiHostUrlValidator {
strict: bool,
max_length: Option<usize>,
allowed_schemes: AllowedSchemas,
allowed_schemes: AllowedSchemes,
host_required: bool,
default_host: Option<String>,
default_port: Option<u16>,
Expand All @@ -208,7 +201,7 @@ impl BuildValidator for MultiHostUrlValidator {
config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let (allowed_schemes, name) = get_allowed_schemas(schema, Self::EXPECTED_TYPE)?;
let (allowed_schemes, name) = get_allowed_schemes(schema, Self::EXPECTED_TYPE)?;

let default_host: Option<String> = schema.get_as(intern!(schema.py(), "default_host"))?;
if let Some(ref default_host) = default_host {
Expand Down Expand Up @@ -276,32 +269,26 @@ impl Validator for MultiHostUrlValidator {

impl MultiHostUrlValidator {
fn get_url<'py>(&self, input: &(impl Input<'py> + ?Sized), strict: bool) -> ValResult<EitherMultiHostUrl<'py>> {
match input.validate_str(strict, false) {
Ok(val_match) => {
let either_str = val_match.into_inner();
let cow = either_str.as_cow()?;
let url_str = cow.as_ref();

self.check_length(input, || url_str.len())?;

parse_multihost_url(url_str, input, strict).map(EitherMultiHostUrl::Rust)
}
Err(_) => {
// we don't need to worry about whether the url was parsed in strict mode before,
// even if it was, any syntax errors would have been fixed by the first validation
if let Some(multi_url) = downcast_python_input::<PyMultiHostUrl>(input) {
self.check_length(input, || multi_url.get().__str__().len())?;
Ok(EitherMultiHostUrl::Py(multi_url.clone()))
} else if let Some(py_url) = downcast_python_input::<PyUrl>(input) {
self.check_length(input, || py_url.get().url().as_str().len())?;
Ok(EitherMultiHostUrl::Rust(PyMultiHostUrl::new(
py_url.get().url().clone(),
None,
)))
} else {
Err(ValError::new(ErrorTypeDefaults::UrlType, input))
}
}
// we don't need to worry about whether the url was parsed in strict mode before,
// even if it was, any syntax errors would have been fixed by the first validation
if let Some(multi_url) = downcast_python_input::<PyMultiHostUrl>(input) {
self.check_length(input, || multi_url.get().__str__().len())?;
Ok(EitherMultiHostUrl::Py(multi_url.clone()))
} else if let Some(py_url) = downcast_python_input::<PyUrl>(input) {
self.check_length(input, || py_url.get().url().as_str().len())?;
Ok(EitherMultiHostUrl::Rust(PyMultiHostUrl::new(
py_url.get().url().clone(),
None,
)))
} else if let Ok(either_str) = input.validate_str(strict, false).map(ValidationMatch::into_inner) {
let cow = either_str.as_cow()?;
let url_str = cow.as_ref();

self.check_length(input, || url_str.len())?;

parse_multihost_url(url_str, input, strict).map(EitherMultiHostUrl::Rust)
} else {
Err(ValError::new(ErrorTypeDefaults::UrlType, input))
}
}

Expand Down Expand Up @@ -399,24 +386,24 @@ fn parse_multihost_url<'py>(
}
}

// consume the url schema, some logic from `parse_scheme`
// consume the url scheme, some logic from `parse_scheme`
// https://github.com/servo/rust-url/blob/v2.3.1/url/src/parser.rs#L387-L411
let schema_start = chars.position;
let schema_end = loop {
let scheme_start = chars.position;
let scheme_end = loop {
match chars.next() {
Some('a'..='z' | 'A'..='Z' | '0'..='9' | '+' | '-' | '.') => continue,
Some(':') => {
// require the schema to be non-empty
let schema_end = chars.position - ':'.len_utf8();
if schema_end > schema_start {
break schema_end;
// require the scheme to be non-empty
let scheme_end = chars.position - ':'.len_utf8();
if scheme_end > scheme_start {
break scheme_end;
}
}
_ => {}
}
return parsing_err!(ParseError::RelativeUrlWithoutBase);
};
let schema = url_str[schema_start..schema_end].to_ascii_lowercase();
let scheme = url_str[scheme_start..scheme_end].to_ascii_lowercase();

// consume the double slash, or any number of slashes, including backslashes, taken from `parse_with_scheme`
// https://github.com/servo/rust-url/blob/v2.3.1/url/src/parser.rs#L413-L456
Expand All @@ -437,7 +424,7 @@ fn parse_multihost_url<'py>(
let mut start = chars.position;
while let Some(c) = chars.next() {
match c {
'\\' if schema_is_special(&schema) => break,
'\\' if scheme_is_special(&scheme) => break,
'/' | '?' | '#' => break,
',' => {
// minus 1 because we know that the last char was a `,` with length 1
Expand Down Expand Up @@ -587,7 +574,7 @@ trait CopyFromPyUrl {
fn url_mut(&mut self) -> &mut Url;
}

fn get_allowed_schemas(schema: &Bound<'_, PyDict>, name: &'static str) -> PyResult<(AllowedSchemas, String)> {
fn get_allowed_schemes(schema: &Bound<'_, PyDict>, name: &'static str) -> PyResult<(AllowedSchemes, String)> {
match schema.get_as::<Bound<'_, PyList>>(intern!(schema.py(), "allowed_schemes"))? {
Some(list) => {
if list.is_empty() {
Expand Down
14 changes: 14 additions & 0 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,3 +1475,17 @@ def test_enum_str_core(benchmark):
assert v.validate_python('apple') is FooStr.a

benchmark(v.validate_python, 'apple')


@pytest.mark.benchmark(group='url')
def test_url_core(benchmark):
    """Benchmark ``validate_python`` on a single URL string using a ``url_schema`` validator."""
    # Build the validator once, outside the timed call, so only validation is measured.
    validator = SchemaValidator(core_schema.url_schema())
    url = 'https://example.com/some/path?query=string#fragment'

    benchmark(validator.validate_python, url)


@pytest.mark.benchmark(group='url')
def test_multi_host_url_core(benchmark):
    """Benchmark ``validate_python`` on a two-host URL string using a ``multi_host_url_schema`` validator."""
    v = SchemaValidator(core_schema.multi_host_url_schema())

    # Two comma-separated hosts; the second one includes userinfo (`b:c@`) and an explicit port (`:777`).
    benchmark(v.validate_python, 'https://example.com,b:c@example.org:777/some/path?query=string#fragment')
6 changes: 6 additions & 0 deletions tests/validators/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ def url_test_case_helper(
('http://example.com:65535', 'http://example.com:65535/'),
('http:\\\\example.com', 'http://example.com/'),
('http:example.com', 'http://example.com/'),
('http:example.com/path', 'http://example.com/path'),
('http:example.com/path/', 'http://example.com/path/'),
('http:example.com?query=nopath', 'http://example.com/?query=nopath'),
('http:example.com/?query=haspath', 'http://example.com/?query=haspath'),
('http:example.com#nopath', 'http://example.com/#nopath'),
('http:example.com/#haspath', 'http://example.com/#haspath'),
('http://example.com:65536', Err('invalid port number')),
('http://1...1', Err('invalid IPv4 address')),
('https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334[', Err('invalid IPv6 address')),
Expand Down
Loading