From aee62296cc6b507c04045317c95dee7c34a9c587 Mon Sep 17 00:00:00 2001 From: Lawrin Novitsky Date: Mon, 3 Nov 2025 03:07:25 +0100 Subject: [PATCH 1/2] Constants and classes MariaDB COM_STMT_BULK_EXECUTE command Added StmtBulkExecuteParamsFlags bitflags for packet's bulk flags and MariadbBulkIndicator enum to define possible paramater value indicators. Added ComStmtBulkExecuteRequestBuilder class to build COM_STMT_BULK_EXECUTE packet representation in ComStmtBulkExecuteRequest class implementing packet data serialization. --- src/constants.rs | 32 +++++++- src/packets/mod.rs | 180 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 206 insertions(+), 6 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 45e54bb..78cc9c1 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -366,7 +366,7 @@ my_bitflags! { UnknownMariadbCapabilityFlags, u32, - /// Mariadb client capability flags + /// MariaDB client capability flags #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub struct MariadbCapabilities: u32 { /// Permits feedback during long-running operations @@ -431,6 +431,20 @@ my_bitflags! { } } +my_bitflags! { + StmtBulkExecuteParamsFlags, + #[error("Unknown flags in the raw value of StmtBulkExecuteParamsFlags (raw={0:b})")] + UnknownStmtBulkExecuteParamsFlags, + u16, + + /// MySql stmt execute params flags. + #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] + pub struct StmtBulkExecuteParamsFlags: u16 { + const SEND_UNIT_RESULTS = 64_u16; + const SEND_TYPES_TO_SERVER = 128_u16; + } +} + my_bitflags! { ColumnFlags, #[error("Unknown flags in the raw value of ColumnFlags (raw={0:b})")] @@ -528,6 +542,22 @@ pub enum Command { COM_BINLOG_DUMP_GTID, COM_RESET_CONNECTION, COM_END, + COM_STMT_BULK_EXECUTE = 0xfa_u8, +} + +/// MariaDB bulk execute parameter value indicators +#[allow(non_camel_case_types)] +#[derive(Clone, Copy, Eq, PartialEq, Debug)] +#[repr(u8)] +pub enum MariadbBulkIndicator { + /// No special indicator, normal value + BULK_INDICATOR_NONE = 0x00_u8, + /// NULL value + BULK_INDICATOR_NULL = 0x01_u8, + /// For INSERT/UPDATE, value is default. Not used + BULK_INDICATOR_DEFAULT = 0x02_u8, + /// Value is default for insert, Is ignored for update. Not used. + BULK_INDICATOR_IGNORE = 0x03_u8, } /// Type of state change information (part of MySql's Ok packet). diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 9d94c99..9ae5829 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -18,12 +18,13 @@ use std::{ }; use crate::collations::CollationId; +use crate::constants::StmtBulkExecuteParamsFlags; use crate::scramble::create_response_for_ed25519; use crate::{ constants::{ CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN, - MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags, - StmtExecuteParamsFlags, + MariadbBulkIndicator, MariadbCapabilities, SessionStateType, StatusFlags, + StmtExecuteParamFlags, StmtExecuteParamsFlags, }, io::{BufMutExt, ParseBuf}, misc::{ @@ -2762,6 +2763,171 @@ impl MySerialize for ComStmtClose { } } +/// Sends array of parameters to the server for the bulk execution of a prepared statement with +/// COM_STMT_BULK_EXECUTE command. +#[derive(Debug, Clone, PartialEq)] +pub struct ComStmtBulkExecuteRequestBuilder { + pub stmt_id: u32, + pub with_types: bool, + pub paramset: Vec>, + pub payload_len: usize, + pub max_payload_len: usize, /* max_allowed_packet(if known) - 4 */ +} + +impl ComStmtBulkExecuteRequestBuilder { + pub fn new(stmt_id: u32, max_payload: usize) -> Self { + Self { + stmt_id, + with_types: true, + paramset: Vec::new(), + payload_len: 0, + max_payload_len: max_payload, + } + } + + pub fn next(&mut self) -> () { + self.with_types = false; + self.paramset.clear(); + self.payload_len = 0; + } + pub fn add_row(&mut self, params: &[Value]) -> bool { + if self.with_types && self.payload_len == 0 { + self.payload_len = params.len() * 2; + } + let mut data_len = 0; + for p in params { + match p.bin_len() as usize { + 0 => data_len += 1, // NULLs take 1 byte for the indicator + x => data_len += x + 1, // non-NULLs take their length + 1 byte for the indicator + } + } + // It should be really total packet len(+7 + 4)compared against max allowed packet size, not MAX_PAYLOAD_LEN + if 7 + self.payload_len + data_len > self.max_payload_len { + return true; + } + self.paramset.push(params.to_vec()); + self.payload_len += data_len; + false + } + + pub fn has_rows(&self) -> bool { + !self.paramset.is_empty() + } + + pub fn build(&self) -> ComStmtBulkExecuteRequest { + ComStmtBulkExecuteRequest { + com_stmt_bulk_execute: ConstU8::new(), + stmt_id: RawInt::new(self.stmt_id), + bulk_flags: if self.with_types { + Const::new(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER) + } else { + Const::new(StmtBulkExecuteParamsFlags::empty()) + }, + params: &self.paramset, + } + } +} + +define_header!( + ComStmtBulkExecuteHeader, + COM_STMT_BULK_EXECUTE, + InvalidComStmtBulkExecuteHeader +); + +#[derive(Debug, Clone, PartialEq)] +pub struct ComStmtBulkExecuteRequest<'a> { + com_stmt_bulk_execute: ComStmtBulkExecuteHeader, + stmt_id: RawInt, + bulk_flags: Const, + // max params / bits per byte = 8192 + params: &'a Vec>, +} + +impl<'a> ComStmtBulkExecuteRequest<'a> { + pub fn stmt_id(&self) -> u32 { + self.stmt_id.0 + } + + pub fn bulk_flags(&self) -> StmtBulkExecuteParamsFlags { + self.bulk_flags.0 + } + + pub fn params(&self) -> &[Vec] { + self.params.as_ref() + } +} + +impl MySerialize for ComStmtBulkExecuteRequest<'_> { + fn serialize(&self, buf: &mut Vec) { + self.com_stmt_bulk_execute.serialize(&mut *buf); + self.stmt_id.serialize(&mut *buf); + self.bulk_flags.serialize(&mut *buf); + + if self + .bulk_flags + .0 + .contains(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER) + { + for param in &self.params[0] { + let (column_type, flags) = match param { + Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()), + Value::Bytes(_) => ( + ColumnType::MYSQL_TYPE_VAR_STRING, + StmtExecuteParamFlags::empty(), + ), + Value::Int(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::empty(), + ), + Value::UInt(_) => ( + ColumnType::MYSQL_TYPE_LONGLONG, + StmtExecuteParamFlags::UNSIGNED, + ), + Value::Float(_) => { + (ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty()) + } + Value::Double(_) => ( + ColumnType::MYSQL_TYPE_DOUBLE, + StmtExecuteParamFlags::empty(), + ), + Value::Date(..) => ( + ColumnType::MYSQL_TYPE_DATETIME, + StmtExecuteParamFlags::empty(), + ), + Value::Time(..) => { + (ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty()) + } + }; + buf.put_slice(&[column_type as u8, flags.bits()]); + } + } + + for row in self.params { + for param in row { + match param { + Value::Int(_) + | Value::UInt(_) + | Value::Float(_) + | Value::Double(_) + | Value::Date(..) + | Value::Time(..) => { + buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL + param.serialize(buf); + } + Value::Bytes(_) => { + buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL + param.serialize(buf); + } + Value::NULL => { + buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NULL as u8); // NULL indicator + } + } + } + } + } +} +// ------------------------------------------------------------------------------ + define_header!( ComRegisterSlaveHeader, COM_REGISTER_SLAVE, @@ -4129,7 +4295,7 @@ mod test { fn should_parse_handshake_packet_with_mariadb_ext_capabilities() { const HSP: &[u8] = b"\x0a5.5.5-11.4.7-MariaDB-log\x00\x0b\x00\ \x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\ - \x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x2a\x34\x64\ + \x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x00\x00\x2a\x34\x64\ \x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00"; let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap(); @@ -4150,6 +4316,7 @@ mod test { assert_eq!( hsp.mariadb_ext_capabilities(), MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA + | MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS ); let mut output = Vec::new(); hsp.serialize(&mut output); @@ -4169,7 +4336,10 @@ mod test { None, 1_u32.to_be(), ) - .with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA); + .with_mariadb_ext_capabilities( + MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA + | MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS, + ); let mut actual = Vec::new(); response.serialize(&mut actual); @@ -4179,7 +4349,7 @@ mod test { 0x2d, // charset 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // reserved - 0x10, 0x00, 0x00, 0x00, // mariadb capabilities + 0x14, 0x00, 0x00, 0x00, // mariadb capabilities 0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root 0x00, // blank scramble 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, From 99331f105ce76084e1d8c2a0316136707e71072a Mon Sep 17 00:00:00 2001 From: Lawrin Novitsky Date: Mon, 3 Nov 2025 09:52:23 +0100 Subject: [PATCH 2/2] Correction of the elided lifetime --- src/packets/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 9ae5829..45304f4 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -2814,7 +2814,7 @@ impl ComStmtBulkExecuteRequestBuilder { !self.paramset.is_empty() } - pub fn build(&self) -> ComStmtBulkExecuteRequest { + pub fn build(&self) -> ComStmtBulkExecuteRequest<'_> { ComStmtBulkExecuteRequest { com_stmt_bulk_execute: ConstU8::new(), stmt_id: RawInt::new(self.stmt_id),