From 48dff26f76ad29d53857f99c01b57ce3f030b323 Mon Sep 17 00:00:00 2001 From: Valerii Ponomarov Date: Tue, 5 May 2026 19:41:43 +0300 Subject: [PATCH 1/2] Remove unused ThreadRng field from Context ThreadRng contains Rc> pointing to thread-local storage, making it !Send. While the field was never read or written by any code, its construction (rand::rng()) and destruction involved Rc refcount operations on thread-local data. When tokio migrated tasks between worker threads, the Rc drop could race with Rc operations on the originating thread, causing data corruption and segfaults. This was the root cause of intermittent crashes observed after switching to rune 0.14, where Context became VM-owned (created and dropped per call) rather than borrowed via AnyObj::from_ref as in rune 0.13. --- src/scripting/cql/context.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/scripting/cql/context.rs b/src/scripting/cql/context.rs index 62301e2..dec0cc8 100644 --- a/src/scripting/cql/context.rs +++ b/src/scripting/cql/context.rs @@ -9,7 +9,6 @@ use crate::scripting::row_distribution::RowDistributionPreset; use crate::stats::session::SessionStats; use once_cell::sync::Lazy; -use rand::prelude::ThreadRng; use regex::Regex; use rune::runtime::{Object, Shared, Vec as RuneVec}; use rune::{Any, Value}; @@ -52,7 +51,6 @@ pub struct Context { pub preferred_rack: String, #[rune(get)] pub data: Value, - pub rng: ThreadRng, } // Needed, because Rune `Value` is !Send, as it may contain some internal pointers. @@ -88,7 +86,6 @@ impl Context { preferred_datacenter, preferred_rack, data: Value::Object(Shared::new(Object::new()).unwrap()), - rng: rand::rng(), } } @@ -113,7 +110,6 @@ impl Context { preferred_rack: self.preferred_rack.clone(), data: deserialized, start_time: TryLock::new(*self.start_time.try_lock().unwrap()), - rng: rand::rng(), }) } From eced210a48896fcd63e91850b1a9651312147281 Mon Sep 17 00:00:00 2001 From: Valerii Ponomarov Date: Fri, 1 May 2026 15:54:27 +0300 Subject: [PATCH 2/2] Upgrade rune from 0.13 to 0.14 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adapt to rune 0.14 API changes: - AnyObj::from_ref/from_mut became pub(crate), so Context is now passed to the VM as an owned shallow clone via Value::new(context) instead of a borrowed reference. - Update Value/VmResult/VmError APIs to match 0.14 signatures (Value::Result → downcast, Shared removal, VmResult → Result, etc). - Retain unsafe impl Send/Sync for Context. This is sound because the only !Send field is `data: Value` (non-atomic Cell refcount), and within latte's architecture each Context instance is confined to sequential polling within a single tokio task — no concurrent access to the same Cell can occur. The previously-unsound ThreadRng field was removed in the prior commit. --- Cargo.lock | 144 +-- Cargo.toml | 2 +- src/error.rs | 5 +- src/exec/workload.rs | 96 +- src/main.rs | 22 +- src/scripting/alternator/alternator_error.rs | 10 +- src/scripting/alternator/context.rs | 36 +- src/scripting/alternator/functions.rs | 276 +++--- src/scripting/alternator/types.rs | 153 +-- src/scripting/cql/cass_error.rs | 38 +- src/scripting/cql/context.rs | 92 +- src/scripting/cql/deserialize.rs | 424 ++++----- src/scripting/cql/functions.rs | 56 +- src/scripting/cql/serialize.rs | 945 +++++++++---------- src/scripting/functions_common.rs | 54 +- src/scripting/row_distribution.rs | 65 +- src/scripting/rune_uuid.rs | 4 +- src/scripting/split_lines_iter.rs | 54 +- 18 files changed, 1281 insertions(+), 1195 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7e4712b..08ab3e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1533,6 +1533,21 @@ dependencies = [ "slab", ] +[[package]] +name = "generator" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows-link", + "windows-result", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1856,7 +1871,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.3", "system-configuration", "tokio", "tower-service", @@ -2299,6 +2314,19 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru-slab" version = "0.1.2" @@ -2311,7 +2339,7 @@ version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "373f5eceeeab7925e0c1098212f2fbc4d416adec9d35051a6ab251e824c1854a" dependencies = [ - "twox-hash 2.1.2", + "twox-hash", ] [[package]] @@ -2382,43 +2410,36 @@ checksum = "1fafa6961cabd9c63bcd77a45d7e3b7f3b552b70417831fb0f56db717e72407e" [[package]] name = "musli" -version = "0.0.42" +version = "0.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c21124dd24833900879114414b877f2136f4b7b7a3b49756ecc5c36eca332bb" +checksum = "8b310b280353d9e1c92861820321f8742b02666acaf984a29cd8946965444384" dependencies = [ - "musli-macros", + "loom", + "musli-core", + "serde", + "simdutf8", ] [[package]] -name = "musli-common" -version = "0.0.42" +name = "musli-core" +version = "0.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "178446623aa62978aa0f894b2081bc11ea77c2119ccfe35be428ab9ddb495dfc" +checksum = "d00e227a374e92550ce2eb5002ae116e02a43926d7243c95997138406ae4e157" dependencies = [ - "musli", + "musli-macros", ] [[package]] name = "musli-macros" -version = "0.0.42" +version = "0.0.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f1ab0e4ac2721bc4fa3528a6a2640c1c30c36c820f8c85159252fbf6c2fac24" +checksum = "e7427c9aa85c882cd4dbe712d2fcdc511db05d595f7787e6747c90cd7d67efc4" dependencies = [ "proc-macro2", "quote", "syn 2.0.117", ] -[[package]] -name = "musli-storage" -version = "0.0.42" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2fc1f80b166f611c462e1344220e9b3a9ad37c885e43039d5d2e6887445937c" -dependencies = [ - "musli", - "musli-common", -] - [[package]] name = "nom" version = "7.1.3" @@ -2843,7 +2864,7 @@ dependencies = [ "quinn-udp", "rustc-hash", "rustls", - "socket2 0.5.10", + "socket2 0.6.3", "thiserror 2.0.18", "tokio", "tracing", @@ -2881,7 +2902,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.3", "tracing", "windows-sys 0.59.0", ] @@ -3252,33 +3273,35 @@ dependencies = [ [[package]] name = "rune" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d21925ac4f8974395d0d9e43f96a34c778e71ed86fe96d0313b2211102537234" +checksum = "b7b9174512d64882469ea9b12876305680154a541be448dcdc56f58acacbc3e0" dependencies = [ "anyhow", "codespan-reporting", "futures-core", "futures-util", "itoa", + "memchr", "musli", - "musli-storage", "num", "once_cell", "pin-project", "rune-alloc", "rune-core", "rune-macros", + "rune-tracing", "ryu", "serde", - "tracing", + "syntree", + "unicode-ident", ] [[package]] name = "rune-alloc" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e85c26e19f7efb91c6e19afc68b008f04685fdb2852e96ce8fbd3cf4a0b4e76c" +checksum = "f12484a608c8907b9a4590f11b669263e960d1fd40f3d3c992c6f15eec931ae9" dependencies = [ "ahash 0.8.12", "pin-project", @@ -3288,9 +3311,9 @@ dependencies = [ [[package]] name = "rune-alloc-macros" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "810588952a8710959d35ad17c933804d60f96c3792f216277cda68c1a9887120" +checksum = "382b14f6d8e65e9cfec789e85125f3e1d758b2756705739e39ccf06fd249a564" dependencies = [ "proc-macro2", "quote", @@ -3299,22 +3322,21 @@ dependencies = [ [[package]] name = "rune-core" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d30fa78b6cb15d1560bb4cb18f4b99a9097b08ade3a5fc23e5ae7311f97c537b" +checksum = "c424b28fde0f5012680361662145f238f04aeac8a320f352a6e2de863709e7b3" dependencies = [ - "byteorder", "musli", "rune-alloc", "serde", - "twox-hash 1.6.3", + "twox-hash", ] [[package]] name = "rune-macros" -version = "0.13.4" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1b91e53bae3804e4d72e2b04fa5d5108bd93e880ca597c0cae0fb0a662fe198" +checksum = "c86600b36281adeb101c2e4f0be325752fa4c07431e9234e05be5678ad9a97f7" dependencies = [ "proc-macro2", "quote", @@ -3322,6 +3344,26 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "rune-tracing" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717a7c726015688a19ebfa4bea03a20d2acdd96eeacb5ef48f4ec780e11ac4b" +dependencies = [ + "rune-tracing-macros", +] + +[[package]] +name = "rune-tracing-macros" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12387f96a3e131ce5be8c5668e55f1581dbc6635555d77aa07ab509fd13562bb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "rust-embed" version = "8.11.0" @@ -3532,6 +3574,12 @@ dependencies = [ "serde_json", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -3862,12 +3910,6 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "status-line" version = "0.2.0" @@ -3983,6 +4025,12 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "syntree" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00c99c9cda412afe293a6b962af651b4594161ba88c1affe7ef66459ea040a06" + [[package]] name = "system-configuration" version = "0.7.0" @@ -4457,16 +4505,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "twox-hash" -version = "1.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" -dependencies = [ - "cfg-if", - "static_assertions", -] - [[package]] name = "twox-hash" version = "2.1.2" diff --git a/Cargo.toml b/Cargo.toml index 5faea7a..31aeccb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ plotters = { version = "0.3", default-features = false, features = ["line_series rand = { version = "0.9.4", default-features = false, features = ["small_rng", "std", "thread_rng"] } rand_distr = "0.5" regex = "1.5" -rune = "0.13" +rune = "0.14" rust_decimal = "1.36" rust-embed = "8" scylla = { version = "1.6", features = ["openssl-010", "chrono-04"], optional = true } diff --git a/src/error.rs b/src/error.rs index 7cfb519..fff2b4b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,7 +2,7 @@ use crate::scripting::db_error::DbError; use hdrhistogram::serialization::interval_log::IntervalLogWriterError; use hdrhistogram::serialization::V2DeflateSerializeError; use rune::alloc; -use rune::runtime::{AccessError, VmError}; +use rune::runtime::{AccessError, RuntimeError, VmError}; use std::path::PathBuf; use thiserror::Error; @@ -55,6 +55,9 @@ pub enum LatteError { #[error("Rune AccessError: {0}")] RuneAccessError(#[from] AccessError), + + #[error("Rune runtime error: {0}")] + RuneRuntimeError(#[from] RuntimeError), } impl From for LatteError { diff --git a/src/exec/workload.rs b/src/exec/workload.rs index 19d8808..2e721ff 100644 --- a/src/exec/workload.rs +++ b/src/exec/workload.rs @@ -21,56 +21,30 @@ use rand::{Rng, SeedableRng}; use rune::alloc::clone::TryClone; use rune::compile::meta::Kind; use rune::compile::{CompileVisitor, MetaError, MetaRef}; -use rune::runtime::{AnyObj, Args, RuntimeContext, Shared, VmError, VmResult}; +use rune::runtime::{Args, RuntimeContext, RuntimeError, VmError}; use rune::termcolor::{ColorChoice, StandardStream}; -use rune::{vm_try, Any, Diagnostics, Source, Sources, ToValue, Unit, Value, Vm}; +use rune::{Diagnostics, Source, Sources, ToValue, Unit, Value, Vm}; use serde::{Deserialize, Serialize}; use try_lock::TryLock; -/// Wraps a reference to Session that can be converted to a Rune `Value` -/// and passed as one of `Args` arguments to a function. -struct SessionRef<'a> { - context: &'a Context, -} - -impl SessionRef<'_> { - pub fn new(context: &Context) -> SessionRef<'_> { - SessionRef { context } - } -} - -/// We need this to be able to pass a reference to `Session` as an argument -/// to Rune function. -/// -/// Caution! Be careful using this trait. Undefined Behaviour possible. -/// This is unsound - it is theoretically -/// possible that the underlying `Session` gets dropped before the `Value` produced by this trait -/// implementation and the compiler is not going to catch that. -/// The receiver of a `Value` must ensure that it is dropped before `Session`! -impl ToValue for SessionRef<'_> { - fn to_value(self) -> VmResult { - let obj = unsafe { AnyObj::from_ref(self.context) }; - VmResult::Ok(Value::from(vm_try!(Shared::new(obj)))) - } -} - -/// Wraps a mutable reference to Session that can be converted to a Rune `Value` and passed -/// as one of `Args` arguments to a function. -struct ContextRefMut<'a> { - context: &'a mut Context, +/// Wraps a shallow clone of Context that can be converted to a rune-owned `Value`. +/// The clone shares Arc-backed fields (stats, statements, presets) with the original, +/// so stats tracking and prepared statements remain shared across function calls. +struct SessionRef { + context: Context, } -impl ContextRefMut<'_> { - pub fn new(context: &mut Context) -> ContextRefMut<'_> { - ContextRefMut { context } +impl SessionRef { + pub fn new(context: &Context) -> SessionRef { + SessionRef { + context: context.shallow_clone(), + } } } -/// Caution! See `impl ToValue for SessionRef`. -impl ToValue for ContextRefMut<'_> { - fn to_value(self) -> VmResult { - let obj = unsafe { AnyObj::from_mut(self.context) }; - VmResult::Ok(Value::from(vm_try!(Shared::new(obj)))) +impl ToValue for SessionRef { + fn to_value(self) -> Result { + Value::new(self.context).map_err(Into::into) } } @@ -199,25 +173,21 @@ impl Program { /// fine, but the function could return an error value, and in this case we should not /// ignore it. fn convert_error(&self, function_name: &str, result: Value) -> Result { - match result { - Value::Result(result) => match result.take().unwrap() { + if result.borrow_ref::>().is_ok() { + match result.downcast::>().unwrap() { Ok(value) => Ok(value), - Err(Value::Any(e)) => { - if e.borrow_ref().unwrap().type_hash() == DbError::type_hash() { - let e = e.take_downcast::().unwrap(); - return Err(LatteError::Database(Box::new(e))); + Err(err_value) => { + if err_value.borrow_ref::().is_ok() { + let e = err_value.downcast::().unwrap(); + Err(LatteError::Database(Box::new(e))) + } else { + let msg = self.vm().with(|| format!("{err_value:?}")); + Err(LatteError::FunctionResult(function_name.to_string(), msg)) } - - let e = Value::Any(e); - let msg = self.vm().with(|| format!("{e:?}")); - Err(LatteError::FunctionResult(function_name.to_string(), msg)) } - Err(other) => Err(LatteError::FunctionResult( - function_name.to_string(), - format!("{other:?}"), - )), - }, - other => Ok(other), + } + } else { + Ok(result) } } @@ -267,24 +237,24 @@ impl Program { /// Calls the script's `init` function. /// Called once at the beginning of the benchmark. /// Typically used to prepare statements. - pub async fn prepare(&mut self, context: &mut Context) -> Result<(), LatteError> { - let context = ContextRefMut::new(context); + pub async fn prepare(&mut self, context: &Context) -> Result<(), LatteError> { + let context = SessionRef::new(context); self.async_call(&FnRef::new(PREPARE_FN), (context,)).await?; Ok(()) } /// Calls the script's `schema` function. /// Typically used to create database schema. - pub async fn schema(&mut self, context: &mut Context) -> Result<(), LatteError> { - let context = ContextRefMut::new(context); + pub async fn schema(&mut self, context: &Context) -> Result<(), LatteError> { + let context = SessionRef::new(context); self.async_call(&FnRef::new(SCHEMA_FN), (context,)).await?; Ok(()) } /// Calls the script's `erase` function. /// Typically used to remove the data from the database before running the benchmark. - pub async fn erase(&mut self, context: &mut Context) -> Result<(), LatteError> { - let context = ContextRefMut::new(context); + pub async fn erase(&mut self, context: &Context) -> Result<(), LatteError> { + let context = SessionRef::new(context); self.async_call(&FnRef::new(ERASE_FN), (context,)).await?; Ok(()) } diff --git a/src/main.rs b/src/main.rs index 18d2be6..427c448 100644 --- a/src/main.rs +++ b/src/main.rs @@ -137,13 +137,13 @@ async fn connect(conf: &ConnectionConf) -> Result<(Context, Option) /// Exits with error if the `schema` function is not present or fails. async fn schema(conf: SchemaCommand) -> Result<()> { let mut program = load_workload_script(&conf.workload, &conf.params)?; - let (mut session, _) = connect(&conf.connection).await?; + let (session, _) = connect(&conf.connection).await?; if !program.has_schema() { eprintln!("error: Function `schema` not found in the workload script."); exit(255); } eprintln!("info: Creating schema..."); - if let Err(e) = program.schema(&mut session).await { + if let Err(e) = program.schema(&session).await { eprintln!("error: Failed to create schema: {e}"); exit(255); } @@ -155,11 +155,11 @@ async fn schema(conf: SchemaCommand) -> Result<()> { /// Exits with error if the `load` function is not present or fails. async fn load(conf: LoadCommand) -> Result<()> { let mut program = load_workload_script(&conf.workload, &conf.params)?; - let (mut session, _) = connect(&conf.connection).await?; + let (session, _) = connect(&conf.connection).await?; if program.has_prepare() { eprintln!("info: Preparing..."); - if let Err(e) = program.prepare(&mut session).await { + if let Err(e) = program.prepare(&session).await { eprintln!("error: Failed to prepare: {e}"); exit(255); } @@ -173,7 +173,7 @@ async fn load(conf: LoadCommand) -> Result<()> { if program.has_erase() { eprintln!("info: Erasing data..."); - if let Err(e) = program.erase(&mut session).await { + if let Err(e) = program.erase(&session).await { eprintln!("error: Failed to erase: {e}"); exit(255); } @@ -236,18 +236,18 @@ async fn run(conf: RunCommand) -> Result<()> { functions.push((function, f.weight)) } - let (mut session, cluster_info) = connect(&conf.connection).await?; + let (session, cluster_info) = connect(&conf.connection).await?; // NOTE: Add info about the target rune functions to the context // for the more flexible tweaking of the 'prepare' rune function. - match &mut session.data { - rune::Value::Object(shared_obj) => { - let _ = shared_obj.borrow_mut()?.insert( + match session.data.borrow_mut::() { + Ok(mut obj) => { + let _ = obj.insert( rune::alloc::String::try_from("functions_to_invoke")?, rune::to_value(functions_to_invoke)?, ); } - _ => { + Err(_) => { eprintln!("error: session.data is not a Rune Object"); exit(255); } @@ -260,7 +260,7 @@ async fn run(conf: RunCommand) -> Result<()> { if program.has_prepare() { eprintln!("info: Preparing..."); - if let Err(e) = program.prepare(&mut session).await { + if let Err(e) = program.prepare(&session).await { eprintln!("error: Failed to prepare: {e}"); exit(255); } diff --git a/src/scripting/alternator/alternator_error.rs b/src/scripting/alternator/alternator_error.rs index 81ffb2e..38e6078 100644 --- a/src/scripting/alternator/alternator_error.rs +++ b/src/scripting/alternator/alternator_error.rs @@ -32,9 +32,9 @@ impl AlternatorError { ))) } - #[rune::function(protocol = STRING_DISPLAY)] + #[rune::function(protocol = DISPLAY_FMT)] pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.to_string()); + let _ = vm_write!(f, "{}", self.to_string()); VmResult::Ok(()) } } @@ -97,6 +97,12 @@ impl From for AlternatorError { } } +impl From for AlternatorError { + fn from(error: rune::runtime::RuntimeError) -> Self { + AlternatorError::new(AlternatorErrorKind::ConversionError(error.to_string())) + } +} + impl From for AlternatorError { fn from(error: rune::alloc::Error) -> Self { AlternatorError::new(AlternatorErrorKind::ConversionError(error.to_string())) diff --git a/src/scripting/alternator/context.rs b/src/scripting/alternator/context.rs index 4d4d78f..db37ae2 100644 --- a/src/scripting/alternator/context.rs +++ b/src/scripting/alternator/context.rs @@ -5,9 +5,10 @@ use crate::scripting::cluster_info::ClusterInfo; use crate::scripting::row_distribution::RowDistributionPreset; use crate::stats::session::SessionStats; use aws_sdk_dynamodb::Client; -use rune::runtime::{Object, Shared}; +use rune::runtime::Object; use rune::{Any, Value}; use std::collections::HashMap; +use std::sync::Arc; use std::time::Instant; use try_lock::TryLock; @@ -15,12 +16,12 @@ use try_lock::TryLock; pub struct Context { client: Option, page_size: u64, - pub stats: TryLock, + pub stats: Arc>, pub start_time: TryLock, pub retry_number: u64, pub retry_interval: RetryInterval, pub validation_strategy: ValidationStrategy, - pub partition_row_presets: HashMap, + pub partition_row_presets: Arc>>, #[rune(get, set, add_assign, copy)] pub load_cycle_count: u64, #[rune(get)] @@ -41,14 +42,14 @@ impl Context { Context { client, page_size, - stats: TryLock::new(SessionStats::new()), + stats: Arc::new(TryLock::new(SessionStats::new())), start_time: TryLock::new(Instant::now()), retry_number, retry_interval, validation_strategy, - partition_row_presets: HashMap::new(), + partition_row_presets: Arc::new(TryLock::new(HashMap::new())), load_cycle_count: 0, - data: Value::Object(Shared::new(Object::new()).unwrap()), + data: Value::new(Object::new()).unwrap(), } } @@ -58,17 +59,36 @@ impl Context { Ok(Context { client: self.client.clone(), page_size: self.page_size, - stats: TryLock::new(SessionStats::default()), + stats: Arc::new(TryLock::new(SessionStats::default())), start_time: TryLock::new(*self.start_time.try_lock().unwrap()), retry_number: self.retry_number, retry_interval: self.retry_interval, validation_strategy: self.validation_strategy, - partition_row_presets: self.partition_row_presets.clone(), + partition_row_presets: Arc::new(TryLock::new( + self.partition_row_presets.try_lock().unwrap().clone(), + )), load_cycle_count: self.load_cycle_count, data: deserialized, }) } + /// Creates a shallow clone that shares the Arc-backed fields (stats, presets) + /// with the original. Used to create a rune-owned `Value` for function call arguments. + pub fn shallow_clone(&self) -> Self { + Context { + client: self.client.clone(), + page_size: self.page_size, + stats: Arc::clone(&self.stats), + start_time: TryLock::new(*self.start_time.try_lock().unwrap()), + retry_number: self.retry_number, + retry_interval: self.retry_interval, + validation_strategy: self.validation_strategy, + partition_row_presets: Arc::clone(&self.partition_row_presets), + load_cycle_count: self.load_cycle_count, + data: self.data.clone(), + } + } + /// Returns cluster metadata. pub async fn cluster_info(&self) -> Result, AlternatorError> { let client = self.get_client()?; diff --git a/src/scripting/alternator/functions.rs b/src/scripting/alternator/functions.rs index 101fff3..96f0118 100644 --- a/src/scripting/alternator/functions.rs +++ b/src/scripting/alternator/functions.rs @@ -10,7 +10,7 @@ use aws_sdk_dynamodb::client::Waiters; use aws_sdk_dynamodb::types::{ AttributeDefinition, KeySchemaElement, KeyType, ScalarAttributeType, }; -use rune::runtime::{Object, Ref, Shared}; +use rune::runtime::{Object, Ref}; use rune::{ToValue, Value}; use std::cmp::min; use std::collections::HashMap; @@ -26,25 +26,33 @@ fn bad_input(msg: impl Into) -> Result { /// Gets the name and type of a primary or sort key from a object. fn extract_key_definition( - object: &Shared, + object: &Object, ) -> Result<(String, ScalarAttributeType), AlternatorError> { - let key_name = if let Some(Value::String(s)) = object.borrow_ref()?.get("name") { - s.borrow_ref()?.to_string() + let key_name = if let Some(v) = object.get("name") { + if let Ok(s) = v.borrow_ref::() { + s.as_str().to_string() + } else { + return bad_input("Key definition object must have a 'name' field"); + } } else { return bad_input("Key definition object must have a 'name' field"); }; - let key_type = if let Some(Value::String(t)) = object.borrow_ref()?.get("type") { - match t.borrow_ref()?.as_str() { - "N" => ScalarAttributeType::N, - "S" => ScalarAttributeType::S, - "B" => ScalarAttributeType::B, - other => { - return bad_input(format!( - "Invalid key type: {}, only N, S, and B are allowed.", - other - )) + let key_type = if let Some(v) = object.get("type") { + if let Ok(t) = v.borrow_ref::() { + match t.as_str() { + "N" => ScalarAttributeType::N, + "S" => ScalarAttributeType::S, + "B" => ScalarAttributeType::B, + other => { + return bad_input(format!( + "Invalid key type: {}, only N, S, and B are allowed.", + other + )) + } } + } else { + return bad_input("Key definition object must have a 'type' field"); } } else { return bad_input("Key definition object must have a 'type' field"); @@ -53,20 +61,15 @@ fn extract_key_definition( Ok((key_name, key_type)) } -fn extract_attribute_names( - object: &Shared, -) -> Result, AlternatorError> { +fn extract_attribute_names(object: &Object) -> Result, AlternatorError> { object - .borrow_ref()? .iter() .map(|(k, v)| { - Ok(( - k.to_string(), - match v { - Value::String(s) => s.borrow_ref()?.to_string(), - _ => return bad_input("Attribute names must be strings"), - }, - )) + if let Ok(s) = v.borrow_ref::() { + Ok((k.to_string(), s.as_str().to_string())) + } else { + bad_input("Attribute names must be strings") + } }) .collect::>() } @@ -195,25 +198,51 @@ pub async fn create_table( let client = ctx.get_client()?; // Extract primary key definition - let (pk_name, pk_type) = match ¶ms { - Value::String(s) => (s.borrow_ref()?.to_string(), ScalarAttributeType::S), - Value::Object(o) => match o.borrow_ref()?.get("primary_key") { - Some(Value::String(s)) => (s.borrow_ref()?.to_string(), ScalarAttributeType::S), - Some(Value::Object(pk_obj)) => extract_key_definition(pk_obj)?, + let (pk_name, pk_type) = if let Ok(s) = params.borrow_ref::() { + (s.as_str().to_string(), ScalarAttributeType::S) + } else if let Ok(o) = params.borrow_ref::() { + match o.get("primary_key") { + Some(v) if v.borrow_ref::().is_ok() => ( + v.borrow_ref::() + .unwrap() + .as_str() + .to_string(), + ScalarAttributeType::S, + ), + Some(v) => { + if let Ok(pk_obj) = v.borrow_ref::() { + extract_key_definition(&pk_obj)? + } else { + return bad_input("Invalid 'primary_key' object in params"); + } + } _ => return bad_input("Invalid 'primary_key' object in params"), - }, - _ => return bad_input("Params must be a string or an object"), + } + } else { + return bad_input("Params must be a string or an object"); }; // Extract sort key definition if present - let sk = match ¶ms { - Value::Object(o) => match o.borrow_ref()?.get("sort_key") { - Some(Value::String(s)) => Some((s.borrow_ref()?.to_string(), ScalarAttributeType::S)), - Some(Value::Object(sk_obj)) => Some(extract_key_definition(sk_obj)?), - Some(_) => return bad_input("Invalid 'sort_key' object in params"), + let sk = if let Ok(o) = params.borrow_ref::() { + match o.get("sort_key") { + Some(v) if v.borrow_ref::().is_ok() => Some(( + v.borrow_ref::() + .unwrap() + .as_str() + .to_string(), + ScalarAttributeType::S, + )), + Some(v) => { + if let Ok(sk_obj) = v.borrow_ref::() { + Some(extract_key_definition(&sk_obj)?) + } else { + return bad_input("Invalid 'sort_key' object in params"); + } + } None => None, - }, - _ => None, + } + } else { + None }; let mut builder = client @@ -292,7 +321,7 @@ pub async fn put( let builder = client .put_item() .table_name(table_name.deref()) - .set_item(Some(rune_object_to_alternator_map(item)?)); + .set_item(Some(rune_object_to_alternator_map(&item)?)); handle_request(&ctx, builder).await?; @@ -316,7 +345,7 @@ pub async fn delete( let builder = client .delete_item() .table_name(table_name.deref()) - .set_key(Some(rune_object_to_alternator_map(key)?)); + .set_key(Some(rune_object_to_alternator_map(&key)?)); handle_request(&ctx, builder).await?; @@ -349,25 +378,23 @@ pub async fn get( let mut builder = client .get_item() .table_name(table_name.deref()) - .set_key(Some(rune_object_to_alternator_map(key)?)); + .set_key(Some(rune_object_to_alternator_map(&key)?)); - if let Value::Object(opts) = &options { - if let Some(Value::Bool(consistent_read)) = opts.borrow_ref()?.get("consistent_read") { - builder = builder.consistent_read(*consistent_read); + if let Ok(opts) = options.borrow_ref::() { + if let Some(b) = opts.get("consistent_read").and_then(|v| v.as_bool().ok()) { + builder = builder.consistent_read(b); } } let result = handle_request(&ctx, builder).await?; - if let Value::Object(opts) = &options { - if let Some(Value::Bool(with_result)) = opts.borrow_ref()?.get("with_result") { - if *with_result { - return Ok(result.into_iter().next().to_value().into_result()?); - } + if let Ok(opts) = options.borrow_ref::() { + if opts.get("with_result").and_then(|v| v.as_bool().ok()) == Some(true) { + return Ok(result.into_iter().next().to_value()?); } } - Ok(Value::EmptyTuple) + Ok(Value::from(())) } /// Updates an item in the table. @@ -392,21 +419,25 @@ pub async fn update( let mut builder = client .update_item() .table_name(table_name.deref()) - .set_key(Some(rune_object_to_alternator_map(key)?)); + .set_key(Some(rune_object_to_alternator_map(&key)?)); - if let Some(Value::String(update_expression)) = params.get("update") { - builder = builder.update_expression(update_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("update") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.update_expression(s.as_str().to_string()); + } } - if let Some(Value::Object(attr_names)) = params.get("attribute_names") { - builder = - builder.set_expression_attribute_names(Some(extract_attribute_names(attr_names)?)); + if let Some(v) = params.get("attribute_names") { + if let Ok(obj) = v.borrow_ref::() { + builder = builder.set_expression_attribute_names(Some(extract_attribute_names(&obj)?)); + } } - if let Some(Value::Object(attr_values)) = params.get("attribute_values") { - builder = builder.set_expression_attribute_values(Some(rune_object_to_alternator_map( - attr_values.clone().into_ref()?, - )?)); + if let Some(v) = params.get("attribute_values") { + if let Ok(obj) = v.borrow_ref::() { + builder = + builder.set_expression_attribute_values(Some(rune_object_to_alternator_map(&obj)?)); + } } handle_request(&ctx, builder).await?; @@ -444,58 +475,66 @@ pub async fn query( let mut builder = client.query().table_name(table_name.deref()); - if let Some(Value::String(key_condition_expression)) = params.get("query") { - builder = - builder.key_condition_expression(key_condition_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("query") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.key_condition_expression(s.as_str().to_string()); + } } - if let Some(Value::String(filter_expression)) = params.get("filter") { - builder = builder.filter_expression(filter_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("filter") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.filter_expression(s.as_str().to_string()); + } } - if let Some(Value::Object(attr_names)) = params.get("attribute_names") { - builder = - builder.set_expression_attribute_names(Some(extract_attribute_names(attr_names)?)); + if let Some(v) = params.get("attribute_names") { + if let Ok(obj) = v.borrow_ref::() { + builder = builder.set_expression_attribute_names(Some(extract_attribute_names(&obj)?)); + } } - if let Some(Value::Object(attr_values)) = params.get("attribute_values") { - builder = builder.set_expression_attribute_values(Some(rune_object_to_alternator_map( - attr_values.clone().into_ref()?, - )?)); + if let Some(v) = params.get("attribute_values") { + if let Ok(obj) = v.borrow_ref::() { + builder = + builder.set_expression_attribute_values(Some(rune_object_to_alternator_map(&obj)?)); + } } - if let Some(Value::Bool(consistent_read)) = params.get("consistent_read") { - builder = builder.consistent_read(*consistent_read); + if let Some(b) = params.get("consistent_read").and_then(|v| v.as_bool().ok()) { + builder = builder.consistent_read(b); } if let Some(limit_val) = params.get("limit") { - builder = builder.limit(match limit_val { - Value::Integer(i) => match i32::try_from(*i) { + if let Ok(i) = limit_val.as_signed() { + builder = builder.limit(match i32::try_from(i) { Ok(val) => val, Err(_) => return bad_input("limit is out of range"), - }, - _ => return bad_input("limit must be an integer"), - }); + }); + } else { + return bad_input("limit must be an integer"); + } } - let validation = if let Some(Value::Vec(validation)) = params.get("validation") { - Some( - extract_validation_args(validation.borrow_ref()?.to_vec()) - .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, - ) + let validation = if let Some(v) = params.get("validation") { + if let Ok(vec) = v.borrow_ref::() { + Some( + extract_validation_args(vec.to_vec()) + .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, + ) + } else { + None + } } else { None }; let result = handle_request_with_validation(&ctx, builder, validation, "Query").await?; - if let Some(Value::Bool(with_result)) = params.get("with_result") { - if *with_result { - return Ok(result.to_value().into_result()?); - } + if params.get("with_result").and_then(|v| v.as_bool().ok()) == Some(true) { + return Ok(result.to_value()?); } - Ok(Value::EmptyTuple) + Ok(Value::from(())) } /// Scans items from the table. @@ -523,48 +562,55 @@ pub async fn scan( let mut builder = client.scan().table_name(table_name.deref()); - if let Some(Value::String(filter_expression)) = params.get("filter") { - builder = builder.filter_expression(filter_expression.borrow_ref()?.to_string()); + if let Some(v) = params.get("filter") { + if let Ok(s) = v.borrow_ref::() { + builder = builder.filter_expression(s.as_str().to_string()); + } } - if let Some(Value::Object(attr_names)) = params.get("attribute_names") { - builder = - builder.set_expression_attribute_names(Some(extract_attribute_names(attr_names)?)); + if let Some(v) = params.get("attribute_names") { + if let Ok(obj) = v.borrow_ref::() { + builder = builder.set_expression_attribute_names(Some(extract_attribute_names(&obj)?)); + } } - if let Some(Value::Object(attr_values)) = params.get("attribute_values") { - builder = builder.set_expression_attribute_values(Some(rune_object_to_alternator_map( - attr_values.clone().into_ref()?, - )?)); + if let Some(v) = params.get("attribute_values") { + if let Ok(obj) = v.borrow_ref::() { + builder = + builder.set_expression_attribute_values(Some(rune_object_to_alternator_map(&obj)?)); + } } - if let Some(Value::Bool(consistent_read)) = params.get("consistent_read") { - builder = builder.consistent_read(*consistent_read); + if let Some(b) = params.get("consistent_read").and_then(|v| v.as_bool().ok()) { + builder = builder.consistent_read(b); } if let Some(limit_val) = params.get("limit") { - builder = builder.limit(match limit_val { - Value::Integer(i) => *i as i32, - _ => return bad_input("limit must be an integer"), - }); + if let Ok(i) = limit_val.as_signed() { + builder = builder.limit(i as i32); + } else { + return bad_input("limit must be an integer"); + } } - let validation = if let Some(Value::Vec(validation)) = params.get("validation") { - Some( - extract_validation_args(validation.borrow_ref()?.to_vec()) - .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, - ) + let validation = if let Some(v) = params.get("validation") { + if let Ok(vec) = v.borrow_ref::() { + Some( + extract_validation_args(vec.to_vec()) + .map_err(|s| AlternatorError::new(AlternatorErrorKind::BadInput(s)))?, + ) + } else { + None + } } else { None }; let result = handle_request_with_validation(&ctx, builder, validation, "Scan").await?; - if let Some(Value::Bool(with_result)) = params.get("with_result") { - if *with_result { - return Ok(result.to_value().into_result()?); - } + if params.get("with_result").and_then(|v| v.as_bool().ok()) == Some(true) { + return Ok(result.to_value()?); } - Ok(Value::EmptyTuple) + Ok(Value::from(())) } diff --git a/src/scripting/alternator/types.rs b/src/scripting/alternator/types.rs index 1a2d5d1..6e8a74e 100644 --- a/src/scripting/alternator/types.rs +++ b/src/scripting/alternator/types.rs @@ -1,67 +1,66 @@ use super::alternator_error::{AlternatorError, AlternatorErrorKind}; use aws_sdk_dynamodb::types::AttributeValue; -use rune::runtime::{Bytes, Object, Ref}; +use rune::alloc::String as RuneString; +use rune::runtime::{Bytes, Object}; use rune::{ToValue, Value}; use std::collections::HashMap; pub fn rune_value_to_alternator_attribute(v: Value) -> Result { - match v { - Value::Bool(b) => Ok(AttributeValue::Bool(b)), - - // DynamoDB represents all numbers as strings - Value::Integer(i) => Ok(AttributeValue::N(i.to_string())), - // To distinguish floats from integers, we print them with the decimal point - Value::Float(f) => Ok(AttributeValue::N(format!("{:?}", f))), - - Value::String(s) => Ok(AttributeValue::S(s.into_ref()?.to_string())), - - Value::Bytes(b) => Ok(AttributeValue::B(b.into_ref()?.to_vec().into())), - - Value::Vec(v) => Ok(AttributeValue::L( - v.into_ref()? - .iter() - .map(|v| rune_value_to_alternator_attribute(v.clone())) - .collect::>()?, - )), - - Value::Object(o) => Ok(AttributeValue::M(rune_object_to_alternator_map( - o.into_ref()?, - )?)), - - Value::Option(o) => match o.into_ref()?.as_ref() { - Some(v) => rune_value_to_alternator_attribute(v.clone()), + if let Ok(b) = v.as_bool() { + return Ok(AttributeValue::Bool(b)); + } + if let Ok(i) = v.as_signed() { + return Ok(AttributeValue::N(i.to_string())); + } + if let Ok(f) = v.as_float() { + return Ok(AttributeValue::N(format!("{:?}", f))); + } + if let Ok(s) = v.borrow_ref::() { + return Ok(AttributeValue::S(s.as_str().to_string())); + } + if let Ok(b) = v.borrow_ref::() { + return Ok(AttributeValue::B(b.to_vec().into())); + } + if let Ok(vec) = v.borrow_ref::() { + let list = vec + .iter() + .map(|v| rune_value_to_alternator_attribute(v.clone())) + .collect::>()?; + return Ok(AttributeValue::L(list)); + } + if let Ok(obj) = v.borrow_ref::() { + let map = obj + .iter() + .map(|(k, v)| { + Ok(( + k.to_string(), + rune_value_to_alternator_attribute(v.clone())?, + )) + }) + .collect::, AlternatorError>>()?; + return Ok(AttributeValue::M(map)); + } + if let Ok(opt) = v.borrow_ref::>() { + return match opt.as_ref() { + Some(inner) => rune_value_to_alternator_attribute(inner.clone()), None => Ok(AttributeValue::Null(true)), - }, - - _ => Err(AlternatorError::new(AlternatorErrorKind::ConversionError( - format!("Unsupported Rune Value type for: {:?}", v), - ))), + }; } -} - -pub fn rune_object_to_alternator_map( - o: Ref, -) -> Result, AlternatorError> { - o.iter() - .map(|(k, v)| { - Ok(( - k.to_string(), - rune_value_to_alternator_attribute(v.clone())?, - )) - }) - .collect() + Err(AlternatorError::new(AlternatorErrorKind::ConversionError( + format!("Unsupported Rune Value type for: {:?}", v), + ))) } pub fn alternator_attribute_to_rune_value(attr: AttributeValue) -> Result { match attr { - AttributeValue::Bool(b) => Ok(Value::Bool(b)), + AttributeValue::Bool(b) => Ok(Value::from(b)), AttributeValue::N(n) => { // Try parsing as integer first, then as float if let Ok(i) = n.parse::() { - Ok(Value::Integer(i)) + Ok(Value::from(i)) } else if let Ok(f) = n.parse::() { - Ok(Value::Float(f)) + Ok(Value::from(f)) } else { Err(AlternatorError::new(AlternatorErrorKind::ConversionError( format!("Invalid number format: {}", n), @@ -69,20 +68,31 @@ pub fn alternator_attribute_to_rune_value(attr: AttributeValue) -> Result Ok(s.as_str().to_value().into_result()?), + AttributeValue::S(s) => Ok(Value::new(RuneString::try_from(s).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?) + .map_err(|e| AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())))?), - AttributeValue::B(b) => Ok(Bytes::try_from(b.into_inner())?.to_value().into_result()?), + AttributeValue::B(b) => Ok(Bytes::try_from(b.into_inner())?.to_value()?), - AttributeValue::L(l) => Ok(l - .into_iter() - .map(alternator_attribute_to_rune_value) - .collect::, _>>()? - .to_value() - .into_result()?), + AttributeValue::L(l) => { + let mut rune_vec = rune::runtime::Vec::new(); + for attr in l { + let val = alternator_attribute_to_rune_value(attr)?; + rune_vec.push(val).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?; + } + Ok(Value::vec(rune_vec.into_inner()).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?) + } AttributeValue::M(map) => Ok(alternator_map_to_rune_object(map)?), - AttributeValue::Null(_) => Ok(None::.to_value().into_result()?), + AttributeValue::Null(_) => Ok(Value::try_from(None::).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?), _ => Err(AlternatorError::new(AlternatorErrorKind::ConversionError( format!("Unsupported Alternator AttributeValue type: {:?}", attr), @@ -90,13 +100,32 @@ pub fn alternator_attribute_to_rune_value(attr: AttributeValue) -> Result Result, AlternatorError> { + obj.iter() + .map(|(k, v)| { + Ok(( + k.to_string(), + rune_value_to_alternator_attribute(v.clone())?, + )) + }) + .collect() +} + pub fn alternator_map_to_rune_object( map: HashMap, ) -> Result { - Ok(map - .into_iter() - .map(|(k, v)| Ok((k, alternator_attribute_to_rune_value(v)?))) - .collect::, AlternatorError>>()? - .to_value() - .into_result()?) + let mut obj = Object::new(); + for (k, v) in map { + let rune_key = RuneString::try_from(k).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?; + let rune_val = alternator_attribute_to_rune_value(v)?; + obj.insert(rune_key, rune_val).map_err(|e| { + AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string())) + })?; + } + Value::new(obj) + .map_err(|e| AlternatorError::new(AlternatorErrorKind::ConversionError(e.to_string()))) } diff --git a/src/scripting/cql/cass_error.rs b/src/scripting/cql/cass_error.rs index 2054ca0..3d2970f 100644 --- a/src/scripting/cql/cass_error.rs +++ b/src/scripting/cql/cass_error.rs @@ -116,9 +116,9 @@ pub struct QueryInfo { } impl CassError { - #[rune::function(protocol = STRING_DISPLAY)] + #[rune::function(protocol = DISPLAY_FMT)] pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.to_string()); + let _ = vm_write!(f, "{}", self.to_string()); VmResult::Ok(()) } @@ -280,26 +280,22 @@ impl std::error::Error for CassError {} /// Formats a rune Value into a list of parameter display strings for error messages fn rune_value_to_param_strings(value: Option<&Value>) -> Vec { - match value { - None => vec![], - Some(Value::Tuple(tuple)) => match tuple.borrow_ref() { - Ok(tuple) => tuple.iter().map(|v| format!("{v:?}")).collect(), - Err(_) => vec!["".to_string()], - }, - Some(Value::Vec(vec)) => match vec.borrow_ref() { - Ok(vec) => vec.iter().map(|v| format!("{v:?}")).collect(), - Err(_) => vec!["".to_string()], - }, - Some(Value::Object(obj)) => match obj.borrow_ref() { - Ok(obj) => obj.iter().map(|(k, v)| format!("{k}: {v:?}")).collect(), - Err(_) => vec!["".to_string()], - }, - Some(Value::Struct(obj)) => match obj.borrow_ref() { - Ok(obj) => vec![format!("{obj:?}")], - Err(_) => vec!["".to_string()], - }, - Some(other) => vec![format!("{other:?}")], + let Some(v) = value else { + return vec![]; + }; + if let Ok(tuple) = v.borrow_ref::() { + return tuple.iter().map(|v| format!("{v:?}")).collect(); } + if let Ok(vec) = v.borrow_ref::() { + return vec.iter().map(|v| format!("{v:?}")).collect(); + } + if let Ok(obj) = v.borrow_ref::() { + return obj.iter().map(|(k, v)| format!("{k}: {v:?}")).collect(); + } + if let Ok(rune::runtime::TypeValue::Struct(s)) = v.as_type_value() { + return vec![format!("{s:?}")]; + } + vec![format!("{v:?}")] } impl Display for QueryInfo { diff --git a/src/scripting/cql/context.rs b/src/scripting/cql/context.rs index dec0cc8..39866ca 100644 --- a/src/scripting/cql/context.rs +++ b/src/scripting/cql/context.rs @@ -10,7 +10,7 @@ use crate::stats::session::SessionStats; use once_cell::sync::Lazy; use regex::Regex; -use rune::runtime::{Object, Shared, Vec as RuneVec}; +use rune::runtime::{Object, Vec as RuneVec}; use rune::{Any, Value}; use scylla::client::session::Session; use scylla::response::PagingState; @@ -37,12 +37,12 @@ pub struct Context { // which don't 'depend on'/'use' the 'session' object. session: Option>, page_size: u64, - statements: HashMap>, - pub stats: TryLock, + statements: Arc>>>, + pub stats: Arc>, pub retry_number: u64, pub retry_interval: RetryInterval, pub validation_strategy: ValidationStrategy, - pub partition_row_presets: HashMap, + pub partition_row_presets: Arc>>, #[rune(get, set, add_assign, copy)] pub load_cycle_count: u64, #[rune(get)] @@ -72,20 +72,21 @@ impl Context { retry_interval: RetryInterval, validation_strategy: ValidationStrategy, ) -> Context { + let data = Value::new(Object::new()).unwrap(); Context { start_time: TryLock::new(Instant::now()), session: session.map(Arc::new), page_size, - statements: HashMap::new(), - stats: TryLock::new(SessionStats::new()), + statements: Arc::new(TryLock::new(HashMap::new())), + stats: Arc::new(TryLock::new(SessionStats::new())), retry_number, retry_interval, validation_strategy, - partition_row_presets: HashMap::new(), + partition_row_presets: Arc::new(TryLock::new(HashMap::new())), load_cycle_count: 0, preferred_datacenter, preferred_rack, - data: Value::Object(Shared::new(Object::new()).unwrap()), + data, } } @@ -99,12 +100,14 @@ impl Context { Ok(Context { session: self.session.clone(), page_size: self.page_size, - statements: self.statements.clone(), - stats: TryLock::new(SessionStats::default()), + statements: Arc::new(TryLock::new(self.statements.try_lock().unwrap().clone())), + stats: Arc::new(TryLock::new(SessionStats::default())), retry_number: self.retry_number, retry_interval: self.retry_interval, validation_strategy: self.validation_strategy, - partition_row_presets: self.partition_row_presets.clone(), + partition_row_presets: Arc::new(TryLock::new( + self.partition_row_presets.try_lock().unwrap().clone(), + )), load_cycle_count: self.load_cycle_count, preferred_datacenter: self.preferred_datacenter.clone(), preferred_rack: self.preferred_rack.clone(), @@ -113,6 +116,27 @@ impl Context { }) } + /// Creates a shallow clone that shares the Arc-backed fields (stats, statements, presets) + /// with the original. Used to create a rune-owned `Value` for function call arguments + /// without losing stats tracking. + pub fn shallow_clone(&self) -> Self { + Context { + start_time: TryLock::new(*self.start_time.try_lock().unwrap()), + session: self.session.clone(), + page_size: self.page_size, + statements: Arc::clone(&self.statements), + stats: Arc::clone(&self.stats), + retry_number: self.retry_number, + retry_interval: self.retry_interval, + validation_strategy: self.validation_strategy, + partition_row_presets: Arc::clone(&self.partition_row_presets), + load_cycle_count: self.load_cycle_count, + preferred_datacenter: self.preferred_datacenter.clone(), + preferred_rack: self.preferred_rack.clone(), + data: self.data.clone(), + } + } + /// Returns cluster metadata such as cluster name and DB version. pub async fn cluster_info(&self) -> Result, CassError> { let session = match &self.session { @@ -195,14 +219,17 @@ impl Context { } /// Prepares a statement and stores it in an internal statement map for future use. - pub async fn prepare(&mut self, key: &str, cql: &str) -> Result<(), CassError> { + pub async fn prepare(&self, key: &str, cql: &str) -> Result<(), CassError> { match &self.session { Some(session) => { let statement = session .prepare(Statement::new(cql).with_page_size(self.page_size as i32)) .await .map_err(|e| CassError::prepare_error(cql, e))?; - self.statements.insert(key.to_string(), Arc::new(statement)); + self.statements + .try_lock() + .unwrap() + .insert(key.to_string(), Arc::new(statement)); Ok(()) } None => Err(CassError(CassErrorKind::Error( @@ -319,12 +346,17 @@ impl Context { ))); } let stmt = if let Some(key) = key { - self.statements.get(key).ok_or_else(|| { - CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) - })? + self.statements + .try_lock() + .unwrap() + .get(key) + .cloned() + .ok_or_else(|| { + CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) + })? } else { let cql = cql.expect("failed to unwrap the 'cql' parameter"); - &Arc::new( + Arc::new( session .prepare(Statement::new(cql).with_page_size(self.page_size as i32)) .await @@ -359,7 +391,7 @@ impl Context { while current_attempt_num <= self.retry_number { let start_time = self.stats.try_lock().unwrap().start_request(); let rs = session - .execute_single_page(stmt, &query_params, paging_state.clone()) + .execute_single_page(&stmt, &query_params, paging_state.clone()) .await; let current_duration = Instant::now() - start_time; let (page, paging_state_response) = match rs { @@ -380,11 +412,11 @@ impl Context { match row_result { Ok(RuneRow(row_obj)) => { rune_rows - .push(Value::Object(Shared::new(row_obj).map_err(|_| { + .push(Value::new(row_obj).map_err(|_| { CassError(CassErrorKind::Error( "Failed to create shared row object".to_string(), )) - })?)) + })?) .map_err(|_| { CassError(CassErrorKind::Error( "Failed to push row to result vector".to_string(), @@ -413,13 +445,13 @@ impl Context { .unwrap() .complete_request(all_pages_duration, rows_num); if process_and_return_data { - return Ok(Value::Vec(Shared::new(rune_rows).map_err(|_| { + return Value::vec(rune_rows.into_inner()).map_err(|_| { CassError(CassErrorKind::Error( "Failed to create shared result vector".to_string(), )) - })?)); + }); } else { - let empty_rune_vec = Value::Vec(Shared::new(RuneVec::new())?); + let empty_rune_vec = Value::vec(Default::default())?; let rows_min = match expected_rows_num_min { None => return Ok(empty_rune_vec), Some(rows_min) => rows_min, @@ -492,10 +524,16 @@ impl Context { let mut batch: Batch = Batch::new(BatchType::Logged); let mut batch_values: Vec> = Vec::with_capacity(keys_len); for (i, key) in keys.into_iter().enumerate() { - let statement = self.statements.get(key).ok_or_else(|| { - CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) - })?; - batch.append_statement((**statement).clone()); + let statement = self + .statements + .try_lock() + .unwrap() + .get(key) + .cloned() + .ok_or_else(|| { + CassError(CassErrorKind::PreparedStatementNotFound(key.to_string())) + })?; + batch.append_statement((*statement).clone()); batch_values.push(RuneQueryParams::new(params.get(i))); } match &self.session { diff --git a/src/scripting/cql/deserialize.rs b/src/scripting/cql/deserialize.rs index 328685e..d233574 100644 --- a/src/scripting/cql/deserialize.rs +++ b/src/scripting/cql/deserialize.rs @@ -5,7 +5,7 @@ use std::net::IpAddr; use super::cass_error::{CassError, CassErrorKind}; use rune::alloc::String as RuneString; -use rune::runtime::{Object, OwnedTuple, Shared, Vec as RuneVec}; +use rune::runtime::{Object, OwnedTuple, Vec as RuneVec}; use rune::Value; use scylla::cluster::metadata::{CollectionType, ColumnType, NativeType}; use scylla::deserialize::row::{ColumnIterator, DeserializeRow}; @@ -34,13 +34,9 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { v: Option>, ) -> Result { let Some(slice) = v else { - return Ok(RuneValue(Value::Option(Shared::new(None).map_err( - |_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared None".to_string(), - ))) - }, - )?))); + return Ok(RuneValue( + Value::try_from(None::).map_err(DeserializationError::new)?, + )); }; // This matches the old logic that used CqlValue @@ -52,13 +48,9 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { // can't be empty } _ => { - return Ok(RuneValue(Value::Option(Shared::new(None).map_err( - |_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared None".to_string(), - ))) - }, - )?))) + return Ok(RuneValue( + Value::try_from(None::).map_err(DeserializationError::new)?, + )); } } } @@ -67,58 +59,56 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ColumnType::Native(NativeType::Ascii) | ColumnType::Native(NativeType::Text) => { let string = >::deserialize(typ, Some(slice))?; - Value::String( - Shared::new(RuneString::try_from(string).expect("Failed to create RuneString")) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string".to_string(), - ))) - })?, - ) + Value::new(RuneString::try_from(string).expect("Failed to create RuneString")) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value".to_string(), + ))) + })? } ColumnType::Native(NativeType::Boolean) => { >::deserialize(typ, Some(slice)) - .map(Value::Bool)? + .map(Value::from)? } ColumnType::Native(NativeType::TinyInt) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Integer(i.into()))? + .map(|i| Value::from(i as i64))? } ColumnType::Native(NativeType::SmallInt) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Integer(i.into()))? + .map(|i| Value::from(i as i64))? } ColumnType::Native(NativeType::Int) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Integer(i.into()))? + .map(|i| Value::from(i as i64))? } ColumnType::Native(NativeType::BigInt) => { >::deserialize(typ, Some(slice)) - .map(Value::Integer)? + .map(Value::from)? } ColumnType::Native(NativeType::Float) => { >::deserialize(typ, Some(slice)) - .map(|i| Value::Float(i.into()))? + .map(|f| Value::from(f as f64))? } ColumnType::Native(NativeType::Double) => { >::deserialize(typ, Some(slice)) - .map(Value::Float)? + .map(Value::from)? } ColumnType::Native(NativeType::Counter) => { >::deserialize(typ, Some(slice)) - .map(|c| Value::Integer(c.0))? + .map(|c| Value::from(c.0))? } ColumnType::Native(NativeType::Timestamp) => { >::deserialize(typ, Some(slice)) - .map(|ts| Value::Integer(ts.0))? + .map(|ts| Value::from(ts.0))? } ColumnType::Native(NativeType::Date) => { >::deserialize(typ, Some(slice)) - .map(|date| Value::Integer(date.0.into()))? + .map(|date| Value::from(date.0 as i64))? } ColumnType::Native(NativeType::Time) => { >::deserialize(typ, Some(slice)) - .map(|time| Value::Integer(time.0))? + .map(|time| Value::from(time.0))? } ColumnType::Native(NativeType::Blob) => { // Note: is it intentional that blobs are representes as Vec of rune Bytes? @@ -130,64 +120,58 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { )?; let mut rune_vec = RuneVec::new(); for byte in bytes_slice { - rune_vec.push(Value::Byte(*byte)).map_err(|_| { + rune_vec.push(Value::from(*byte)).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to push byte to Rune vector".to_string(), ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector for blob".to_string(), + "Failed to create vector for blob".to_string(), ))) - })?) + })? } ColumnType::Native(NativeType::Uuid) => { let uuid = >::deserialize(typ, Some(slice))?; - Value::String( - Shared::new( - RuneString::try_from(uuid.to_string()) - .expect("Failed to create RuneString for UUID"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for UUID".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(uuid.to_string()) + .expect("Failed to create RuneString for UUID"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for UUID".to_string(), + ))) + })? } ColumnType::Native(NativeType::Timeuuid) => { let timeuuid = >::deserialize( typ, Some(slice), )?; - Value::String( - Shared::new( - RuneString::try_from(timeuuid.to_string()) - .expect("Failed to create RuneString for TimeUuid"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for TimeUuid".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(timeuuid.to_string()) + .expect("Failed to create RuneString for TimeUuid"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for TimeUuid".to_string(), + ))) + })? } ColumnType::Native(NativeType::Inet) => { let addr = >::deserialize(typ, Some(slice))?; - Value::String( - Shared::new( - RuneString::try_from(addr.to_string()) - .expect("Failed to create RuneString for IpAddr"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for IpAddr".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(addr.to_string()) + .expect("Failed to create RuneString for IpAddr"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for IpAddr".to_string(), + ))) + })? } ColumnType::Native(NativeType::Varint) => { let varint = as DeserializeValue<'frame, 'metadata>>::deserialize(typ, Some(slice))?; @@ -205,7 +189,7 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { padded[8 - varint_bytes.len()..].copy_from_slice(varint_bytes); i64::from_be_bytes(padded) }; - Value::Integer(integer) + Value::from(integer) } ColumnType::Native(NativeType::Decimal) => { let decimal = as DeserializeValue< @@ -228,16 +212,14 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { u32::try_from(scale).map_err(DeserializationError::new)?, ) .unwrap(); - Value::String( - Shared::new( - RuneString::try_from(dec.to_string()).expect("Failed to create RuneString"), - ) - .map_err(|_| { - DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared string for Decimal".to_string(), - ))) - })?, + Value::new( + RuneString::try_from(dec.to_string()).expect("Failed to create RuneString"), ) + .map_err(|_| { + DeserializationError::new(CassError(CassErrorKind::Error( + "Failed to create string value for Decimal".to_string(), + ))) + })? } ColumnType::Native(NativeType::Duration) => { let duration = >::deserialize( @@ -249,7 +231,7 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { rune_obj .insert( RuneString::try_from("months").expect("Failed to create RuneString"), - Value::Integer(duration.months as i64), + Value::from(duration.months as i64), ) .map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( @@ -259,7 +241,7 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { rune_obj .insert( RuneString::try_from("days").expect("Failed to create RuneString"), - Value::Integer(duration.days as i64), + Value::from(duration.days as i64), ) .map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( @@ -269,18 +251,18 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { rune_obj .insert( RuneString::try_from("nanoseconds").expect("Failed to create RuneString"), - Value::Integer(duration.nanoseconds), + Value::from(duration.nanoseconds), ) .map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to insert nanoseconds into duration object".to_string(), ))) })?; - Value::Object(Shared::new(rune_obj).map_err(|_| { + Value::new(rune_obj).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared object for Duration".to_string(), + "Failed to create object value for Duration".to_string(), ))) - })?) + })? } ColumnType::Vector { dimensions, .. } => { let mut rune_vec = @@ -298,11 +280,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector".to_string(), + "Failed to create vector value".to_string(), ))) - })?) + })? } ColumnType::Collection { typ: coll_type, .. } => match coll_type { CollectionType::List(_) | CollectionType::Set(_) => { @@ -321,11 +303,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector".to_string(), + "Failed to create vector value".to_string(), ))) - })?) + })? } CollectionType::Map(_, _) => { let cql_map_iterator = as DeserializeValue< @@ -342,22 +324,22 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { "Failed to create Rune OwnedTuple".to_string(), ))) })?; - let tuple = Value::Tuple(Shared::new(owned_tuple).map_err(|_| { + let tuple = Value::new(owned_tuple).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create Rune Shared".to_string(), + "Failed to create Rune tuple value".to_string(), ))) - })?); + })?; rune_vec.push(tuple).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to push map key-value pair to the Rune vector".to_string(), ))) })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( "Failed to create shared Rune vector".to_string(), ))) - })?) + })? } _ => todo!(), // unexpected, should never be reached }, @@ -388,11 +370,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { ))) })?; } - Value::Object(Shared::new(rune_obj).map_err(|_| { + Value::new(rune_obj).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared object for UDT".to_string(), + "Failed to create object value for UDT".to_string(), ))) - })?) + })? } ColumnType::Tuple(tuple) => { let mut rune_vec = @@ -418,11 +400,11 @@ impl<'frame, 'metadata> DeserializeValue<'frame, 'metadata> for RuneValue { })?; } - Value::Vec(Shared::new(rune_vec).map_err(|_| { + Value::vec(rune_vec.into_inner()).map_err(|_| { DeserializationError::new(CassError(CassErrorKind::Error( - "Failed to create shared vector for tuple".to_string(), + "Failed to create vector value for tuple".to_string(), ))) - })?) + })? } _ => todo!(), // unexpected, should never be reached @@ -496,19 +478,17 @@ mod tests { } fn assert_is_none(v: Value) { - match v { - Value::Option(shared) => { - assert!(shared.borrow_ref().unwrap().is_none(), "expected None"); - } - other => panic!("expected Option(None), got {other:?}"), + match v.borrow_ref::>() { + Ok(opt) => assert!(opt.is_none(), "expected None"), + Err(_) => panic!("expected Option(None)"), } } fn str_from(v: &Value) -> String { - match v { - Value::String(s) => s.borrow_ref().unwrap().as_str().to_owned(), - other => panic!("expected String, got {other:?}"), - } + v.borrow_ref::() + .expect("expected String") + .as_str() + .to_owned() } fn col_spec<'a>(name: &'a str, typ: ColumnType<'a>) -> ColumnSpec<'a> { @@ -550,12 +530,10 @@ mod tests { let bytes = Bytes::new(); let slice = FrameSlice::new(&bytes); let result = RuneValue::deserialize(&typ, Some(slice)).unwrap().0; - match result { - Value::Vec(shared) => { - assert!(shared.borrow_ref().unwrap().is_empty()); - } - other => panic!("expected Vec, got {other:?}"), - } + let vec = result + .borrow_ref::() + .expect("expected Vec"); + assert!(vec.is_empty()); } #[test] @@ -573,14 +551,14 @@ mod tests { fn bool_true() { let typ = ColumnType::Native(NativeType::Boolean); let bytes = cql_to_raw(&typ, CqlValue::Boolean(true)); - assert!(matches!(deser(&typ, &bytes), Value::Bool(true))); + assert!(deser(&typ, &bytes).as_bool().unwrap()); } #[test] fn bool_false() { let typ = ColumnType::Native(NativeType::Boolean); let bytes = cql_to_raw(&typ, CqlValue::Boolean(false)); - assert!(matches!(deser(&typ, &bytes), Value::Bool(false))); + assert!(!deser(&typ, &bytes).as_bool().unwrap()); } // ── integer types ────────────────────────────────────────────────────────── @@ -589,35 +567,35 @@ mod tests { fn tiny_int_roundtrip() { let typ = ColumnType::Native(NativeType::TinyInt); let bytes = cql_to_raw(&typ, CqlValue::TinyInt(-42)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(-42))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), -42_i64); } #[test] fn small_int_roundtrip() { let typ = ColumnType::Native(NativeType::SmallInt); let bytes = cql_to_raw(&typ, CqlValue::SmallInt(1000)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(1000))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 1000_i64); } #[test] fn int_roundtrip() { let typ = ColumnType::Native(NativeType::Int); let bytes = cql_to_raw(&typ, CqlValue::Int(100_000)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(100_000))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 100_000_i64); } #[test] fn big_int_roundtrip() { let typ = ColumnType::Native(NativeType::BigInt); let bytes = cql_to_raw(&typ, CqlValue::BigInt(i64::MAX)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(i64::MAX))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), i64::MAX); } #[test] fn counter_roundtrip() { let typ = ColumnType::Native(NativeType::Counter); let bytes = cql_to_raw(&typ, CqlValue::Counter(Counter(42))); - assert!(matches!(deser(&typ, &bytes), Value::Integer(42))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 42_i64); } // ── floating-point ───────────────────────────────────────────────────────── @@ -626,20 +604,16 @@ mod tests { fn float_roundtrip() { let typ = ColumnType::Native(NativeType::Float); let bytes = cql_to_raw(&typ, CqlValue::Float(1.5)); - match deser(&typ, &bytes) { - Value::Float(f) => assert!((f - 1.5_f32 as f64).abs() < 1e-7), - other => panic!("expected Float, got {other:?}"), - } + let f = deser(&typ, &bytes).as_float().expect("expected Float"); + assert!((f - 1.5_f32 as f64).abs() < 1e-7); } #[test] fn double_roundtrip() { let typ = ColumnType::Native(NativeType::Double); let bytes = cql_to_raw(&typ, CqlValue::Double(1.23456789)); - match deser(&typ, &bytes) { - Value::Float(f) => assert!((f - 1.23456789).abs() < 1e-10), - other => panic!("expected Float, got {other:?}"), - } + let f = deser(&typ, &bytes).as_float().expect("expected Float"); + assert!((f - 1.23456789).abs() < 1e-10); } // ── text / ascii ─────────────────────────────────────────────────────────── @@ -664,17 +638,15 @@ mod tests { fn blob_becomes_vec_of_bytes() { let typ = ColumnType::Native(NativeType::Blob); let bytes = cql_to_raw(&typ, CqlValue::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF])); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 4); - assert!(matches!(vec[0], Value::Byte(0xDE))); - assert!(matches!(vec[1], Value::Byte(0xAD))); - assert!(matches!(vec[2], Value::Byte(0xBE))); - assert!(matches!(vec[3], Value::Byte(0xEF))); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 4); + assert_eq!(vec[0].as_integer::().unwrap(), 0xDE); + assert_eq!(vec[1].as_integer::().unwrap(), 0xAD); + assert_eq!(vec[2].as_integer::().unwrap(), 0xBE); + assert_eq!(vec[3].as_integer::().unwrap(), 0xEF); } // ── UUID / Timeuuid ──────────────────────────────────────────────────────── @@ -726,10 +698,10 @@ mod tests { let typ = ColumnType::Native(NativeType::Timestamp); let ts = CqlTimestamp(1_609_459_200_000); // 2021-01-01 00:00:00 UTC in ms let bytes = cql_to_raw(&typ, CqlValue::Timestamp(ts)); - assert!(matches!( - deser(&typ, &bytes), - Value::Integer(1_609_459_200_000) - )); + assert_eq!( + deser(&typ, &bytes).as_signed().unwrap(), + 1_609_459_200_000_i64 + ); } #[test] @@ -737,7 +709,7 @@ mod tests { let typ = ColumnType::Native(NativeType::Date); let date = CqlDate(2_147_483_648u32); // 2^31 let bytes = cql_to_raw(&typ, CqlValue::Date(date)); - assert!(matches!(deser(&typ, &bytes), Value::Integer(2_147_483_648))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 2_147_483_648_i64); } #[test] @@ -745,10 +717,10 @@ mod tests { let typ = ColumnType::Native(NativeType::Time); let time = CqlTime(3_600_000_000_000i64); // 1 hour in nanoseconds let bytes = cql_to_raw(&typ, CqlValue::Time(time)); - assert!(matches!( - deser(&typ, &bytes), - Value::Integer(3_600_000_000_000) - )); + assert_eq!( + deser(&typ, &bytes).as_signed().unwrap(), + 3_600_000_000_000_i64 + ); } // ── varint ──────────────────────────────────────────────────────────────── @@ -760,7 +732,7 @@ mod tests { &typ, CqlValue::Varint(CqlVarint::from_signed_bytes_be_slice(&[42])), ); - assert!(matches!(deser(&typ, &bytes), Value::Integer(42))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), 42_i64); } #[test] @@ -770,7 +742,7 @@ mod tests { &typ, CqlValue::Varint(CqlVarint::from_signed_bytes_be_slice(&[0xFF])), ); - assert!(matches!(deser(&typ, &bytes), Value::Integer(-1))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), -1_i64); } #[test] @@ -782,7 +754,7 @@ mod tests { &i64::MAX.to_be_bytes(), )), ); - assert!(matches!(deser(&typ, &bytes), Value::Integer(i64::MAX))); + assert_eq!(deser(&typ, &bytes).as_signed().unwrap(), i64::MAX); } #[test] @@ -817,15 +789,13 @@ mod tests { nanoseconds: 3, }; let bytes = cql_to_raw(&typ, CqlValue::Duration(dur)); - match deser(&typ, &bytes) { - Value::Object(shared) => { - let obj = shared.borrow_ref().unwrap(); - assert!(matches!(obj.get("months").unwrap(), Value::Integer(1))); - assert!(matches!(obj.get("days").unwrap(), Value::Integer(2))); - assert!(matches!(obj.get("nanoseconds").unwrap(), Value::Integer(3))); - } - other => panic!("expected Object, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let obj = binding + .borrow_ref::() + .expect("expected Object"); + assert_eq!(obj.get("months").unwrap().as_signed().unwrap(), 1_i64); + assert_eq!(obj.get("days").unwrap().as_signed().unwrap(), 2_i64); + assert_eq!(obj.get("nanoseconds").unwrap().as_signed().unwrap(), 3_i64); } // ── list ────────────────────────────────────────────────────────────────── @@ -841,16 +811,14 @@ mod tests { &typ, CqlValue::List(vec![CqlValue::Int(1), CqlValue::Int(2), CqlValue::Int(3)]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 3); - assert!(matches!(vec[0], Value::Integer(1))); - assert!(matches!(vec[1], Value::Integer(2))); - assert!(matches!(vec[2], Value::Integer(3))); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 3); + assert_eq!(vec[0].as_signed().unwrap(), 1_i64); + assert_eq!(vec[1].as_signed().unwrap(), 2_i64); + assert_eq!(vec[2].as_signed().unwrap(), 3_i64); } // ── set ─────────────────────────────────────────────────────────────────── @@ -866,10 +834,11 @@ mod tests { &typ, CqlValue::Set(vec![CqlValue::Text("a".into()), CqlValue::Text("b".into())]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => assert_eq!(shared.borrow_ref().unwrap().len(), 2), - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 2); } // ── map ─────────────────────────────────────────────────────────────────── @@ -886,21 +855,16 @@ mod tests { &typ, CqlValue::Map(vec![(CqlValue::Text("key".into()), CqlValue::Int(99))]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 1); - match &vec[0] { - Value::Tuple(tuple_shared) => { - let tuple = tuple_shared.borrow_ref().unwrap(); - assert_eq!(str_from(&tuple[0]), "key"); - assert!(matches!(tuple[1], Value::Integer(99))); - } - other => panic!("expected Tuple inside map Vec, got {other:?}"), - } - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 1); + let tuple = vec[0] + .borrow_ref::() + .expect("expected Tuple inside map Vec"); + assert_eq!(str_from(&tuple[0]), "key"); + assert_eq!(tuple[1].as_signed().unwrap(), 99_i64); } // ── tuple ───────────────────────────────────────────────────────────────── @@ -915,15 +879,13 @@ mod tests { &typ, CqlValue::Tuple(vec![Some(CqlValue::Int(7)), Some(CqlValue::Boolean(true))]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 2); - assert!(matches!(vec[0], Value::Integer(7))); - assert!(matches!(vec[1], Value::Bool(true))); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 2); + assert_eq!(vec[0].as_signed().unwrap(), 7_i64); + assert!(vec[1].as_bool().unwrap()); } #[test] @@ -933,15 +895,13 @@ mod tests { ColumnType::Native(NativeType::Int), ]); let bytes = cql_to_raw(&typ, CqlValue::Tuple(vec![Some(CqlValue::Int(1)), None])); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 2); - assert!(matches!(vec[0], Value::Integer(1))); - assert_is_none(vec[1].clone()); - } - other => panic!("expected Vec, got {other:?}"), - } + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 2); + assert_eq!(vec[0].as_signed().unwrap(), 1_i64); + assert_is_none(vec[1].clone()); } // ── vector (CQL vector type) ─────────────────────────────────────────────── @@ -961,15 +921,13 @@ mod tests { CqlValue::Float(3.0), ]), ); - match deser(&typ, &bytes) { - Value::Vec(shared) => { - let vec = shared.borrow_ref().unwrap(); - assert_eq!(vec.len(), 3); - for v in vec.iter() { - assert!(matches!(v, Value::Float(_))); - } - } - other => panic!("expected Vec, got {other:?}"), + let binding = deser(&typ, &bytes); + let vec = binding + .borrow_ref::() + .expect("expected Vec"); + assert_eq!(vec.len(), 3); + for v in vec.iter() { + assert!(v.as_float().is_ok()); } } @@ -999,14 +957,12 @@ mod tests { ], }, ); - match deser(&udt_typ, &bytes) { - Value::Object(shared) => { - let obj = shared.borrow_ref().unwrap(); - assert!(matches!(obj.get("x").unwrap(), Value::Integer(42))); - assert!(matches!(obj.get("y").unwrap(), Value::Bool(false))); - } - other => panic!("expected Object, got {other:?}"), - } + let binding = deser(&udt_typ, &bytes); + let obj = binding + .borrow_ref::() + .expect("expected Object"); + assert_eq!(obj.get("x").unwrap().as_signed().unwrap(), 42_i64); + assert!(!obj.get("y").unwrap().as_bool().unwrap()); } #[test] @@ -1027,13 +983,11 @@ mod tests { fields: vec![("v".into(), None)], }, ); - match deser(&udt_typ, &bytes) { - Value::Object(shared) => { - let obj = shared.borrow_ref().unwrap(); - assert_is_none(obj.get("v").unwrap().clone()); - } - other => panic!("expected Object, got {other:?}"), - } + let binding = deser(&udt_typ, &bytes); + let obj = binding + .borrow_ref::() + .expect("expected Object"); + assert_is_none(obj.get("v").unwrap().clone()); } // ── RuneRow ─────────────────────────────────────────────────────────────── @@ -1060,7 +1014,7 @@ mod tests { let row = RuneRow::deserialize(col_iter).unwrap(); let obj = row.0; - assert!(matches!(obj.get("id").unwrap(), Value::Integer(42))); + assert_eq!(obj.get("id").unwrap().as_signed().unwrap(), 42_i64); assert_eq!(str_from(obj.get("name").unwrap()), "alice"); } diff --git a/src/scripting/cql/functions.rs b/src/scripting/cql/functions.rs index 1a9ba4b..81cb90c 100644 --- a/src/scripting/cql/functions.rs +++ b/src/scripting/cql/functions.rs @@ -2,12 +2,12 @@ use crate::scripting::functions_common::extract_validation_args; use super::cass_error::{CassError, CassErrorKind}; use super::context::Context; -use rune::runtime::{Mut, Ref}; +use rune::runtime::Ref; use rune::Value; use std::ops::Deref; #[rune::function(instance)] -pub async fn prepare(mut ctx: Mut, key: Ref, cql: Ref) -> Result<(), CassError> { +pub async fn prepare(ctx: Ref, key: Ref, cql: Ref) -> Result<(), CassError> { ctx.prepare(&key, &cql).await } @@ -22,46 +22,18 @@ pub async fn execute_with_validation( cql: Ref, validation_args: Vec, ) -> Result { - match validation_args.as_slice() { - // (int): expected_rows - [Value::Integer(expected_rows)] => { - ctx.execute_with_validation( - cql.deref(), - *expected_rows as u64, - *expected_rows as u64, - "", - ) - .await - } - // (int, int): expected_rows_num_min, expected_rows_num_max - [Value::Integer(min), Value::Integer(max)] => { - ctx.execute_with_validation(cql.deref(), *min as u64, *max as u64, "") - .await - } - // (int, str): expected_rows, custom_err_msg - [Value::Integer(expected_rows), Value::String(custom_err_msg)] => { - ctx.execute_with_validation( - cql.deref(), - *expected_rows as u64, - *expected_rows as u64, - &custom_err_msg.borrow_ref().unwrap(), - ) - .await - } - // (int, int, str): expected_rows_num_min, expected_rows_num_max, custom_err_msg - [Value::Integer(min), Value::Integer(max), Value::String(custom_err_msg)] => { - ctx.execute_with_validation( - cql.deref(), - *min as u64, - *max as u64, - &custom_err_msg.borrow_ref().unwrap(), - ) - .await - } - _ => Err(CassError(CassErrorKind::Error( - "Invalid arguments for execute_with_validation".to_string(), - ))), - } + let args = extract_validation_args(validation_args).map_err(|e| { + CassError(CassErrorKind::Error(format!( + "execute_with_validation: {e}" + ))) + })?; + ctx.execute_with_validation( + cql.deref(), + args.expected_min, + args.expected_max, + &args.custom_err_msg, + ) + .await } #[rune::function(instance)] diff --git a/src/scripting/cql/serialize.rs b/src/scripting/cql/serialize.rs index efa732b..cc1e0b3 100644 --- a/src/scripting/cql/serialize.rs +++ b/src/scripting/cql/serialize.rs @@ -5,7 +5,8 @@ use crate::scripting::rune_uuid::Uuid; use chrono::{NaiveDate, NaiveTime}; use once_cell::sync::Lazy; use regex::Regex; -use rune::{Any, ToValue, Value}; +use rune::runtime::{Object, OwnedTuple, Vec as RuneVec}; +use rune::{ToValue, Value}; use scylla::_macro_internal::ColumnType; use scylla::frame::response::result::{CollectionType, ColumnSpec, NativeType}; use scylla::serialize::row::{RowSerializationContext, SerializeRow}; @@ -62,68 +63,55 @@ fn serialize_rune_params( columns: &[ColumnSpec<'_>], writer: &mut RowWriter<'_>, ) -> Result<(), SerializationError> { - match value { - Value::Tuple(tuple) => { - let tuple = tuple.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - if tuple.len() != columns.len() { - return Err(SerializationError::new(CassError( - CassErrorKind::InvalidNumberOfQueryParams, - ))); - } - for (v, col) in tuple.iter().zip(columns) { - serialize_rune_cell(v, col.typ(), writer)?; - } - Ok(()) + if let Ok(tuple) = value.borrow_ref::() { + if tuple.len() != columns.len() { + return Err(SerializationError::new(CassError( + CassErrorKind::InvalidNumberOfQueryParams, + ))); } - Value::Vec(vec) => { - let vec = vec.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - for (v, col) in vec.iter().zip(columns) { - serialize_rune_cell(v, col.typ(), writer)?; - } - Ok(()) + for (v, col) in tuple.iter().zip(columns) { + serialize_rune_cell(v, col.typ(), writer)?; } - Value::Object(obj) => { - let obj = obj.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - for col in columns { - let cql_val = match obj.get(col.name()) { - Some(v) => { - to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? - } - None => Some(CqlValue::Empty), - }; - cql_val - .serialize(col.typ(), writer.make_cell_writer()) - .map_err(SerializationError::new)?; - } - Ok(()) + return Ok(()); + } + if let Ok(vec) = value.borrow_ref::() { + for (v, col) in vec.iter().zip(columns) { + serialize_rune_cell(v, col.typ(), writer)?; } - Value::Struct(obj) => { - let obj = obj.borrow_ref().map_err(|e| { - SerializationError::new(CassError(CassErrorKind::Error(e.to_string()))) - })?; - for col in columns { - let cql_val = match obj.get(col.name()) { - Some(v) => { - to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? - } - None => Some(CqlValue::Empty), - }; - cql_val - .serialize(col.typ(), writer.make_cell_writer()) - .map_err(SerializationError::new)?; - } - Ok(()) + return Ok(()); + } + if let Ok(obj) = value.borrow_ref::() { + for col in columns { + let cql_val = match obj.get(col.name()) { + Some(v) => { + to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? + } + None => Some(CqlValue::Empty), + }; + cql_val + .serialize(col.typ(), writer.make_cell_writer()) + .map_err(SerializationError::new)?; + } + return Ok(()); + } + // Handle struct types (rune typed structs) + if let Ok(rune::runtime::TypeValue::Struct(s)) = value.as_type_value() { + for col in columns { + let cql_val = match s.get(col.name()) { + Some(v) => { + to_scylla_value(v, col.typ()).map_err(|e| SerializationError::new(*e))? + } + None => Some(CqlValue::Empty), + }; + cql_val + .serialize(col.typ(), writer.make_cell_writer()) + .map_err(SerializationError::new)?; } - other => Err(SerializationError::new(CassError( - CassErrorKind::InvalidQueryParamsObject(other.type_info().unwrap()), - ))), + return Ok(()); } + Err(SerializationError::new(CassError( + CassErrorKind::InvalidQueryParamsObject(value.type_info()), + ))) } /// Serializes a single rune value as a CQL cell. @@ -157,468 +145,460 @@ static DURATION_REGEX: Lazy = Lazy::new(|| { }); fn to_scylla_value(v: &Value, typ: &ColumnType) -> Result, Box> { - match (v, typ) { - (Value::Bool(v), ColumnType::Native(NativeType::Boolean)) => { - Ok(Some(CqlValue::Boolean(*v))) - } + // Option (must be checked before inline types) + if let Ok(opt) = v.borrow_ref::>() { + return match opt.as_ref() { + Some(inner) => to_scylla_value(inner, typ), + None => Ok(None), + }; + } - (Value::Byte(v), ColumnType::Native(NativeType::TinyInt)) => { - Ok(Some(CqlValue::TinyInt(*v as i8))) - } - (Value::Byte(v), ColumnType::Native(NativeType::SmallInt)) => { - Ok(Some(CqlValue::SmallInt(*v as i16))) - } - (Value::Byte(v), ColumnType::Native(NativeType::Int)) => Ok(Some(CqlValue::Int(*v as i32))), - (Value::Byte(v), ColumnType::Native(NativeType::BigInt)) => { - Ok(Some(CqlValue::BigInt(*v as i64))) - } + // Bool + if let Ok(b) = v.as_bool() { + return match typ { + ColumnType::Native(NativeType::Boolean) => Ok(Some(CqlValue::Boolean(b))), + _ => type_mismatch(v, typ), + }; + } - (Value::Integer(v), ColumnType::Native(NativeType::TinyInt)) => { - convert_int(*v, NativeType::TinyInt, CqlValue::TinyInt) - } - (Value::Integer(v), ColumnType::Native(NativeType::SmallInt)) => { - convert_int(*v, NativeType::SmallInt, CqlValue::SmallInt) - } - (Value::Integer(v), ColumnType::Native(NativeType::Int)) => { - convert_int(*v, NativeType::Int, CqlValue::Int) - } - (Value::Integer(v), ColumnType::Native(NativeType::BigInt)) => { - Ok(Some(CqlValue::BigInt(*v))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Counter)) => { - Ok(Some(CqlValue::Counter(scylla::value::Counter(*v)))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Timestamp)) => { - Ok(Some(CqlValue::Timestamp(scylla::value::CqlTimestamp(*v)))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Date)) => match (*v).try_into() { - Ok(date) => Ok(Some(CqlValue::Date(CqlDate(date)))), - Err(_) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Date".to_string(), - Some("Invalid date value".to_string()), - )))), - }, - (Value::Integer(v), ColumnType::Native(NativeType::Time)) => { - Ok(Some(CqlValue::Time(CqlTime(*v)))) - } - (Value::Integer(v), ColumnType::Native(NativeType::Varint)) => Ok(Some(CqlValue::Varint( - CqlVarint::from_signed_bytes_be((*v).to_be_bytes().to_vec()), - ))), - (Value::Integer(v), ColumnType::Native(NativeType::Decimal)) => { - Ok(Some(CqlValue::Decimal( + // Unsigned integer (u64) — byte literals and other unsigned values + if let Ok(u) = v.as_unsigned() { + return match typ { + ColumnType::Native(NativeType::TinyInt) => Ok(Some(CqlValue::TinyInt(u as i8))), + ColumnType::Native(NativeType::SmallInt) => Ok(Some(CqlValue::SmallInt(u as i16))), + ColumnType::Native(NativeType::Int) => Ok(Some(CqlValue::Int(u as i32))), + ColumnType::Native(NativeType::BigInt) => Ok(Some(CqlValue::BigInt(u as i64))), + _ => type_mismatch(v, typ), + }; + } + + // Signed integer (i64) + if let Ok(i) = v.as_signed() { + return match typ { + ColumnType::Native(NativeType::TinyInt) => { + convert_int(i, NativeType::TinyInt, CqlValue::TinyInt) + } + ColumnType::Native(NativeType::SmallInt) => { + convert_int(i, NativeType::SmallInt, CqlValue::SmallInt) + } + ColumnType::Native(NativeType::Int) => convert_int(i, NativeType::Int, CqlValue::Int), + ColumnType::Native(NativeType::BigInt) => Ok(Some(CqlValue::BigInt(i))), + ColumnType::Native(NativeType::Counter) => { + Ok(Some(CqlValue::Counter(scylla::value::Counter(i)))) + } + ColumnType::Native(NativeType::Timestamp) => { + Ok(Some(CqlValue::Timestamp(scylla::value::CqlTimestamp(i)))) + } + ColumnType::Native(NativeType::Date) => match i.try_into() { + Ok(date) => Ok(Some(CqlValue::Date(CqlDate(date)))), + Err(_) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Date".to_string(), + Some("Invalid date value".to_string()), + )))), + }, + ColumnType::Native(NativeType::Time) => Ok(Some(CqlValue::Time(CqlTime(i)))), + ColumnType::Native(NativeType::Varint) => Ok(Some(CqlValue::Varint( + CqlVarint::from_signed_bytes_be(i.to_be_bytes().to_vec()), + ))), + ColumnType::Native(NativeType::Decimal) => Ok(Some(CqlValue::Decimal( scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( - (*v).to_be_bytes().to_vec(), + i.to_be_bytes().to_vec(), 0, ), - ))) - } + ))), + _ => type_mismatch(v, typ), + }; + } - (Value::Float(v), ColumnType::Native(NativeType::Float)) => { - Ok(Some(CqlValue::Float(*v as f32))) - } - (Value::Float(v), ColumnType::Native(NativeType::Double)) => Ok(Some(CqlValue::Double(*v))), - (Value::Float(v), ColumnType::Native(NativeType::Decimal)) => { - let decimal = rust_decimal::Decimal::from_f64_retain(*v).unwrap(); - Ok(Some(CqlValue::Decimal( - scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( - decimal.mantissa().to_be_bytes().to_vec(), - decimal.scale().try_into().unwrap(), - ), - ))) - } + // Float (f64) + if let Ok(f) = v.as_float() { + return match typ { + ColumnType::Native(NativeType::Float) => Ok(Some(CqlValue::Float(f as f32))), + ColumnType::Native(NativeType::Double) => Ok(Some(CqlValue::Double(f))), + ColumnType::Native(NativeType::Decimal) => { + let decimal = rust_decimal::Decimal::from_f64_retain(f).unwrap(); + Ok(Some(CqlValue::Decimal( + scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( + decimal.mantissa().to_be_bytes().to_vec(), + decimal.scale().try_into().unwrap(), + ), + ))) + } + _ => type_mismatch(v, typ), + }; + } - (Value::String(s), ColumnType::Native(NativeType::Date)) => { - let date_str = s.borrow_ref().unwrap(); - let naive_date = NaiveDate::parse_from_str(&date_str, "%Y-%m-%d").map_err(|e| { - CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Date".to_string(), - Some(format!("{e}")), - )) - })?; - let cql_date = CqlDate::from(naive_date); - Ok(Some(CqlValue::Date(cql_date))) - } - (Value::String(s), ColumnType::Native(NativeType::Time)) => { - let time_str = s.borrow_ref().unwrap(); - let mut time_format = "%H:%M:%S".to_string(); - if time_str.contains('.') { - time_format = format!("{time_format}.%f"); + // String + if let Ok(s) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Date) => { + let naive_date = + NaiveDate::parse_from_str(s.as_str(), "%Y-%m-%d").map_err(|e| { + CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Date".to_string(), + Some(format!("{e}")), + )) + })?; + Ok(Some(CqlValue::Date(CqlDate::from(naive_date)))) } - let naive_time = NaiveTime::parse_from_str(&time_str, &time_format).map_err(|e| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Time".to_string(), - Some(format!("{e}")), - ))) - })?; - let cql_time = CqlTime::try_from(naive_time)?; - Ok(Some(CqlValue::Time(cql_time))) - } - (Value::String(s), ColumnType::Native(NativeType::Duration)) => { - // TODO: add support for the following 'ISO 8601' format variants: - // - ISO 8601 format: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W - // - ISO 8601 alternative format: P[YYYY]-[MM]-[DD]T[hh]:[mm]:[ss] - // See: https://opensource.docs.scylladb.com/stable/cql/types.html#working-with-durations - let duration_str = s.borrow_ref().unwrap(); - if duration_str.is_empty() { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Duration".to_string(), - Some("Duration cannot be empty".to_string()), - )))); + ColumnType::Native(NativeType::Time) => { + let mut time_format = "%H:%M:%S".to_string(); + if s.as_str().contains('.') { + time_format = format!("{time_format}.%f"); + } + let naive_time = + NaiveTime::parse_from_str(s.as_str(), &time_format).map_err(|e| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Time".to_string(), + Some(format!("{e}")), + ))) + })?; + Ok(Some(CqlValue::Time(CqlTime::try_from(naive_time)?))) } - // NOTE: we parse the duration explicitly because of the 'CqlDuration' type specifics. - // It stores only months, days and nanoseconds. - // So, we do not translate days to months and hours to days because those are ambiguous - let (mut months, mut days, mut nanoseconds) = (0, 0, 0); - let mut matches_counter = HashMap::from([ - ("y", 0), - ("mo", 0), - ("w", 0), - ("d", 0), - ("h", 0), - ("m", 0), - ("s", 0), - ("ms", 0), - ("us", 0), - ("ns", 0), - ]); - for cap in DURATION_REGEX.captures_iter(&duration_str) { - if let Some(m) = cap.name("years") { - months += m.as_str().parse::().unwrap() * 12; - *matches_counter.entry("y").or_insert(1) += 1; - } else if let Some(m) = cap.name("months") { - months += m.as_str().parse::().unwrap(); - *matches_counter.entry("mo").or_insert(1) += 1; - } else if let Some(m) = cap.name("weeks") { - days += m.as_str().parse::().unwrap() * 7; - *matches_counter.entry("w").or_insert(1) += 1; - } else if let Some(m) = cap.name("days") { - days += m.as_str().parse::().unwrap(); - *matches_counter.entry("d").or_insert(1) += 1; - } else if let Some(m) = cap.name("hours") { - nanoseconds += m.as_str().parse::().unwrap() * 3_600_000_000_000; - *matches_counter.entry("h").or_insert(1) += 1; - } else if let Some(m) = cap.name("minutes") { - nanoseconds += m.as_str().parse::().unwrap() * 60_000_000_000; - *matches_counter.entry("m").or_insert(1) += 1; - } else if let Some(m) = cap.name("seconds") { - nanoseconds += m.as_str().parse::().unwrap() * 1_000_000_000; - *matches_counter.entry("s").or_insert(1) += 1; - } else if let Some(m) = cap.name("millis") { - nanoseconds += m.as_str().parse::().unwrap() * 1_000_000; - *matches_counter.entry("ms").or_insert(1) += 1; - } else if let Some(m) = cap.name("micros") { - nanoseconds += m.as_str().parse::().unwrap() * 1_000; - *matches_counter.entry("us").or_insert(1) += 1; - } else if let Some(m) = cap.name("nanoseconds") { - nanoseconds += m.as_str().parse::().unwrap(); - *matches_counter.entry("ns").or_insert(1) += 1; - } else if cap.name("invalid").is_some() { + ColumnType::Native(NativeType::Duration) => { + // TODO: add support for the following 'ISO 8601' format variants: + // - ISO 8601 format: P[n]Y[n]M[n]DT[n]H[n]M[n]S or P[n]W + // - ISO 8601 alternative format: P[YYYY]-[MM]-[DD]T[hh]:[mm]:[ss] + // See: https://opensource.docs.scylladb.com/stable/cql/types.html#working-with-durations + let duration_str = s.as_str(); + if duration_str.is_empty() { return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( format!("{v:?}"), "NativeType::Duration".to_string(), - Some("Got invalid duration value".to_string()), + Some("Duration cannot be empty".to_string()), )))); } + // NOTE: we parse the duration explicitly because of the 'CqlDuration' type specifics. + // It stores only months, days and nanoseconds. + // So, we do not translate days to months and hours to days because those are ambiguous + let (mut months, mut days, mut nanoseconds) = (0, 0, 0); + let mut matches_counter = HashMap::from([ + ("y", 0), + ("mo", 0), + ("w", 0), + ("d", 0), + ("h", 0), + ("m", 0), + ("s", 0), + ("ms", 0), + ("us", 0), + ("ns", 0), + ]); + for cap in DURATION_REGEX.captures_iter(duration_str) { + if let Some(m) = cap.name("years") { + months += m.as_str().parse::().unwrap() * 12; + *matches_counter.entry("y").or_insert(1) += 1; + } else if let Some(m) = cap.name("months") { + months += m.as_str().parse::().unwrap(); + *matches_counter.entry("mo").or_insert(1) += 1; + } else if let Some(m) = cap.name("weeks") { + days += m.as_str().parse::().unwrap() * 7; + *matches_counter.entry("w").or_insert(1) += 1; + } else if let Some(m) = cap.name("days") { + days += m.as_str().parse::().unwrap(); + *matches_counter.entry("d").or_insert(1) += 1; + } else if let Some(m) = cap.name("hours") { + nanoseconds += m.as_str().parse::().unwrap() * 3_600_000_000_000; + *matches_counter.entry("h").or_insert(1) += 1; + } else if let Some(m) = cap.name("minutes") { + nanoseconds += m.as_str().parse::().unwrap() * 60_000_000_000; + *matches_counter.entry("m").or_insert(1) += 1; + } else if let Some(m) = cap.name("seconds") { + nanoseconds += m.as_str().parse::().unwrap() * 1_000_000_000; + *matches_counter.entry("s").or_insert(1) += 1; + } else if let Some(m) = cap.name("millis") { + nanoseconds += m.as_str().parse::().unwrap() * 1_000_000; + *matches_counter.entry("ms").or_insert(1) += 1; + } else if let Some(m) = cap.name("micros") { + nanoseconds += m.as_str().parse::().unwrap() * 1_000; + *matches_counter.entry("us").or_insert(1) += 1; + } else if let Some(m) = cap.name("nanoseconds") { + nanoseconds += m.as_str().parse::().unwrap(); + *matches_counter.entry("ns").or_insert(1) += 1; + } else if cap.name("invalid").is_some() { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Duration".to_string(), + Some("Got invalid duration value".to_string()), + )))); + } + } + if matches_counter.values().all(|&v| v == 0) { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Duration".to_string(), + Some("None time units were found".to_string()), + )))); + } + let duplicated_units: Vec<&str> = matches_counter + .iter() + .filter(|&(_, &count)| count > 1) + .map(|(&unit, _)| unit) + .collect(); + if !duplicated_units.is_empty() { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Duration".to_string(), + Some(format!( + "Got multiple matches for time unit(s): {}", + duplicated_units.join(", ") + )), + )))); + } + Ok(Some(CqlValue::Duration(CqlDuration { + months, + days, + nanoseconds, + }))) } - if matches_counter.values().all(|&v| v == 0) { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Duration".to_string(), - Some("None time units were found".to_string()), - )))); - } - let duplicated_units: Vec<&str> = matches_counter - .iter() - .filter(|&(_, &count)| count > 1) - .map(|(&unit, _)| unit) - .collect(); - if !duplicated_units.is_empty() { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Duration".to_string(), - Some(format!( - "Got multiple matches for time unit(s): {}", - duplicated_units.join(", ") - )), - )))); - } - let cql_duration = CqlDuration { - months, - days, - nanoseconds, - }; - Ok(Some(CqlValue::Duration(cql_duration))) - } - - (Value::String(s), ColumnType::Native(NativeType::Varint)) => { - let varint_str = s.borrow_ref().unwrap(); - if !varint_str.chars().all(|c| c.is_ascii_digit()) { - return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Varint".to_string(), - Some("Input contains non-digit characters".to_string()), - )))); + ColumnType::Native(NativeType::Varint) => { + if !s.as_str().chars().all(|c| c.is_ascii_digit()) { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "NativeType::Varint".to_string(), + Some("Input contains non-digit characters".to_string()), + )))); + } + let byte_vector: Vec = s + .as_str() + .chars() + .map(|c| c.to_digit(10).expect("Invalid digit") as u8) + .collect(); + Ok(Some(CqlValue::Varint( + scylla::value::CqlVarint::from_signed_bytes_be(byte_vector), + ))) } - let byte_vector: Vec = varint_str - .chars() - .map(|c| c.to_digit(10).expect("Invalid digit") as u8) - .collect(); - Ok(Some(CqlValue::Varint( - scylla::value::CqlVarint::from_signed_bytes_be(byte_vector), - ))) - } - (Value::String(s), ColumnType::Native(NativeType::Timeuuid)) => { - let timeuuid_str = s.borrow_ref().unwrap(); - let timeuuid = CqlTimeuuid::from_str(timeuuid_str.as_str()); - match timeuuid { + ColumnType::Native(NativeType::Timeuuid) => match CqlTimeuuid::from_str(s.as_str()) { Ok(timeuuid) => Ok(Some(CqlValue::Timeuuid(timeuuid))), Err(e) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( format!("{v:?}"), "NativeType::Timeuuid".to_string(), Some(format!("{e}")), )))), + }, + ColumnType::Native(NativeType::Text) | ColumnType::Native(NativeType::Ascii) => { + Ok(Some(CqlValue::Text(s.as_str().to_string()))) } - } - ( - Value::String(v), - ColumnType::Native(NativeType::Text) | ColumnType::Native(NativeType::Ascii), - ) => Ok(Some(CqlValue::Text( - v.borrow_ref().unwrap().as_str().to_string(), - ))), - (Value::String(s), ColumnType::Native(NativeType::Inet)) => { - let ipaddr_str = s.borrow_ref().unwrap(); - let ipaddr = IpAddr::from_str(ipaddr_str.as_str()); - match ipaddr { + ColumnType::Native(NativeType::Inet) => match IpAddr::from_str(s.as_str()) { Ok(ipaddr) => Ok(Some(CqlValue::Inet(ipaddr))), Err(e) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( format!("{v:?}"), "NativeType::Inet".to_string(), Some(format!("{e}")), )))), + }, + ColumnType::Native(NativeType::Decimal) => { + let decimal = rust_decimal::Decimal::from_str_exact(s.as_str()).unwrap(); + Ok(Some(CqlValue::Decimal( + scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( + decimal.mantissa().to_be_bytes().to_vec(), + decimal.scale().try_into().unwrap(), + ), + ))) } - } - (Value::String(s), ColumnType::Native(NativeType::Decimal)) => { - let dec_str = s.borrow_ref().unwrap(); - let decimal = rust_decimal::Decimal::from_str_exact(&dec_str).unwrap(); - Ok(Some(CqlValue::Decimal( - scylla::value::CqlDecimal::from_signed_be_bytes_and_exponent( - decimal.mantissa().to_be_bytes().to_vec(), - decimal.scale().try_into().unwrap(), - ), - ))) - } - (Value::Bytes(v), ColumnType::Native(NativeType::Blob)) => { - Ok(Some(CqlValue::Blob(v.borrow_ref().unwrap().to_vec()))) - } - (Value::Vec(v), ColumnType::Native(NativeType::Blob)) => { - let v: Vec = v.borrow_ref().unwrap().to_vec(); - let byte_vec: Vec = v - .into_iter() - .map(|value| value.as_byte().unwrap()) - .collect(); - Ok(Some(CqlValue::Blob(byte_vec))) - } - (Value::Option(v), typ) => match v.borrow_ref().unwrap().as_ref() { - Some(v) => to_scylla_value(v, typ), - None => Ok(None), - }, - (Value::Tuple(v), ColumnType::Tuple(tuple)) => { - let v = v.borrow_ref().unwrap(); - let mut elements = Vec::with_capacity(v.len()); - for (i, current_element) in v.iter().enumerate() { - let element = to_scylla_value(current_element, &tuple[i])?; - elements.push(element); + _ => type_mismatch(v, typ), + }; + } + + // Bytes (rune::runtime::Bytes) + if let Ok(b) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Blob) => Ok(Some(CqlValue::Blob(b.to_vec()))), + _ => type_mismatch(v, typ), + }; + } + + // Vec (rune::runtime::Vec) + if let Ok(vec) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Blob) => { + let byte_vec: Vec = + vec.iter().map(|v| v.as_unsigned().unwrap() as u8).collect(); + Ok(Some(CqlValue::Blob(byte_vec))) } - Ok(Some(CqlValue::Tuple(elements))) - } - (Value::Vec(v), ColumnType::Tuple(tuple)) => { - let v = v.borrow_ref().unwrap(); - let mut elements = Vec::with_capacity(v.len()); - for (i, current_element) in v.iter().enumerate() { - let element = to_scylla_value(current_element, &tuple[i])?; - elements.push(element); + ColumnType::Tuple(tuple) => { + let mut elements = Vec::with_capacity(vec.len()); + for (i, current_element) in vec.iter().enumerate() { + elements.push(to_scylla_value(current_element, &tuple[i])?); + } + Ok(Some(CqlValue::Tuple(elements))) } - Ok(Some(CqlValue::Tuple(elements))) - } - (Value::Vec(v), ColumnType::Vector { typ, .. }) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| { - to_scylla_value(v, typ).and_then(|opt| { - opt.ok_or_else(|| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "ColumnType::Vector".to_string(), - None, - ))) + ColumnType::Vector { typ: elem_typ, .. } => { + let elements = vec + .iter() + .map(|v| { + to_scylla_value(v, elem_typ).and_then(|opt| { + opt.ok_or_else(|| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "ColumnType::Vector".to_string(), + None, + ))) + }) }) }) - }) - .try_collect()?; - Ok(Some(CqlValue::Vector(elements))) - } - ( - Value::Vec(v), + .try_collect()?; + Ok(Some(CqlValue::Vector(elements))) + } ColumnType::Collection { - frozen: _, typ: CollectionType::List(elt), - }, - ) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| { - to_scylla_value(v, elt).and_then(|opt| { - opt.ok_or_else(|| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "CollectionType::List".to_string(), - None, - ))) + .. + } => { + let elements = vec + .iter() + .map(|v| { + to_scylla_value(v, elt).and_then(|opt| { + opt.ok_or_else(|| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "CollectionType::List".to_string(), + None, + ))) + }) }) }) - }) - .try_collect()?; - Ok(Some(CqlValue::List(elements))) - } - ( - Value::Vec(v), + .try_collect()?; + Ok(Some(CqlValue::List(elements))) + } ColumnType::Collection { - frozen: _, typ: CollectionType::Set(elt), - }, - ) => { - let v = v.borrow_ref().unwrap(); - let elements = v - .as_ref() - .iter() - .map(|v| { - to_scylla_value(v, elt).and_then(|opt| { - opt.ok_or_else(|| { - Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "CollectionType::Set".to_string(), - None, - ))) + .. + } => { + let elements = vec + .iter() + .map(|v| { + to_scylla_value(v, elt).and_then(|opt| { + opt.ok_or_else(|| { + Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + "CollectionType::Set".to_string(), + None, + ))) + }) }) }) - }) - .try_collect()?; - Ok(Some(CqlValue::Set(elements))) - } - ( - Value::Vec(v), + .try_collect()?; + Ok(Some(CqlValue::Set(elements))) + } ColumnType::Collection { - frozen: _, typ: CollectionType::Map(key_elt, value_elt), - }, - ) => { - let v = v.borrow_ref().unwrap(); - let mut map_vec = Vec::with_capacity(v.len()); - for tuple in v.iter() { - match tuple { - Value::Tuple(tuple) if tuple.borrow_ref().unwrap().len() == 2 => { - let tuple = tuple.borrow_ref().unwrap(); - let key = to_scylla_value(tuple.first().unwrap(), key_elt)?.unwrap(); - let value = to_scylla_value(tuple.get(1).unwrap(), value_elt)?.unwrap(); - map_vec.push((key, value)); - } - _ => { + .. + } => { + let mut map_vec = Vec::with_capacity(vec.len()); + for item in vec.iter() { + if let Ok(tuple) = item.borrow_ref::() { + if tuple.len() == 2 { + let key = to_scylla_value(tuple.first().unwrap(), key_elt)?.unwrap(); + let value = to_scylla_value(tuple.get(1).unwrap(), value_elt)?.unwrap(); + map_vec.push((key, value)); + } else { + return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{item:?}"), + "CollectionType::Map".to_string(), + None, + )))); + } + } else { return Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{tuple:?}"), + format!("{item:?}"), "CollectionType::Map".to_string(), None, )))); } } + Ok(Some(CqlValue::Map(map_vec))) } - Ok(Some(CqlValue::Map(map_vec))) - } - ( - Value::Object(obj), + _ => type_mismatch(v, typ), + }; + } + + // OwnedTuple + if let Ok(tuple) = v.borrow_ref::() { + return match typ { + ColumnType::Tuple(types) => { + let mut elements = Vec::with_capacity(tuple.len()); + for (i, current_element) in tuple.iter().enumerate() { + elements.push(to_scylla_value(current_element, &types[i])?); + } + Ok(Some(CqlValue::Tuple(elements))) + } + _ => type_mismatch(v, typ), + }; + } + + // Object + if let Ok(obj) = v.borrow_ref::() { + return match typ { ColumnType::Collection { - frozen: _, typ: CollectionType::Map(key_elt, value_elt), - }, - ) => { - let obj = obj.borrow_ref().unwrap(); - let mut map_vec = Vec::with_capacity(obj.keys().len()); - for (k, v) in obj.iter() { - let key = String::from(k.as_str()); - let key = to_scylla_value(&(key.to_value().unwrap()), key_elt)?.unwrap(); - let value = to_scylla_value(v, value_elt)?.unwrap(); - map_vec.push((key, value)); + .. + } => { + let mut map_vec = Vec::with_capacity(obj.keys().len()); + for (k, val) in obj.iter() { + let key = String::from(k.as_str()); + let key = to_scylla_value(&key.to_value().unwrap(), key_elt)?.unwrap(); + let value = to_scylla_value(val, value_elt)?.unwrap(); + map_vec.push((key, value)); + } + Ok(Some(CqlValue::Map(map_vec))) } - Ok(Some(CqlValue::Map(map_vec))) - } - ( - Value::Object(v), - ColumnType::UserDefinedType { - frozen: _, - definition, - }, - ) => { - let obj = v.borrow_ref().unwrap(); - let field_types: Vec<(String, ColumnType)> = definition - .field_types - .iter() - .map(|(name, typ)| (name.to_string(), typ.clone())) - .collect(); - let fields = read_fields(|s| obj.get(s), &field_types)?; - Ok(Some(CqlValue::UserDefinedType { - name: definition.name.to_string(), - keyspace: definition.keyspace.to_string(), - fields, - })) - } - ( - Value::Struct(v), - ColumnType::UserDefinedType { - frozen: _, - definition, - }, - ) => { - let obj = v.borrow_ref().unwrap(); - let field_types: Vec<(String, ColumnType)> = definition - .field_types - .iter() - .map(|(name, typ)| (name.to_string(), typ.clone())) - .collect(); - let fields = read_fields(|s| obj.get(s), &field_types)?; - Ok(Some(CqlValue::UserDefinedType { - name: definition.name.to_string(), - keyspace: definition.keyspace.to_string(), - fields, - })) - } + ColumnType::UserDefinedType { definition, .. } => { + let field_types: Vec<(String, ColumnType)> = definition + .field_types + .iter() + .map(|(name, typ)| (name.to_string(), typ.clone())) + .collect(); + let fields = read_fields(|s| obj.get(s), &field_types)?; + Ok(Some(CqlValue::UserDefinedType { + name: definition.name.to_string(), + keyspace: definition.keyspace.to_string(), + fields, + })) + } + _ => type_mismatch(v, typ), + }; + } - (Value::Any(obj), ColumnType::Native(NativeType::Uuid)) => { - let obj = obj.borrow_ref().unwrap(); - let h = obj.type_hash(); - if h == Uuid::type_hash() { - let uuid: &Uuid = obj.downcast_borrow_ref().unwrap(); - Ok(Some(CqlValue::Uuid(uuid.0))) - } else { - Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{v:?}"), - "NativeType::Uuid".to_string(), - None, - )))) + // Struct (rune typed struct) + if let Ok(rune::runtime::TypeValue::Struct(s)) = v.as_type_value() { + return match typ { + ColumnType::UserDefinedType { definition, .. } => { + let field_types: Vec<(String, ColumnType)> = definition + .field_types + .iter() + .map(|(name, typ)| (name.to_string(), typ.clone())) + .collect(); + let fields = read_fields(|name| s.get(name), &field_types)?; + Ok(Some(CqlValue::UserDefinedType { + name: definition.name.to_string(), + keyspace: definition.keyspace.to_string(), + fields, + })) } - } - (value, typ) => Err(Box::new(CassError(CassErrorKind::QueryParamConversion( - format!("{value:?}"), - format!("{typ:?}").to_string(), - None, - )))), + _ => type_mismatch(v, typ), + }; + } + + // UUID (custom Rune Any type) + if let Ok(uuid) = v.borrow_ref::() { + return match typ { + ColumnType::Native(NativeType::Uuid) => Ok(Some(CqlValue::Uuid(uuid.0))), + _ => type_mismatch(v, typ), + }; } + + type_mismatch(v, typ) +} + +fn type_mismatch(v: &Value, typ: &ColumnType) -> Result, Box> { + Err(Box::new(CassError(CassErrorKind::QueryParamConversion( + format!("{v:?}"), + format!("{typ:?}").to_string(), + None, + )))) } fn convert_int, R>( @@ -655,7 +635,7 @@ mod tests { use rstest::rstest; use rune::alloc::String as RuneString; - use rune::runtime::{Object, Shared, Vec as RuneVec}; + use rune::runtime::{Object, Vec as RuneVec}; use scylla::frame::response::result::TableSpec; use scylla::serialize::row::RowSerializationContext; use scylla::serialize::writers::RowWriter; @@ -665,19 +645,19 @@ mod tests { // ── helpers ────────────────────────────────────────────────────── fn rune_string(s: &str) -> Value { - Value::String(Shared::new(RuneString::try_from(s).unwrap()).unwrap()) + Value::new(RuneString::try_from(s).unwrap()).unwrap() } fn rune_int(i: i64) -> Value { - Value::Integer(i) + Value::from(i) } fn rune_float(f: f64) -> Value { - Value::Float(f) + Value::from(f) } fn rune_bool(b: bool) -> Value { - Value::Bool(b) + Value::from(b) } fn rune_vec(items: Vec) -> Value { @@ -685,7 +665,7 @@ mod tests { for item in items { v.push(item).unwrap(); } - Value::Vec(Shared::new(v).unwrap()) + Value::vec(v.into_inner()).unwrap() } fn rune_tuple(items: Vec) -> Value { @@ -693,7 +673,7 @@ mod tests { for item in items { v.try_push(item).unwrap(); } - Value::Tuple(Shared::new(rune::runtime::OwnedTuple::try_from(v).unwrap()).unwrap()) + Value::new(rune::runtime::OwnedTuple::try_from(v).unwrap()).unwrap() } fn rune_object(pairs: Vec<(&str, Value)>) -> Value { @@ -701,7 +681,7 @@ mod tests { for (k, v) in pairs { obj.insert(RuneString::try_from(k).unwrap(), v).unwrap(); } - Value::Object(Shared::new(obj).unwrap()) + Value::new(obj).unwrap() } fn col_spec<'a>(name: &'a str, typ: ColumnType<'a>) -> ColumnSpec<'a> { @@ -894,14 +874,14 @@ mod tests { #[test] fn test_to_scylla_value_option_none() { - let val = Value::Option(Shared::new(None).unwrap()); + let val = Value::try_from(None::).unwrap(); let result = to_scylla_value(&val, &ColumnType::Native(NativeType::Int)); assert_eq!(result.unwrap(), None); } #[test] fn test_to_scylla_value_option_some() { - let val = Value::Option(Shared::new(Some(rune_int(42))).unwrap()); + let val = Value::try_from(Some(rune_int(42))).unwrap(); let result = to_scylla_value(&val, &ColumnType::Native(NativeType::Int)); assert_eq!(result.unwrap(), Some(CqlValue::Int(42))); } @@ -912,6 +892,15 @@ mod tests { assert!(result.is_err()); } + // ── to_scylla_value tests (blob) ──────────────────────────────── + + #[test] + fn test_to_scylla_value_blob_from_rune_bytes() { + let bytes_val = rune::to_value(vec![1u8, 2u8, 3u8]).unwrap(); + let result = to_scylla_value(&bytes_val, &ColumnType::Native(NativeType::Blob)).unwrap(); + assert_eq!(result, Some(CqlValue::Blob(vec![1u8, 2u8, 3u8]))); + } + // ── to_scylla_value tests (collections) ───────────────────────── #[test] @@ -1259,10 +1248,7 @@ mod tests { #[case] ns: i64, ) { let expected = format!("{mo:?}mo{d:?}d{ns:?}ns"); - let duration_rune_str = Value::String( - Shared::new(RuneString::try_from(input).expect("Failed to create RuneString")) - .expect("Failed to create Shared RuneString"), - ); + let duration_rune_str = rune_string(input.as_str()); let actual = to_scylla_value( &duration_rune_str, &ColumnType::Native(NativeType::Duration), @@ -1281,10 +1267,7 @@ mod tests { #[case("fake")] #[case("1d2h3m4h")] fn test_to_scylla_value_duration_neg(#[case] input: String) { - let duration_rune_str = Value::String( - Shared::new(RuneString::try_from(input.clone()).expect("Failed to create RuneString")) - .expect("Failed to create Shared RuneString"), - ); + let duration_rune_str = rune_string(input.as_str()); let actual = to_scylla_value( &duration_rune_str, &ColumnType::Native(NativeType::Duration), diff --git a/src/scripting/functions_common.rs b/src/scripting/functions_common.rs index aa74cf8..40c2034 100644 --- a/src/scripting/functions_common.rs +++ b/src/scripting/functions_common.rs @@ -55,31 +55,43 @@ pub struct ValidationArgs { /// * [Integer, String] -> Exact number of expected rows and custom error message. /// * [Integer, Integer, String] -> Range of expected rows and custom error message. pub fn extract_validation_args(validation_args: Vec) -> Result { + let as_int = |v: &Value| v.as_signed().ok().map(|i| i as u64); + let as_str = |v: &Value| { + v.borrow_ref::() + .ok() + .map(|s| s.as_str().to_string()) + }; match validation_args.as_slice() { // (int): expected_rows - [Value::Integer(expected_rows)] => Ok(ValidationArgs { - expected_min: *expected_rows as u64, - expected_max: *expected_rows as u64, - custom_err_msg: String::new(), - }), + [a] if as_int(a).is_some() => { + let n = as_int(a).unwrap(); + Ok(ValidationArgs { + expected_min: n, + expected_max: n, + custom_err_msg: String::new(), + }) + } // (int, int): expected_rows_num_min, expected_rows_num_max - [Value::Integer(min), Value::Integer(max)] => Ok(ValidationArgs { - expected_min: *min as u64, - expected_max: *max as u64, + [a, b] if as_int(a).is_some() && as_int(b).is_some() => Ok(ValidationArgs { + expected_min: as_int(a).unwrap(), + expected_max: as_int(b).unwrap(), custom_err_msg: String::new(), }), // (int, str): expected_rows, custom_err_msg - [Value::Integer(expected_rows), Value::String(custom_err_msg)] => Ok(ValidationArgs { - expected_min: *expected_rows as u64, - expected_max: *expected_rows as u64, - custom_err_msg: custom_err_msg.borrow_ref().unwrap().to_string(), - }), + [a, b] if as_int(a).is_some() && as_str(b).is_some() => { + let n = as_int(a).unwrap(); + Ok(ValidationArgs { + expected_min: n, + expected_max: n, + custom_err_msg: as_str(b).unwrap(), + }) + } // (int, int, str): expected_rows_num_min, expected_rows_num_max, custom_err_msg - [Value::Integer(min), Value::Integer(max), Value::String(custom_err_msg)] => { + [a, b, c] if as_int(a).is_some() && as_int(b).is_some() && as_str(c).is_some() => { Ok(ValidationArgs { - expected_min: *min as u64, - expected_max: *max as u64, - custom_err_msg: custom_err_msg.borrow_ref().unwrap().to_string(), + expected_min: as_int(a).unwrap(), + expected_max: as_int(b).unwrap(), + custom_err_msg: as_str(c).unwrap(), }) } _ => Err("Invalid validation arguments".to_string()), @@ -211,7 +223,7 @@ pub fn join(collection: &[Value], separator: &str) -> VmResult { if !first { result.push_str(separator); } - result.push_str(vm_try!(v.borrow_ref()).as_str()); + result.push_str(v.as_str()); first = false; } VmResult::Ok(result) @@ -225,10 +237,8 @@ pub fn is_none(input: Value) -> bool { // With this function it is possible to check for None the following way: // let result = if is_none(row.some_col) { "None" } else { row.some_col }; // println!("DEBUG: value for some_col is '{result}'", result=result); - if let Value::Option(option) = input { - if let Ok(borrowed) = option.borrow_ref() { - return borrowed.is_none(); - } + if let Ok(opt) = input.borrow_ref::>() { + return opt.is_none(); } false } diff --git a/src/scripting/row_distribution.rs b/src/scripting/row_distribution.rs index 80f77f0..c97bb01 100644 --- a/src/scripting/row_distribution.rs +++ b/src/scripting/row_distribution.rs @@ -1,4 +1,4 @@ -use rune::runtime::{Mut, Ref}; +use rune::runtime::Ref; use rune::Any; use std::collections::HashMap; @@ -169,14 +169,14 @@ impl RowDistributionPreset { #[rune::function(instance)] pub async fn init_partition_row_distribution_preset( - mut ctx: Mut, + ctx: Ref, preset_name: Ref, row_count: u64, rows_per_partitions_base: u64, rows_per_partitions_groups: Ref, ) -> Result<(), DbError> { _init_partition_row_distribution_preset( - &mut ctx, + &ctx, &preset_name, row_count, rows_per_partitions_base, @@ -214,7 +214,7 @@ pub async fn get_partition_idx(ctx: Ref, preset_name: Ref, idx: u6 /// Creates a preset for uneven row distribution among partitions #[allow(clippy::comparison_chain)] async fn _init_partition_row_distribution_preset( - ctx: &mut Context, + ctx: &Context, preset_name: &str, row_count: u64, rows_per_partitions_base: u64, @@ -412,6 +412,8 @@ async fn _init_partition_row_distribution_preset( // NOTE: generate row distributions only after the partition groups are finished with changes row_distribution_preset.generate_row_distributions(); ctx.partition_row_presets + .try_lock() + .unwrap() .insert(preset_name.to_string(), row_distribution_preset); Ok(()) @@ -423,11 +425,17 @@ async fn _get_partition_info( preset_name: &str, idx: u64, ) -> Result<(u64, u64), DbError> { - let preset = ctx.partition_row_presets.get(preset_name).ok_or_else(|| { - DbError::new(DbErrorKind::PartitionRowPresetNotFound( - preset_name.to_string(), - )) - })?; + let preset = ctx + .partition_row_presets + .try_lock() + .unwrap() + .get(preset_name) + .cloned() + .ok_or_else(|| { + DbError::new(DbErrorKind::PartitionRowPresetNotFound( + preset_name.to_string(), + )) + })?; Ok(preset.get_partition_info(idx).await) } @@ -526,20 +534,23 @@ mod tests { expected_idx_partition_idx_mapping: Vec<(u64, u64)>, ) { for (rows_per_partitions_base, rows_per_partitions_groups) in rows_per_partitions_base_and_groups_mapping { - let mut ctxt: Context = create_test_context(); + let ctxt: Context = create_test_context(); let preset_name = "foo_name"; - assert!(ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should not be empty"); + assert!(ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should not be empty"); tokio::runtime::Runtime::new().unwrap().block_on(async { - let _ = _init_partition_row_distribution_preset(&mut ctxt, + let _ = _init_partition_row_distribution_preset(&ctxt, preset_name, row_count, rows_per_partitions_base, &rows_per_partitions_groups).await; }); - assert!(!ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should not be empty"); - let actual_preset = ctxt.partition_row_presets.get(preset_name) - .unwrap_or_else(|| panic!("Preset with name '{preset_name}' was not found")); - assert_eq!(expected_partition_groups, actual_preset.partition_groups); + assert!(!ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should not be empty"); + { + let binding = ctxt.partition_row_presets.try_lock().unwrap(); + let actual_preset = binding.get(preset_name) + .unwrap_or_else(|| panic!("Preset with name '{preset_name}' was not found")); + assert_eq!(expected_partition_groups, actual_preset.partition_groups); + } for (idx, expected_partition_idx) in expected_idx_partition_idx_mapping.clone() { let (p_idx, _p_size) = tokio::runtime::Runtime::new().unwrap().block_on(async { @@ -703,28 +714,28 @@ mod tests { fn test_partition_row_distribution_preset_05_pos_multiple_presets() { let name_foo: String = "foo".to_string(); let name_bar: String = "bar".to_string(); - let mut ctxt: Context = create_test_context(); + let ctxt: Context = create_test_context(); - assert!(ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should be empty"); - let foo_value = ctxt.partition_row_presets.get(&name_foo); + assert!(ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should be empty"); + let foo_value = ctxt.partition_row_presets.try_lock().unwrap().get(&name_foo).cloned(); assert_eq!(None, foo_value); tokio::runtime::Runtime::new().unwrap().block_on(async { - _init_partition_row_distribution_preset(&mut ctxt, + _init_partition_row_distribution_preset(&ctxt, &name_foo, 1000, 10, "100:1").await }).unwrap_or_else(|_| panic!("The '{name_foo}' preset must have been created successfully")); - assert!(!ctxt.partition_row_presets.is_empty(), "The 'partition_row_presets' HashMap should not be empty"); - ctxt.partition_row_presets.get(&name_foo) + assert!(!ctxt.partition_row_presets.try_lock().unwrap().is_empty(), "The 'partition_row_presets' HashMap should not be empty"); + ctxt.partition_row_presets.try_lock().unwrap().get(&name_foo) .unwrap_or_else(|| panic!("Preset with name '{name_foo}' was not found")); - let absent_bar = ctxt.partition_row_presets.get(&name_bar); + let absent_bar = ctxt.partition_row_presets.try_lock().unwrap().get(&name_bar).cloned(); assert_eq!(None, absent_bar, "{}", format_args!("The '{}' preset was expected to be absent", name_bar)); tokio::runtime::Runtime::new().unwrap().block_on(async { - _init_partition_row_distribution_preset(&mut ctxt, + _init_partition_row_distribution_preset(&ctxt, &name_bar, 1000, 10, "90:1,10:2").await }).unwrap_or_else(|_| panic!("The '{name_bar}' preset must have been created successfully")); - ctxt.partition_row_presets.get(&name_bar) + ctxt.partition_row_presets.try_lock().unwrap().get(&name_bar) .unwrap_or_else(|| panic!("Preset with name '{name_bar}' was not found")); } @@ -734,9 +745,9 @@ mod tests { rows_per_partitions_base: u64, rows_per_partitions_groups: String, ) { - let mut ctxt: Context = create_test_context(); + let ctxt: Context = create_test_context(); let result = tokio::runtime::Runtime::new().unwrap().block_on(async { - _init_partition_row_distribution_preset(&mut ctxt, + _init_partition_row_distribution_preset(&ctxt, &preset_name, row_count, rows_per_partitions_base, &rows_per_partitions_groups).await }); diff --git a/src/scripting/rune_uuid.rs b/src/scripting/rune_uuid.rs index 039c7dd..3e3d92e 100644 --- a/src/scripting/rune_uuid.rs +++ b/src/scripting/rune_uuid.rs @@ -20,9 +20,9 @@ impl Uuid { Uuid(builder.into_uuid()) } - #[rune::function(protocol = STRING_DISPLAY)] + #[rune::function(protocol = DISPLAY_FMT)] pub fn string_display(&self, f: &mut rune::runtime::Formatter) -> VmResult<()> { - vm_write!(f, "{}", self.0); + let _ = vm_write!(f, "{}", self.0); VmResult::Ok(()) } } diff --git a/src/scripting/split_lines_iter.rs b/src/scripting/split_lines_iter.rs index 66ff85a..9be2794 100644 --- a/src/scripting/split_lines_iter.rs +++ b/src/scripting/split_lines_iter.rs @@ -92,43 +92,53 @@ pub fn read_split_lines_iter( let mut maxsplit = -1; let mut do_trim = true; let mut skip_empty = true; + let as_str = |v: &Value| -> Option { + v.borrow_ref::() + .ok() + .map(|s| s.as_str().to_string()) + }; + let as_int = |v: &Value| v.as_signed().ok(); + let as_bool = |v: &Value| v.as_bool().ok(); match params.as_slice() { // (str): delimiter - [Value::String(custom_delimiter)] => { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); + [a] if as_str(a).is_some() => { + delimiter = as_str(a).unwrap(); } // (int): maxsplit - [Value::Integer(custom_maxsplit)] => { - maxsplit = *custom_maxsplit; + [a] if as_int(a).is_some() => { + maxsplit = as_int(a).unwrap(); } // (bool): do_trim - [Value::Bool(custom_do_trim)] => { - do_trim = *custom_do_trim; + [a] if as_bool(a).is_some() => { + do_trim = as_bool(a).unwrap(); } // (bool, bool): do_trim, skip_empty - [Value::Bool(custom_do_trim), Value::Bool(custom_skip_empty)] => { - do_trim = *custom_do_trim; - skip_empty = *custom_skip_empty; + [a, b] if as_bool(a).is_some() && as_bool(b).is_some() => { + do_trim = as_bool(a).unwrap(); + skip_empty = as_bool(b).unwrap(); } // (str, int): delimiter, maxsplit - [Value::String(custom_delimiter), Value::Integer(custom_maxsplit)] => { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); - maxsplit = *custom_maxsplit; + [a, b] if as_str(a).is_some() && as_int(b).is_some() => { + delimiter = as_str(a).unwrap(); + maxsplit = as_int(b).unwrap(); } // (str, int, bool): delimiter, maxsplit, do_trim - [Value::String(custom_delimiter), Value::Integer(custom_maxsplit), Value::Bool(custom_do_trim)] => - { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); - maxsplit = *custom_maxsplit; - do_trim = *custom_do_trim; + [a, b, c] if as_str(a).is_some() && as_int(b).is_some() && as_bool(c).is_some() => { + delimiter = as_str(a).unwrap(); + maxsplit = as_int(b).unwrap(); + do_trim = as_bool(c).unwrap(); } // (str, int, bool, bool): delimiter, maxsplit, do_trim, skip_empty - [Value::String(custom_delimiter), Value::Integer(custom_maxsplit), Value::Bool(custom_do_trim), Value::Bool(custom_skip_empty)] => + [a, b, c, d] + if as_str(a).is_some() + && as_int(b).is_some() + && as_bool(c).is_some() + && as_bool(d).is_some() => { - delimiter = custom_delimiter.borrow_ref().unwrap().to_string(); - maxsplit = *custom_maxsplit; - do_trim = *custom_do_trim; - skip_empty = *custom_skip_empty; + delimiter = as_str(a).unwrap(); + maxsplit = as_int(b).unwrap(); + do_trim = as_bool(c).unwrap(); + skip_empty = as_bool(d).unwrap(); } _ => panic!("Invalid arguments for read_split_lines_iter"), }