Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite UTF-8 validation as a shift-based DFA for a 70%–135% performance increase on non-ASCII strings #136693

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
35 changes: 29 additions & 6 deletions library/core/src/str/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,24 @@ use crate::fmt;
/// }
/// }
/// ```
#[derive(Copy, Eq, PartialEq, Clone, Debug)]
#[derive(Copy, Eq, PartialEq, Clone)]
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Utf8Error {
pub(super) valid_up_to: usize,
pub(super) error_len: Option<u8>,
// Use a single value instead of tagged enum `Option<u8>` to make `Result<(), Utf8Error>` fits
// in two machine words, so `run_utf8_validation` does not need to returns values on stack on
// x86(_64). Register spill is very expensive on `run_utf8_validation` and can give up to 200%
// latency penalty on the error path.
pub(super) error_len: Utf8ErrorLen,
}

// Packed replacement for `Option<u8>`: `Eof` (discriminant 0) plays the role of
// `None` (input ended before the sequence completed), while the remaining
// discriminants equal the error length in bytes, so a plain `len as usize`
// recovers the numeric value in `Utf8Error::error_len`.
#[derive(Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub(super) enum Utf8ErrorLen {
    Eof = 0,
    One = 1,
    Two = 2,
    Three = 3,
}

impl Utf8Error {
Expand Down Expand Up @@ -100,18 +113,28 @@ impl Utf8Error {
#[must_use]
#[inline]
pub const fn error_len(&self) -> Option<usize> {
    // Unpack the internal representation into the `Option` the public API
    // promises: `Eof` (0) encodes "unexpected end of input", i.e. `None`;
    // any other variant's discriminant is the error length in bytes.
    match self.error_len {
        Utf8ErrorLen::Eof => None,
        // FIXME(136972): Direct `match` gives suboptimal codegen involving two table lookups.
        len => Some(len as usize),
    }
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl fmt::Debug for Utf8Error {
    // Manual impl (instead of `derive`): expose `error_len` as the public
    // `Option<usize>` accessor value rather than the packed internal enum.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Utf8Error");
        builder.field("valid_up_to", &self.valid_up_to);
        builder.field("error_len", &self.error_len());
        builder.finish()
    }
}

#[stable(feature = "rust1", since = "1.0.0")]
impl fmt::Display for Utf8Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if let Some(error_len) = self.error_len {
if let Some(error_len) = self.error_len() {
write!(
f,
"invalid utf-8 sequence of {} bytes from index {}",
Expand Down
108 changes: 22 additions & 86 deletions library/core/src/str/lossy.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::from_utf8_unchecked;
use super::validations::utf8_char_width;
use super::validations::run_utf8_validation;
use crate::fmt;
use crate::fmt::{Formatter, Write};
use crate::iter::FusedIterator;
Expand Down Expand Up @@ -197,93 +197,29 @@ impl<'a> Iterator for Utf8Chunks<'a> {
return None;
}

const TAG_CONT_U8: u8 = 128;
fn safe_get(xs: &[u8], i: usize) -> u8 {
*xs.get(i).unwrap_or(&0)
}

let mut i = 0;
let mut valid_up_to = 0;
while i < self.source.len() {
// SAFETY: `i < self.source.len()` per previous line.
// For some reason the following are both significantly slower:
// while let Some(&byte) = self.source.get(i) {
// while let Some(byte) = self.source.get(i).copied() {
let byte = unsafe { *self.source.get_unchecked(i) };
i += 1;

if byte < 128 {
// This could be a `1 => ...` case in the match below, but for
// the common case of all-ASCII inputs, we bypass loading the
// sizeable UTF8_CHAR_WIDTH table into cache.
} else {
let w = utf8_char_width(byte);

match w {
2 => {
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
3 => {
match (byte, safe_get(self.source, i)) {
(0xE0, 0xA0..=0xBF) => (),
(0xE1..=0xEC, 0x80..=0xBF) => (),
(0xED, 0x80..=0x9F) => (),
(0xEE..=0xEF, 0x80..=0xBF) => (),
_ => break,
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
4 => {
match (byte, safe_get(self.source, i)) {
(0xF0, 0x90..=0xBF) => (),
(0xF1..=0xF3, 0x80..=0xBF) => (),
(0xF4, 0x80..=0x8F) => (),
_ => break,
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
_ => break,
}
match run_utf8_validation(self.source) {
Ok(()) => {
// SAFETY: The whole `source` is valid in UTF-8.
let valid = unsafe { from_utf8_unchecked(&self.source) };
// Truncate the slice, no need to touch the pointer.
self.source = &self.source[..0];
Some(Utf8Chunk { valid, invalid: &[] })
}
Err(err) => {
let valid_up_to = err.valid_up_to();
let error_len = err.error_len().unwrap_or(self.source.len() - valid_up_to);
// SAFETY: `valid_up_to` is the valid UTF-8 string length, so is in bound.
let (valid, remaining) = unsafe { self.source.split_at_unchecked(valid_up_to) };
// SAFETY: `error_len` is the errornous byte sequence length, so is in bound.
let (invalid, after_invalid) = unsafe { remaining.split_at_unchecked(error_len) };
self.source = after_invalid;
Some(Utf8Chunk {
// SAFETY: All bytes up to `valid_up_to` are valid UTF-8.
valid: unsafe { from_utf8_unchecked(valid) },
invalid,
})
}

valid_up_to = i;
}

// SAFETY: `i <= self.source.len()` because it is only ever incremented
// via `i += 1` and in between every single one of those increments, `i`
// is compared against `self.source.len()`. That happens either
// literally by `i < self.source.len()` in the while-loop's condition,
// or indirectly by `safe_get(self.source, i) & 192 != TAG_CONT_U8`. The
// loop is terminated as soon as the latest `i += 1` has made `i` no
// longer less than `self.source.len()`, which means it'll be at most
// equal to `self.source.len()`.
let (inspected, remaining) = unsafe { self.source.split_at_unchecked(i) };
self.source = remaining;

// SAFETY: `valid_up_to <= i` because it is only ever assigned via
// `valid_up_to = i` and `i` only increases.
let (valid, invalid) = unsafe { inspected.split_at_unchecked(valid_up_to) };

Some(Utf8Chunk {
// SAFETY: All bytes up to `valid_up_to` are valid UTF-8.
valid: unsafe { from_utf8_unchecked(valid) },
invalid,
})
}
}

Expand Down
78 changes: 78 additions & 0 deletions library/core/src/str/solve_dfa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Use z3 to solve the UTF-8 validation DFA for offset and transition tables,
# in order to encode the transition table into a u32.
# We minimize the output variables in the solution to make it deterministic.
# Ref: <https://gist.github.com/dougallj/166e326de6ad4cf2c94be97a204c025f>
# See the detailed explanation in `./validations.rs`.
#
# It is expected to find a solution in <30s on a modern machine, and the
# solution is appended to the end of this file.
from z3 import *

STATE_CNT = 9

# The transition table.
# A value X on column Y means state Y should transition to state X on some
# input bytes. We assign state 0 as ERROR and state 1 as ACCEPT (initial).
# Eg. first line: for input byte 00..=7F, transition S1 -> S1, others -> S0.
TRANSITIONS = [
    #  0  1  2  3  4  5  6  7  8
    # First bytes
    ((0, 1, 0, 0, 0, 0, 0, 0, 0), "00-7F"),
    ((0, 2, 0, 0, 0, 0, 0, 0, 0), "C2-DF"),
    ((0, 3, 0, 0, 0, 0, 0, 0, 0), "E0"),
    ((0, 4, 0, 0, 0, 0, 0, 0, 0), "E1-EC, EE-EF"),
    ((0, 5, 0, 0, 0, 0, 0, 0, 0), "ED"),
    ((0, 6, 0, 0, 0, 0, 0, 0, 0), "F0"),
    ((0, 7, 0, 0, 0, 0, 0, 0, 0), "F1-F3"),
    ((0, 8, 0, 0, 0, 0, 0, 0, 0), "F4"),
    # Continuation bytes
    ((0, 0, 1, 0, 2, 2, 0, 4, 4), "80-8F"),
    ((0, 0, 1, 0, 2, 2, 4, 4, 0), "90-9F"),
    ((0, 0, 1, 2, 2, 0, 4, 4, 0), "A0-BF"),
    # Illegal
    ((0, 0, 0, 0, 0, 0, 0, 0, 0), "C0-C1, F5-FF"),
]

o = Optimize()
offsets = [BitVec(f"o{i}", 32) for i in range(STATE_CNT)]
trans_table = [BitVec(f"t{i}", 32) for i in range(len(TRANSITIONS))]

# Add some guiding constraints to make solving faster.
o.add(offsets[0] == 0)
o.add(trans_table[-1] == 0)

for i in range(len(offsets)):
    # Do not over-shift. It's not necessary but makes solving faster.
    o.add(offsets[i] < 32 - 5)
    # All state offsets must be distinct so states are distinguishable.
    for j in range(i):
        o.add(offsets[i] != offsets[j])
# Encode the DFA: shifting a transition word right by the source state's
# offset (masked to 5 bits) must yield the target state's offset.
for trans, (targets, _) in zip(trans_table, TRANSITIONS):
    for src, tgt in enumerate(targets):
        o.add((LShR(trans, offsets[src]) & 31) == offsets[tgt])

# Minimize ordered outputs to get a unique solution.
goal = Concat(*offsets, *trans_table)
o.minimize(goal)

result = o.check()
print(result)
# Fail loudly instead of crashing inside `o.model()` with an opaque
# Z3 exception when the constraints are unsatisfiable.
if result != sat:
    raise SystemExit(f"DFA constraints are not satisfiable: {result}")

# Query the model once, rather than re-querying it for every variable.
model = o.model()
print("Offset[]= ", [model[v].as_long() for v in offsets])
print("Transitions:")
for (_, label), v in zip(TRANSITIONS, (model[t].as_long() for t in trans_table)):
    print(f"{label:14} => {v:#10x}, // {v:032b}")

# Output should be deterministic:
# sat
# Offset[]= [0, 6, 16, 19, 1, 25, 11, 18, 24]
# Transitions:
# 00-7F => 0x180, // 00000000000000000000000110000000
# C2-DF => 0x400, // 00000000000000000000010000000000
# E0 => 0x4c0, // 00000000000000000000010011000000
# E1-EC, EE-EF => 0x40, // 00000000000000000000000001000000
# ED => 0x640, // 00000000000000000000011001000000
# F0 => 0x2c0, // 00000000000000000000001011000000
# F1-F3 => 0x480, // 00000000000000000000010010000000
# F4 => 0x600, // 00000000000000000000011000000000
# 80-8F => 0x21060020, // 00100001000001100000000000100000
# 90-9F => 0x20060820, // 00100000000001100000100000100000
# A0-BF => 0x860820, // 00000000100001100000100000100000
# C0-C1, F5-FF => 0x0, // 00000000000000000000000000000000
Loading
Loading