From 51a7396ad3d78d9326ee1537b9ff29ab3919556f Mon Sep 17 00:00:00 2001 From: Jan Verbeek Date: Sun, 3 Dec 2023 12:25:11 +0100 Subject: [PATCH] Move `OsStr::slice_encoded_bytes` validation to platform modules On Windows and UEFI this improves performance and error messaging. On other platforms we optimize the fast path a bit more. This also prepares for later relaxing the checks on certain platforms. --- library/std/src/ffi/mod.rs | 7 +++ library/std/src/ffi/os_str.rs | 43 +++------------ library/std/src/ffi/os_str/tests.rs | 68 +++++++++++++++++++++--- library/std/src/sys/os_str/bytes.rs | 43 +++++++++++++++ library/std/src/sys/os_str/wtf8.rs | 7 ++- library/std/src/sys_common/wtf8.rs | 36 +++++++++++-- library/std/src/sys_common/wtf8/tests.rs | 62 +++++++++++++++++++++ 7 files changed, 219 insertions(+), 47 deletions(-) diff --git a/library/std/src/ffi/mod.rs b/library/std/src/ffi/mod.rs index 97e78d1778677..818571ddaaa16 100644 --- a/library/std/src/ffi/mod.rs +++ b/library/std/src/ffi/mod.rs @@ -127,6 +127,11 @@ //! trait, which provides a [`from_wide`] method to convert a native Windows //! string (without the terminating nul character) to an [`OsString`]. //! +//! ## Other platforms +//! +//! Many other platforms provide their own extension traits in a +//! `std::os::*::ffi` module. +//! //! ## On all platforms //! //! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of @@ -135,6 +140,8 @@ //! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and //! [`OsStr::from_encoded_bytes_unchecked`]. //! +//! For basic string processing, see [`OsStr::slice_encoded_bytes`]. +//! //! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value //! [Unicode code point]: https://www.unicode.org/glossary/#code_point //! [`env::set_var()`]: crate::env::set_var "env::set_var" diff --git a/library/std/src/ffi/os_str.rs b/library/std/src/ffi/os_str.rs index 81973182148ef..28747ad8f340a 100644 --- a/library/std/src/ffi/os_str.rs +++ b/library/std/src/ffi/os_str.rs @@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher}; use crate::ops::{self, Range}; use crate::rc::Rc; use crate::slice; -use crate::str::{from_utf8 as str_from_utf8, FromStr}; +use crate::str::FromStr; use crate::sync::Arc; use crate::sys::os_str::{Buf, Slice}; @@ -997,42 +997,15 @@ impl OsStr { /// ``` #[unstable(feature = "os_str_slice", issue = "118485")] pub fn slice_encoded_bytes>(&self, range: R) -> &Self { - #[track_caller] - fn check_valid_boundary(bytes: &[u8], index: usize) { - if index == 0 || index == bytes.len() { - return; - } - - // Fast path - if bytes[index - 1].is_ascii() || bytes[index].is_ascii() { - return; - } - - let (before, after) = bytes.split_at(index); - - // UTF-8 takes at most 4 bytes per codepoint, so we don't - // need to check more than that. - let after = after.get(..4).unwrap_or(after); - match str_from_utf8(after) { - Ok(_) => return, - Err(err) if err.valid_up_to() != 0 => return, - Err(_) => (), - } - - for len in 2..=4.min(index) { - let before = &before[index - len..]; - if str_from_utf8(before).is_ok() { - return; - } - } - - panic!("byte index {index} is not an OsStr boundary"); - } - let encoded_bytes = self.as_encoded_bytes(); let Range { start, end } = slice::range(range, ..encoded_bytes.len()); - check_valid_boundary(encoded_bytes, start); - check_valid_boundary(encoded_bytes, end); + + // `check_public_boundary` should panic if the index does not lie on an + // `OsStr` boundary as described above. It's possible to do this in an + // encoding-agnostic way, but details of the internal encoding might + // permit a more efficient implementation. + self.inner.check_public_boundary(start); + self.inner.check_public_boundary(end); // SAFETY: `slice::range` ensures that `start` and `end` are valid let slice = unsafe { encoded_bytes.get_unchecked(start..end) }; diff --git a/library/std/src/ffi/os_str/tests.rs b/library/std/src/ffi/os_str/tests.rs index 60cde376d326a..b020e05eaab20 100644 --- a/library/std/src/ffi/os_str/tests.rs +++ b/library/std/src/ffi/os_str/tests.rs @@ -194,15 +194,65 @@ fn slice_encoded_bytes() { } #[test] -#[should_panic(expected = "byte index 2 is not an OsStr boundary")] +#[should_panic] +fn slice_out_of_bounds() { + let crab = OsStr::new("🦀"); + let _ = crab.slice_encoded_bytes(..5); +} + +#[test] +#[should_panic] fn slice_mid_char() { let crab = OsStr::new("🦀"); let _ = crab.slice_encoded_bytes(..2); } +#[cfg(unix)] +#[test] +#[should_panic(expected = "byte index 1 is not an OsStr boundary")] +fn slice_invalid_data() { + use crate::os::unix::ffi::OsStrExt; + + let os_string = OsStr::from_bytes(b"\xFF\xFF"); + let _ = os_string.slice_encoded_bytes(1..); +} + +#[cfg(unix)] +#[test] +#[should_panic(expected = "byte index 1 is not an OsStr boundary")] +fn slice_partial_utf8() { + use crate::os::unix::ffi::{OsStrExt, OsStringExt}; + + let part_crab = OsStr::from_bytes(&"🦀".as_bytes()[..3]); + let mut os_string = OsString::from_vec(vec![0xFF]); + os_string.push(part_crab); + let _ = os_string.slice_encoded_bytes(1..); +} + +#[cfg(unix)] +#[test] +fn slice_invalid_edge() { + use crate::os::unix::ffi::{OsStrExt, OsStringExt}; + + let os_string = OsStr::from_bytes(b"a\xFFa"); + assert_eq!(os_string.slice_encoded_bytes(..1), "a"); + assert_eq!(os_string.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa")); + assert_eq!(os_string.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF")); + assert_eq!(os_string.slice_encoded_bytes(2..), "a"); + + let os_string = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]); + assert_eq!(os_string.slice_encoded_bytes(..3), "abc"); + assert_eq!(os_string.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6")); + + let mut os_string = OsString::from_vec(vec![0xFF]); + os_string.push("🦀"); + assert_eq!(os_string.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF")); + assert_eq!(os_string.slice_encoded_bytes(1..), "🦀"); +} + #[cfg(windows)] #[test] -#[should_panic(expected = "byte index 3 is not an OsStr boundary")] +#[should_panic(expected = "byte index 3 lies between surrogate codepoints")] fn slice_between_surrogates() { use crate::os::windows::ffi::OsStringExt; @@ -216,10 +266,14 @@ fn slice_between_surrogates() { fn slice_surrogate_edge() { use crate::os::windows::ffi::OsStringExt; - let os_string = OsString::from_wide(&[0xD800]); - let mut with_crab = os_string.clone(); - with_crab.push("🦀"); + let surrogate = OsString::from_wide(&[0xD800]); + let mut pre_crab = surrogate.clone(); + pre_crab.push("🦀"); + assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate); + assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀"); - assert_eq!(with_crab.slice_encoded_bytes(..3), os_string); - assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀"); + let mut post_crab = OsString::from("🦀"); + post_crab.push(&surrogate); + assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀"); + assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate); } diff --git a/library/std/src/sys/os_str/bytes.rs b/library/std/src/sys/os_str/bytes.rs index 3a75ce9ebb781..4ca3f1cd1853a 100644 --- a/library/std/src/sys/os_str/bytes.rs +++ b/library/std/src/sys/os_str/bytes.rs @@ -211,6 +211,49 @@ impl Slice { unsafe { mem::transmute(s) } } + #[track_caller] + #[inline] + pub fn check_public_boundary(&self, index: usize) { + if index == 0 || index == self.inner.len() { + return; + } + if index < self.inner.len() + && (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii()) + { + return; + } + + slow_path(&self.inner, index); + + /// We're betting that typical splits will involve an ASCII character. + /// + /// Putting the expensive checks in a separate function generates notably + /// better assembly. + #[track_caller] + #[inline(never)] + fn slow_path(bytes: &[u8], index: usize) { + let (before, after) = bytes.split_at(index); + + // UTF-8 takes at most 4 bytes per codepoint, so we don't + // need to check more than that. + let after = after.get(..4).unwrap_or(after); + match str::from_utf8(after) { + Ok(_) => return, + Err(err) if err.valid_up_to() != 0 => return, + Err(_) => (), + } + + for len in 2..=4.min(index) { + let before = &before[index - len..]; + if str::from_utf8(before).is_ok() { + return; + } + } + + panic!("byte index {index} is not an OsStr boundary"); + } + } + #[inline] pub fn from_str(s: &str) -> &Slice { unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) } diff --git a/library/std/src/sys/os_str/wtf8.rs b/library/std/src/sys/os_str/wtf8.rs index 237854fac4e2a..352bd7359033a 100644 --- a/library/std/src/sys/os_str/wtf8.rs +++ b/library/std/src/sys/os_str/wtf8.rs @@ -6,7 +6,7 @@ use crate::fmt; use crate::mem; use crate::rc::Rc; use crate::sync::Arc; -use crate::sys_common::wtf8::{Wtf8, Wtf8Buf}; +use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf}; use crate::sys_common::{AsInner, FromInner, IntoInner}; #[derive(Clone, Hash)] @@ -171,6 +171,11 @@ impl Slice { mem::transmute(Wtf8::from_bytes_unchecked(s)) } + #[track_caller] + pub fn check_public_boundary(&self, index: usize) { + check_utf8_boundary(&self.inner, index); + } + #[inline] pub fn from_str(s: &str) -> &Slice { unsafe { mem::transmute(Wtf8::from_str(s)) } diff --git a/library/std/src/sys_common/wtf8.rs b/library/std/src/sys_common/wtf8.rs index 67db5ebd89cfc..2dbd19d717199 100644 --- a/library/std/src/sys_common/wtf8.rs +++ b/library/std/src/sys_common/wtf8.rs @@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char { unsafe { char::from_u32_unchecked(code_point) } } -/// Copied from core::str::StrPrelude::is_char_boundary +/// Copied from str::is_char_boundary #[inline] pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { - if index == slice.len() { + if index == 0 { return true; } match slice.bytes.get(index) { - None => false, - Some(&b) => b < 128 || b >= 192, + None => index == slice.len(), + Some(&b) => (b as i8) >= -0x40, + } +} + +/// Verify that `index` is at the edge of either a valid UTF-8 codepoint +/// (i.e. a codepoint that's not a surrogate) or of the whole string. +/// +/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`. +/// Splitting between surrogates is valid as far as WTF-8 is concerned, but +/// we do not permit it in the public API because WTF-8 is considered an +/// implementation detail. +#[track_caller] +#[inline] +pub fn check_utf8_boundary(slice: &Wtf8, index: usize) { + if index == 0 { + return; + } + match slice.bytes.get(index) { + Some(0xED) => (), // Might be a surrogate + Some(&b) if (b as i8) >= -0x40 => return, + Some(_) => panic!("byte index {index} is not a codepoint boundary"), + None if index == slice.len() => return, + None => panic!("byte index {index} is out of bounds"), + } + if slice.bytes[index + 1] >= 0xA0 { + // There's a surrogate after index. Now check before index. + if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 { + panic!("byte index {index} lies between surrogate codepoints"); + } } } diff --git a/library/std/src/sys_common/wtf8/tests.rs b/library/std/src/sys_common/wtf8/tests.rs index 28a426648e501..6a1cc41a8fb04 100644 --- a/library/std/src/sys_common/wtf8/tests.rs +++ b/library/std/src/sys_common/wtf8/tests.rs @@ -663,3 +663,65 @@ fn wtf8_to_owned() { assert_eq!(string.bytes, b"\xED\xA0\x80"); assert!(!string.is_known_utf8); } + +#[test] +fn wtf8_valid_utf8_boundaries() { + let mut string = Wtf8Buf::from_str("aé 💩"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 0); + check_utf8_boundary(&string, 1); + check_utf8_boundary(&string, 3); + check_utf8_boundary(&string, 4); + check_utf8_boundary(&string, 8); + check_utf8_boundary(&string, 14); + assert_eq!(string.len(), 14); + + string.push_char('a'); + check_utf8_boundary(&string, 14); + check_utf8_boundary(&string, 15); + + let mut string = Wtf8Buf::from_str("a"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 1); + + let mut string = Wtf8Buf::from_str("\u{D7FF}"); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 3); + + let mut string = Wtf8Buf::new(); + string.push(CodePoint::from_u32(0xD800).unwrap()); + string.push_char('\u{D7FF}'); + check_utf8_boundary(&string, 3); +} + +#[test] +#[should_panic(expected = "byte index 4 is out of bounds")] +fn wtf8_utf8_boundary_out_of_bounds() { + let string = Wtf8::from_str("aé"); + check_utf8_boundary(&string, 4); +} + +#[test] +#[should_panic(expected = "byte index 1 is not a codepoint boundary")] +fn wtf8_utf8_boundary_inside_codepoint() { + let string = Wtf8::from_str("é"); + check_utf8_boundary(&string, 1); +} + +#[test] +#[should_panic(expected = "byte index 1 is not a codepoint boundary")] +fn wtf8_utf8_boundary_inside_surrogate() { + let mut string = Wtf8Buf::new(); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 1); +} + +#[test] +#[should_panic(expected = "byte index 3 lies between surrogate codepoints")] +fn wtf8_utf8_boundary_between_surrogates() { + let mut string = Wtf8Buf::new(); + string.push(CodePoint::from_u32(0xD800).unwrap()); + string.push(CodePoint::from_u32(0xD800).unwrap()); + check_utf8_boundary(&string, 3); +}