Skip to content
207 changes: 10 additions & 197 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use std::time::Duration;

use anyhow::anyhow;
use chrono::Utc;
#[cfg(feature = "cql")]
use clap::builder::PossibleValue;
use clap::{Parser, ValueEnum};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::scripting::db_config;

/// Limit of retry errors to be kept and then printed in scope of a sampling interval
pub const PRINT_RETRY_ERROR_LIMIT: u64 = 5;

Expand Down Expand Up @@ -158,71 +158,10 @@ impl FromStr for RetryInterval {

#[derive(Parser, Debug, Serialize, Deserialize)]
pub struct ConnectionConf {
/// Number of connections per Cassandra node / Scylla shard.
#[clap(
short('c'),
long("connections"),
default_value = "1",
value_name = "COUNT"
)]
pub count: NonZeroUsize,

/// List of Cassandra addresses to connect to.
/// List of addresses to connect to.
#[clap(name = "addresses", default_value = "localhost")]
pub addresses: Vec<String>,

/// Cassandra user name
#[clap(long, env("CASSANDRA_USER"), default_value = "")]
pub user: String,

/// Password to use if password authentication is required by the server
#[clap(long, env("CASSANDRA_PASSWORD"), default_value = "")]
pub password: String,

/// Enable SSL
#[clap(long("ssl"))]
pub ssl: bool,

/// Path to the CA certificate file in PEM format
#[clap(long("ssl-ca"), value_name = "PATH")]
pub ssl_ca_cert_file: Option<PathBuf>,

/// Path to the client SSL certificate file in PEM format
#[clap(long("ssl-cert"), value_name = "PATH")]
pub ssl_cert_file: Option<PathBuf>,

/// Path to the client SSL private key file in PEM format
#[clap(long("ssl-key"), value_name = "PATH")]
pub ssl_key_file: Option<PathBuf>,

/// Verify if the peer's certificate is trusted
#[clap(long("ssl-peer-verification"))]
pub ssl_peer_verification: bool,

/// Datacenter name
#[clap(long("datacenter"), required = false)]
pub datacenter: Option<String>,

/// Rack name
#[clap(long("rack"), required = false)]
pub rack: Option<String>,

/// CQL query consistency level.
/// 'SERIAL' and 'LOCAL_SERIAL' values are compatible only with SELECT statements
/// and make Scylla use Paxos consensus algorithm
#[cfg(feature = "cql")]
#[clap(long("consistency"), required = false, default_value = "LOCAL_QUORUM")]
pub consistency: Consistency,

/// Serial consistency level for conditional (LWT) queries
#[cfg(feature = "cql")]
#[clap(
long("serial-consistency"),
required = false,
default_value = "LOCAL_SERIAL"
)]
pub serial_consistency: SerialConsistency,

#[clap(
long("request-timeout"),
default_value = "5s",
Expand Down Expand Up @@ -263,6 +202,10 @@ pub struct ConnectionConf {
default_value = "fail-fast"
)]
pub validation_strategy: ValidationStrategy,

#[clap(flatten)]
#[serde(flatten)]
pub db: db_config::DbConnectionConf,
}

#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize, ValueEnum)]
Expand All @@ -273,136 +216,6 @@ pub enum ValidationStrategy {
Ignore, // Ignore validation errors - face, print, go on.
}

#[cfg(feature = "cql")]
#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Consistency {
Any,
One,
Two,
Three,
Quorum,
All,
LocalOne,
#[default]
LocalQuorum,
EachQuorum,
// NOTE: 'Serial' and 'LocalSerial' values may be used in SELECT statements
// to make them use Paxos consensus algorithm.
Serial,
LocalSerial,
}

#[cfg(feature = "cql")]
impl Consistency {
pub fn consistency(&self) -> scylla::frame::types::Consistency {
match self {
Self::Any => scylla::frame::types::Consistency::Any,
Self::One => scylla::frame::types::Consistency::One,
Self::Two => scylla::frame::types::Consistency::Two,
Self::Three => scylla::frame::types::Consistency::Three,
Self::Quorum => scylla::frame::types::Consistency::Quorum,
Self::All => scylla::frame::types::Consistency::All,
Self::LocalOne => scylla::frame::types::Consistency::LocalOne,
Self::LocalQuorum => scylla::frame::types::Consistency::LocalQuorum,
Self::EachQuorum => scylla::frame::types::Consistency::EachQuorum,
Self::Serial => scylla::frame::types::Consistency::Serial,
Self::LocalSerial => scylla::frame::types::Consistency::LocalSerial,
}
}
}

#[cfg(feature = "cql")]
impl ValueEnum for Consistency {
fn value_variants<'a>() -> &'a [Self] {
&[
Self::Any,
Self::One,
Self::Two,
Self::Three,
Self::Quorum,
Self::All,
Self::LocalOne,
Self::LocalQuorum,
Self::EachQuorum,
Self::Serial,
Self::LocalSerial,
]
}

fn from_str(s: &str, _ignore_case: bool) -> Result<Self, String> {
match s.to_lowercase().as_str() {
"any" => Ok(Self::Any),
"one" | "1" => Ok(Self::One),
"two" | "2" => Ok(Self::Two),
"three" | "3" => Ok(Self::Three),
"quorum" | "q" => Ok(Self::Quorum),
"all" => Ok(Self::All),
"local_one" | "localone" | "l1" => Ok(Self::LocalOne),
"local_quorum" | "localquorum" | "lq" => Ok(Self::LocalQuorum),
"each_quorum" | "eachquorum" | "eq" => Ok(Self::EachQuorum),
"serial" | "s" => Ok(Self::Serial),
"local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial),
s => Err(format!("Unknown consistency level {s}")),
}
}

fn to_possible_value(&self) -> Option<PossibleValue> {
match self {
Self::Any => Some(PossibleValue::new("ANY")),
Self::One => Some(PossibleValue::new("ONE")),
Self::Two => Some(PossibleValue::new("TWO")),
Self::Three => Some(PossibleValue::new("THREE")),
Self::Quorum => Some(PossibleValue::new("QUORUM")),
Self::All => Some(PossibleValue::new("ALL")),
Self::LocalOne => Some(PossibleValue::new("LOCAL_ONE")),
Self::LocalQuorum => Some(PossibleValue::new("LOCAL_QUORUM")),
Self::EachQuorum => Some(PossibleValue::new("EACH_QUORUM")),
Self::Serial => Some(PossibleValue::new("SERIAL")),
Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")),
}
}
}

#[cfg(feature = "cql")]
#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum SerialConsistency {
Serial,
#[default]
LocalSerial,
}

#[cfg(feature = "cql")]
impl SerialConsistency {
pub fn serial_consistency(&self) -> scylla::frame::types::SerialConsistency {
match self {
Self::Serial => scylla::frame::types::SerialConsistency::Serial,
Self::LocalSerial => scylla::frame::types::SerialConsistency::LocalSerial,
}
}
}

#[cfg(feature = "cql")]
impl ValueEnum for SerialConsistency {
fn value_variants<'a>() -> &'a [Self] {
&[Self::Serial, Self::LocalSerial]
}

fn from_str(s: &str, _ignore_case: bool) -> Result<Self, String> {
match s.to_lowercase().as_str() {
"serial" | "s" => Ok(Self::Serial),
"local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial),
s => Err(format!("Unknown serial consistency level {s}")),
}
}

fn to_possible_value(&self) -> Option<PossibleValue> {
match self {
Self::Serial => Some(PossibleValue::new("SERIAL")),
Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")),
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WeightedFunction {
pub name: String,
Expand Down Expand Up @@ -489,7 +302,7 @@ pub struct SchemaCommand {
#[clap(name = "workload", required = true, value_name = "PATH")]
pub workload: PathBuf,

// Cassandra connection settings.
// Connection settings.
#[clap(flatten)]
pub connection: ConnectionConf,
}
Expand Down Expand Up @@ -520,7 +333,7 @@ pub struct LoadCommand {
#[clap(name = "workload", required = true, value_name = "PATH")]
pub workload: PathBuf,

// Cassandra connection settings.
// Connection settings.
#[clap(flatten)]
pub connection: ConnectionConf,
}
Expand Down Expand Up @@ -642,7 +455,7 @@ pub struct RunCommand {
#[clap(short, long)]
pub quiet: bool,

// Cassandra connection settings.
// Connection settings.
#[clap(flatten)]
pub connection: ConnectionConf,

Expand Down
12 changes: 8 additions & 4 deletions src/report/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,13 @@ impl Display for RunConfigCmp<'_> {
self.line("Cluster", "", |conf| {
OptionDisplay(conf.cluster_name.clone())
}),
#[cfg(feature = "cql")]
self.line("Datacenter", "", |conf| {
conf.connection.datacenter.clone().unwrap_or_default()
conf.connection.db.datacenter.clone().unwrap_or_default()
}),
#[cfg(feature = "cql")]
self.line("Rack", "", |conf| {
conf.connection.rack.clone().unwrap_or_default()
conf.connection.db.rack.clone().unwrap_or_default()
}),
self.line("DB version", "", |conf| {
OptionDisplay(conf.db_version.clone())
Expand All @@ -507,11 +509,12 @@ impl Display for RunConfigCmp<'_> {
}),
#[cfg(feature = "cql")]
self.line("Consistency", "", |conf| {
conf.connection.consistency.consistency().to_string()
conf.connection.db.consistency.consistency().to_string()
}),
#[cfg(feature = "cql")]
self.line("Serial consistency", "", |conf| {
conf.connection
.db
.serial_consistency
.serial_consistency()
.to_string()
Expand Down Expand Up @@ -550,8 +553,9 @@ impl Display for RunConfigCmp<'_> {

let lines: Vec<Box<dyn Display>> = vec![
self.line("Threads", "", |conf| Quantity::from(conf.threads)),
#[cfg(feature = "cql")]
self.line("Connections", "", |conf| {
Quantity::from(conf.connection.count)
Quantity::from(conf.connection.db.count)
}),
self.line("Concurrency", "req", |conf| {
Quantity::from(conf.concurrency)
Expand Down
23 changes: 23 additions & 0 deletions src/scripting/alternator/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use clap::Parser;
use serde::{Deserialize, Serialize};

#[derive(Parser, Debug, Default, Serialize, Deserialize)]
pub struct DbConnectionConf {
/// Use AWS credentials and region from the environment.
/// If this flag is set, access key ID, secret access key, and region args will be ignored.
#[clap(long("aws-credentials"))]
pub aws_credentials: bool,
Copy link
Copy Markdown
Collaborator

@vponomaryov vponomaryov Apr 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the aws-credentials must be mutually exclusive to other 3 on the config validation level.

So, user will know he needs to specify this information in one or another way.
It will not make him to remember the order of significance...

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Used conflicts_with_all to mark this as mutually exclusive with access-key-id, secret-access-key and region.


/// Access key ID.
#[clap(long("access-key-id"), default_value = "")]
pub access_key_id: String,

/// Secret access key.
#[serde(skip_serializing)] // Don't save the secret to generated reports.
#[clap(long("secret-access-key"), default_value = "")]
pub secret_access_key: String,

/// Region.
#[clap(long("region"), default_value = "us-east-1")]
pub region: String,
}
30 changes: 24 additions & 6 deletions src/scripting/alternator/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,32 @@ use aws_sdk_dynamodb::Client;
pub async fn connect(conf: &ConnectionConf) -> Result<Context, AlternatorError> {
let address = conf.addresses.first().cloned().unwrap_or_default();

// TODO: use latte parameters for setting the configuration
let config = aws_config::defaults(BehaviorVersion::latest())
let mut config_loader = aws_config::defaults(BehaviorVersion::latest())
.endpoint_url(&address)
.region(Region::new("us-east-1"))
.credentials_provider(Credentials::new("", "", None, None, ""))
.retry_config(RetryConfig::standard().with_max_attempts(1))
.load()
.await;
.timeout_config(
aws_config::timeout::TimeoutConfig::builder()
.operation_timeout(conf.request_timeout)
.build(),
);

// We only specify custom credentials if aws_credentials flag is not set.
// If aws_credentials flag is set, the SDK will automatically use credentials from the environment.
if !conf.db.aws_credentials {
let creds = Credentials::new(
&conf.db.access_key_id,
&conf.db.secret_access_key,
None,
None,
"",
);

config_loader = config_loader
.credentials_provider(creds)
.region(Region::new(conf.db.region.clone()));
}

let config = config_loader.load().await;

let client = Client::new(&config);

Expand Down
Loading
Loading