diff --git a/src/config.rs b/src/config.rs index 8f4ece4..7b40a8b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,12 +8,12 @@ use std::time::Duration; use anyhow::anyhow; use chrono::Utc; -#[cfg(feature = "cql")] -use clap::builder::PossibleValue; use clap::{Parser, ValueEnum}; use itertools::Itertools; use serde::{Deserialize, Serialize}; +use crate::scripting::db_config; + /// Limit of retry errors to be kept and then printed in scope of a sampling interval pub const PRINT_RETRY_ERROR_LIMIT: u64 = 5; @@ -158,71 +158,10 @@ impl FromStr for RetryInterval { #[derive(Parser, Debug, Serialize, Deserialize)] pub struct ConnectionConf { - /// Number of connections per Cassandra node / Scylla shard. - #[clap( - short('c'), - long("connections"), - default_value = "1", - value_name = "COUNT" - )] - pub count: NonZeroUsize, - - /// List of Cassandra addresses to connect to. + /// List of addresses to connect to. #[clap(name = "addresses", default_value = "localhost")] pub addresses: Vec, - /// Cassandra user name - #[clap(long, env("CASSANDRA_USER"), default_value = "")] - pub user: String, - - /// Password to use if password authentication is required by the server - #[clap(long, env("CASSANDRA_PASSWORD"), default_value = "")] - pub password: String, - - /// Enable SSL - #[clap(long("ssl"))] - pub ssl: bool, - - /// Path to the CA certificate file in PEM format - #[clap(long("ssl-ca"), value_name = "PATH")] - pub ssl_ca_cert_file: Option, - - /// Path to the client SSL certificate file in PEM format - #[clap(long("ssl-cert"), value_name = "PATH")] - pub ssl_cert_file: Option, - - /// Path to the client SSL private key file in PEM format - #[clap(long("ssl-key"), value_name = "PATH")] - pub ssl_key_file: Option, - - /// Verify if the peer's certificate is trusted - #[clap(long("ssl-peer-verification"))] - pub ssl_peer_verification: bool, - - /// Datacenter name - #[clap(long("datacenter"), required = false)] - pub datacenter: Option, - - /// Rack name - #[clap(long("rack"), required = false)] - pub rack: Option, - - /// CQL query consistency level. - /// 'SERIAL' and 'LOCAL_SERIAL' values are compatible only with SELECT statements - /// and make Scylla use Paxos consensus algorithm - #[cfg(feature = "cql")] - #[clap(long("consistency"), required = false, default_value = "LOCAL_QUORUM")] - pub consistency: Consistency, - - /// Serial consistency level for conditional (LWT) queries - #[cfg(feature = "cql")] - #[clap( - long("serial-consistency"), - required = false, - default_value = "LOCAL_SERIAL" - )] - pub serial_consistency: SerialConsistency, - #[clap( long("request-timeout"), default_value = "5s", @@ -263,6 +202,10 @@ pub struct ConnectionConf { default_value = "fail-fast" )] pub validation_strategy: ValidationStrategy, + + #[clap(flatten)] + #[serde(flatten)] + pub db: db_config::DbConnectionConf, } #[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize, ValueEnum)] @@ -273,136 +216,6 @@ pub enum ValidationStrategy { Ignore, // Ignore validation errors - face, print, go on. } -#[cfg(feature = "cql")] -#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub enum Consistency { - Any, - One, - Two, - Three, - Quorum, - All, - LocalOne, - #[default] - LocalQuorum, - EachQuorum, - // NOTE: 'Serial' and 'LocalSerial' values may be used in SELECT statements - // to make them use Paxos consensus algorithm. - Serial, - LocalSerial, -} - -#[cfg(feature = "cql")] -impl Consistency { - pub fn consistency(&self) -> scylla::frame::types::Consistency { - match self { - Self::Any => scylla::frame::types::Consistency::Any, - Self::One => scylla::frame::types::Consistency::One, - Self::Two => scylla::frame::types::Consistency::Two, - Self::Three => scylla::frame::types::Consistency::Three, - Self::Quorum => scylla::frame::types::Consistency::Quorum, - Self::All => scylla::frame::types::Consistency::All, - Self::LocalOne => scylla::frame::types::Consistency::LocalOne, - Self::LocalQuorum => scylla::frame::types::Consistency::LocalQuorum, - Self::EachQuorum => scylla::frame::types::Consistency::EachQuorum, - Self::Serial => scylla::frame::types::Consistency::Serial, - Self::LocalSerial => scylla::frame::types::Consistency::LocalSerial, - } - } -} - -#[cfg(feature = "cql")] -impl ValueEnum for Consistency { - fn value_variants<'a>() -> &'a [Self] { - &[ - Self::Any, - Self::One, - Self::Two, - Self::Three, - Self::Quorum, - Self::All, - Self::LocalOne, - Self::LocalQuorum, - Self::EachQuorum, - Self::Serial, - Self::LocalSerial, - ] - } - - fn from_str(s: &str, _ignore_case: bool) -> Result { - match s.to_lowercase().as_str() { - "any" => Ok(Self::Any), - "one" | "1" => Ok(Self::One), - "two" | "2" => Ok(Self::Two), - "three" | "3" => Ok(Self::Three), - "quorum" | "q" => Ok(Self::Quorum), - "all" => Ok(Self::All), - "local_one" | "localone" | "l1" => Ok(Self::LocalOne), - "local_quorum" | "localquorum" | "lq" => Ok(Self::LocalQuorum), - "each_quorum" | "eachquorum" | "eq" => Ok(Self::EachQuorum), - "serial" | "s" => Ok(Self::Serial), - "local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial), - s => Err(format!("Unknown consistency level {s}")), - } - } - - fn to_possible_value(&self) -> Option { - match self { - Self::Any => Some(PossibleValue::new("ANY")), - Self::One => Some(PossibleValue::new("ONE")), - Self::Two => Some(PossibleValue::new("TWO")), - Self::Three => Some(PossibleValue::new("THREE")), - Self::Quorum => Some(PossibleValue::new("QUORUM")), - Self::All => Some(PossibleValue::new("ALL")), - Self::LocalOne => Some(PossibleValue::new("LOCAL_ONE")), - Self::LocalQuorum => Some(PossibleValue::new("LOCAL_QUORUM")), - Self::EachQuorum => Some(PossibleValue::new("EACH_QUORUM")), - Self::Serial => Some(PossibleValue::new("SERIAL")), - Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")), - } - } -} - -#[cfg(feature = "cql")] -#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)] -pub enum SerialConsistency { - Serial, - #[default] - LocalSerial, -} - -#[cfg(feature = "cql")] -impl SerialConsistency { - pub fn serial_consistency(&self) -> scylla::frame::types::SerialConsistency { - match self { - Self::Serial => scylla::frame::types::SerialConsistency::Serial, - Self::LocalSerial => scylla::frame::types::SerialConsistency::LocalSerial, - } - } -} - -#[cfg(feature = "cql")] -impl ValueEnum for SerialConsistency { - fn value_variants<'a>() -> &'a [Self] { - &[Self::Serial, Self::LocalSerial] - } - - fn from_str(s: &str, _ignore_case: bool) -> Result { - match s.to_lowercase().as_str() { - "serial" | "s" => Ok(Self::Serial), - "local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial), - s => Err(format!("Unknown serial consistency level {s}")), - } - } - - fn to_possible_value(&self) -> Option { - match self { - Self::Serial => Some(PossibleValue::new("SERIAL")), - Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")), - } - } -} - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct WeightedFunction { pub name: String, @@ -489,7 +302,7 @@ pub struct SchemaCommand { #[clap(name = "workload", required = true, value_name = "PATH")] pub workload: PathBuf, - // Cassandra connection settings. + // Connection settings. #[clap(flatten)] pub connection: ConnectionConf, } @@ -520,7 +333,7 @@ pub struct LoadCommand { #[clap(name = "workload", required = true, value_name = "PATH")] pub workload: PathBuf, - // Cassandra connection settings. + // Connection settings. #[clap(flatten)] pub connection: ConnectionConf, } @@ -642,7 +455,7 @@ pub struct RunCommand { #[clap(short, long)] pub quiet: bool, - // Cassandra connection settings. + // Connection settings. #[clap(flatten)] pub connection: ConnectionConf, diff --git a/src/report/mod.rs b/src/report/mod.rs index f62a5fc..4d0ee14 100644 --- a/src/report/mod.rs +++ b/src/report/mod.rs @@ -484,11 +484,13 @@ impl Display for RunConfigCmp<'_> { self.line("Cluster", "", |conf| { OptionDisplay(conf.cluster_name.clone()) }), + #[cfg(feature = "cql")] self.line("Datacenter", "", |conf| { - conf.connection.datacenter.clone().unwrap_or_default() + conf.connection.db.datacenter.clone().unwrap_or_default() }), + #[cfg(feature = "cql")] self.line("Rack", "", |conf| { - conf.connection.rack.clone().unwrap_or_default() + conf.connection.db.rack.clone().unwrap_or_default() }), self.line("DB version", "", |conf| { OptionDisplay(conf.db_version.clone()) @@ -507,11 +509,12 @@ impl Display for RunConfigCmp<'_> { }), #[cfg(feature = "cql")] self.line("Consistency", "", |conf| { - conf.connection.consistency.consistency().to_string() + conf.connection.db.consistency.consistency().to_string() }), #[cfg(feature = "cql")] self.line("Serial consistency", "", |conf| { conf.connection + .db .serial_consistency .serial_consistency() .to_string() @@ -550,8 +553,9 @@ impl Display for RunConfigCmp<'_> { let lines: Vec> = vec![ self.line("Threads", "", |conf| Quantity::from(conf.threads)), + #[cfg(feature = "cql")] self.line("Connections", "", |conf| { - Quantity::from(conf.connection.count) + Quantity::from(conf.connection.db.count) }), self.line("Concurrency", "req", |conf| { Quantity::from(conf.concurrency) diff --git a/src/scripting/alternator/config.rs b/src/scripting/alternator/config.rs new file mode 100644 index 0000000..b50835c --- /dev/null +++ b/src/scripting/alternator/config.rs @@ -0,0 +1,23 @@ +use clap::Parser; +use serde::{Deserialize, Serialize}; + +#[derive(Parser, Debug, Default, Serialize, Deserialize)] +pub struct DbConnectionConf { + /// Use AWS credentials and region from the environment. + /// If this flag is set, access key ID, secret access key, and region args will be ignored. + #[clap(long("aws-credentials"))] + pub aws_credentials: bool, + + /// Access key ID. + #[clap(long("access-key-id"), default_value = "")] + pub access_key_id: String, + + /// Secret access key. + #[serde(skip_serializing)] // Don't save the secret to generated reports. + #[clap(long("secret-access-key"), default_value = "")] + pub secret_access_key: String, + + /// Region. + #[clap(long("region"), default_value = "us-east-1")] + pub region: String, +} diff --git a/src/scripting/alternator/connect.rs b/src/scripting/alternator/connect.rs index f313af0..63dc8af 100644 --- a/src/scripting/alternator/connect.rs +++ b/src/scripting/alternator/connect.rs @@ -10,14 +10,32 @@ use aws_sdk_dynamodb::Client; pub async fn connect(conf: &ConnectionConf) -> Result { let address = conf.addresses.first().cloned().unwrap_or_default(); - // TODO: use latte parameters for setting the configuration - let config = aws_config::defaults(BehaviorVersion::latest()) + let mut config_loader = aws_config::defaults(BehaviorVersion::latest()) .endpoint_url(&address) - .region(Region::new("us-east-1")) - .credentials_provider(Credentials::new("", "", None, None, "")) .retry_config(RetryConfig::standard().with_max_attempts(1)) - .load() - .await; + .timeout_config( + aws_config::timeout::TimeoutConfig::builder() + .operation_timeout(conf.request_timeout) + .build(), + ); + + // We only specify custom credentials if aws_credentials flag is not set. + // If aws_credentials flag is set, the SDK will automatically use credentials from the environment. + if !conf.db.aws_credentials { + let creds = Credentials::new( + &conf.db.access_key_id, + &conf.db.secret_access_key, + None, + None, + "", + ); + + config_loader = config_loader + .credentials_provider(creds) + .region(Region::new(conf.db.region.clone())); + } + + let config = config_loader.load().await; let client = Client::new(&config); diff --git a/src/scripting/alternator/functions.rs b/src/scripting/alternator/functions.rs index 101fff3..7ab6bf4 100644 --- a/src/scripting/alternator/functions.rs +++ b/src/scripting/alternator/functions.rs @@ -1,16 +1,20 @@ use crate::config::ValidationStrategy; -use crate::scripting::alternator::traits::{AlternatorRequest, IntoAlternatorOutput}; +use crate::scripting::alternator::traits::{ + AlternatorRequest, IntoAlternatorOutput, PaginationToken, +}; use crate::scripting::functions_common::{extract_validation_args, ValidationArgs}; use crate::scripting::retry_error::handle_retry_error; use super::alternator_error::{AlternatorError, AlternatorErrorKind}; use super::context::Context; -use super::types::rune_object_to_alternator_map; +use super::types::{alternator_map_to_rune_object, rune_object_to_alternator_map}; +use super::types::{BSET_KEY, NSET_KEY, SSET_KEY}; use aws_sdk_dynamodb::client::Waiters; use aws_sdk_dynamodb::types::{ - AttributeDefinition, KeySchemaElement, KeyType, ScalarAttributeType, + AttributeDefinition, DeleteRequest, KeySchemaElement, KeyType, KeysAndAttributes, PutRequest, + ScalarAttributeType, WriteRequest, }; -use rune::runtime::{Object, Ref, Shared}; +use rune::runtime::{Object, Ref, Shared, VmResult}; use rune::{ToValue, Value}; use std::cmp::min; use std::collections::HashMap; @@ -71,11 +75,12 @@ fn extract_attribute_names( .collect::>() } -async fn handle_request( +async fn handle_request_with_pagination( ctx: &Context, builder: impl AlternatorRequest, -) -> Result, AlternatorError> { - let mut token: Option> = None; + auto_paginate: bool, +) -> Result<(Vec, Option), AlternatorError> { + let mut token: Option = None; let mut current_attempt_num = 0; let mut all_pages_duration = Duration::ZERO; let mut all_items = Vec::new(); @@ -109,20 +114,28 @@ async fn handle_request( .try_lock() .unwrap() .complete_request(all_pages_duration, total_item_count); - return Ok(all_items); + return Ok((all_items, token)); } } if token.is_some() && builder.has_pagination() { - current_attempt_num = 0; // reset retries for next page - continue; + if auto_paginate { + current_attempt_num = 0; // reset retries for next page + continue; + } else { + ctx.stats + .try_lock() + .unwrap() + .complete_request(all_pages_duration, total_item_count); + return Ok((all_items, token)); + } } ctx.stats .try_lock() .unwrap() .complete_request(all_pages_duration, total_item_count); - return Ok(all_items); + return Ok((all_items, token)); } Err(e) => { let current_error = e; @@ -135,6 +148,13 @@ async fn handle_request( Err(AlternatorError::query_retries_exceeded(ctx.retry_number)) } +async fn handle_request( + ctx: &Context, + builder: impl AlternatorRequest, +) -> Result, AlternatorError> { + Ok(handle_request_with_pagination(ctx, builder, true).await?.0) +} + async fn handle_request_with_validation( ctx: &Context, builder: impl AlternatorRequest, @@ -143,7 +163,7 @@ async fn handle_request_with_validation( ) -> Result, AlternatorError> { let mut current_attempt_num: u64 = 0; loop { - let result = handle_request(ctx, builder.clone()).await?; + let (result, _) = handle_request_with_pagination(ctx, builder.clone(), true).await?; let validation = match validation { None => return Ok(result), @@ -179,6 +199,86 @@ async fn handle_request_with_validation( } } +fn format_batch_result( + items: Vec, + token: Option, + auto_paginate: bool, + with_result: bool, + table_name: &str, +) -> Result { + if !with_result { + return Ok(Value::EmptyTuple); + } + + if auto_paginate { + return Ok(items.to_value().into_result()?); + } + + let mut res_obj = rune::runtime::Object::new(); + + res_obj.insert( + rune::alloc::String::try_from("items")?, + items.to_value().into_result()?, + )?; + + match token { + Some(PaginationToken::UnprocessedKeys(mut u_keys)) => { + let keys: Vec = u_keys + .remove(table_name) + .map(|k| k.keys) + .unwrap_or_default() + .into_iter() + .map(alternator_map_to_rune_object) + .collect::>()?; + + res_obj.insert( + rune::alloc::String::try_from("unprocessed_keys")?, + keys.to_value().into_result()?, + )?; + } + Some(PaginationToken::UnprocessedItems(mut u_items)) => { + let requests: Vec = u_items + .remove(table_name) + .unwrap_or_default() + .into_iter() + .map(|req| { + let mut o = rune::runtime::Object::new(); + + if let Some(put) = req.put_request { + o.insert( + rune::alloc::String::try_from("type")?, + "put".to_value().into_result()?, + )?; + o.insert( + rune::alloc::String::try_from("item")?, + alternator_map_to_rune_object(put.item)?, + )?; + } else if let Some(del) = req.delete_request { + o.insert( + rune::alloc::String::try_from("type")?, + "delete".to_value().into_result()?, + )?; + o.insert( + rune::alloc::String::try_from("key")?, + alternator_map_to_rune_object(del.key)?, + )?; + } + + Ok(Value::Object(Shared::new(o)?)) + }) + .collect::>()?; + + res_obj.insert( + rune::alloc::String::try_from("unprocessed_items")?, + requests.to_value().into_result()?, + )?; + } + _ => {} + } + + Ok(Value::Object(Shared::new(res_obj)?)) +} + /// Creates a new table. /// /// # Arguments @@ -414,6 +514,159 @@ pub async fn update( Ok(()) } +/// Batch retrieves items from the table. +/// +/// If `with_result` is set to true, the retrieved items are returned as a `Vec`. +/// Otherwise, the unit value is returned. +/// +/// # Arguments +/// * `table_name` - The name of the table. +/// * `keys` - A list of items, where each item is an object representing a primary key. +/// * `options` - Optional parameters. An object containing: +/// - `consistent_read`: Boolean to enable consistent read for all keys (default: false). +/// - `with_result`: If true, the retrieved items are returned (default: false). +/// - `get_unprocessed`: If true, disables auto-pagination. When `with_result: true` returns an object with `items` and `unprocessed_keys`. +#[rune::function(instance)] +pub async fn batch_get_item( + ctx: Ref, + table_name: Ref, + keys: Ref, + options: Value, +) -> Result { + let client = ctx.get_client()?; + + // Convert keys vector to DynamoDB keys + let keys_list = keys + .iter() + .map(|key_val| match key_val { + Value::Object(key_obj) => rune_object_to_alternator_map(key_obj.clone().into_ref()?), + _ => bad_input("Each key in the keys list must be an object"), + }) + .collect::>()?; + + let mut with_result = false; + let mut get_unprocessed = false; + + // BatchGetItem requires the keys to be wrapped in a KeysAndAttributes struct + let mut keys_request_builder = KeysAndAttributes::builder().set_keys(Some(keys_list)); + if let Value::Object(opts) = &options { + let opts_ref = opts.borrow_ref()?; + if let Some(Value::Bool(consistent_read)) = opts_ref.get("consistent_read") { + keys_request_builder = keys_request_builder.consistent_read(*consistent_read); + } + if let Some(Value::Bool(w)) = opts_ref.get("with_result") { + with_result = *w; + } + if let Some(Value::Bool(u)) = opts_ref.get("get_unprocessed") { + get_unprocessed = *u; + } + } + + let builder = client + .batch_get_item() + .request_items(table_name.deref(), keys_request_builder.build()?); + + let (result_items, token) = + handle_request_with_pagination(&ctx, builder, !get_unprocessed).await?; + + format_batch_result( + result_items, + token, + !get_unprocessed, + with_result, + table_name.deref(), + ) +} + +/// Batch writes items to the table. +/// +/// # Arguments +/// * `table_name` - The name of the table. +/// * `write_requests` - A list of write requests. Each request is an object containing: +/// - `type`: Either "put" or "delete". +/// - `item`: For put requests, the item object to insert. +/// - `key`: For delete requests, the key object to delete. +/// * `options` - Optional parameters. An object containing: +/// - `get_unprocessed`: If true, disables auto-pagination. Returns an object with `unprocessed_items`. +#[rune::function(instance)] +pub async fn batch_write_item( + ctx: Ref, + table_name: Ref, + write_requests: Ref, + options: Value, +) -> Result { + let client = ctx.get_client()?; + + let writes = write_requests + .iter() + .map(|req_val| { + let Value::Object(req_obj) = req_val else { + return bad_input("Each write request must be an object"); + }; + + let req_ref = req_obj.borrow_ref()?; + + let req_type = match req_ref.get("type") { + Some(Value::String(t)) => t.borrow_ref()?.to_string(), + _ => return bad_input("Write request must have a 'type' field (put or delete)"), + }; + + match req_type.as_str() { + "put" => { + let Some(Value::Object(item)) = req_ref.get("item") else { + return bad_input("Put request must have an 'item' field"); + }; + + let item_map = rune_object_to_alternator_map(item.clone().into_ref()?)?; + + Ok(WriteRequest::builder() + .put_request(PutRequest::builder().set_item(Some(item_map)).build()?) + .build()) + } + "delete" => { + let Some(Value::Object(key)) = req_ref.get("key") else { + return bad_input("Delete request must have a 'key' field"); + }; + + let key_map = rune_object_to_alternator_map(key.clone().into_ref()?)?; + + Ok(WriteRequest::builder() + .delete_request(DeleteRequest::builder().set_key(Some(key_map)).build()?) + .build()) + } + _ => bad_input(format!( + "Invalid request type: {}, must be 'put' or 'delete'", + req_type + )), + } + }) + .collect::>()?; + + let mut get_unprocessed = false; + + if let Value::Object(opts) = &options { + let opts_ref = opts.borrow_ref()?; + if let Some(Value::Bool(x)) = opts_ref.get("get_unprocessed") { + get_unprocessed = *x; + } + } + + let builder = client + .batch_write_item() + .request_items(table_name.deref(), writes); + + let (result_items, token) = + handle_request_with_pagination(&ctx, builder, !get_unprocessed).await?; + + format_batch_result( + result_items, + token, + !get_unprocessed, + get_unprocessed, + table_name.deref(), + ) +} + /// Queries items from the table. /// /// Unlike `get`, which retrieves a single item by its exact primary key, @@ -568,3 +821,30 @@ pub async fn scan( Ok(Value::EmptyTuple) } + +/// Marks a list of items as an Alternator string set. +#[rune::function] +pub fn sset(items: Vec) -> VmResult { + let mut map = HashMap::new(); + let items_val = rune::vm_try!(items.to_value()); + map.insert(SSET_KEY.to_string(), items_val); + map.to_value() +} + +/// Marks a list of items as an Alternator number set. +#[rune::function] +pub fn nset(items: Vec) -> VmResult { + let mut map = HashMap::new(); + let items_val = rune::vm_try!(items.to_value()); + map.insert(NSET_KEY.to_string(), items_val); + map.to_value() +} + +/// Marks a list of items as an Alternator binary set. +#[rune::function] +pub fn bset(items: Vec) -> VmResult { + let mut map = HashMap::new(); + let items_val = rune::vm_try!(items.to_value()); + map.insert(BSET_KEY.to_string(), items_val); + map.to_value() +} diff --git a/src/scripting/alternator/mod.rs b/src/scripting/alternator/mod.rs index 210f778..ae5163e 100644 --- a/src/scripting/alternator/mod.rs +++ b/src/scripting/alternator/mod.rs @@ -1,4 +1,5 @@ pub mod alternator_error; +pub mod config; pub mod connect; pub mod context; pub mod functions; diff --git a/src/scripting/alternator/traits.rs b/src/scripting/alternator/traits.rs index 4be5106..0c18c7c 100644 --- a/src/scripting/alternator/traits.rs +++ b/src/scripting/alternator/traits.rs @@ -2,17 +2,25 @@ use super::alternator_error::AlternatorError; use super::types::alternator_map_to_rune_object; use aws_sdk_dynamodb::error::{ProvideErrorMetadata, SdkError}; use aws_sdk_dynamodb::operation::{ + batch_get_item::BatchGetItemOutput, batch_write_item::BatchWriteItemOutput, create_table::CreateTableOutput, delete_item::DeleteItemOutput, delete_table::DeleteTableOutput, get_item::GetItemOutput, put_item::PutItemOutput, query::QueryOutput, scan::ScanOutput, update_item::UpdateItemOutput, }; -use aws_sdk_dynamodb::types::AttributeValue; +use aws_sdk_dynamodb::types::{AttributeValue, KeysAndAttributes, WriteRequest}; use rune::Value; use std::collections::HashMap; use std::future::Future; +#[derive(Clone)] +pub(super) enum PaginationToken { + LastEvaluatedKey(HashMap), + UnprocessedKeys(HashMap), + UnprocessedItems(HashMap>), +} + pub(super) type AlternatorOutputResult = - Result<(Vec, u64, Option>), AlternatorError>; + Result<(Vec, u64, Option), AlternatorError>; pub(super) trait IntoAlternatorOutput { fn into_output(self) -> AlternatorOutputResult; @@ -36,7 +44,12 @@ impl IntoAlternatorOutput for QueryOutput { result.push(alternator_map_to_rune_object(item)?); } let len = result.len() as u64; - Ok((result, len, self.last_evaluated_key)) + Ok(( + result, + len, + self.last_evaluated_key + .map(PaginationToken::LastEvaluatedKey), + )) } } @@ -48,7 +61,44 @@ impl IntoAlternatorOutput for ScanOutput { result.push(alternator_map_to_rune_object(item)?); } let len = result.len() as u64; - Ok((result, len, self.last_evaluated_key)) + Ok(( + result, + len, + self.last_evaluated_key + .map(PaginationToken::LastEvaluatedKey), + )) + } +} + +impl IntoAlternatorOutput for BatchGetItemOutput { + fn into_output(self) -> AlternatorOutputResult { + let responses = self.responses.unwrap_or_default(); + + let result = responses + .into_values() + .flatten() + .map(alternator_map_to_rune_object) + .collect::, _>>()?; + + let len = result.len() as u64; + + let token = self + .unprocessed_keys + .filter(|keys| !keys.is_empty()) + .map(PaginationToken::UnprocessedKeys); + + Ok((result, len, token)) + } +} + +impl IntoAlternatorOutput for BatchWriteItemOutput { + fn into_output(self) -> AlternatorOutputResult { + let token = self + .unprocessed_items + .filter(|keys| !keys.is_empty()) + .map(PaginationToken::UnprocessedItems); + + Ok((vec![], 0, token)) } } @@ -94,11 +144,7 @@ pub(super) trait SendRequest { } pub(super) trait AlternatorRequest: SendRequest + Clone { - fn set_pagination( - self, - token: Option>, - limit: Option, - ) -> Self; + fn set_pagination(self, token: Option, limit: Option) -> Self; fn has_pagination(&self) -> bool; fn get_limit_val(&self) -> Option; } @@ -124,7 +170,7 @@ macro_rules! impl_alternator_request_no_pagination { $( impl_send_request!($t); impl AlternatorRequest for $t { - fn set_pagination(self, _: Option>, _: Option) -> Self { self } + fn set_pagination(self, _: Option, _: Option) -> Self { self } fn has_pagination(&self) -> bool { false } fn get_limit_val(&self) -> Option { None } } @@ -141,16 +187,19 @@ impl_alternator_request_no_pagination!( aws_sdk_dynamodb::operation::update_item::builders::UpdateItemFluentBuilder ); -impl_send_request!(aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder); -impl_send_request!(aws_sdk_dynamodb::operation::scan::builders::ScanFluentBuilder); +impl_send_request!( + aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder, + aws_sdk_dynamodb::operation::scan::builders::ScanFluentBuilder, + aws_sdk_dynamodb::operation::batch_get_item::builders::BatchGetItemFluentBuilder, + aws_sdk_dynamodb::operation::batch_write_item::builders::BatchWriteItemFluentBuilder +); impl AlternatorRequest for aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder { - fn set_pagination( - self, - token: Option>, - limit: Option, - ) -> Self { - let mut b = self.set_exclusive_start_key(token); + fn set_pagination(self, token: Option, limit: Option) -> Self { + let mut b = self.set_exclusive_start_key(match token { + Some(PaginationToken::LastEvaluatedKey(key)) => Some(key), + _ => None, + }); if let Some(limit) = limit { b = b.limit(limit); } @@ -165,12 +214,11 @@ impl AlternatorRequest for aws_sdk_dynamodb::operation::query::builders::QueryFl } impl AlternatorRequest for aws_sdk_dynamodb::operation::scan::builders::ScanFluentBuilder { - fn set_pagination( - self, - token: Option>, - limit: Option, - ) -> Self { - let mut b = self.set_exclusive_start_key(token); + fn set_pagination(self, token: Option, limit: Option) -> Self { + let mut b = self.set_exclusive_start_key(match token { + Some(PaginationToken::LastEvaluatedKey(key)) => Some(key), + _ => None, + }); if let Some(limit) = limit { b = b.limit(limit); } @@ -183,3 +231,39 @@ impl AlternatorRequest for aws_sdk_dynamodb::operation::scan::builders::ScanFlue *self.get_limit() } } + +impl AlternatorRequest + for aws_sdk_dynamodb::operation::batch_get_item::builders::BatchGetItemFluentBuilder +{ + fn set_pagination(self, token: Option, _limit: Option) -> Self { + if let Some(PaginationToken::UnprocessedKeys(keys)) = token { + self.set_request_items(Some(keys)) + } else { + self + } + } + fn has_pagination(&self) -> bool { + true + } + fn get_limit_val(&self) -> Option { + None + } +} + +impl AlternatorRequest + for aws_sdk_dynamodb::operation::batch_write_item::builders::BatchWriteItemFluentBuilder +{ + fn set_pagination(self, token: Option, _limit: Option) -> Self { + if let Some(PaginationToken::UnprocessedItems(items)) = token { + self.set_request_items(Some(items)) + } else { + self + } + } + fn has_pagination(&self) -> bool { + true + } + fn get_limit_val(&self) -> Option { + None + } +} diff --git a/src/scripting/alternator/types.rs b/src/scripting/alternator/types.rs index 1a2d5d1..85c169b 100644 --- a/src/scripting/alternator/types.rs +++ b/src/scripting/alternator/types.rs @@ -4,6 +4,56 @@ use rune::runtime::{Bytes, Object, Ref}; use rune::{ToValue, Value}; use std::collections::HashMap; +pub const SSET_KEY: &str = "__sset"; +pub const NSET_KEY: &str = "__nset"; +pub const BSET_KEY: &str = "__bset"; + +fn alternator_set_to_rune(key: &str, iter: I, wrapper: F) -> Result +where + I: IntoIterator, + F: Fn(T) -> AttributeValue, +{ + let items = iter + .into_iter() + .map(|item| alternator_attribute_to_rune_value(wrapper(item))) + .collect::, AlternatorError>>()?; + + let mut map = HashMap::new(); + map.insert(key.to_string(), items.to_value().into_result()?); + Ok(map.to_value().into_result()?) +} + +fn rune_set_to_alternator( + v: Value, + key: &str, + attribute_constructor: W, + rune_unwrapper: U, +) -> Result +where + W: Fn(Vec) -> AttributeValue, + U: Fn(AttributeValue) -> Option, +{ + if let Value::Vec(vec) = v { + let items = vec + .into_ref()? + .iter() + .map(|item| { + rune_unwrapper(rune_value_to_alternator_attribute(item.clone())?).ok_or_else(|| { + AlternatorError::new(AlternatorErrorKind::ConversionError(format!( + "Invalid element type found in set {}: {:?}", + key, item + ))) + }) + }) + .collect::, _>>()?; + Ok(attribute_constructor(items)) + } else { + Err(AlternatorError::new(AlternatorErrorKind::ConversionError( + format!("Expected a vector of elements for {}", key), + ))) + } +} + pub fn rune_value_to_alternator_attribute(v: Value) -> Result { match v { Value::Bool(b) => Ok(AttributeValue::Bool(b)), @@ -24,9 +74,50 @@ pub fn rune_value_to_alternator_attribute(v: Value) -> Result>()?, )), - Value::Object(o) => Ok(AttributeValue::M(rune_object_to_alternator_map( - o.into_ref()?, - )?)), + Value::Object(o) => { + let obj = o.into_ref()?; + + // Check for special Set representations. + // They have to be objects with exactly one key with special name, and the value has to be a vector of appropriate types. + if obj.len() == 1 { + let mut iter = obj.iter(); + let (k, val) = iter.next().unwrap(); + + match k.as_str() { + SSET_KEY => { + return rune_set_to_alternator(val.clone(), k, AttributeValue::Ss, |a| { + if let AttributeValue::S(s) = a { + Some(s) + } else { + None + } + }); + } + NSET_KEY => { + return rune_set_to_alternator(val.clone(), k, AttributeValue::Ns, |a| { + if let AttributeValue::N(n) = a { + Some(n) + } else { + None + } + }); + } + BSET_KEY => { + return rune_set_to_alternator(val.clone(), k, AttributeValue::Bs, |a| { + if let AttributeValue::B(b) = a { + Some(b) + } else { + None + } + }); + } + // Does not match any of the special set keys, so we treat it as a regular object. + _ => {} + } + } + + Ok(AttributeValue::M(rune_object_to_alternator_map(obj)?)) + } Value::Option(o) => match o.into_ref()?.as_ref() { Some(v) => rune_value_to_alternator_attribute(v.clone()), @@ -84,6 +175,10 @@ pub fn alternator_attribute_to_rune_value(attr: AttributeValue) -> Result Ok(None::.to_value().into_result()?), + AttributeValue::Ss(ss) => alternator_set_to_rune(SSET_KEY, ss, AttributeValue::S), + AttributeValue::Ns(ns) => alternator_set_to_rune(NSET_KEY, ns, AttributeValue::N), + AttributeValue::Bs(bs) => alternator_set_to_rune(BSET_KEY, bs, AttributeValue::B), + _ => Err(AlternatorError::new(AlternatorErrorKind::ConversionError( format!("Unsupported Alternator AttributeValue type: {:?}", attr), ))), diff --git a/src/scripting/cql/config.rs b/src/scripting/cql/config.rs new file mode 100644 index 0000000..64858fc --- /dev/null +++ b/src/scripting/cql/config.rs @@ -0,0 +1,192 @@ +use clap::builder::PossibleValue; +use clap::{Parser, ValueEnum}; +use serde::{Deserialize, Serialize}; +use std::num::NonZeroUsize; +use std::path::PathBuf; + +#[derive(Parser, Debug, Serialize, Deserialize)] +pub struct DbConnectionConf { + /// Number of connections per Cassandra node / Scylla shard. + #[clap( + short('c'), + long("connections"), + default_value = "1", + value_name = "COUNT" + )] + pub count: NonZeroUsize, + + /// Cassandra user name + #[clap(long, env("CASSANDRA_USER"), default_value = "")] + pub user: String, + + /// Password to use if password authentication is required by the server + #[serde(skip_serializing)] // Don't save the password to generated reports. + #[clap(long, env("CASSANDRA_PASSWORD"), default_value = "")] + pub password: String, + + /// Enable SSL + #[clap(long("ssl"))] + pub ssl: bool, + + /// Path to the CA certificate file in PEM format + #[clap(long("ssl-ca"), value_name = "PATH")] + pub ssl_ca_cert_file: Option, + + /// Path to the client SSL certificate file in PEM format + #[clap(long("ssl-cert"), value_name = "PATH")] + pub ssl_cert_file: Option, + + /// Path to the client SSL private key file in PEM format + #[clap(long("ssl-key"), value_name = "PATH")] + pub ssl_key_file: Option, + + /// Verify if the peer's certificate is trusted + #[clap(long("ssl-peer-verification"))] + pub ssl_peer_verification: bool, + + /// Datacenter name + #[clap(long("datacenter"), required = false)] + pub datacenter: Option, + + /// Rack name + #[clap(long("rack"), required = false)] + pub rack: Option, + + /// CQL query consistency level. + /// 'SERIAL' and 'LOCAL_SERIAL' values are compatible only with SELECT statements + /// and make Scylla use Paxos consensus algorithm + #[clap(long("consistency"), required = false, default_value = "LOCAL_QUORUM")] + pub consistency: Consistency, + + /// Serial consistency level for conditional (LWT) queries + #[clap( + long("serial-consistency"), + required = false, + default_value = "LOCAL_SERIAL" + )] + pub serial_consistency: SerialConsistency, +} + +#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub enum Consistency { + Any, + One, + Two, + Three, + Quorum, + All, + LocalOne, + #[default] + LocalQuorum, + EachQuorum, + // NOTE: 'Serial' and 'LocalSerial' values may be used in SELECT statements + // to make them use Paxos consensus algorithm. + Serial, + LocalSerial, +} + +impl Consistency { + pub fn consistency(&self) -> scylla::frame::types::Consistency { + match self { + Self::Any => scylla::frame::types::Consistency::Any, + Self::One => scylla::frame::types::Consistency::One, + Self::Two => scylla::frame::types::Consistency::Two, + Self::Three => scylla::frame::types::Consistency::Three, + Self::Quorum => scylla::frame::types::Consistency::Quorum, + Self::All => scylla::frame::types::Consistency::All, + Self::LocalOne => scylla::frame::types::Consistency::LocalOne, + Self::LocalQuorum => scylla::frame::types::Consistency::LocalQuorum, + Self::EachQuorum => scylla::frame::types::Consistency::EachQuorum, + Self::Serial => scylla::frame::types::Consistency::Serial, + Self::LocalSerial => scylla::frame::types::Consistency::LocalSerial, + } + } +} + +impl ValueEnum for Consistency { + fn value_variants<'a>() -> &'a [Self] { + &[ + Self::Any, + Self::One, + Self::Two, + Self::Three, + Self::Quorum, + Self::All, + Self::LocalOne, + Self::LocalQuorum, + Self::EachQuorum, + Self::Serial, + Self::LocalSerial, + ] + } + + fn from_str(s: &str, _ignore_case: bool) -> Result { + match s.to_lowercase().as_str() { + "any" => Ok(Self::Any), + "one" | "1" => Ok(Self::One), + "two" | "2" => Ok(Self::Two), + "three" | "3" => Ok(Self::Three), + "quorum" | "q" => Ok(Self::Quorum), + "all" => Ok(Self::All), + "local_one" | "localone" | "l1" => Ok(Self::LocalOne), + "local_quorum" | "localquorum" | "lq" => Ok(Self::LocalQuorum), + "each_quorum" | "eachquorum" | "eq" => Ok(Self::EachQuorum), + "serial" | "s" => Ok(Self::Serial), + "local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial), + s => Err(format!("Unknown consistency level {s}")), + } + } + + fn to_possible_value(&self) -> Option { + match self { + Self::Any => Some(PossibleValue::new("ANY")), + Self::One => Some(PossibleValue::new("ONE")), + Self::Two => Some(PossibleValue::new("TWO")), + Self::Three => Some(PossibleValue::new("THREE")), + Self::Quorum => Some(PossibleValue::new("QUORUM")), + Self::All => Some(PossibleValue::new("ALL")), + Self::LocalOne => Some(PossibleValue::new("LOCAL_ONE")), + Self::LocalQuorum => Some(PossibleValue::new("LOCAL_QUORUM")), + Self::EachQuorum => Some(PossibleValue::new("EACH_QUORUM")), + Self::Serial => Some(PossibleValue::new("SERIAL")), + Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")), + } + } +} + +#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub enum SerialConsistency { + Serial, + #[default] + LocalSerial, +} + +impl SerialConsistency { + pub fn serial_consistency(&self) -> scylla::frame::types::SerialConsistency { + match self { + Self::Serial => scylla::frame::types::SerialConsistency::Serial, + Self::LocalSerial => scylla::frame::types::SerialConsistency::LocalSerial, + } + } +} + +impl ValueEnum for SerialConsistency { + fn value_variants<'a>() -> &'a [Self] { + &[Self::Serial, Self::LocalSerial] + } + + fn from_str(s: &str, _ignore_case: bool) -> Result { + match s.to_lowercase().as_str() { + "serial" | "s" => Ok(Self::Serial), + "local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial), + s => Err(format!("Unknown serial consistency level {s}")), + } + } + + fn to_possible_value(&self) -> Option { + match self { + Self::Serial => Some(PossibleValue::new("SERIAL")), + Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")), + } + } +} diff --git a/src/scripting/cql/connect.rs b/src/scripting/cql/connect.rs index a0abaa0..7358c6a 100644 --- a/src/scripting/cql/connect.rs +++ b/src/scripting/cql/connect.rs @@ -10,18 +10,18 @@ use scylla::client::execution_profile::ExecutionProfile; use scylla::client::session_builder::SessionBuilder; fn tls_context(conf: &&ConnectionConf) -> Result, Box> { - if conf.ssl { + if conf.db.ssl { let mut ssl = SslContextBuilder::new(SslMethod::tls())?; - if let Some(path) = &conf.ssl_ca_cert_file { + if let Some(path) = &conf.db.ssl_ca_cert_file { ssl.set_ca_file(path)?; } - if let Some(path) = &conf.ssl_cert_file { + if let Some(path) = &conf.db.ssl_cert_file { ssl.set_certificate_file(path, SslFiletype::PEM)?; } - if let Some(path) = &conf.ssl_key_file { + if let Some(path) = &conf.db.ssl_key_file { ssl.set_private_key_file(path, SslFiletype::PEM)?; } - if conf.ssl_peer_verification { + if conf.db.ssl_peer_verification { ssl.set_verify(SslVerifyMode::PEER); } Ok(Some(TlsContext::from(ssl.build()))) @@ -35,8 +35,8 @@ pub async fn connect(conf: &ConnectionConf) -> Result { let mut policy_builder = DefaultPolicy::builder().token_aware(true); let mut datacenter: String = "".to_string(); let mut rack: String = "".to_string(); - if let Some(dc) = &conf.datacenter { - if let Some(current_rack) = &conf.rack { + if let Some(dc) = &conf.db.datacenter { + if let Some(current_rack) = &conf.db.rack { policy_builder = policy_builder .prefer_datacenter_and_rack(dc.to_owned(), current_rack.to_owned()) .permit_dc_failover(true); @@ -47,20 +47,20 @@ pub async fn connect(conf: &ConnectionConf) -> Result { .permit_dc_failover(true); } datacenter = dc.clone(); - } else if let Some(_rack) = &conf.rack { + } else if let Some(_rack) = &conf.db.rack { panic!("Datacenter must also be defined when rack is defined"); } let profile = ExecutionProfile::builder() - .consistency(conf.consistency.consistency()) - .serial_consistency(Some(conf.serial_consistency.serial_consistency())) + .consistency(conf.db.consistency.consistency()) + .serial_consistency(Some(conf.db.serial_consistency.serial_consistency())) .load_balancing_policy(policy_builder.build()) .request_timeout(Some(conf.request_timeout)) .build(); let scylla_session = SessionBuilder::new() .known_nodes(&conf.addresses) - .pool_size(PoolSize::PerShard(conf.count)) - .user(&conf.user, &conf.password) + .pool_size(PoolSize::PerShard(conf.db.count)) + .user(&conf.db.user, &conf.db.password) .tls_context(tls_context(&conf)?) .default_execution_profile_handle(profile.into_handle()) .build() diff --git a/src/scripting/cql/mod.rs b/src/scripting/cql/mod.rs index 02494f9..f820b19 100644 --- a/src/scripting/cql/mod.rs +++ b/src/scripting/cql/mod.rs @@ -1,5 +1,6 @@ mod bind; pub mod cass_error; +pub mod config; pub mod connect; pub mod context; pub mod cql_types; diff --git a/src/scripting/mod.rs b/src/scripting/mod.rs index b4e6289..913fe5d 100644 --- a/src/scripting/mod.rs +++ b/src/scripting/mod.rs @@ -17,6 +17,8 @@ mod cql; #[cfg(feature = "cql")] pub use cql::cass_error as db_error; #[cfg(feature = "cql")] +pub use cql::config as db_config; +#[cfg(feature = "cql")] pub use cql::connect; #[cfg(feature = "cql")] pub use cql::context; @@ -24,6 +26,8 @@ pub use cql::context; #[cfg(feature = "alternator")] pub use alternator::alternator_error as db_error; #[cfg(feature = "alternator")] +pub use alternator::config as db_config; +#[cfg(feature = "alternator")] pub use alternator::connect; #[cfg(feature = "alternator")] pub use alternator::context; @@ -101,15 +105,21 @@ fn try_install( context_module.function_meta(functions::get)?; context_module.function_meta(functions::delete)?; context_module.function_meta(functions::update)?; + context_module.function_meta(functions::batch_get_item)?; + context_module.function_meta(functions::batch_write_item)?; context_module.function_meta(functions::query)?; context_module.function_meta(functions::scan)?; let err_module = init_error_module()?; let uuid_module = init_uuid_module()?; - let latte_module = init_latte_module(params)?; + let mut latte_module = init_latte_module(params)?; let mut fs_module = init_fs_module()?; let iter_module = init_iter_module(&mut fs_module)?; + latte_module.function_meta(functions::sset)?; + latte_module.function_meta(functions::nset)?; + latte_module.function_meta(functions::bset)?; + rune_ctx.install(&context_module)?; rune_ctx.install(&err_module)?; rune_ctx.install(&uuid_module)?; diff --git a/workloads/alternator/batch_operations.rn b/workloads/alternator/batch_operations.rn new file mode 100644 index 0000000..7e1cfea --- /dev/null +++ b/workloads/alternator/batch_operations.rn @@ -0,0 +1,68 @@ +use latte::*; + +// Usage: +// latte schema workloads/alternator/batch_operations.rn http://172.17.0.2:8000 +// latte run -d 1 workloads/alternator/batch_operations.rn http://172.17.0.2:8000 + +const TABLE = "batch_ops_table"; + +pub async fn schema(db) { + db.delete_table(TABLE).await; + db.create_table(TABLE, "pk").await?; +} + +pub async fn run(db, i) { + let batch_size = 5; + let base_id = "user_" + i.to_string() + "_"; + + let write_requests = []; + for j in 0..batch_size { + write_requests.push(#{ + type: "put", + item: #{ + pk: base_id + j.to_string(), + data: "batch_item_" + j.to_string() + } + }); + } + + db.batch_write_item(TABLE, write_requests, ()).await?; + + let keys = []; + for j in 0..batch_size { + keys.push(#{ + pk: base_id + j.to_string(), + }); + } + + db.batch_get_item(TABLE, keys, None).await?; + + // Batch get with result retrieval + let result = db.batch_get_item(TABLE, keys, #{ + consistent_read: true, + with_result: true + }).await?; + + // Check that we got the expected number of items back + assert!(result.len() == batch_size); + + let delete_requests = []; + for j in 0..batch_size { + delete_requests.push(#{ + type: "delete", + key: #{ + pk: base_id + j.to_string(), + } + }); + } + + db.batch_write_item(TABLE, delete_requests, ()).await?; + + let result_after_delete = db.batch_get_item(TABLE, keys, #{ + consistent_read: true, + with_result: true + }).await?; + + // Make sure all items were deleted + assert!(result_after_delete.len() == 0); +} diff --git a/workloads/alternator/manual_batch_operations.rn b/workloads/alternator/manual_batch_operations.rn new file mode 100644 index 0000000..8cc72bc --- /dev/null +++ b/workloads/alternator/manual_batch_operations.rn @@ -0,0 +1,39 @@ +use latte::*; + +// Usage: +// latte schema workloads/alternator/manual_batch_operations.rn http://172.17.0.2:8000 +// latte run -d 1 workloads/alternator/manual_batch_operations.rn http://172.17.0.2:8000 + +const TABLE = "batch_ops_table"; + +pub async fn schema(db) { + db.delete_table(TABLE).await; + db.create_table(TABLE, "pk").await?; +} + +pub async fn run(db, i) { + let batch_size = 5; + let base_id = "user_" + i.to_string() + "_"; + + let write_requests = []; + for j in 0..batch_size { + write_requests.push(#{ + type: "put", + item: #{ + pk: base_id + j.to_string(), + data: "batch_item_" + j.to_string() + } + }); + } + + // Manual pagination loop for batch writes + let current_writes = write_requests; + while current_writes.len() > 0 { + let res = db.batch_write_item(TABLE, current_writes, #{ get_unprocessed: true }).await?; + if let Some(unprocessed) = res.get("unprocessed_items") { + current_writes = unprocessed; + } else { + current_writes = []; // All items processed + } + } +} diff --git a/workloads/alternator/type_validation.rn b/workloads/alternator/type_validation.rn index 986747f..edce2e6 100644 --- a/workloads/alternator/type_validation.rn +++ b/workloads/alternator/type_validation.rn @@ -32,6 +32,14 @@ pub async fn run(db, i) { // Binary "AttributeValue::B" test_bytes: b"binary data", + // Sets (AttributeValue::Ss, AttributeValue::Ns, AttributeValue::Bs) + test_string_set: sset(["a", "b", "c"]), + test_number_set: nset([1, 2, 3]), + test_bytes_set: bset([b"bar", b"foo"]), + manual_string_set: #{ + __sset: ["x", "y", "z"] + }, + // List "AttributeValue::L" test_list: [ 1, @@ -83,6 +91,18 @@ pub async fn run(db, i) { assert!(result["test_bytes"] == b"binary data"); + // Validate set contents ignoring order + result["test_string_set"]["__sset"].sort(); + result["test_number_set"]["__nset"].sort(); + result["test_bytes_set"]["__bset"].sort(); + result["manual_string_set"]["__sset"].sort(); + assert!(result["test_string_set"] == sset(["a", "b", "c"])); + assert!(result["test_number_set"] == nset([1, 2, 3])); + assert!(result["test_bytes_set"] == bset([b"bar", b"foo"])); + assert!(result["manual_string_set"] == #{ + __sset: ["x", "y", "z"] + }); + assert!(result["test_list"] == [ 1, "two",