diff --git a/Cargo.lock b/Cargo.lock index 7e4712b..08ab3e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1533,6 +1533,21 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows-link", + "windows-result", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1856,7 +1871,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.3", "system-configuration", "tokio", "tower-service", @@ -2299,6 +2314,19 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru-slab" version = "0.1.2" @@ -2311,7 +2339,7 @@ version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "373f5eceeeab7925e0c1098212f2fbc4d416adec9d35051a6ab251e824c1854a" dependencies = [ - "twox-hash 2.1.2", + "twox-hash", ] [[package]] @@ -2382,43 +2410,36 @@ checksum = "1fafa6961cabd9c63bcd77a45d7e3b7f3b552b70417831fb0f56db717e72407e" [[package]] name = "musli" -version = "0.0.42" +version = "0.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c21124dd24833900879114414b877f2136f4b7b7a3b49756ecc5c36eca332bb" +checksum = "8b310b280353d9e1c92861820321f8742b02666acaf984a29cd8946965444384" dependencies = [ - "musli-macros", + "loom", + "musli-core", + "serde", + "simdutf8", ] [[package]] -name = "musli-common" -version = "0.0.42" +name = "musli-core" +version = "0.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "178446623aa62978aa0f894b2081bc11ea77c2119ccfe35be428ab9ddb495dfc" +checksum = "d00e227a374e92550ce2eb5002ae116e02a43926d7243c95997138406ae4e157" dependencies = [ - "musli", + "musli-macros", ] [[package]] name = "musli-macros" -version = "0.0.42" +version = "0.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f1ab0e4ac2721bc4fa3528a6a2640c1c30c36c820f8c85159252fbf6c2fac24" +checksum = "e7427c9aa85c882cd4dbe712d2fcdc511db05d595f7787e6747c90cd7d67efc4" dependencies = [ "proc-macro2", "quote", "syn 2.0.117", ] -[[package]] -name = "musli-storage" -version = "0.0.42" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2fc1f80b166f611c462e1344220e9b3a9ad37c885e43039d5d2e6887445937c" -dependencies = [ - "musli", - "musli-common", -] - [[package]] name = "nom" version = "7.1.3" @@ -2843,7 +2864,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2 0.5.10", + "socket2 0.6.3", "thiserror 2.0.18", "tokio", "tracing", @@ -2881,7 +2902,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.3", "tracing", "windows-sys 0.59.0", ] @@ -3252,33 +3273,35 @@ dependencies = [ [[package]] name = "rune" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d21925ac4f8974395d0d9e43f96a34c778e71ed86fe96d0313b2211102537234" +checksum = "b7b9174512d64882469ea9b12876305680154a541be448dcdc56f58acacbc3e0" dependencies = [ "anyhow", "codespan-reporting", "futures-core", "futures-util", "itoa", + "memchr", "musli", - "musli-storage", "num", "once_cell", "pin-project", "rune-alloc", "rune-core", "rune-macros", + "rune-tracing", "ryu", "serde", - "tracing", + "syntree", + "unicode-ident", ] [[package]] name = "rune-alloc" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e85c26e19f7efb91c6e19afc68b008f04685fdb2852e96ce8fbd3cf4a0b4e76c" +checksum = "f12484a608c8907b9a4590f11b669263e960d1fd40f3d3c992c6f15eec931ae9" dependencies = [ "ahash 0.8.12", "pin-project", @@ -3288,9 +3311,9 @@ dependencies = [ [[package]] name = "rune-alloc-macros" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "810588952a8710959d35ad17c933804d60f96c3792f216277cda68c1a9887120" +checksum = "382b14f6d8e65e9cfec789e85125f3e1d758b2756705739e39ccf06fd249a564" dependencies = [ "proc-macro2", "quote", @@ -3299,22 +3322,21 @@ dependencies = [ [[package]] name = "rune-core" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d30fa78b6cb15d1560bb4cb18f4b99a9097b08ade3a5fc23e5ae7311f97c537b" +checksum = "c424b28fde0f5012680361662145f238f04aeac8a320f352a6e2de863709e7b3" dependencies = [ - "byteorder", "musli", "rune-alloc", "serde", - "twox-hash 1.6.3", + "twox-hash", ] [[package]] name = "rune-macros" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1b91e53bae3804e4d72e2b04fa5d5108bd93e880ca597c0cae0fb0a662fe198" +checksum = "c86600b36281adeb101c2e4f0be325752fa4c07431e9234e05be5678ad9a97f7" dependencies = [ "proc-macro2", "quote", @@ -3322,6 +3344,26 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "rune-tracing" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717a7c726015688a19ebfa4bea03a20d2acdd96eeacb5ef48f4ec780e11ac4b" +dependencies = [ + "rune-tracing-macros", +] + +[[package]] +name = "rune-tracing-macros" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12387f96a3e131ce5be8c5668e55f1581dbc6635555d77aa07ab509fd13562bb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "rust-embed" version = "8.11.0" @@ -3532,6 +3574,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -3862,12 +3910,6 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "status-line" version = "0.2.0" @@ -3983,6 +4025,12 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "syntree" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00c99c9cda412afe293a6b962af651b4594161ba88c1affe7ef66459ea040a06" + [[package]] name = "system-configuration" version = "0.7.0" @@ -4457,16 +4505,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "twox-hash" -version = "1.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" -dependencies = [ - "cfg-if", - "static_assertions", -] - [[package]] name = "twox-hash" version = "2.1.2" diff --git a/Cargo.toml b/Cargo.toml index 5faea7a..31aeccb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ plotters = { version = "0.3", default-features = false, features = ["line_series rand = { version = "0.9.4", default-features = false, features = ["small_rng", "std", "thread_rng"] } rand_distr = "0.5" regex = "1.5" -rune = "0.13" +rune = "0.14" rust_decimal = "1.36" rust-embed = "8" scylla = { version = "1.6", features = ["openssl-010", "chrono-04"], optional = true } diff --git a/src/error.rs b/src/error.rs index 7cfb519..fff2b4b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use crate::scripting::db_error::DbError; use hdrhistogram::serialization::interval_log::IntervalLogWriterError; use hdrhistogram::serialization::V2DeflateSerializeError; use rune::alloc; -use rune::runtime::{AccessError, VmError}; +use rune::runtime::{AccessError, RuntimeError, VmError}; use std::path::PathBuf; use thiserror::Error; @@ -55,6 +55,9 @@ pub enum LatteError { #[error("Rune AccessError: {0}")] RuneAccessError(#[from] AccessError), + + #[error("Rune runtime error: {0}")] + RuneRuntimeError(#[from] RuntimeError), } impl From for LatteError { diff --git a/src/exec/workload.rs b/src/exec/workload.rs index 19d8808..2e721ff 100644 --- a/src/exec/workload.rs +++ b/src/exec/workload.rs @@ -21,56 +21,30 @@ use rand::{Rng, SeedableRng}; use rune::alloc::clone::TryClone; use rune::compile::meta::Kind; use rune::compile::{CompileVisitor, MetaError, MetaRef}; -use rune::runtime::{AnyObj, Args, RuntimeContext, Shared, VmError, VmResult}; +use rune::runtime::{Args, RuntimeContext, RuntimeError, VmError}; use rune::termcolor::{ColorChoice, StandardStream}; -use rune::{vm_try, Any, Diagnostics, Source, Sources, ToValue, Unit, Value, Vm}; +use rune::{Diagnostics, Source, Sources, ToValue, Unit, Value, Vm}; use serde::{Deserialize, Serialize}; use try_lock::TryLock; -/// Wraps a reference to Session that can be converted to a Rune `Value` -/// and passed as one of `Args` arguments to a function. -struct SessionRef<'a> { - context: &'a Context, -} - -impl SessionRef<'_> { - pub fn new(context: &Context) -> SessionRef<'_> { - SessionRef { context } - } -} - -/// We need this to be able to pass a reference to `Session` as an argument -/// to Rune function. -/// -/// Caution! Be careful using this trait. Undefined Behaviour possible. -/// This is unsound - it is theoretically -/// possible that the underlying `Session` gets dropped before the `Value` produced by this trait -/// implementation and the compiler is not going to catch that. -/// The receiver of a `Value` must ensure that it is dropped before `Session`! -impl ToValue for SessionRef<'_> { - fn to_value(self) -> VmResult { - let obj = unsafe { AnyObj::from_ref(self.context) }; - VmResult::Ok(Value::from(vm_try!(Shared::new(obj)))) - } -} - -/// Wraps a mutable reference to Session that can be converted to a Rune `Value` and passed -/// as one of `Args` arguments to a function. -struct ContextRefMut<'a> { - context: &'a mut Context, +/// Wraps a shallow clone of Context that can be converted to a rune-owned `Value`. +/// The clone shares Arc-backed fields (stats, statements, presets) with the original, +/// so stats tracking and prepared statements remain shared across function calls. +struct SessionRef { + context: Context, } -impl ContextRefMut<'_> { - pub fn new(context: &mut Context) -> ContextRefMut<'_> { - ContextRefMut { context } +impl SessionRef { + pub fn new(context: &Context) -> SessionRef { + SessionRef { + context: context.shallow_clone(), + } } } -/// Caution! See `impl ToValue for SessionRef`. -impl ToValue for ContextRefMut<'_> { - fn to_value(self) -> VmResult { - let obj = unsafe { AnyObj::from_mut(self.context) }; - VmResult::Ok(Value::from(vm_try!(Shared::new(obj)))) +impl ToValue for SessionRef { + fn to_value(self) -> Result { + Value::new(self.context).map_err(Into::into) } } @@ -199,25 +173,21 @@ impl Program { /// fine, but the function could return an error value, and in this case we should not /// ignore it. fn convert_error(&self, function_name: &str, result: Value) -> Result { - match result { - Value::Result(result) => match result.take().unwrap() { + if result.borrow_ref::>().is_ok() { + match result.downcast::>().unwrap() { Ok(value) => Ok(value), - Err(Value::Any(e)) => { - if e.borrow_ref().unwrap().type_hash() == DbError::type_hash() { - let e = e.take_downcast::().unwrap(); - return Err(LatteError::Database(Box::new(e))); + Err(err_value) => { + if err_value.borrow_ref::().is_ok() { + let e = err_value.downcast::().unwrap(); + Err(LatteError::Database(Box::new(e))) + } else { + let msg = self.vm().with(|| format!("{err_value:?}")); + Err(LatteError::FunctionResult(function_name.to_string(), msg)) } - - let e = Value::Any(e); - let msg = self.vm().with(|| format!("{e:?}")); - Err(LatteError::FunctionResult(function_name.to_string(), msg)) } - Err(other) => Err(LatteError::FunctionResult( - function_name.to_string(), - format!("{other:?}"), - )), - }, - other => Ok(other), + } + } else { + Ok(result) } } @@ -267,24 +237,24 @@ impl Program { /// Calls the script's `init` function. /// Called once at the beginning of the benchmark. /// Typically used to prepare statements. - pub async fn prepare(&mut self, context: &mut Context) -> Result<(), LatteError> { - let context = ContextRefMut::new(context); + pub async fn prepare(&mut self, context: &Context) -> Result<(), LatteError> { + let context = SessionRef::new(context); self.async_call(&FnRef::new(PREPARE_FN), (context,)).await?; Ok(()) } /// Calls the script's `schema` function. /// Typically used to create database schema. - pub async fn schema(&mut self, context: &mut Context) -> Result<(), LatteError> { - let context = ContextRefMut::new(context); + pub async fn schema(&mut self, context: &Context) -> Result<(), LatteError> { + let context = SessionRef::new(context); self.async_call(&FnRef::new(SCHEMA_FN), (context,)).await?; Ok(()) } /// Calls the script's `erase` function. /// Typically used to remove the data from the database before running the benchmark. - pub async fn erase(&mut self, context: &mut Context) -> Result<(), LatteError> { - let context = ContextRefMut::new(context); + pub async fn erase(&mut self, context: &Context) -> Result<(), LatteError> { + let context = SessionRef::new(context); self.async_call(&FnRef::new(ERASE_FN), (context,)).await?; Ok(()) } diff --git a/src/main.rs b/src/main.rs index 18d2be6..427c448 100644 --- a/src/main.rs +++ b/src/main.rs @@ -137,13 +137,13 @@ async fn connect(conf: &ConnectionConf) -> Result<(Context, Option) /// Exits with error if the `schema` function is not present or fails. async fn schema(conf: SchemaCommand) -> Result<()> { let mut program = load_workload_script(&conf.workload, &conf.params)?; - let (mut session, _) = connect(&conf.connection).await?; + let (session, _) = connect(&conf.connection).await?; if !program.has_schema() { eprintln!("error: Function `schema` not found in the workload script."); exit(255); } eprintln!("info: Creating schema..."); - if let Err(e) = program.schema(&mut session).await { + if let Err(e) = program.schema(&session).await { eprintln!("error: Failed to create schema: {e}"); exit(255); } @@ -155,11 +155,11 @@ async fn schema(conf: SchemaCommand) -> Result<()> { /// Exits with error if the `load` function is not present or fails. async fn load(conf: LoadCommand) -> Result<()> { let mut program = load_workload_script(&conf.workload, &conf.params)?; - let (mut session, _) = connect(&conf.connection).await?; + let (session, _) = connect(&conf.connection).await?; if program.has_prepare() { eprintln!("info: Preparing..."); - if let Err(e) = program.prepare(&mut session).await { + if let Err(e) = program.prepare(&session).await { eprintln!("error: Failed to prepare: {e}"); exit(255); } @@ -173,7 +173,7 @@ async fn load(conf: LoadCommand) -> Result<()> { if program.has_erase() { eprintln!("info: Erasing data..."); - if let Err(e) = program.erase(&mut session).await { + if let Err(e) = program.erase(&session).await { eprintln!("error: Failed to erase: {e}"); exit(255); } @@ -236,18 +236,18 @@ async fn run(conf: RunCommand) -> Result<()> { functions.push((function, f.weight)) } - let (mut session, cluster_info) = connect(&conf.connection).await?; + let (session, cluster_info) = connect(&conf.connection).await?; // NOTE: Add info about the target rune functions to the context // for the more flexible tweaking of the 'prepare' rune function. - match &mut session.data { - rune::Value::Object(shared_obj) => { - let _ = shared_obj.borrow_mut()?.insert( + match session.data.borrow_mut::() { + Ok(mut obj) => { + let _ = obj.insert( rune::alloc::String::try_from("functions_to_invoke")?, rune::to_value(functions_to_invoke)?, ); } - _ => { + Err(_) => { eprintln!("error: session.data is not a Rune Object"); exit(255); } @@ -260,7 +260,7 @@ async fn run(conf: RunCommand) -> Result<()> { if program.has_prepare() { eprintln!("info: Preparing..."); - if let Err(e) = program.prepare(&mut session).await { + if let Err(e) = program.prepare(&session).await { eprintln!("error: Failed to prepare: {e}"); exit(255); } diff --git a/src/scripting/alternator/alternator_error.rs b/src/scripting/alternator/alternator_error.rs index 81ffb2e..38e6078 100644 --- a/src/scripting/alternator/alternator_error.rs +++ b/src/scripting/alternator/alternator_error.rs @@ -32,9 +32,9 @@ impl AlternatorError { ))) } - #[rune::function(protocol = STRING_DISPLAY)] + #[rune::function(protocol = DISPLAY_FMT)] pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.to_string()); + let _ = vm_write!(f, "{}", self.to_string()); VmResult::Ok(()) } } @@ -97,6 +97,12 @@ impl From for AlternatorError { } } +impl From for AlternatorError { + fn from(error: rune::runtime::RuntimeError) -> Self { + AlternatorError::new(AlternatorErrorKind::ConversionError(error.to_string())) + } +} + impl From for AlternatorError { fn from(error: rune::alloc::Error) -> Self { AlternatorError::new(AlternatorErrorKind::ConversionError(error.to_string())) diff --git a/src/scripting/alternator/context.rs b/src/scripting/alternator/context.rs index 4d4d78f..db37ae2 100644 --- a/src/scripting/alternator/context.rs +++ b/src/scripting/alternator/context.rs @@ -5,9 +5,10 @@ use crate::scripting::cluster_info::ClusterInfo; use crate::scripting::row_distribution::RowDistributionPreset; use crate::stats::session::SessionStats; use aws_sdk_dynamodb::Client; -use rune::runtime::{Object, Shared}; +use rune::runtime::Object; use rune::{Any, Value}; use std::collections::HashMap; +use std::sync::Arc; use std::time::Instant; use try_lock::TryLock; @@ -15,12 +16,12 @@ use try_lock::TryLock; pub struct Context { client: Option, page_size: u64, - pub stats: TryLock, + pub stats: Arc>, pub start_time: TryLock, pub retry_number: u64, pub retry_interval: RetryInterval, pub validation_strategy: ValidationStrategy, - pub partition_row_presets: HashMap, + pub partition_row_presets: Arc>>, #[rune(get, set, add_assign, copy)] pub load_cycle_count: u64, #[rune(get)] @@ -41,14 +42,14 @@ impl Context { Context { client, page_size, - stats: TryLock::new(SessionStats::new()), + stats: Arc::new(TryLock::new(SessionStats::new())), start_time: TryLock::new(Instant::now()), retry_number, retry_interval, validation_strategy, - partition_row_presets: HashMap::new(), + partition_row_presets: Arc::new(TryLock::new(HashMap::new())), load_cycle_count: 0, - data: Value::Object(Shared::new(Object::new()).unwrap()), + data: Value::new(Object::new()).unwrap(), } } @@ -58,17 +59,36 @@ impl Context { Ok(Context { client: self.client.clone(), page_size: self.page_size, - stats: TryLock::new(SessionStats::default()), + stats: Arc::new(TryLock::new(SessionStats::default())), start_time: TryLock::new(*self.start_time.try_lock().unwrap()), retry_number: self.retry_number, retry_interval: self.retry_interval, validation_strategy: self.validation_strategy, - partition_row_presets: self.partition_row_presets.clone(), + partition_row_presets: Arc::new(TryLock::new( + self.partition_row_presets.try_lock().unwrap().clone(), + )), load_cycle_count: self.load_cycle_count, data: deserialized, }) } + /// Creates a shallow clone that shares the Arc-backed fields (stats, presets) + /// with the original. Used to create a rune-owned `Value` for function call arguments. + pub fn shallow_clone(&self) -> Self { + Context { + client: self.client.clone(), + page_size: self.page_size, + stats: Arc::clone(&self.stats), + start_time: TryLock::new(*self.start_time.try_lock().unwrap()), + retry_number: self.retry_number, + retry_interval: self.retry_interval, + validation_strategy: self.validation_strategy, + partition_row_presets: Arc::clone(&self.partition_row_presets), + load_cycle_count: self.load_cycle_count, + data: self.data.clone(), + } + } + /// Returns cluster metadata. pub async fn cluster_info(&self) -> Result, AlternatorError> { let client = self.get_client()?; diff --git a/src/scripting/alternator/functions.rs b/src/scripting/alternator/functions.rs index 101fff3..96f0118 100644 --- a/src/scripting/alternator/functions.rs +++ b/src/scripting/alternator/functions.rs @@ -10,7 +10,7 @@ use aws_sdk_dynamodb::client::Waiters; use aws_sdk_dynamodb::types::{ AttributeDefinition, KeySchemaElement, KeyType, ScalarAttributeType, }; -use rune::runtime::{Object, Ref, Shared}; +use rune::runtime::{Object, Ref}; use rune::{ToValue, Value}; use std::cmp::min; use std::collections::HashMap; @@ -26,25 +26,33 @@ fn bad_input(msg: impl Into) -> Result { /// Gets the name and type of a primary or sort key from a object. fn extract_key_definition( - object: &Shared, + object: &Object, ) -> Result<(String, ScalarAttributeType), AlternatorError> { - let key_name = if let Some(Value::String(s)) = object.borrow_ref()?.get("name") { - s.borrow_ref()?.to_string() + let key_name = if let Some(v) = object.get("name") { + if let Ok(s) = v.borrow_ref::() { + s.as_str().to_string() + } else { + return bad_input("Key definition object must have a 'name' field"); + } } else { return bad_input("Key definition object must have a 'name' field"); }; - let key_type = if let Some(Value::String(t)) = object.borrow_ref()?.get("type") { - match t.borrow_ref()?.as_str() { - "N" => ScalarAttributeType::N, - "S" => ScalarAttributeType::S, - "B" => ScalarAttributeType::B, - other => { - return bad_input(format!( - "Invalid key type: {}, only N, S, and B are allowed.", - other - )) + let key_type = if let Some(v) = object.get("type") { + if let Ok(t) = v.borrow_ref::() { + match t.as_str() { + "N" => ScalarAttributeType::N, + "S" => ScalarAttributeType::S, + "B" => ScalarAttributeType::B, + other => { + return bad_input(format!( + "Invalid key type: {}, only N, S, and B are allowed.", + other + )) + } } + } else { + return bad_input("Key definition object must have a 'type' field"); } } else { return bad_input("Key definition object must have a 'type' field"); @@ -53,20 +61,15 @@ fn extract_key_definition( Ok((key_name, key_type)) } -fn extract_attribute_names( - object: &Shared, -) -> Result, AlternatorError> { +fn extract_attribute_names(object: &Object) -> Result, AlternatorError> { object - .borrow_ref()? .iter() .map(|(k, v)| { - Ok(( - k.to_string(), - match v { - Value::String(s) => s.borrow_ref()?.to_string(), - _ => return bad_input("Attribute names must be strings"), - }, - )) + if let Ok(s) = v.borrow_ref::() { + Ok((k.to_string(), s.as_str().to_string())) + } else { + bad_input("Attribute names must be strings") + } }) .collect::>() } @@ -195,25 +198,51 @@ pub async fn create_table( let client = ctx.get_client()?; // Extract primary key definition - let (pk_name, pk_type) = match ¶ms { - Value::String(s) => (s.borrow_ref()?.to_string(), ScalarAttributeType::S), - Value::Object(o) => match o.borrow_ref()?.get("primary_key") { - Some(Value::String(s)) => (s.borrow_ref()?.to_string(), ScalarAttributeType::S), - Some(Value::Object(pk_obj)) => extract_key_definition(pk_obj)?, + let (pk_name, pk_type) = if let Ok(s) = params.borrow_ref::() { + (s.as_str().to_string(), ScalarAttributeType::S) + } else if let Ok(o) = params.borrow_ref::() { + match o.get("primary_key") { + Some(v) if v.borrow_ref::().is_ok() => ( + v.borrow_ref::() + .unwrap() + .as_str() + .to_string(), + ScalarAttributeType::S, + ), + Some(v) => { + if let Ok(pk_obj) = v.borrow_ref::() { + extract_key_definition(&pk_obj)? + } else { + return bad_input("Invalid 'primary_key' object in params"); + } + } _ => return bad_input("Invalid 'primary_key' object in params"), - }, - _ => return bad_input("Params must be a string or an object"), + } + } else { + return bad_input("Params must be a string or an object"); }; // Extract sort key definition if present - let sk = match ¶ms { - Value::Object(o) => match o.borrow_ref()?.get("sort_key") { - Some(Value::String(s)) => Some((s.borrow_ref()?.to_string(), ScalarAttributeType::S)), - Some(Value::Object(sk_obj)) => Some(extract_key_definition(sk_obj)?), - Some(_) => return bad_input("Invalid 'sort_key' object in params"), + let sk = if let Ok(o) = params.borrow_ref::() { + match o.get("sort_key") { + Some(v) if v.borrow_ref::().is_ok() => Some(( + v.borrow_ref::() + .unwrap() + .as_str() + .to_string(), + ScalarAttributeType::S, + )), + Some(v) => { + if let Ok(sk_obj) = v.borrow_ref::() { + Some(extract_key_definition(&sk_obj)?) + } else { + return bad_input("Invalid 'sort_key' object in params"); + } + } None => None, - }, - _ => None, + } + } else { + None }; let mut builder = client @@ -292,7 +321,7 @@ pub async fn put( let builder = client .put_item() .table_name(table_name.deref()) - .set_item(Some(rune_object_to_alternator_map(item)?)); + .set_item(Some(rune_object_to_alternator_map(&item)?)); handle_request(&ctx, builder).await?; @@ -316,7 +345,7 @@ pub async fn delete( let builder = client .delete_item() .table_name(table_name.deref()) - .set_key(Some(rune_object_to_alternator_map(key)?)); + .set_key(Some(rune_object_to_alternator_map(&key)?)); handle_request(&ctx, builder).await?; @@ -349,25 +378,23 @@ pub async fn get( let mut builder = client .get_item() .table_name(table_name.deref()) - .set_key(Some(rune_object_to_alternator_map(key)?)); + .set_key(Some(rune_object_to_alternator_map(&key)?)); - if let Value::Object(opts) = &options { - if let Some(Value::Bool(consistent_read)) = opts.borrow_ref()?.get("consistent_read") { - builder = builder.consistent_read(*consistent_read); + if let Ok(opts) = options.borrow_ref::() { + if let Some(b) = opts.get("consistent_read").and_then(|v| v.as_bool().ok()) { + builder = builder.consistent_read(b); } } let result = handle_request(&ctx, builder).await?; - if let Value::Object(opts) = &options { - if let Some(Value::Bool(with_result)) = opts.borrow_ref()?.get("with_result") { - if *with_result { - return Ok(result.into_iter().next().to_value().into_result()?); - } + if let Ok(opts) = options.borrow_ref::() { + if opts.get("with_result").and_then(|v| v.as_bool().ok()) == Some(true) { + return Ok(result.into_iter().next().to_value()?); } } - Ok(Value::EmptyTuple) + Ok(Value::from(())) } /// Updates an item in the table. @@ -392,21 +419,25 @@ pub async fn update( let mut builder = client .update_item() .table_name(table_name.deref()) - .set_key(Some(rune_object_to_alternator_map(key)?)); + .set_key(Some(rune_object_to_alternator_map(&key)?)); - if let Some(Value::String(update_expression)) = params.get("update") { - builder = builder.update_expression(update_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("update") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.update_expression(s.as_str().to_string()); + } } - if let Some(Value::Object(attr_names)) = params.get("attribute_names") { - builder = - builder.set_expression_attribute_names(Some(extract_attribute_names(attr_names)?)); + if let Some(v) = params.get("attribute_names") { + if let Ok(obj) = v.borrow_ref::() { + builder = builder.set_expression_attribute_names(Some(extract_attribute_names(&obj)?)); + } } - if let Some(Value::Object(attr_values)) = params.get("attribute_values") { - builder = builder.set_expression_attribute_values(Some(rune_object_to_alternator_map( - attr_values.clone().into_ref()?, - )?)); + if let Some(v) = params.get("attribute_values") { + if let Ok(obj) = v.borrow_ref::() { + builder = + builder.set_expression_attribute_values(Some(rune_object_to_alternator_map(&obj)?)); + } } handle_request(&ctx, builder).await?; @@ -444,58 +475,66 @@ pub async fn query( let mut builder = client.query().table_name(table_name.deref()); - if let Some(Value::String(key_condition_expression)) = params.get("query") { - builder = - builder.key_condition_expression(key_condition_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("query") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.key_condition_expression(s.as_str().to_string()); + } } - if let Some(Value::String(filter_expression)) = params.get("filter") { - builder = builder.filter_expression(filter_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("filter") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.filter_expression(s.as_str().to_string()); + } } - if let Some(Value::Object(attr_names)) = params.get("attribute_names") { - builder = - builder.set_expression_attribute_names(Some(extract_attribute_names(attr_names)?)); + if let Some(v) = params.get("attribute_names") { + if let Ok(obj) = v.borrow_ref::() { + builder = builder.set_expression_attribute_names(Some(extract_attribute_names(&obj)?)); + } } - if let Some(Value::Object(attr_values)) = params.get("attribute_values") { - builder = builder.set_expression_attribute_values(Some(rune_object_to_alternator_map( - attr_values.clone().into_ref()?, - )?)); + if let Some(v) = params.get("attribute_values") { + if let Ok(obj) = v.borrow_ref::() { + builder = + builder.set_expression_attribute_values(Some(rune_object_to_alternator_map(&obj)?)); + } } - if let Some(Value::Bool(consistent_read)) = params.get("consistent_read") { - builder = builder.consistent_read(*consistent_read); + if let Some(b) = params.get("consistent_read").and_then(|v| v.as_bool().ok()) { + builder = builder.consistent_read(b); } if let Some(limit_val) = params.get("limit") { - builder = builder.limit(match limit_val { - Value::Integer(i) => match i32::try_from(*i) { + if let Ok(i) = limit_val.as_signed() { + builder = builder.limit(match i32::try_from(i) { Ok(val) => val, Err(_) => return bad_input("limit is out of range"), - }, - _ => return bad_input("limit must be an integer"), - }); + }); + } else { + return bad_input("limit must be an integer"); + } } - let validation = if let Some(Value::Vec(validation)) = params.get("validation") { - Some( - extract_validation_args(validation.borrow_ref()?.to_vec()) - .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, - ) + let validation = if let Some(v) = params.get("validation") { + if let Ok(vec) = v.borrow_ref::() { + Some( + extract_validation_args(vec.to_vec()) + .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, + ) + } else { + None + } } else { None }; let result = handle_request_with_validation(&ctx, builder, validation, "Query").await?; - if let Some(Value::Bool(with_result)) = params.get("with_result") { - if *with_result { - return Ok(result.to_value().into_result()?); - } + if params.get("with_result").and_then(|v| v.as_bool().ok()) == Some(true) { + return Ok(result.to_value()?); } - Ok(Value::EmptyTuple) + Ok(Value::from(())) } /// Scans items from the table. @@ -523,48 +562,55 @@ pub async fn scan( let mut builder = client.scan().table_name(table_name.deref()); - if let Some(Value::String(filter_expression)) = params.get("filter") { - builder = builder.filter_expression(filter_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("filter") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.filter_expression(s.as_str().to_string()); + } } - if let Some(Value::Object(attr_names)) = params.get("attribute_names") { - builder = - builder.set_expression_attribute_names(Some(extract_attribute_names(attr_names)?)); + if let Some(v) = params.get("attribute_names") { + if let Ok(obj) = v.borrow_ref::() { + builder = builder.set_expression_attribute_names(Some(extract_attribute_names(&obj)?)); + } } - if let Some(Value::Object(attr_values)) = params.get("attribute_values") { - builder = builder.set_expression_attribute_values(Some(rune_object_to_alternator_map( - attr_values.clone().into_ref()?, - )?)); + if let Some(v) = params.get("attribute_values") { + if let Ok(obj) = v.borrow_ref::() { + builder = + builder.set_expression_attribute_values(Some(rune_object_to_alternator_map(&obj)?)); + } } - if let Some(Value::Bool(consistent_read)) = params.get("consistent_read") { - builder = builder.consistent_read(*consistent_read); + if let Some(b) = params.get("consistent_read").and_then(|v| v.as_bool().ok()) { + builder = builder.consistent_read(b); } if let Some(limit_val) = params.get("limit") { - builder = builder.limit(match limit_val { - Value::Integer(i) => *i as i32, - _ => return bad_input("limit must be an integer"), - }); + if let Ok(i) = limit_val.as_signed() { + builder = builder.limit(i as i32); + } else { + return bad_input("limit must be an integer"); + } } - let validation = if let Some(Value::Vec(validation)) = params.get("validation") { - Some( - extract_validation_args(validation.borrow_ref()?.to_vec()) - .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, - ) + let validation = if let Some(v) = params.get("validation") { + if let Ok(vec) = v.borrow_ref::() { + Some( + extract_validation_args(vec.to_vec()) + .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, + ) + } else { + None + } } else { None }; let result = handle_request_with_validation(&ctx, builder, validation, "Scan").await?; - if let Some(Value::Bool(with_result)) = params.get("with_result") { - if *with_result { - return Ok(result.to_value().into_result()?); - } + if params.get("with_result").and_then(|v| v.as_bool().ok()) == Some(true) { + return Ok(result.to_value()?); } - Ok(Value::EmptyTuple) + Ok(Value::from(())) } diff --git a/src/scripting/alternator/types.rs b/src/scripting/alternator/types.rs index 1a2d5d1..6e8a74e 100644 --- a/src/scripting/alternator/types.rs +++ b/src/scripting/alternator/types.rs @@ -1,67 +1,66 @@ use super::alternator_error::{AlternatorError, AlternatorErrorKind}; use aws_sdk_dynamodb::types::AttributeValue; -use rune::runtime::{Bytes, Object, Ref}; +use rune::alloc::String as RuneString; +use rune::runtime::{Bytes, Object}; use rune::{ToValue, Value}; use std::collections::HashMap; pub fn rune_value_to_alternator_attribute(v: Value) -> Result { - match v { - Value::Bool(b) => Ok(AttributeValue::Bool(b)), - - // DynamoDB represents all numbers as strings - Value::Integer(i) => Ok(AttributeValue::N(i.to_string())), - // To distinguish floats from integers, we print them with the decimal point - Value::Float(f) => Ok(AttributeValue::N(format!("{:?}", f))), - - Value::String(s) => Ok(AttributeValue::S(s.into_ref()?.to_string())), - - Value::Bytes(b) => Ok(AttributeValue::B(b.into_ref()?.to_vec().into())), - - Value::Vec(v) => Ok(AttributeValue::L( - v.into_ref()? - .iter() - .map(|v| rune_value_to_alternator_attribute(v.clone())) - .collect::>()?, - )), - - Value::Object(o) => Ok(AttributeValue::M(rune_object_to_alternator_map( - o.into_ref()?, - )?)), - - Value::Option(o) => match o.into_ref()?.as_ref() { - Some(v) => rune_value_to_alternator_attribute(v.clone()), + if let Ok(b) = v.as_bool() { + return Ok(AttributeValue::Bool(b)); + } + if let Ok(i) = v.as_signed() { + return Ok(AttributeValue::N(i.to_string())); + } + if let Ok(f) = v.as_float() { + return Ok(AttributeValue::N(format!("{:?}", f))); + } + if let Ok(s) = v.borrow_ref::() { + return Ok(AttributeValue::S(s.as_str().to_string())); + } + if let Ok(b) = v.borrow_ref::() { + return Ok(AttributeValue::B(b.to_vec().into())); + } + if let Ok(vec) = v.borrow_ref::() { + let list = vec + .iter() + .map(|v| rune_value_to_alternator_attribute(v.clone())) + .collect::>()?; + return Ok(AttributeValue::L(list)); + } + if let Ok(obj) = v.borrow_ref::() { + let map = obj + .iter() + .map(|(k, v)| { + Ok(( + k.to_string(), + rune_value_to_alternator_attribute(v.clone())?, + )) + }) + .collect::, AlternatorError>>()?; + return Ok(AttributeValue::M(map)); + } + if let Ok(opt) = v.borrow_ref::>() { + return match opt.as_ref() { + Some(inner) => rune_value_to_alternator_attribute(inner.clone()), None => Ok(AttributeValue::Null(true)), - }, - - _ => Err(AlternatorError::new(AlternatorErrorKind::ConversionError( - format!("Unsupported Rune Value type for: {:?}", v), - ))), + }; } -} - -pub fn rune_object_to_alternator_map( - o: Ref, -) -> Result, AlternatorError> { - o.iter() - .map(|(k, v)| { - Ok(( - k.to_string(), - rune_value_to_alternator_attribute(v.clone())?, - )) - }) - .collect() + Err(AlternatorError::new(AlternatorErrorKind::ConversionError( + format!("Unsupported Rune Value type for: {:?}", v), + ))) } pub fn alternator_attribute_to_rune_value(attr: AttributeValue) -> Result { match attr { - AttributeValue::Bool(b) => Ok(Value::Bool(b)), + AttributeValue::Bool(b) => Ok(Value::from(b)), AttributeValue::N(n) => { // Try parsing as integer first, then as float if let Ok(i) = n.parse::() { - Ok(Value::Integer(i)) + Ok(Value::from(i)) } else if let Ok(f) = n.parse::() { - Ok(Value::Float(f)) + Ok(Value::from(f)) } else { Err(AlternatorError::new(AlternatorErrorKind::ConversionError( format!("Invalid number format: {}", n), @@ -69,20 +68,31 @@ pub fn alternator_attribute_to_rune_value(attr: AttributeValue) -> Result Ok(s.as_str().to_value().into_result()?), + AttributeValue::S(s) => Ok(Value::new(RuneString::try_from(s).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?) + .map_err(|e| AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())))?), - AttributeValue::B(b) => Ok(Bytes::try_from(b.into_inner())?.to_value().into_result()?), + AttributeValue::B(b) => Ok(Bytes::try_from(b.into_inner())?.to_value()?), - AttributeValue::L(l) => Ok(l - .into_iter() - .map(alternator_attribute_to_rune_value) - .collect::, _>>()? - .to_value() - .into_result()?), + AttributeValue::L(l) => { + let mut rune_vec = rune::runtime::Vec::new(); + for attr in l { + let val = alternator_attribute_to_rune_value(attr)?; + rune_vec.push(val).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?; + } + Ok(Value::vec(rune_vec.into_inner()).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?) + } AttributeValue::M(map) => Ok(alternator_map_to_rune_object(map)?), - AttributeValue::Null(_) => Ok(None::.to_value().into_result()?), + AttributeValue::Null(_) => Ok(Value::try_from(None::).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?), _ => Err(AlternatorError::new(AlternatorErrorKind::ConversionError( format!("Unsupported Alternator AttributeValue type: {:?}", attr), @@ -90,13 +100,32 @@ pub fn alternator_attribute_to_rune_value(attr: AttributeValue) -> Result Result, AlternatorError> { + obj.iter() + .map(|(k, v)| { + Ok(( + k.to_string(), + rune_value_to_alternator_attribute(v.clone())?, + )) + }) + .collect() +} + pub fn alternator_map_to_rune_object( map: HashMap, ) -> Result { - Ok(map - .into_iter() - .map(|(k, v)| Ok((k, alternator_attribute_to_rune_value(v)?))) - .collect::, AlternatorError>>()? - .to_value() - .into_result()?) + let mut obj = Object::new(); + for (k, v) in map { + let rune_key = RuneString::try_from(k).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?; + let rune_val = alternator_attribute_to_rune_value(v)?; + obj.insert(rune_key, rune_val).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?; + } + Value::new(obj) + .map_err(|e| AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string()))) } diff --git a/src/scripting/cql/cass_error.rs b/src/scripting/cql/cass_error.rs index 2054ca0..3d2970f 100644 --- a/src/scripting/cql/cass_error.rs +++ b/src/scripting/cql/cass_error.rs @@ -116,9 +116,9 @@ pub struct QueryInfo { } impl CassError { - #[rune::function(protocol = STRING_DISPLAY)] + #[rune::function(protocol = DISPLAY_FMT)] pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.to_string()); + let _ = vm_write!(f, "{}", self.to_string()); VmResult::Ok(()) } @@ -280,26 +280,22 @@ impl std::error::Error for CassError {} /// Formats a rune Value into a list of parameter display strings for error messages fn rune_value_to_param_strings(value: Option<&Value>) -> Vec { - match value { - None => vec![], - Some(Value::Tuple(tuple)) => match tuple.borrow_ref() { - Ok(tuple) => tuple.iter().map(|v| format!("{v:?}")).collect(), - Err(_) => vec!["".to_string()], - }, - Some(Value::Vec(vec)) => match vec.borrow_ref() { - Ok(vec) => vec.iter().map(|v| format!("{v:?}")).collect(), - Err(_) => vec!["".to_string()], - }, - Some(Value::Object(obj)) => match obj.borrow_ref() { - Ok(obj) => obj.iter().map(|(k, v)| format!("{k}: {v:?}")).collect(), - Err(_) => vec!["".to_string()], - }, - Some(Value::Struct(obj)) => match obj.borrow_ref() { - Ok(obj) => vec![format!("{obj:?}")], - Err(_) => vec!["".to_string()], - }, - Some(other) => vec![format!("{other:?}")], + let Some(v) = value else { + return vec![]; + }; + if let Ok(tuple) = v.borrow_ref::() { + return tuple.iter().map(|v| format!("{v:?}")).collect(); } + if let Ok(vec) = v.borrow_ref::() { + return vec.iter().map(|v| format!("{v:?}")).collect(); + } + if let Ok(obj) = v.borrow_ref::() { + return obj.iter().map(|(k, v)| format!("{k}: {v:?}")).collect(); + } + if let Ok(rune::runtime::TypeValue::Struct(s)) = v.as_type_value() { + return vec![format!("{s:?}")]; + } + vec![format!("{v:?}")] } impl Display for QueryInfo { diff --git a/src/scripting/cql/context.rs b/src/scripting/cql/context.rs index 62301e2..39866ca 100644 --- a/src/scripting/cql/context.rs +++ b/src/scripting/cql/context.rs @@ -9,9 +9,8 @@ use crate::scripting::row_distribution::RowDistributionPreset; use crate::stats::session::SessionStats; use once_cell::sync::Lazy; -use rand::prelude::ThreadRng; use regex::Regex; -use rune::runtime::{Object, Shared, Vec as RuneVec}; +use rune::runtime::{Object, Vec as RuneVec}; use rune::{Any, Value}; use scylla::client::session::Session; use scylla::response::PagingState; @@ -38,12 +37,12 @@ pub struct Context { // which don't 'depend on'/'use' the 'session' object. session: Option>, page_size: u64, - statements: HashMap>, - pub stats: TryLock, + statements: Arc>>>, + pub stats: Arc>, pub retry_number: u64, pub retry_interval: RetryInterval, pub validation_strategy: ValidationStrategy, - pub partition_row_presets: HashMap, + pub partition_row_presets: Arc>>, #[rune(get, set, add_assign, copy)] pub load_cycle_count: u64, #[rune(get)] @@ -52,7 +51,6 @@ pub struct Context { pub preferred_rack: String, #[rune(get)] pub data: Value, - pub rng: ThreadRng, } // Needed, because Rune `Value` is !Send, as it may contain some internal pointers. @@ -74,21 +72,21 @@ impl Context { retry_interval: RetryInterval, validation_strategy: ValidationStrategy, ) -> Context { + let data = Value::new(Object::new()).unwrap(); Context { start_time: TryLock::new(Instant::now()), session: session.map(Arc::new), page_size, - statements: HashMap::new(), - stats: TryLock::new(SessionStats::new()), + statements: Arc::new(TryLock::new(HashMap::new())), + stats: Arc::new(TryLock::new(SessionStats::new())), retry_number, retry_interval, validation_strategy, - partition_row_presets: HashMap::new(), + partition_row_presets: Arc::new(TryLock::new(HashMap::new())), load_cycle_count: 0, preferred_datacenter, preferred_rack, - data: Value::Object(Shared::new(Object::new()).unwrap()), - rng: rand::rng(), + data, } } @@ -102,21 +100,43 @@ impl Context { Ok(Context { session: self.session.clone(), page_size: self.page_size, - statements: self.statements.clone(), - stats: TryLock::new(SessionStats::default()), + statements: Arc::new(TryLock::new(self.statements.try_lock().unwrap().clone())), + stats: Arc::new(TryLock::new(SessionStats::default())), retry_number: self.retry_number, retry_interval: self.retry_interval, validation_strategy: self.validation_strategy, - partition_row_presets: self.partition_row_presets.clone(), + partition_row_presets: Arc::new(TryLock::new( + self.partition_row_presets.try_lock().unwrap().clone(), + )), load_cycle_count: self.load_cycle_count, preferred_datacenter: self.preferred_datacenter.clone(), preferred_rack: self.preferred_rack.clone(), data: deserialized, start_time: TryLock::new(*self.start_time.try_lock().unwrap()), - rng: rand::rng(), }) } + /// Creates a shallow clone that shares the Arc-backed fields (stats, statements, presets) + /// with the original. Used to create a rune-owned `Value` for function call arguments + /// without losing stats tracking. + pub fn shallow_clone(&self) -> Self { + Context { + start_time: TryLock::new(*self.start_time.try_lock().unwrap()), + session: self.session.clone(), + page_size: self.page_size, + statements: Arc::clone(&self.statements), + stats: Arc::clone(&self.stats), + retry_number: self.retry_number, + retry_interval: self.retry_interval, + validation_strategy: self.validation_strategy, + partition_row_presets: Arc::clone(&self.partition_row_presets), + load_cycle_count: self.load_cycle_count, + preferred_datacenter: self.preferred_datacenter.clone(), + preferred_rack: self.preferred_rack.clone(), + data: self.data.clone(), + } + } + /// Returns cluster metadata such as cluster name and DB version. pub async fn cluster_info(&self) -> Result, CassError> { let session = match &self.session { @@ -199,14 +219,17 @@ impl Context { } /// Prepares a statement and stores it in an internal statement map for future use. - pub async fn prepare(&mut self, key: &str, cql: &str) -> Result<(), CassError> { + pub async fn prepare(&self, key: &str, cql: &str) -> Result<(), CassError> { match &self.session { Some(session) => { let statement = session .prepare(Statement::new(cql).with_page_size(self.page_size as i32)) .await .map_err(|e| CassError::prepare_error(cql, e))?; - self.statements.insert(key.to_string(), Arc::new(statement)); + self.statements + .try_lock() + .unwrap() + .insert(key.to_string(), Arc::new(statement)); Ok(()) } None => Err(CassError(CassErrorKind::Error( @@ -323,12 +346,17 @@ impl Context { ))); } let stmt = if let Some(key) = key { - self.statements.get(key).ok_or_else(|| { - CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) - })? + self.statements + .try_lock() + .unwrap() + .get(key) + .cloned() + .ok_or_else(|| { + CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) + })? } else { let cql = cql.expect("failed to unwrap the 'cql' parameter"); - &Arc::new( + Arc::new( session .prepare(Statement::new(cql).with_page_size(self.page_size as i32)) .await @@ -363,7 +391,7 @@ impl Context { while current_attempt_num <= self.retry_number { let start_time = self.stats.try_lock().unwrap().start_request(); let rs = session - .execute_single_page(stmt, &query_params, paging_state.clone()) + .execute_single_page(&stmt, &query_params, paging_state.clone()) .await; let current_duration = Instant::now() - start_time; let (page, paging_state_response) = match rs { @@ -384,11 +412,11 @@ impl Context { match row_result { Ok(RuneRow(row_obj)) => { rune_rows - .push(Value::Object(Shared::new(row_obj).map_err(|_| { + .push(Value::new(row_obj).map_err(|_| { CassError(CassErrorKind::Error( "Failed to create shared row object".to_string(), )) - })?)) + })?) .map_err(|_| { CassError(CassErrorKind::Error( "Failed to push row to result vector".to_string(), @@ -417,13 +445,13 @@ impl Context { .unwrap() .complete_request(all_pages_duration, rows_num); if process_and_return_data { - return Ok(Value::Vec(Shared::new(rune_rows).map_err(|_| { + return Value::vec(rune_rows.into_inner()).map_err(|_| { CassError(CassErrorKind::Error( "Failed to create shared result vector".to_string(), )) - })?)); + }); } else { - let empty_rune_vec = Value::Vec(Shared::new(RuneVec::new())?); + let empty_rune_vec = Value::vec(Default::default())?; let rows_min = match expected_rows_num_min { None => return Ok(empty_rune_vec), Some(rows_min) => rows_min, @@ -496,10 +524,16 @@ impl Context { let mut batch: Batch = Batch::new(BatchType::Logged); let mut batch_values: Vec> = Vec::with_capacity(keys_len); for (i, key) in keys.into_iter().enumerate() { - let statement = self.statements.get(key).ok_or_else(|| { - CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) - })?; - batch.append_statement((**statement).clone()); + let statement = self + .statements + .try_lock() + .unwrap() + .get(key) + .cloned() + .ok_or_else(|| { + CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) + })?; + batch.append_statement((*statement).clone()); batch_values.push(RuneQueryParams::new(params.get(i))); } match &self.session { diff --git a/src/scripting/cql/deserialize.rs b/src/scripting/cql/deserialize.rs index 328685e..d233574 100644 --- a/src/scripting/cql/deserialize.rs +++ b/src/scripting/cql/deserialize.rs @@ -5,7 +5,7 @@ use std::net::IpAddr; use super::cass_error::{CassError, CassErrorKind}; use rune::alloc::String as RuneString; -use rune::runtime::{Object, OwnedTuple, Shared, Vec as RuneVec}; +use rune::runtime::{Object, OwnedTuple, Vec as RuneVec}; use rune::Value; use scylla::cluster::metadata::{CollectionType, ColumnType, NativeType}; use scylla::deserialize::row::{ColumnIterator, DeserializeRow}; @@ -34,13 +34,9 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { v: Option>, ) -> Result { let Some(slice) = v else { - return Ok(RuneValue(Value::Option(Shared::new(None).map_err( - |_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared None".to_string(), - ))) - }, - )?))); + return Ok(RuneValue( + Value::try_from(None::).map_err(DeserializationError::new)?, + )); }; // This matches the old logic that used CqlValue @@ -52,13 +48,9 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { // can't be empty } _ => { - return Ok(RuneValue(Value::Option(Shared::new(None).map_err( - |_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared None".to_string(), - ))) - }, - )?))) + return Ok(RuneValue( + Value::try_from(None::).map_err(DeserializationError::new)?, + )); } } } @@ -67,58 +59,56 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ColumnType::Native(NativeType::Ascii) | ColumnType::Native(NativeType::Text) => { let string = >::deserialize(typ, Some(slice))?; - Value::String( - Shared::new(RuneString::try_from(string).expect("Failed to create RuneString")) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string".to_string(), - ))) - })?, - ) + Value::new(RuneString::try_from(string).expect("Failed to create RuneString")) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value".to_string(), + ))) + })? } ColumnType::Native(NativeType::Boolean) => { >::deserialize(typ, Some(slice)) - .map(Value::Bool)? + .map(Value::from)? } ColumnType::Native(NativeType::TinyInt) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Integer(i.into()))? + .map(|i| Value::from(i as i64))? } ColumnType::Native(NativeType::SmallInt) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Integer(i.into()))? + .map(|i| Value::from(i as i64))? } ColumnType::Native(NativeType::Int) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Integer(i.into()))? + .map(|i| Value::from(i as i64))? } ColumnType::Native(NativeType::BigInt) => { >::deserialize(typ, Some(slice)) - .map(Value::Integer)? + .map(Value::from)? } ColumnType::Native(NativeType::Float) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Float(i.into()))? + .map(|f| Value::from(f as f64))? } ColumnType::Native(NativeType::Double) => { >::deserialize(typ, Some(slice)) - .map(Value::Float)? + .map(Value::from)? } ColumnType::Native(NativeType::Counter) => { >::deserialize(typ, Some(slice)) - .map(|c| Value::Integer(c.0))? + .map(|c| Value::from(c.0))? } ColumnType::Native(NativeType::Timestamp) => { >::deserialize(typ, Some(slice)) - .map(|ts| Value::Integer(ts.0))? + .map(|ts| Value::from(ts.0))? } ColumnType::Native(NativeType::Date) => { >::deserialize(typ, Some(slice)) - .map(|date| Value::Integer(date.0.into()))? + .map(|date| Value::from(date.0 as i64))? } ColumnType::Native(NativeType::Time) => { >::deserialize(typ, Some(slice)) - .map(|time| Value::Integer(time.0))? + .map(|time| Value::from(time.0))? } ColumnType::Native(NativeType::Blob) => { // Note: is it intentional that blobs are representes as Vec of rune Bytes? @@ -130,64 +120,58 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { )?; let mut rune_vec = RuneVec::new(); for byte in bytes_slice { - rune_vec.push(Value::Byte(*byte)).map_err(|_| { + rune_vec.push(Value::from(*byte)).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to push byte to Rune vector".to_string(), ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector for blob".to_string(), + "Failed to create vector for blob".to_string(), ))) - })?) + })? } ColumnType::Native(NativeType::Uuid) => { let uuid = >::deserialize(typ, Some(slice))?; - Value::String( - Shared::new( - RuneString::try_from(uuid.to_string()) - .expect("Failed to create RuneString for UUID"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for UUID".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(uuid.to_string()) + .expect("Failed to create RuneString for UUID"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for UUID".to_string(), + ))) + })? } ColumnType::Native(NativeType::Timeuuid) => { let timeuuid = >::deserialize( typ, Some(slice), )?; - Value::String( - Shared::new( - RuneString::try_from(timeuuid.to_string()) - .expect("Failed to create RuneString for TimeUuid"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for TimeUuid".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(timeuuid.to_string()) + .expect("Failed to create RuneString for TimeUuid"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for TimeUuid".to_string(), + ))) + })? } ColumnType::Native(NativeType::Inet) => { let addr = >::deserialize(typ, Some(slice))?; - Value::String( - Shared::new( - RuneString::try_from(addr.to_string()) - .expect("Failed to create RuneString for IpAddr"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for IpAddr".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(addr.to_string()) + .expect("Failed to create RuneString for IpAddr"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for IpAddr".to_string(), + ))) + })? } ColumnType::Native(NativeType::Varint) => { let varint = as DeserializeValue<'frame, 'metadata>>::deserialize(typ, Some(slice))?; @@ -205,7 +189,7 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { padded[8 - varint_bytes.len()..].copy_from_slice(varint_bytes); i64::from_be_bytes(padded) }; - Value::Integer(integer) + Value::from(integer) } ColumnType::Native(NativeType::Decimal) => { let decimal = as DeserializeValue< @@ -228,16 +212,14 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { u32::try_from(scale).map_err(DeserializationError::new)?, ) .unwrap(); - Value::String( - Shared::new( - RuneString::try_from(dec.to_string()).expect("Failed to create RuneString"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for Decimal".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(dec.to_string()).expect("Failed to create RuneString"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for Decimal".to_string(), + ))) + })? } ColumnType::Native(NativeType::Duration) => { let duration = >::deserialize( @@ -249,7 +231,7 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { rune_obj .insert( RuneString::try_from("months").expect("Failed to create RuneString"), - Value::Integer(duration.months as i64), + Value::from(duration.months as i64), ) .map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( @@ -259,7 +241,7 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { rune_obj .insert( RuneString::try_from("days").expect("Failed to create RuneString"), - Value::Integer(duration.days as i64), + Value::from(duration.days as i64), ) .map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( @@ -269,18 +251,18 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { rune_obj .insert( RuneString::try_from("nanoseconds").expect("Failed to create RuneString"), - Value::Integer(duration.nanoseconds), + Value::from(duration.nanoseconds), ) .map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to insert nanoseconds into duration object".to_string(), ))) })?; - Value::Object(Shared::new(rune_obj).map_err(|_| { + Value::new(rune_obj).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared object for Duration".to_string(), + "Failed to create object value for Duration".to_string(), ))) - })?) + })? } ColumnType::Vector { dimensions, .. } => { let mut rune_vec = @@ -298,11 +280,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector".to_string(), + "Failed to create vector value".to_string(), ))) - })?) + })? } ColumnType::Collection { typ: coll_type, .. } => match coll_type { CollectionType::List(_) | CollectionType::Set(_) => { @@ -321,11 +303,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector".to_string(), + "Failed to create vector value".to_string(), ))) - })?) + })? } CollectionType::Map(_, _) => { let cql_map_iterator = as DeserializeValue< @@ -342,22 +324,22 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { "Failed to create Rune OwnedTuple".to_string(), ))) })?; - let tuple = Value::Tuple(Shared::new(owned_tuple).map_err(|_| { + let tuple = Value::new(owned_tuple).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create Rune Shared".to_string(), + "Failed to create Rune tuple value".to_string(), ))) - })?); + })?; rune_vec.push(tuple).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to push map key-value pair to the Rune vector".to_string(), ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to create shared Rune vector".to_string(), ))) - })?) + })? } _ => todo!(), // unexpected, should never be reached }, @@ -388,11 +370,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ))) })?; } - Value::Object(Shared::new(rune_obj).map_err(|_| { + Value::new(rune_obj).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared object for UDT".to_string(), + "Failed to create object value for UDT".to_string(), ))) - })?) + })? } ColumnType::Tuple(tuple) => { let mut rune_vec = @@ -418,11 +400,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector for tuple".to_string(), + "Failed to create vector value for tuple".to_string(), ))) - })?) + })? } _ => todo!(), // unexpected, should never be reached @@ -496,19 +478,17 @@ mod tests { } fn assert_is_none(v: Value) { - match v { - Value::Option(shared) => { - assert!(shared.borrow_ref().unwrap().is_none(), "expected None"); - } - other => panic!("expected Option(None), got {other:?}"), + match v.borrow_ref::>() { + Ok(opt) => assert!(opt.is_none(), "expected None"), + Err(_) => panic!("expected Option(None)"), } } fn str_from(v: &Value) -> String { - match v { - Value::String(s) => s.borrow_ref().unwrap().as_str().to_owned(), - other => panic!("expected String, got {other:?}"), - } + v.borrow_ref::() + .expect("expected String") + .as_str() + .to_owned() } fn col_spec<'a>(name: &'a str, typ: ColumnType<'a>) -> ColumnSpec<'a> { @@ -550,12 +530,10 @@ mod tests { let bytes = Bytes::new(); let slice = FrameSlice::new(&bytes); let result = RuneValue::deserialize(&typ, Some(slice)).unwrap().0; - match result { - Value::Vec(shared) => { - assert!(shared.borrow_ref().unwrap().is_empty()); - } - other => panic!("expected Vec, got {other:?}"), - } + let vec = result + .borrow_ref::() + .expect("expected Vec"); + assert!(vec.is_empty()); } #[test] @@ -573,14 +551,14 @@ mod tests { fn bool_true() { let typ = ColumnType::Native(NativeType::Boolean); let bytes = cql_to_raw(&typ, CqlValue::Boolean(true)); - assert!(matches!(deser(&typ, &bytes), Value::Bool(true))); + assert!(deser(&typ, &bytes).as_bool().unwrap()); } #[test] fn bool_false() { let typ = ColumnType::Native(NativeType::Boolean); let bytes = cql_to_raw(&typ, CqlValue::Boolean(false)); - assert!(matches!(deser(&typ, &bytes), Value::Bool(false))); + assert!(!deser(&typ, &bytes).as_bool().unwrap()); } // ── integer types ────────────────────────────────────────────────────────── @@ -589,35 +567,35 @@ mod tests { fn tiny_int_roundtrip() { let typ = ColumnType::Native(NativeType::TinyInt); let bytes = cql_to_raw(&typ, CqlValue::TinyInt(-42)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(-42))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), -42_i64); } #[test] fn small_int_roundtrip() { let typ = ColumnType::Native(NativeType::SmallInt); let bytes = cql_to_raw(&typ, CqlValue::SmallInt(1000)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(1000))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 1000_i64); } #[test] fn int_roundtrip() { let typ = ColumnType::Native(NativeType::Int); let bytes = cql_to_raw(&typ, CqlValue::Int(100_000)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(100_000))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 100_000_i64); } #[test] fn big_int_roundtrip() { let typ = ColumnType::Native(NativeType::BigInt); let bytes = cql_to_raw(&typ, CqlValue::BigInt(i64::MAX)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(i64::MAX))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), i64::MAX); } #[test] fn counter_roundtrip() { let typ = ColumnType::Native(NativeType::Counter); let bytes = cql_to_raw(&typ, CqlValue::Counter(Counter(42))); - assert!(matches!(deser(&typ, &bytes), Value::Integer(42))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 42_i64); } // ── floating-point ───────────────────────────────────────────────────────── @@ -626,20 +604,16 @@ mod tests { fn float_roundtrip() { let typ = ColumnType::Native(NativeType::Float); let bytes = cql_to_raw(&typ, CqlValue::Float(1.5)); - match deser(&typ, &bytes) { - Value::Float(f) => assert!((f - 1.5_f32 as f64).abs() < 1e-7), - other => panic!("expected Float, got {other:?}"), - } + let f = deser(&typ, &bytes).as_float().expect("expected Float"); + assert!((f - 1.5_f32 as f64).abs() < 1e-7); } #[test] fn double_roundtrip() { let typ = ColumnType::Native(NativeType::Double); let bytes = cql_to_raw(&typ, CqlValue::Double(1.23456789)); - match deser(&typ, &bytes) { - Value::Float(f) => assert!((f - 1.23456789).abs() < 1e-10), - other => panic!("expected Float, got {other:?}"), - } + let f = deser(&typ, &bytes).as_float().expect("expected Float"); + assert!((f - 1.23456789).abs() < 1e-10); } // ── text / ascii ─────────────────────────────────────────────────────────── @@ -664,17 +638,15 @@ mod tests { fn blob_becomes_vec_of_bytes() { let typ = ColumnType::Native(NativeType::Blob); let bytes = cql_to_raw(&typ, CqlValue::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF])); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 4); - assert!(matches!(vec[0], Value::Byte(0xDE))); - assert!(matches!(vec[1], Value::Byte(0xAD))); - assert!(matches!(vec[2], Value::Byte(0xBE))); - assert!(matches!(vec[3], Value::Byte(0xEF))); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 4); + assert_eq!(vec[0].as_integer::().unwrap(), 0xDE); + assert_eq!(vec[1].as_integer::().unwrap(), 0xAD); + assert_eq!(vec[2].as_integer::().unwrap(), 0xBE); + assert_eq!(vec[3].as_integer::().unwrap(), 0xEF); } // ── UUID / Timeuuid ──────────────────────────────────────────────────────── @@ -726,10 +698,10 @@ mod tests { let typ = ColumnType::Native(NativeType::Timestamp); let ts = CqlTimestamp(1_609_459_200_000); // 2021-01-01 00:00:00 UTC in ms let bytes = cql_to_raw(&typ, CqlValue::Timestamp(ts)); - assert!(matches!( - deser(&typ, &bytes), - Value::Integer(1_609_459_200_000) - )); + assert_eq!( + deser(&typ, &bytes).as_signed().unwrap(), + 1_609_459_200_000_i64 + ); } #[test] @@ -737,7 +709,7 @@ mod tests { let typ = ColumnType::Native(NativeType::Date); let date = CqlDate(2_147_483_648u32); // 2^31 let bytes = cql_to_raw(&typ, CqlValue::Date(date)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(2_147_483_648))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 2_147_483_648_i64); } #[test] @@ -745,10 +717,10 @@ mod tests { let typ = ColumnType::Native(NativeType::Time); let time = CqlTime(3_600_000_000_000i64); // 1 hour in nanoseconds let bytes = cql_to_raw(&typ, CqlValue::Time(time)); - assert!(matches!( - deser(&typ, &bytes), - Value::Integer(3_600_000_000_000) - )); + assert_eq!( + deser(&typ, &bytes).as_signed().unwrap(), + 3_600_000_000_000_i64 + ); } // ── varint ──────────────────────────────────────────────────────────────── @@ -760,7 +732,7 @@ mod tests { &typ, CqlValue::Varint(CqlVarint::from_signed_bytes_be_slice(&[42])), ); - assert!(matches!(deser(&typ, &bytes), Value::Integer(42))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 42_i64); } #[test] @@ -770,7 +742,7 @@ mod tests { &typ, CqlValue::Varint(CqlVarint::from_signed_bytes_be_slice(&[0xFF])), ); - assert!(matches!(deser(&typ, &bytes), Value::Integer(-1))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), -1_i64); } #[test] @@ -782,7 +754,7 @@ mod tests { &i64::MAX.to_be_bytes(), )), ); - assert!(matches!(deser(&typ, &bytes), Value::Integer(i64::MAX))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), i64::MAX); } #[test] @@ -817,15 +789,13 @@ mod tests { nanoseconds: 3, }; let bytes = cql_to_raw(&typ, CqlValue::Duration(dur)); - match deser(&typ, &bytes) { - Value::Object(shared) => { - let obj = shared.borrow_ref().unwrap(); - assert!(matches!(obj.get("months").unwrap(), Value::Integer(1))); - assert!(matches!(obj.get("days").unwrap(), Value::Integer(2))); - assert!(matches!(obj.get("nanoseconds").unwrap(), Value::Integer(3))); - } - other => panic!("expected Object, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let obj = binding + .borrow_ref::() + .expect("expected Object"); + assert_eq!(obj.get("months").unwrap().as_signed().unwrap(), 1_i64); + assert_eq!(obj.get("days").unwrap().as_signed().unwrap(), 2_i64); + assert_eq!(obj.get("nanoseconds").unwrap().as_signed().unwrap(), 3_i64); } // ── list ────────────────────────────────────────────────────────────────── @@ -841,16 +811,14 @@ mod tests { &typ, CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2), CqlValue::Int(3)]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 3); - assert!(matches!(vec[0], Value::Integer(1))); - assert!(matches!(vec[1], Value::Integer(2))); - assert!(matches!(vec[2], Value::Integer(3))); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 3); + assert_eq!(vec[0].as_signed().unwrap(), 1_i64); + assert_eq!(vec[1].as_signed().unwrap(), 2_i64); + assert_eq!(vec[2].as_signed().unwrap(), 3_i64); } // ── set ─────────────────────────────────────────────────────────────────── @@ -866,10 +834,11 @@ mod tests { &typ, CqlValue::Set(vec![CqlValue::Text("a".into()), CqlValue::Text("b".into())]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => assert_eq!(shared.borrow_ref().unwrap().len(), 2), - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 2); } // ── map ─────────────────────────────────────────────────────────────────── @@ -886,21 +855,16 @@ mod tests { &typ, CqlValue::Map(vec![(CqlValue::Text("key".into()), CqlValue::Int(99))]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 1); - match &vec[0] { - Value::Tuple(tuple_shared) => { - let tuple = tuple_shared.borrow_ref().unwrap(); - assert_eq!(str_from(&tuple[0]), "key"); - assert!(matches!(tuple[1], Value::Integer(99))); - } - other => panic!("expected Tuple inside map Vec, got {other:?}"), - } - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 1); + let tuple = vec[0] + .borrow_ref::() + .expect("expected Tuple inside map Vec"); + assert_eq!(str_from(&tuple[0]), "key"); + assert_eq!(tuple[1].as_signed().unwrap(), 99_i64); } // ── tuple ───────────────────────────────────────────────────────────────── @@ -915,15 +879,13 @@ mod tests { &typ, CqlValue::Tuple(vec![Some(CqlValue::Int(7)), Some(CqlValue::Boolean(true))]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 2); - assert!(matches!(vec[0], Value::Integer(7))); - assert!(matches!(vec[1], Value::Bool(true))); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 2); + assert_eq!(vec[0].as_signed().unwrap(), 7_i64); + assert!(vec[1].as_bool().unwrap()); } #[test] @@ -933,15 +895,13 @@ mod tests { ColumnType::Native(NativeType::Int), ]); let bytes = cql_to_raw(&typ, CqlValue::Tuple(vec![Some(CqlValue::Int(1)), None])); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 2); - assert!(matches!(vec[0], Value::Integer(1))); - assert_is_none(vec[1].clone()); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 2); + assert_eq!(vec[0].as_signed().unwrap(), 1_i64); + assert_is_none(vec[1].clone()); } // ── vector (CQL vector type) ─────────────────────────────────────────────── @@ -961,15 +921,13 @@ mod tests { CqlValue::Float(3.0), ]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 3); - for v in vec.iter() { - assert!(matches!(v, Value::Float(_))); - } - } - other => panic!("expected Vec, got {other:?}"), + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 3); + for v in vec.iter() { + assert!(v.as_float().is_ok()); } } @@ -999,14 +957,12 @@ mod tests { ], }, ); - match deser(&udt_typ, &bytes) { - Value::Object(shared) => { - let obj = shared.borrow_ref().unwrap(); - assert!(matches!(obj.get("x").unwrap(), Value::Integer(42))); - assert!(matches!(obj.get("y").unwrap(), Value::Bool(false))); - } - other => panic!("expected Object, got {other:?}"), - } + let binding = deser(&udt_typ, &bytes); + let obj = binding + .borrow_ref::() + .expect("expected Object"); + assert_eq!(obj.get("x").unwrap().as_signed().unwrap(), 42_i64); + assert!(!obj.get("y").unwrap().as_bool().unwrap()); } #[test] @@ -1027,13 +983,11 @@ mod tests { fields: vec![("v".into(), None)], }, ); - match deser(&udt_typ, &bytes) { - Value::Object(shared) => { - let obj = shared.borrow_ref().unwrap(); - assert_is_none(obj.get("v").unwrap().clone()); - } - other => panic!("expected Object, got {other:?}"), - } + let binding = deser(&udt_typ, &bytes); + let obj = binding + .borrow_ref::() + .expect("expected Object"); + assert_is_none(obj.get("v").unwrap().clone()); } // ── RuneRow ─────────────────────────────────────────────────────────────── @@ -1060,7 +1014,7 @@ mod tests { let row = RuneRow::deserialize(col_iter).unwrap(); let obj = row.0; - assert!(matches!(obj.get("id").unwrap(), Value::Integer(42))); + assert_eq!(obj.get("id").unwrap().as_signed().unwrap(), 42_i64); assert_eq!(str_from(obj.get("name").unwrap()), "alice"); } diff --git a/src/scripting/cql/functions.rs b/src/scripting/cql/functions.rs index 1a9ba4b..81cb90c 100644 --- a/src/scripting/cql/functions.rs +++ b/src/scripting/cql/functions.rs @@ -2,12 +2,12 @@ use crate::scripting::functions_common::extract_validation_args; use super::cass_error::{CassError, CassErrorKind}; use super::context::Context; -use rune::runtime::{Mut, Ref}; +use rune::runtime::Ref; use rune::Value; use std::ops::Deref; #[rune::function(instance)] -pub async fn prepare(mut ctx: Mut, key: Ref, cql: Ref) -> Result<(), CassError> { +pub async fn prepare(ctx: Ref, key: Ref, cql: Ref) -> Result<(), CassError> { ctx.prepare(&key, &cql).await } @@ -22,46 +22,18 @@ pub async fn execute_with_validation( cql: Ref, validation_args: Vec, ) -> Result { - match validation_args.as_slice() { - // (int): expected_rows - [Value::Integer(expected_rows)] => { - ctx.execute_with_validation( - cql.deref(), - *expected_rows as u64, - *expected_rows as u64, - "", - ) - .await - } - // (int, int): expected_rows_num_min, expected_rows_num_max - [Value::Integer(min), Value::Integer(max)] => { - ctx.execute_with_validation(cql.deref(), *min as u64, *max as u64, "") - .await - } - // (int, str): expected_rows, custom_err_msg - [Value::Integer(expected_rows), Value::String(custom_err_msg)] => { - ctx.execute_with_validation( - cql.deref(), - *expected_rows as u64, - *expected_rows as u64, - &custom_err_msg.borrow_ref().unwrap(), - ) - .await - } - // (int, int, str): expected_rows_num_min, expected_rows_num_max, custom_err_msg - [Value::Integer(min), Value::Integer(max), Value::String(custom_err_msg)] => { - ctx.execute_with_validation( - cql.deref(), - *min as u64, - *max as u64, - &custom_err_msg.borrow_ref().unwrap(), - ) - .await - } - _ => Err(CassError(CassErrorKind::Error( - "Invalid arguments for execute_with_validation".to_string(), - ))), - } + let args = extract_validation_args(validation_args).map_err(|e| { + CassError(CassErrorKind::Error(format!( + "execute_with_validation: {e}" + ))) + })?; + ctx.execute_with_validation( + cql.deref(), + args.expected_min, + args.expected_max, + &args.custom_err_msg, + ) + .await } #[rune::function(instance)] diff --git a/src/scripting/cql/serialize.rs b/src/scripting/cql/serialize.rs index efa732b..cc1e0b3 100644 --- a/src/scripting/cql/serialize.rs +++ b/src/scripting/cql/serialize.rs @@ -5,7 +5,8 @@ use crate::scripting::rune_uuid::Uuid; use chrono::{NaiveDate, NaiveTime}; use once_cell::sync::Lazy; use regex::Regex; -use rune::{Any, ToValue, Value}; +use rune::runtime::{Object, OwnedTuple, Vec as RuneVec}; +use rune::{ToValue, Value}; use scylla::_macro_internal::ColumnType; use scylla::frame::response::result::{CollectionType, ColumnSpec, NativeType}; use scylla::serialize::row::{RowSerializationContext, SerializeRow}; @@ -62,68 +63,55 @@ fn serialize_rune_params( columns: &[ColumnSpec<'_>], writer: &mut RowWriter<'_>, ) -> Result<(), SerializationError> { - match value { - Value::Tuple(tuple) => { - let tuple = tuple.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - if tuple.len() != columns.len() { - return Err(SerializationError::new(CassError( - CassErrorKind::InvalidNumberOfQueryParams, - ))); - } - for (v, col) in tuple.iter().zip(columns) { - serialize_rune_cell(v, col.typ(), writer)?; - } - Ok(()) + if let Ok(tuple) = value.borrow_ref::() { + if tuple.len() != columns.len() { + return Err(SerializationError::new(CassError( + CassErrorKind::InvalidNumberOfQueryParams, + ))); } - Value::Vec(vec) => { - let vec = vec.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - for (v, col) in vec.iter().zip(columns) { - serialize_rune_cell(v, col.typ(), writer)?; - } - Ok(()) + for (v, col) in tuple.iter().zip(columns) { + serialize_rune_cell(v, col.typ(), writer)?; } - Value::Object(obj) => { - let obj = obj.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - for col in columns { - let cql_val = match obj.get(col.name()) { - Some(v) => { - to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? - } - None => Some(CqlValue::Empty), - }; - cql_val - .serialize(col.typ(), writer.make_cell_writer()) - .map_err(SerializationError::new)?; - } - Ok(()) + return Ok(()); + } + if let Ok(vec) = value.borrow_ref::() { + for (v, col) in vec.iter().zip(columns) { + serialize_rune_cell(v, col.typ(), writer)?; } - Value::Struct(obj) => { - let obj = obj.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - for col in columns { - let cql_val = match obj.get(col.name()) { - Some(v) => { - to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? - } - None => Some(CqlValue::Empty), - }; - cql_val - .serialize(col.typ(), writer.make_cell_writer()) - .map_err(SerializationError::new)?; - } - Ok(()) + return Ok(()); + } + if let Ok(obj) = value.borrow_ref::() { + for col in columns { + let cql_val = match obj.get(col.name()) { + Some(v) => { + to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? + } + None => Some(CqlValue::Empty), + }; + cql_val + .serialize(col.typ(), writer.make_cell_writer()) + .map_err(SerializationError::new)?; + } + return Ok(()); + } + // Handle struct types (rune typed structs) + if let Ok(rune::runtime::TypeValue::Struct(s)) = value.as_type_value() { + for col in columns { + let cql_val = match s.get(col.name()) { + Some(v) => { + to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? + } + None => Some(CqlValue::Empty), + }; + cql_val + .serialize(col.typ(), writer.make_cell_writer()) + .map_err(SerializationError::new)?; } - other => Err(SerializationError::new(CassError( - CassErrorKind::InvalidQueryParamsObject(other.type_info().unwrap()), - ))), + return Ok(()); } + Err(SerializationError::new(CassError( + CassErrorKind::InvalidQueryParamsObject(value.type_info()), + ))) } /// Serializes a single rune value as a CQL cell. @@ -157,468 +145,460 @@ static DURATION_REGEX: Lazy = Lazy::new(|| { }); fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result, Box> { - match (v, typ) { - (Value::Bool(v), ColumnType::Native(NativeType::Boolean)) => { - Ok(Some(CqlValue::Boolean(*v))) - } + // Option (must be checked before inline types) + if let Ok(opt) = v.borrow_ref::>() { + return match opt.as_ref() { + Some(inner) => to_scylla_value(inner, typ), + None => Ok(None), + }; + } - (Value::Byte(v), ColumnType::Native(NativeType::TinyInt)) => { - Ok(Some(CqlValue::TinyInt(*v as i8))) - } - (Value::Byte(v), ColumnType::Native(NativeType::SmallInt)) => { - Ok(Some(CqlValue::SmallInt(*v as i16))) - } - (Value::Byte(v), ColumnType::Native(NativeType::Int)) => Ok(Some(CqlValue::Int(*v as i32))), - (Value::Byte(v), ColumnType::Native(NativeType::BigInt)) => { - Ok(Some(CqlValue::BigInt(*v as i64))) - } + // Bool + if let Ok(b) = v.as_bool() { + return match typ { + ColumnType::Native(NativeType::Boolean) => Ok(Some(CqlValue::Boolean(b))), + _ => type_mismatch(v, typ), + }; + } - (Value::Integer(v), ColumnType::Native(NativeType::TinyInt)) => { - convert_int(*v, NativeType::TinyInt, CqlValue::TinyInt) - } - (Value::Integer(v), ColumnType::Native(NativeType::SmallInt)) => { - convert_int(*v, NativeType::SmallInt, CqlValue::SmallInt) - } - (Value::Integer(v), ColumnType::Native(NativeType::Int)) => { - convert_int(*v, NativeType::Int, CqlValue::Int) - } - (Value::Integer(v), ColumnType::Native(NativeType::BigInt)) => { - Ok(Some(CqlValue::BigInt(*v))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Counter)) => { - Ok(Some(CqlValue::Counter(scylla::value::Counter(*v)))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Timestamp)) => { - Ok(Some(CqlValue::Timestamp(scylla::value::CqlTimestamp(*v)))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Date)) => match (*v).try_into() { - Ok(date) => Ok(Some(CqlValue::Date(CqlDate(date)))), - Err(_) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Date".to_string(), - Some("Invalid date value".to_string()), - )))), - }, - (Value::Integer(v), ColumnType::Native(NativeType::Time)) => { - Ok(Some(CqlValue::Time(CqlTime(*v)))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Varint)) => Ok(Some(CqlValue::Varint( - CqlVarint::from_signed_bytes_be((*v).to_be_bytes().to_vec()), - ))), - (Value::Integer(v), ColumnType::Native(NativeType::Decimal)) => { - Ok(Some(CqlValue::Decimal( + // Unsigned integer (u64) — byte literals and other unsigned values + if let Ok(u) = v.as_unsigned() { + return match typ { + ColumnType::Native(NativeType::TinyInt) => Ok(Some(CqlValue::TinyInt(u as i8))), + ColumnType::Native(NativeType::SmallInt) => Ok(Some(CqlValue::SmallInt(u as i16))), + ColumnType::Native(NativeType::Int) => Ok(Some(CqlValue::Int(u as i32))), + ColumnType::Native(NativeType::BigInt) => Ok(Some(CqlValue::BigInt(u as i64))), + _ => type_mismatch(v, typ), + }; + } + + // Signed integer (i64) + if let Ok(i) = v.as_signed() { + return match typ { + ColumnType::Native(NativeType::TinyInt) => { + convert_int(i, NativeType::TinyInt, CqlValue::TinyInt) + } + ColumnType::Native(NativeType::SmallInt) => { + convert_int(i, NativeType::SmallInt, CqlValue::SmallInt) + } + ColumnType::Native(NativeType::Int) => convert_int(i, NativeType::Int, CqlValue::Int), + ColumnType::Native(NativeType::BigInt) => Ok(Some(CqlValue::BigInt(i))), + ColumnType::Native(NativeType::Counter) => { + Ok(Some(CqlValue::Counter(scylla::value::Counter(i)))) + } + ColumnType::Native(NativeType::Timestamp) => { + Ok(Some(CqlValue::Timestamp(scylla::value::CqlTimestamp(i)))) + } + ColumnType::Native(NativeType::Date) => match i.try_into() { + Ok(date) => Ok(Some(CqlValue::Date(CqlDate(date)))), + Err(_) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Date".to_string(), + Some("Invalid date value".to_string()), + )))), + }, + ColumnType::Native(NativeType::Time) => Ok(Some(CqlValue::Time(CqlTime(i)))), + ColumnType::Native(NativeType::Varint) => Ok(Some(CqlValue::Varint( + CqlVarint::from_signed_bytes_be(i.to_be_bytes().to_vec()), + ))), + ColumnType::Native(NativeType::Decimal) => Ok(Some(CqlValue::Decimal( scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( - (*v).to_be_bytes().to_vec(), + i.to_be_bytes().to_vec(), 0, ), - ))) - } + ))), + _ => type_mismatch(v, typ), + }; + } - (Value::Float(v), ColumnType::Native(NativeType::Float)) => { - Ok(Some(CqlValue::Float(*v as f32))) - } - (Value::Float(v), ColumnType::Native(NativeType::Double)) => Ok(Some(CqlValue::Double(*v))), - (Value::Float(v), ColumnType::Native(NativeType::Decimal)) => { - let decimal = rust_decimal::Decimal::from_f64_retain(*v).unwrap(); - Ok(Some(CqlValue::Decimal( - scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( - decimal.mantissa().to_be_bytes().to_vec(), - decimal.scale().try_into().unwrap(), - ), - ))) - } + // Float (f64) + if let Ok(f) = v.as_float() { + return match typ { + ColumnType::Native(NativeType::Float) => Ok(Some(CqlValue::Float(f as f32))), + ColumnType::Native(NativeType::Double) => Ok(Some(CqlValue::Double(f))), + ColumnType::Native(NativeType::Decimal) => { + let decimal = rust_decimal::Decimal::from_f64_retain(f).unwrap(); + Ok(Some(CqlValue::Decimal( + scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( + decimal.mantissa().to_be_bytes().to_vec(), + decimal.scale().try_into().unwrap(), + ), + ))) + } + _ => type_mismatch(v, typ), + }; + } - (Value::String(s), ColumnType::Native(NativeType::Date)) => { - let date_str = s.borrow_ref().unwrap(); - let naive_date = NaiveDate::parse_from_str(&date_str, "%Y-%m-%d").map_err(|e| { - CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Date".to_string(), - Some(format!("{e}")), - )) - })?; - let cql_date = CqlDate::from(naive_date); - Ok(Some(CqlValue::Date(cql_date))) - } - (Value::String(s), ColumnType::Native(NativeType::Time)) => { - let time_str = s.borrow_ref().unwrap(); - let mut time_format = "%H:%M:%S".to_string(); - if time_str.contains('.') { - time_format = format!("{time_format}.%f"); + // String + if let Ok(s) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Date) => { + let naive_date = + NaiveDate::parse_from_str(s.as_str(), "%Y-%m-%d").map_err(|e| { + CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Date".to_string(), + Some(format!("{e}")), + )) + })?; + Ok(Some(CqlValue::Date(CqlDate::from(naive_date)))) } - let naive_time = NaiveTime::parse_from_str(&time_str, &time_format).map_err(|e| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Time".to_string(), - Some(format!("{e}")), - ))) - })?; - let cql_time = CqlTime::try_from(naive_time)?; - Ok(Some(CqlValue::Time(cql_time))) - } - (Value::String(s), ColumnType::Native(NativeType::Duration)) => { - // TODO: add support for the following 'ISO 8601' format variants: - // - ISO 8601 format: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W - // - ISO 8601 alternative format: P[YYYY]-[MM]-[DD]T[hh]:[mm]:[ss] - // See: https://opensource.docs.scylladb.com/stable/cql/types.html#working-with-durations - let duration_str = s.borrow_ref().unwrap(); - if duration_str.is_empty() { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Duration".to_string(), - Some("Duration cannot be empty".to_string()), - )))); + ColumnType::Native(NativeType::Time) => { + let mut time_format = "%H:%M:%S".to_string(); + if s.as_str().contains('.') { + time_format = format!("{time_format}.%f"); + } + let naive_time = + NaiveTime::parse_from_str(s.as_str(), &time_format).map_err(|e| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Time".to_string(), + Some(format!("{e}")), + ))) + })?; + Ok(Some(CqlValue::Time(CqlTime::try_from(naive_time)?))) } - // NOTE: we parse the duration explicitly because of the 'CqlDuration' type specifics. - // It stores only months, days and nanoseconds. - // So, we do not translate days to months and hours to days because those are ambiguous - let (mut months, mut days, mut nanoseconds) = (0, 0, 0); - let mut matches_counter = HashMap::from([ - ("y", 0), - ("mo", 0), - ("w", 0), - ("d", 0), - ("h", 0), - ("m", 0), - ("s", 0), - ("ms", 0), - ("us", 0), - ("ns", 0), - ]); - for cap in DURATION_REGEX.captures_iter(&duration_str) { - if let Some(m) = cap.name("years") { - months += m.as_str().parse::().unwrap() * 12; - *matches_counter.entry("y").or_insert(1) += 1; - } else if let Some(m) = cap.name("months") { - months += m.as_str().parse::().unwrap(); - *matches_counter.entry("mo").or_insert(1) += 1; - } else if let Some(m) = cap.name("weeks") { - days += m.as_str().parse::().unwrap() * 7; - *matches_counter.entry("w").or_insert(1) += 1; - } else if let Some(m) = cap.name("days") { - days += m.as_str().parse::().unwrap(); - *matches_counter.entry("d").or_insert(1) += 1; - } else if let Some(m) = cap.name("hours") { - nanoseconds += m.as_str().parse::().unwrap() * 3_600_000_000_000; - *matches_counter.entry("h").or_insert(1) += 1; - } else if let Some(m) = cap.name("minutes") { - nanoseconds += m.as_str().parse::().unwrap() * 60_000_000_000; - *matches_counter.entry("m").or_insert(1) += 1; - } else if let Some(m) = cap.name("seconds") { - nanoseconds += m.as_str().parse::().unwrap() * 1_000_000_000; - *matches_counter.entry("s").or_insert(1) += 1; - } else if let Some(m) = cap.name("millis") { - nanoseconds += m.as_str().parse::().unwrap() * 1_000_000; - *matches_counter.entry("ms").or_insert(1) += 1; - } else if let Some(m) = cap.name("micros") { - nanoseconds += m.as_str().parse::().unwrap() * 1_000; - *matches_counter.entry("us").or_insert(1) += 1; - } else if let Some(m) = cap.name("nanoseconds") { - nanoseconds += m.as_str().parse::().unwrap(); - *matches_counter.entry("ns").or_insert(1) += 1; - } else if cap.name("invalid").is_some() { + ColumnType::Native(NativeType::Duration) => { + // TODO: add support for the following 'ISO 8601' format variants: + // - ISO 8601 format: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W + // - ISO 8601 alternative format: P[YYYY]-[MM]-[DD]T[hh]:[mm]:[ss] + // See: https://opensource.docs.scylladb.com/stable/cql/types.html#working-with-durations + let duration_str = s.as_str(); + if duration_str.is_empty() { return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( format!("{v:?}"), "NativeType::Duration".to_string(), - Some("Got invalid duration value".to_string()), + Some("Duration cannot be empty".to_string()), )))); } + // NOTE: we parse the duration explicitly because of the 'CqlDuration' type specifics. + // It stores only months, days and nanoseconds. + // So, we do not translate days to months and hours to days because those are ambiguous + let (mut months, mut days, mut nanoseconds) = (0, 0, 0); + let mut matches_counter = HashMap::from([ + ("y", 0), + ("mo", 0), + ("w", 0), + ("d", 0), + ("h", 0), + ("m", 0), + ("s", 0), + ("ms", 0), + ("us", 0), + ("ns", 0), + ]); + for cap in DURATION_REGEX.captures_iter(duration_str) { + if let Some(m) = cap.name("years") { + months += m.as_str().parse::().unwrap() * 12; + *matches_counter.entry("y").or_insert(1) += 1; + } else if let Some(m) = cap.name("months") { + months += m.as_str().parse::().unwrap(); + *matches_counter.entry("mo").or_insert(1) += 1; + } else if let Some(m) = cap.name("weeks") { + days += m.as_str().parse::().unwrap() * 7; + *matches_counter.entry("w").or_insert(1) += 1; + } else if let Some(m) = cap.name("days") { + days += m.as_str().parse::().unwrap(); + *matches_counter.entry("d").or_insert(1) += 1; + } else if let Some(m) = cap.name("hours") { + nanoseconds += m.as_str().parse::().unwrap() * 3_600_000_000_000; + *matches_counter.entry("h").or_insert(1) += 1; + } else if let Some(m) = cap.name("minutes") { + nanoseconds += m.as_str().parse::().unwrap() * 60_000_000_000; + *matches_counter.entry("m").or_insert(1) += 1; + } else if let Some(m) = cap.name("seconds") { + nanoseconds += m.as_str().parse::().unwrap() * 1_000_000_000; + *matches_counter.entry("s").or_insert(1) += 1; + } else if let Some(m) = cap.name("millis") { + nanoseconds += m.as_str().parse::().unwrap() * 1_000_000; + *matches_counter.entry("ms").or_insert(1) += 1; + } else if let Some(m) = cap.name("micros") { + nanoseconds += m.as_str().parse::().unwrap() * 1_000; + *matches_counter.entry("us").or_insert(1) += 1; + } else if let Some(m) = cap.name("nanoseconds") { + nanoseconds += m.as_str().parse::().unwrap(); + *matches_counter.entry("ns").or_insert(1) += 1; + } else if cap.name("invalid").is_some() { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Duration".to_string(), + Some("Got invalid duration value".to_string()), + )))); + } + } + if matches_counter.values().all(|&v| v == 0) { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Duration".to_string(), + Some("None time units were found".to_string()), + )))); + } + let duplicated_units: Vec<&str> = matches_counter + .iter() + .filter(|&(_, &count)| count > 1) + .map(|(&unit, _)| unit) + .collect(); + if !duplicated_units.is_empty() { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Duration".to_string(), + Some(format!( + "Got multiple matches for time unit(s): {}", + duplicated_units.join(", ") + )), + )))); + } + Ok(Some(CqlValue::Duration(CqlDuration { + months, + days, + nanoseconds, + }))) } - if matches_counter.values().all(|&v| v == 0) { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Duration".to_string(), - Some("None time units were found".to_string()), - )))); - } - let duplicated_units: Vec<&str> = matches_counter - .iter() - .filter(|&(_, &count)| count > 1) - .map(|(&unit, _)| unit) - .collect(); - if !duplicated_units.is_empty() { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Duration".to_string(), - Some(format!( - "Got multiple matches for time unit(s): {}", - duplicated_units.join(", ") - )), - )))); - } - let cql_duration = CqlDuration { - months, - days, - nanoseconds, - }; - Ok(Some(CqlValue::Duration(cql_duration))) - } - - (Value::String(s), ColumnType::Native(NativeType::Varint)) => { - let varint_str = s.borrow_ref().unwrap(); - if !varint_str.chars().all(|c| c.is_ascii_digit()) { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Varint".to_string(), - Some("Input contains non-digit characters".to_string()), - )))); + ColumnType::Native(NativeType::Varint) => { + if !s.as_str().chars().all(|c| c.is_ascii_digit()) { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Varint".to_string(), + Some("Input contains non-digit characters".to_string()), + )))); + } + let byte_vector: Vec = s + .as_str() + .chars() + .map(|c| c.to_digit(10).expect("Invalid digit") as u8) + .collect(); + Ok(Some(CqlValue::Varint( + scylla::value::CqlVarint::from_signed_bytes_be(byte_vector), + ))) } - let byte_vector: Vec = varint_str - .chars() - .map(|c| c.to_digit(10).expect("Invalid digit") as u8) - .collect(); - Ok(Some(CqlValue::Varint( - scylla::value::CqlVarint::from_signed_bytes_be(byte_vector), - ))) - } - (Value::String(s), ColumnType::Native(NativeType::Timeuuid)) => { - let timeuuid_str = s.borrow_ref().unwrap(); - let timeuuid = CqlTimeuuid::from_str(timeuuid_str.as_str()); - match timeuuid { + ColumnType::Native(NativeType::Timeuuid) => match CqlTimeuuid::from_str(s.as_str()) { Ok(timeuuid) => Ok(Some(CqlValue::Timeuuid(timeuuid))), Err(e) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( format!("{v:?}"), "NativeType::Timeuuid".to_string(), Some(format!("{e}")), )))), + }, + ColumnType::Native(NativeType::Text) | ColumnType::Native(NativeType::Ascii) => { + Ok(Some(CqlValue::Text(s.as_str().to_string()))) } - } - ( - Value::String(v), - ColumnType::Native(NativeType::Text) | ColumnType::Native(NativeType::Ascii), - ) => Ok(Some(CqlValue::Text( - v.borrow_ref().unwrap().as_str().to_string(), - ))), - (Value::String(s), ColumnType::Native(NativeType::Inet)) => { - let ipaddr_str = s.borrow_ref().unwrap(); - let ipaddr = IpAddr::from_str(ipaddr_str.as_str()); - match ipaddr { + ColumnType::Native(NativeType::Inet) => match IpAddr::from_str(s.as_str()) { Ok(ipaddr) => Ok(Some(CqlValue::Inet(ipaddr))), Err(e) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( format!("{v:?}"), "NativeType::Inet".to_string(), Some(format!("{e}")), )))), + }, + ColumnType::Native(NativeType::Decimal) => { + let decimal = rust_decimal::Decimal::from_str_exact(s.as_str()).unwrap(); + Ok(Some(CqlValue::Decimal( + scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( + decimal.mantissa().to_be_bytes().to_vec(), + decimal.scale().try_into().unwrap(), + ), + ))) } - } - (Value::String(s), ColumnType::Native(NativeType::Decimal)) => { - let dec_str = s.borrow_ref().unwrap(); - let decimal = rust_decimal::Decimal::from_str_exact(&dec_str).unwrap(); - Ok(Some(CqlValue::Decimal( - scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( - decimal.mantissa().to_be_bytes().to_vec(), - decimal.scale().try_into().unwrap(), - ), - ))) - } - (Value::Bytes(v), ColumnType::Native(NativeType::Blob)) => { - Ok(Some(CqlValue::Blob(v.borrow_ref().unwrap().to_vec()))) - } - (Value::Vec(v), ColumnType::Native(NativeType::Blob)) => { - let v: Vec = v.borrow_ref().unwrap().to_vec(); - let byte_vec: Vec = v - .into_iter() - .map(|value| value.as_byte().unwrap()) - .collect(); - Ok(Some(CqlValue::Blob(byte_vec))) - } - (Value::Option(v), typ) => match v.borrow_ref().unwrap().as_ref() { - Some(v) => to_scylla_value(v, typ), - None => Ok(None), - }, - (Value::Tuple(v), ColumnType::Tuple(tuple)) => { - let v = v.borrow_ref().unwrap(); - let mut elements = Vec::with_capacity(v.len()); - for (i, current_element) in v.iter().enumerate() { - let element = to_scylla_value(current_element, &tuple[i])?; - elements.push(element); + _ => type_mismatch(v, typ), + }; + } + + // Bytes (rune::runtime::Bytes) + if let Ok(b) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Blob) => Ok(Some(CqlValue::Blob(b.to_vec()))), + _ => type_mismatch(v, typ), + }; + } + + // Vec (rune::runtime::Vec) + if let Ok(vec) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Blob) => { + let byte_vec: Vec = + vec.iter().map(|v| v.as_unsigned().unwrap() as u8).collect(); + Ok(Some(CqlValue::Blob(byte_vec))) } - Ok(Some(CqlValue::Tuple(elements))) - } - (Value::Vec(v), ColumnType::Tuple(tuple)) => { - let v = v.borrow_ref().unwrap(); - let mut elements = Vec::with_capacity(v.len()); - for (i, current_element) in v.iter().enumerate() { - let element = to_scylla_value(current_element, &tuple[i])?; - elements.push(element); + ColumnType::Tuple(tuple) => { + let mut elements = Vec::with_capacity(vec.len()); + for (i, current_element) in vec.iter().enumerate() { + elements.push(to_scylla_value(current_element, &tuple[i])?); + } + Ok(Some(CqlValue::Tuple(elements))) } - Ok(Some(CqlValue::Tuple(elements))) - } - (Value::Vec(v), ColumnType::Vector { typ, .. }) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| { - to_scylla_value(v, typ).and_then(|opt| { - opt.ok_or_else(|| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "ColumnType::Vector".to_string(), - None, - ))) + ColumnType::Vector { typ: elem_typ, .. } => { + let elements = vec + .iter() + .map(|v| { + to_scylla_value(v, elem_typ).and_then(|opt| { + opt.ok_or_else(|| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "ColumnType::Vector".to_string(), + None, + ))) + }) }) }) - }) - .try_collect()?; - Ok(Some(CqlValue::Vector(elements))) - } - ( - Value::Vec(v), + .try_collect()?; + Ok(Some(CqlValue::Vector(elements))) + } ColumnType::Collection { - frozen: _, typ: CollectionType::List(elt), - }, - ) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| { - to_scylla_value(v, elt).and_then(|opt| { - opt.ok_or_else(|| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "CollectionType::List".to_string(), - None, - ))) + .. + } => { + let elements = vec + .iter() + .map(|v| { + to_scylla_value(v, elt).and_then(|opt| { + opt.ok_or_else(|| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "CollectionType::List".to_string(), + None, + ))) + }) }) }) - }) - .try_collect()?; - Ok(Some(CqlValue::List(elements))) - } - ( - Value::Vec(v), + .try_collect()?; + Ok(Some(CqlValue::List(elements))) + } ColumnType::Collection { - frozen: _, typ: CollectionType::Set(elt), - }, - ) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| { - to_scylla_value(v, elt).and_then(|opt| { - opt.ok_or_else(|| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "CollectionType::Set".to_string(), - None, - ))) + .. + } => { + let elements = vec + .iter() + .map(|v| { + to_scylla_value(v, elt).and_then(|opt| { + opt.ok_or_else(|| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "CollectionType::Set".to_string(), + None, + ))) + }) }) }) - }) - .try_collect()?; - Ok(Some(CqlValue::Set(elements))) - } - ( - Value::Vec(v), + .try_collect()?; + Ok(Some(CqlValue::Set(elements))) + } ColumnType::Collection { - frozen: _, typ: CollectionType::Map(key_elt, value_elt), - }, - ) => { - let v = v.borrow_ref().unwrap(); - let mut map_vec = Vec::with_capacity(v.len()); - for tuple in v.iter() { - match tuple { - Value::Tuple(tuple) if tuple.borrow_ref().unwrap().len() == 2 => { - let tuple = tuple.borrow_ref().unwrap(); - let key = to_scylla_value(tuple.first().unwrap(), key_elt)?.unwrap(); - let value = to_scylla_value(tuple.get(1).unwrap(), value_elt)?.unwrap(); - map_vec.push((key, value)); - } - _ => { + .. + } => { + let mut map_vec = Vec::with_capacity(vec.len()); + for item in vec.iter() { + if let Ok(tuple) = item.borrow_ref::() { + if tuple.len() == 2 { + let key = to_scylla_value(tuple.first().unwrap(), key_elt)?.unwrap(); + let value = to_scylla_value(tuple.get(1).unwrap(), value_elt)?.unwrap(); + map_vec.push((key, value)); + } else { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{item:?}"), + "CollectionType::Map".to_string(), + None, + )))); + } + } else { return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{tuple:?}"), + format!("{item:?}"), "CollectionType::Map".to_string(), None, )))); } } + Ok(Some(CqlValue::Map(map_vec))) } - Ok(Some(CqlValue::Map(map_vec))) - } - ( - Value::Object(obj), + _ => type_mismatch(v, typ), + }; + } + + // OwnedTuple + if let Ok(tuple) = v.borrow_ref::() { + return match typ { + ColumnType::Tuple(types) => { + let mut elements = Vec::with_capacity(tuple.len()); + for (i, current_element) in tuple.iter().enumerate() { + elements.push(to_scylla_value(current_element, &types[i])?); + } + Ok(Some(CqlValue::Tuple(elements))) + } + _ => type_mismatch(v, typ), + }; + } + + // Object + if let Ok(obj) = v.borrow_ref::() { + return match typ { ColumnType::Collection { - frozen: _, typ: CollectionType::Map(key_elt, value_elt), - }, - ) => { - let obj = obj.borrow_ref().unwrap(); - let mut map_vec = Vec::with_capacity(obj.keys().len()); - for (k, v) in obj.iter() { - let key = String::from(k.as_str()); - let key = to_scylla_value(&(key.to_value().unwrap()), key_elt)?.unwrap(); - let value = to_scylla_value(v, value_elt)?.unwrap(); - map_vec.push((key, value)); + .. + } => { + let mut map_vec = Vec::with_capacity(obj.keys().len()); + for (k, val) in obj.iter() { + let key = String::from(k.as_str()); + let key = to_scylla_value(&key.to_value().unwrap(), key_elt)?.unwrap(); + let value = to_scylla_value(val, value_elt)?.unwrap(); + map_vec.push((key, value)); + } + Ok(Some(CqlValue::Map(map_vec))) } - Ok(Some(CqlValue::Map(map_vec))) - } - ( - Value::Object(v), - ColumnType::UserDefinedType { - frozen: _, - definition, - }, - ) => { - let obj = v.borrow_ref().unwrap(); - let field_types: Vec<(String, ColumnType)> = definition - .field_types - .iter() - .map(|(name, typ)| (name.to_string(), typ.clone())) - .collect(); - let fields = read_fields(|s| obj.get(s), &field_types)?; - Ok(Some(CqlValue::UserDefinedType { - name: definition.name.to_string(), - keyspace: definition.keyspace.to_string(), - fields, - })) - } - ( - Value::Struct(v), - ColumnType::UserDefinedType { - frozen: _, - definition, - }, - ) => { - let obj = v.borrow_ref().unwrap(); - let field_types: Vec<(String, ColumnType)> = definition - .field_types - .iter() - .map(|(name, typ)| (name.to_string(), typ.clone())) - .collect(); - let fields = read_fields(|s| obj.get(s), &field_types)?; - Ok(Some(CqlValue::UserDefinedType { - name: definition.name.to_string(), - keyspace: definition.keyspace.to_string(), - fields, - })) - } + ColumnType::UserDefinedType { definition, .. } => { + let field_types: Vec<(String, ColumnType)> = definition + .field_types + .iter() + .map(|(name, typ)| (name.to_string(), typ.clone())) + .collect(); + let fields = read_fields(|s| obj.get(s), &field_types)?; + Ok(Some(CqlValue::UserDefinedType { + name: definition.name.to_string(), + keyspace: definition.keyspace.to_string(), + fields, + })) + } + _ => type_mismatch(v, typ), + }; + } - (Value::Any(obj), ColumnType::Native(NativeType::Uuid)) => { - let obj = obj.borrow_ref().unwrap(); - let h = obj.type_hash(); - if h == Uuid::type_hash() { - let uuid: &Uuid = obj.downcast_borrow_ref().unwrap(); - Ok(Some(CqlValue::Uuid(uuid.0))) - } else { - Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Uuid".to_string(), - None, - )))) + // Struct (rune typed struct) + if let Ok(rune::runtime::TypeValue::Struct(s)) = v.as_type_value() { + return match typ { + ColumnType::UserDefinedType { definition, .. } => { + let field_types: Vec<(String, ColumnType)> = definition + .field_types + .iter() + .map(|(name, typ)| (name.to_string(), typ.clone())) + .collect(); + let fields = read_fields(|name| s.get(name), &field_types)?; + Ok(Some(CqlValue::UserDefinedType { + name: definition.name.to_string(), + keyspace: definition.keyspace.to_string(), + fields, + })) } - } - (value, typ) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{value:?}"), - format!("{typ:?}").to_string(), - None, - )))), + _ => type_mismatch(v, typ), + }; + } + + // UUID (custom Rune Any type) + if let Ok(uuid) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Uuid) => Ok(Some(CqlValue::Uuid(uuid.0))), + _ => type_mismatch(v, typ), + }; } + + type_mismatch(v, typ) +} + +fn type_mismatch(v: &Value, typ: &ColumnType) -> Result, Box> { + Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + format!("{typ:?}").to_string(), + None, + )))) } fn convert_int, R>( @@ -655,7 +635,7 @@ mod tests { use rstest::rstest; use rune::alloc::String as RuneString; - use rune::runtime::{Object, Shared, Vec as RuneVec}; + use rune::runtime::{Object, Vec as RuneVec}; use scylla::frame::response::result::TableSpec; use scylla::serialize::row::RowSerializationContext; use scylla::serialize::writers::RowWriter; @@ -665,19 +645,19 @@ mod tests { // ── helpers ────────────────────────────────────────────────────── fn rune_string(s: &str) -> Value { - Value::String(Shared::new(RuneString::try_from(s).unwrap()).unwrap()) + Value::new(RuneString::try_from(s).unwrap()).unwrap() } fn rune_int(i: i64) -> Value { - Value::Integer(i) + Value::from(i) } fn rune_float(f: f64) -> Value { - Value::Float(f) + Value::from(f) } fn rune_bool(b: bool) -> Value { - Value::Bool(b) + Value::from(b) } fn rune_vec(items: Vec) -> Value { @@ -685,7 +665,7 @@ mod tests { for item in items { v.push(item).unwrap(); } - Value::Vec(Shared::new(v).unwrap()) + Value::vec(v.into_inner()).unwrap() } fn rune_tuple(items: Vec) -> Value { @@ -693,7 +673,7 @@ mod tests { for item in items { v.try_push(item).unwrap(); } - Value::Tuple(Shared::new(rune::runtime::OwnedTuple::try_from(v).unwrap()).unwrap()) + Value::new(rune::runtime::OwnedTuple::try_from(v).unwrap()).unwrap() } fn rune_object(pairs: Vec<(&str, Value)>) -> Value { @@ -701,7 +681,7 @@ mod tests { for (k, v) in pairs { obj.insert(RuneString::try_from(k).unwrap(), v).unwrap(); } - Value::Object(Shared::new(obj).unwrap()) + Value::new(obj).unwrap() } fn col_spec<'a>(name: &'a str, typ: ColumnType<'a>) -> ColumnSpec<'a> { @@ -894,14 +874,14 @@ mod tests { #[test] fn test_to_scylla_value_option_none() { - let val = Value::Option(Shared::new(None).unwrap()); + let val = Value::try_from(None::).unwrap(); let result = to_scylla_value(&val, &ColumnType::Native(NativeType::Int)); assert_eq!(result.unwrap(), None); } #[test] fn test_to_scylla_value_option_some() { - let val = Value::Option(Shared::new(Some(rune_int(42))).unwrap()); + let val = Value::try_from(Some(rune_int(42))).unwrap(); let result = to_scylla_value(&val, &ColumnType::Native(NativeType::Int)); assert_eq!(result.unwrap(), Some(CqlValue::Int(42))); } @@ -912,6 +892,15 @@ mod tests { assert!(result.is_err()); } + // ── to_scylla_value tests (blob) ──────────────────────────────── + + #[test] + fn test_to_scylla_value_blob_from_rune_bytes() { + let bytes_val = rune::to_value(vec![1u8, 2u8, 3u8]).unwrap(); + let result = to_scylla_value(&bytes_val, &ColumnType::Native(NativeType::Blob)).unwrap(); + assert_eq!(result, Some(CqlValue::Blob(vec![1u8, 2u8, 3u8]))); + } + // ── to_scylla_value tests (collections) ───────────────────────── #[test] @@ -1259,10 +1248,7 @@ mod tests { #[case] ns: i64, ) { let expected = format!("{mo:?}mo{d:?}d{ns:?}ns"); - let duration_rune_str = Value::String( - Shared::new(RuneString::try_from(input).expect("Failed to create RuneString")) - .expect("Failed to create Shared RuneString"), - ); + let duration_rune_str = rune_string(input.as_str()); let actual = to_scylla_value( &duration_rune_str, &ColumnType::Native(NativeType::Duration), @@ -1281,10 +1267,7 @@ mod tests { #[case("fake")] #[case("1d2h3m4h")] fn test_to_scylla_value_duration_neg(#[case] input: String) { - let duration_rune_str = Value::String( - Shared::new(RuneString::try_from(input.clone()).expect("Failed to create RuneString")) - .expect("Failed to create Shared RuneString"), - ); + let duration_rune_str = rune_string(input.as_str()); let actual = to_scylla_value( &duration_rune_str, &ColumnType::Native(NativeType::Duration), diff --git a/src/scripting/functions_common.rs b/src/scripting/functions_common.rs index aa74cf8..40c2034 100644 --- a/src/scripting/functions_common.rs +++ b/src/scripting/functions_common.rs @@ -55,31 +55,43 @@ pub struct ValidationArgs { /// * [Integer, String] -> Exact number of expected rows and custom error message. /// * [Integer, Integer, String] -> Range of expected rows and custom error message. pub fn extract_validation_args(validation_args: Vec) -> Result { + let as_int = |v: &Value| v.as_signed().ok().map(|i| i as u64); + let as_str = |v: &Value| { + v.borrow_ref::() + .ok() + .map(|s| s.as_str().to_string()) + }; match validation_args.as_slice() { // (int): expected_rows - [Value::Integer(expected_rows)] => Ok(ValidationArgs { - expected_min: *expected_rows as u64, - expected_max: *expected_rows as u64, - custom_err_msg: String::new(), - }), + [a] if as_int(a).is_some() => { + let n = as_int(a).unwrap(); + Ok(ValidationArgs { + expected_min: n, + expected_max: n, + custom_err_msg: String::new(), + }) + } // (int, int): expected_rows_num_min, expected_rows_num_max - [Value::Integer(min), Value::Integer(max)] => Ok(ValidationArgs { - expected_min: *min as u64, - expected_max: *max as u64, + [a, b] if as_int(a).is_some() && as_int(b).is_some() => Ok(ValidationArgs { + expected_min: as_int(a).unwrap(), + expected_max: as_int(b).unwrap(), custom_err_msg: String::new(), }), // (int, str): expected_rows, custom_err_msg - [Value::Integer(expected_rows), Value::String(custom_err_msg)] => Ok(ValidationArgs { - expected_min: *expected_rows as u64, - expected_max: *expected_rows as u64, - custom_err_msg: custom_err_msg.borrow_ref().unwrap().to_string(), - }), + [a, b] if as_int(a).is_some() && as_str(b).is_some() => { + let n = as_int(a).unwrap(); + Ok(ValidationArgs { + expected_min: n, + expected_max: n, + custom_err_msg: as_str(b).unwrap(), + }) + } // (int, int, str): expected_rows_num_min, expected_rows_num_max, custom_err_msg - [Value::Integer(min), Value::Integer(max), Value::String(custom_err_msg)] => { + [a, b, c] if as_int(a).is_some() && as_int(b).is_some() && as_str(c).is_some() => { Ok(ValidationArgs { - expected_min: *min as u64, - expected_max: *max as u64, - custom_err_msg: custom_err_msg.borrow_ref().unwrap().to_string(), + expected_min: as_int(a).unwrap(), + expected_max: as_int(b).unwrap(), + custom_err_msg: as_str(c).unwrap(), }) } _ => Err("Invalid validation arguments".to_string()), @@ -211,7 +223,7 @@ pub fn join(collection: &[Value], separator: &str) -> VmResult { if !first { result.push_str(separator); } - result.push_str(vm_try!(v.borrow_ref()).as_str()); + result.push_str(v.as_str()); first = false; } VmResult::Ok(result) @@ -225,10 +237,8 @@ pub fn is_none(input: Value) -> bool { // With this function it is possible to check for None the following way: // let result = if is_none(row.some_col) { "None" } else { row.some_col }; // println!("DEBUG: value for some_col is '{result}'", result=result); - if let Value::Option(option) = input { - if let Ok(borrowed) = option.borrow_ref() { - return borrowed.is_none(); - } + if let Ok(opt) = input.borrow_ref::>() { + return opt.is_none(); } false } diff --git a/src/scripting/row_distribution.rs b/src/scripting/row_distribution.rs index 80f77f0..c97bb01 100644 --- a/src/scripting/row_distribution.rs +++ b/src/scripting/row_distribution.rs @@ -1,4 +1,4 @@ -use rune::runtime::{Mut, Ref}; +use rune::runtime::Ref; use rune::Any; use std::collections::HashMap; @@ -169,14 +169,14 @@ impl RowDistributionPreset { #[rune::function(instance)] pub async fn init_partition_row_distribution_preset( - mut ctx: Mut, + ctx: Ref, preset_name: Ref, row_count: u64, rows_per_partitions_base: u64, rows_per_partitions_groups: Ref, ) -> Result<(), DbError> { _init_partition_row_distribution_preset( - &mut ctx, + &ctx, &preset_name, row_count, rows_per_partitions_base, @@ -214,7 +214,7 @@ pub async fn get_partition_idx(ctx: Ref, preset_name: Ref, idx: u6 /// Creates a preset for uneven row distribution among partitions #[allow(clippy::comparison_chain)] async fn _init_partition_row_distribution_preset( - ctx: &mut Context, + ctx: &Context, preset_name: &str, row_count: u64, rows_per_partitions_base: u64, @@ -412,6 +412,8 @@ async fn _init_partition_row_distribution_preset( // NOTE: generate row distributions only after the partition groups are finished with changes row_distribution_preset.generate_row_distributions(); ctx.partition_row_presets + .try_lock() + .unwrap() .insert(preset_name.to_string(), row_distribution_preset); Ok(()) @@ -423,11 +425,17 @@ async fn _get_partition_info( preset_name: &str, idx: u64, ) -> Result<(u64, u64), DbError> { - let preset = ctx.partition_row_presets.get(preset_name).ok_or_else(|| { - DbError::new(DbErrorKind::PartitionRowPresetNotFound( - preset_name.to_string(), - )) - })?; + let preset = ctx + .partition_row_presets + .try_lock() + .unwrap() + .get(preset_name) + .cloned() + .ok_or_else(|| { + DbError::new(DbErrorKind::PartitionRowPresetNotFound( + preset_name.to_string(), + )) + })?; Ok(preset.get_partition_info(idx).await) } @@ -526,20 +534,23 @@ mod tests { expected_idx_partition_idx_mapping: Vec<(u64, u64)>, ) { for (rows_per_partitions_base, rows_per_partitions_groups) in rows_per_partitions_base_and_groups_mapping { - let mut ctxt: Context = create_test_context(); + let ctxt: Context = create_test_context(); let preset_name = "foo_name"; - assert!(ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should not be empty"); + assert!(ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should not be empty"); tokio::runtime::Runtime::new().unwrap().block_on(async { - let _ = _init_partition_row_distribution_preset(&mut ctxt, + let _ = _init_partition_row_distribution_preset(&ctxt, preset_name, row_count, rows_per_partitions_base, &rows_per_partitions_groups).await; }); - assert!(!ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should not be empty"); - let actual_preset = ctxt.partition_row_presets.get(preset_name) - .unwrap_or_else(|| panic!("Preset with name '{preset_name}' was not found")); - assert_eq!(expected_partition_groups, actual_preset.partition_groups); + assert!(!ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should not be empty"); + { + let binding = ctxt.partition_row_presets.try_lock().unwrap(); + let actual_preset = binding.get(preset_name) + .unwrap_or_else(|| panic!("Preset with name '{preset_name}' was not found")); + assert_eq!(expected_partition_groups, actual_preset.partition_groups); + } for (idx, expected_partition_idx) in expected_idx_partition_idx_mapping.clone() { let (p_idx, _p_size) = tokio::runtime::Runtime::new().unwrap().block_on(async { @@ -703,28 +714,28 @@ mod tests { fn test_partition_row_distribution_preset_05_pos_multiple_presets() { let name_foo: String = "foo".to_string(); let name_bar: String = "bar".to_string(); - let mut ctxt: Context = create_test_context(); + let ctxt: Context = create_test_context(); - assert!(ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should be empty"); - let foo_value = ctxt.partition_row_presets.get(&name_foo); + assert!(ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should be empty"); + let foo_value = ctxt.partition_row_presets.try_lock().unwrap().get(&name_foo).cloned(); assert_eq!(None, foo_value); tokio::runtime::Runtime::new().unwrap().block_on(async { - _init_partition_row_distribution_preset(&mut ctxt, + _init_partition_row_distribution_preset(&ctxt, &name_foo, 1000, 10, "100:1").await }).unwrap_or_else(|_| panic!("The '{name_foo}' preset must have been created successfully")); - assert!(!ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should not be empty"); - ctxt.partition_row_presets.get(&name_foo) + assert!(!ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should not be empty"); + ctxt.partition_row_presets.try_lock().unwrap().get(&name_foo) .unwrap_or_else(|| panic!("Preset with name '{name_foo}' was not found")); - let absent_bar = ctxt.partition_row_presets.get(&name_bar); + let absent_bar = ctxt.partition_row_presets.try_lock().unwrap().get(&name_bar).cloned(); assert_eq!(None, absent_bar, "{}", format_args!("The '{}' preset was expected to be absent", name_bar)); tokio::runtime::Runtime::new().unwrap().block_on(async { - _init_partition_row_distribution_preset(&mut ctxt, + _init_partition_row_distribution_preset(&ctxt, &name_bar, 1000, 10, "90:1,10:2").await }).unwrap_or_else(|_| panic!("The '{name_bar}' preset must have been created successfully")); - ctxt.partition_row_presets.get(&name_bar) + ctxt.partition_row_presets.try_lock().unwrap().get(&name_bar) .unwrap_or_else(|| panic!("Preset with name '{name_bar}' was not found")); } @@ -734,9 +745,9 @@ mod tests { rows_per_partitions_base: u64, rows_per_partitions_groups: String, ) { - let mut ctxt: Context = create_test_context(); + let ctxt: Context = create_test_context(); let result = tokio::runtime::Runtime::new().unwrap().block_on(async { - _init_partition_row_distribution_preset(&mut ctxt, + _init_partition_row_distribution_preset(&ctxt, &preset_name, row_count, rows_per_partitions_base, &rows_per_partitions_groups).await }); diff --git a/src/scripting/rune_uuid.rs b/src/scripting/rune_uuid.rs index 039c7dd..3e3d92e 100644 --- a/src/scripting/rune_uuid.rs +++ b/src/scripting/rune_uuid.rs @@ -20,9 +20,9 @@ impl Uuid { Uuid(builder.into_uuid()) } - #[rune::function(protocol = STRING_DISPLAY)] + #[rune::function(protocol = DISPLAY_FMT)] pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.0); + let _ = vm_write!(f, "{}", self.0); VmResult::Ok(()) } } diff --git a/src/scripting/split_lines_iter.rs b/src/scripting/split_lines_iter.rs index 66ff85a..9be2794 100644 --- a/src/scripting/split_lines_iter.rs +++ b/src/scripting/split_lines_iter.rs @@ -92,43 +92,53 @@ pub fn read_split_lines_iter( let mut maxsplit = -1; let mut do_trim = true; let mut skip_empty = true; + let as_str = |v: &Value| -> Option { + v.borrow_ref::() + .ok() + .map(|s| s.as_str().to_string()) + }; + let as_int = |v: &Value| v.as_signed().ok(); + let as_bool = |v: &Value| v.as_bool().ok(); match params.as_slice() { // (str): delimiter - [Value::String(custom_delimiter)] => { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); + [a] if as_str(a).is_some() => { + delimiter = as_str(a).unwrap(); } // (int): maxsplit - [Value::Integer(custom_maxsplit)] => { - maxsplit = *custom_maxsplit; + [a] if as_int(a).is_some() => { + maxsplit = as_int(a).unwrap(); } // (bool): do_trim - [Value::Bool(custom_do_trim)] => { - do_trim = *custom_do_trim; + [a] if as_bool(a).is_some() => { + do_trim = as_bool(a).unwrap(); } // (bool, bool): do_trim, skip_empty - [Value::Bool(custom_do_trim), Value::Bool(custom_skip_empty)] => { - do_trim = *custom_do_trim; - skip_empty = *custom_skip_empty; + [a, b] if as_bool(a).is_some() && as_bool(b).is_some() => { + do_trim = as_bool(a).unwrap(); + skip_empty = as_bool(b).unwrap(); } // (str, int): delimiter, maxsplit - [Value::String(custom_delimiter), Value::Integer(custom_maxsplit)] => { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); - maxsplit = *custom_maxsplit; + [a, b] if as_str(a).is_some() && as_int(b).is_some() => { + delimiter = as_str(a).unwrap(); + maxsplit = as_int(b).unwrap(); } // (str, int, bool): delimiter, maxsplit, do_trim - [Value::String(custom_delimiter), Value::Integer(custom_maxsplit), Value::Bool(custom_do_trim)] => - { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); - maxsplit = *custom_maxsplit; - do_trim = *custom_do_trim; + [a, b, c] if as_str(a).is_some() && as_int(b).is_some() && as_bool(c).is_some() => { + delimiter = as_str(a).unwrap(); + maxsplit = as_int(b).unwrap(); + do_trim = as_bool(c).unwrap(); } // (str, int, bool, bool): delimiter, maxsplit, do_trim, skip_empty - [Value::String(custom_delimiter), Value::Integer(custom_maxsplit), Value::Bool(custom_do_trim), Value::Bool(custom_skip_empty)] => + [a, b, c, d] + if as_str(a).is_some() + && as_int(b).is_some() + && as_bool(c).is_some() + && as_bool(d).is_some() => { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); - maxsplit = *custom_maxsplit; - do_trim = *custom_do_trim; - skip_empty = *custom_skip_empty; + delimiter = as_str(a).unwrap(); + maxsplit = as_int(b).unwrap(); + do_trim = as_bool(c).unwrap(); + skip_empty = as_bool(d).unwrap(); } _ => panic!("Invalid arguments for read_split_lines_iter"), }