Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ my_bitflags! {
UnknownMariadbCapabilityFlags,
u32,

/// Mariadb client capability flags
/// MariaDB client capability flags
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
pub struct MariadbCapabilities: u32 {
/// Permits feedback during long-running operations
Expand Down Expand Up @@ -431,6 +431,20 @@ my_bitflags! {
}
}

my_bitflags! {
StmtBulkExecuteParamsFlags,
#[error("Unknown flags in the raw value of StmtBulkExecuteParamsFlags (raw={0:b})")]
UnknownStmtBulkExecuteParamsFlags,
u16,

/// MySql stmt execute params flags.
#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)]
pub struct StmtBulkExecuteParamsFlags: u16 {
const SEND_UNIT_RESULTS = 64_u16;
const SEND_TYPES_TO_SERVER = 128_u16;
}
}

my_bitflags! {
ColumnFlags,
#[error("Unknown flags in the raw value of ColumnFlags (raw={0:b})")]
Expand Down Expand Up @@ -528,6 +542,22 @@ pub enum Command {
COM_BINLOG_DUMP_GTID,
COM_RESET_CONNECTION,
COM_END,
COM_STMT_BULK_EXECUTE = 0xfa_u8,
}

/// MariaDB bulk execute parameter value indicators
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
#[repr(u8)]
pub enum MariadbBulkIndicator {
/// No special indicator, normal value
BULK_INDICATOR_NONE = 0x00_u8,
/// NULL value
BULK_INDICATOR_NULL = 0x01_u8,
/// For INSERT/UPDATE, value is default. Not used
BULK_INDICATOR_DEFAULT = 0x02_u8,
/// Value is default for insert, Is ignored for update. Not used.
BULK_INDICATOR_IGNORE = 0x03_u8,
}

/// Type of state change information (part of MySql's Ok packet).
Expand Down
180 changes: 175 additions & 5 deletions src/packets/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ use std::{
};

use crate::collations::CollationId;
use crate::constants::StmtBulkExecuteParamsFlags;
use crate::scramble::create_response_for_ed25519;
use crate::{
constants::{
CapabilityFlags, ColumnFlags, ColumnType, Command, CursorType, MAX_PAYLOAD_LEN,
MariadbCapabilities, SessionStateType, StatusFlags, StmtExecuteParamFlags,
StmtExecuteParamsFlags,
MariadbBulkIndicator, MariadbCapabilities, SessionStateType, StatusFlags,
StmtExecuteParamFlags, StmtExecuteParamsFlags,
},
io::{BufMutExt, ParseBuf},
misc::{
Expand Down Expand Up @@ -2762,6 +2763,171 @@ impl MySerialize for ComStmtClose {
}
}

/// Sends array of parameters to the server for the bulk execution of a prepared statement with
/// COM_STMT_BULK_EXECUTE command.
#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtBulkExecuteRequestBuilder {
pub stmt_id: u32,
pub with_types: bool,
pub paramset: Vec<Vec<Value>>,
pub payload_len: usize,
pub max_payload_len: usize, /* max_allowed_packet(if known) - 4 */
}

impl ComStmtBulkExecuteRequestBuilder {
pub fn new(stmt_id: u32, max_payload: usize) -> Self {
Self {
stmt_id,
with_types: true,
paramset: Vec::new(),
payload_len: 0,
max_payload_len: max_payload,
}
}

pub fn next(&mut self) -> () {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please document the meaning of all the public functions?

self.with_types = false;
self.paramset.clear();
self.payload_len = 0;
}
pub fn add_row(&mut self, params: &[Value]) -> bool {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be add_row(&mut self, params: Vec<Value>) I believe, to avoid unnecessary clone.

if self.with_types && self.payload_len == 0 {
self.payload_len = params.len() * 2;
}
let mut data_len = 0;
for p in params {
match p.bin_len() as usize {
0 => data_len += 1, // NULLs take 1 byte for the indicator
x => data_len += x + 1, // non-NULLs take their length + 1 byte for the indicator
}
}
// It should be really total packet len(+7 + 4)compared against max allowed packet size, not MAX_PAYLOAD_LEN
if 7 + self.payload_len + data_len > self.max_payload_len {
return true;
}
self.paramset.push(params.to_vec());
self.payload_len += data_len;
false
}

pub fn has_rows(&self) -> bool {
!self.paramset.is_empty()
}

pub fn build(&self) -> ComStmtBulkExecuteRequest<'_> {
ComStmtBulkExecuteRequest {
com_stmt_bulk_execute: ConstU8::new(),
stmt_id: RawInt::new(self.stmt_id),
bulk_flags: if self.with_types {
Const::new(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER)
} else {
Const::new(StmtBulkExecuteParamsFlags::empty())
},
params: &self.paramset,
}
}
}

define_header!(
ComStmtBulkExecuteHeader,
COM_STMT_BULK_EXECUTE,
InvalidComStmtBulkExecuteHeader
);

#[derive(Debug, Clone, PartialEq)]
pub struct ComStmtBulkExecuteRequest<'a> {
com_stmt_bulk_execute: ComStmtBulkExecuteHeader,
stmt_id: RawInt<LeU32>,
bulk_flags: Const<StmtBulkExecuteParamsFlags, LeU16>,
// max params / bits per byte = 8192
params: &'a Vec<Vec<Value>>,
}

impl<'a> ComStmtBulkExecuteRequest<'a> {
pub fn stmt_id(&self) -> u32 {
self.stmt_id.0
}

pub fn bulk_flags(&self) -> StmtBulkExecuteParamsFlags {
self.bulk_flags.0
}

pub fn params(&self) -> &[Vec<Value>] {
self.params.as_ref()
}
}

impl MySerialize for ComStmtBulkExecuteRequest<'_> {
fn serialize(&self, buf: &mut Vec<u8>) {
self.com_stmt_bulk_execute.serialize(&mut *buf);
self.stmt_id.serialize(&mut *buf);
self.bulk_flags.serialize(&mut *buf);

if self
.bulk_flags
.0
.contains(StmtBulkExecuteParamsFlags::SEND_TYPES_TO_SERVER)
{
for param in &self.params[0] {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. This will panic if SEND_TYPES_TO_SERVER is set and the params is empty that seems reachable from ComStmtBulkExecuteRequestBuilder 🤔

let (column_type, flags) = match param {
Value::NULL => (ColumnType::MYSQL_TYPE_NULL, StmtExecuteParamFlags::empty()),
Value::Bytes(_) => (
ColumnType::MYSQL_TYPE_VAR_STRING,
StmtExecuteParamFlags::empty(),
),
Value::Int(_) => (
ColumnType::MYSQL_TYPE_LONGLONG,
StmtExecuteParamFlags::empty(),
),
Value::UInt(_) => (
ColumnType::MYSQL_TYPE_LONGLONG,
StmtExecuteParamFlags::UNSIGNED,
),
Value::Float(_) => {
(ColumnType::MYSQL_TYPE_FLOAT, StmtExecuteParamFlags::empty())
}
Value::Double(_) => (
ColumnType::MYSQL_TYPE_DOUBLE,
StmtExecuteParamFlags::empty(),
),
Value::Date(..) => (
ColumnType::MYSQL_TYPE_DATETIME,
StmtExecuteParamFlags::empty(),
),
Value::Time(..) => {
(ColumnType::MYSQL_TYPE_TIME, StmtExecuteParamFlags::empty())
}
};
buf.put_slice(&[column_type as u8, flags.bits()]);
}
}

for row in self.params {
for param in row {
match param {
Value::Int(_)
| Value::UInt(_)
| Value::Float(_)
| Value::Double(_)
| Value::Date(..)
| Value::Time(..) => {
buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL
param.serialize(buf);
}
Value::Bytes(_) => {
buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NONE as u8); // not NULL
param.serialize(buf);
}
Value::NULL => {
buf.put_u8(MariadbBulkIndicator::BULK_INDICATOR_NULL as u8); // NULL indicator
}
}
}
}
}
}
// ------------------------------------------------------------------------------

define_header!(
ComRegisterSlaveHeader,
COM_REGISTER_SLAVE,
Expand Down Expand Up @@ -4129,7 +4295,7 @@ mod test {
fn should_parse_handshake_packet_with_mariadb_ext_capabilities() {
const HSP: &[u8] = b"\x0a5.5.5-11.4.7-MariaDB-log\x00\x0b\x00\
\x00\x00\x64\x76\x48\x40\x49\x2d\x43\x4a\x00\xff\xf7\x08\x02\x00\
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x00\x00\x2a\x34\x64\
\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x00\x00\x2a\x34\x64\
\x7c\x63\x5a\x77\x6b\x34\x5e\x5d\x3a\x00";

let hsp = HandshakePacket::deserialize((), &mut ParseBuf(HSP)).unwrap();
Expand All @@ -4150,6 +4316,7 @@ mod test {
assert_eq!(
hsp.mariadb_ext_capabilities(),
MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
| MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS
);
let mut output = Vec::new();
hsp.serialize(&mut output);
Expand All @@ -4169,7 +4336,10 @@ mod test {
None,
1_u32.to_be(),
)
.with_mariadb_ext_capabilities(MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA);
.with_mariadb_ext_capabilities(
MariadbCapabilities::MARIADB_CLIENT_CACHE_METADATA
| MariadbCapabilities::MARIADB_CLIENT_STMT_BULK_OPERATIONS,
);
let mut actual = Vec::new();
response.serialize(&mut actual);

Expand All @@ -4179,7 +4349,7 @@ mod test {
0x2d, // charset
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, // reserved
0x10, 0x00, 0x00, 0x00, // mariadb capabilities
0x14, 0x00, 0x00, 0x00, // mariadb capabilities
0x72, 0x6f, 0x6f, 0x74, 0x00, // username=root
0x00, // blank scramble
0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70,
Expand Down