Skip to content
Merged
207 changes: 10 additions & 197 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use std::time::Duration;

use anyhow::anyhow;
use chrono::Utc;
#[cfg(feature = "cql")]
use clap::builder::PossibleValue;
use clap::{Parser, ValueEnum};
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use crate::scripting::db_config;

/// Limit of retry errors to be kept and then printed in scope of a sampling interval
pub const PRINT_RETRY_ERROR_LIMIT: u64 = 5;

Expand Down Expand Up @@ -158,71 +158,10 @@ impl FromStr for RetryInterval {

#[derive(Parser, Debug, Serialize, Deserialize)]
pub struct ConnectionConf {
/// Number of connections per Cassandra node / Scylla shard.
#[clap(
short('c'),
long("connections"),
default_value = "1",
value_name = "COUNT"
)]
pub count: NonZeroUsize,

/// List of Cassandra addresses to connect to.
/// List of addresses to connect to.
#[clap(name = "addresses", default_value = "localhost")]
pub addresses: Vec<String>,

/// Cassandra user name
#[clap(long, env("CASSANDRA_USER"), default_value = "")]
pub user: String,

/// Password to use if password authentication is required by the server
#[clap(long, env("CASSANDRA_PASSWORD"), default_value = "")]
pub password: String,

/// Enable SSL
#[clap(long("ssl"))]
pub ssl: bool,

/// Path to the CA certificate file in PEM format
#[clap(long("ssl-ca"), value_name = "PATH")]
pub ssl_ca_cert_file: Option<PathBuf>,

/// Path to the client SSL certificate file in PEM format
#[clap(long("ssl-cert"), value_name = "PATH")]
pub ssl_cert_file: Option<PathBuf>,

/// Path to the client SSL private key file in PEM format
#[clap(long("ssl-key"), value_name = "PATH")]
pub ssl_key_file: Option<PathBuf>,

/// Verify if the peer's certificate is trusted
#[clap(long("ssl-peer-verification"))]
pub ssl_peer_verification: bool,

/// Datacenter name
#[clap(long("datacenter"), required = false)]
pub datacenter: Option<String>,

/// Rack name
#[clap(long("rack"), required = false)]
pub rack: Option<String>,

/// CQL query consistency level.
/// 'SERIAL' and 'LOCAL_SERIAL' values are compatible only with SELECT statements
/// and make Scylla use Paxos consensus algorithm
#[cfg(feature = "cql")]
#[clap(long("consistency"), required = false, default_value = "LOCAL_QUORUM")]
pub consistency: Consistency,

/// Serial consistency level for conditional (LWT) queries
#[cfg(feature = "cql")]
#[clap(
long("serial-consistency"),
required = false,
default_value = "LOCAL_SERIAL"
)]
pub serial_consistency: SerialConsistency,

#[clap(
long("request-timeout"),
default_value = "5s",
Expand Down Expand Up @@ -263,6 +202,10 @@ pub struct ConnectionConf {
default_value = "fail-fast"
)]
pub validation_strategy: ValidationStrategy,

#[clap(flatten)]
#[serde(flatten)]
pub db: db_config::DbConnectionConf,
}

#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize, ValueEnum)]
Expand All @@ -273,136 +216,6 @@ pub enum ValidationStrategy {
Ignore, // Ignore validation errors - face, print, go on.
}

#[cfg(feature = "cql")]
#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Consistency {
Any,
One,
Two,
Three,
Quorum,
All,
LocalOne,
#[default]
LocalQuorum,
EachQuorum,
// NOTE: 'Serial' and 'LocalSerial' values may be used in SELECT statements
// to make them use Paxos consensus algorithm.
Serial,
LocalSerial,
}

#[cfg(feature = "cql")]
impl Consistency {
pub fn consistency(&self) -> scylla::frame::types::Consistency {
match self {
Self::Any => scylla::frame::types::Consistency::Any,
Self::One => scylla::frame::types::Consistency::One,
Self::Two => scylla::frame::types::Consistency::Two,
Self::Three => scylla::frame::types::Consistency::Three,
Self::Quorum => scylla::frame::types::Consistency::Quorum,
Self::All => scylla::frame::types::Consistency::All,
Self::LocalOne => scylla::frame::types::Consistency::LocalOne,
Self::LocalQuorum => scylla::frame::types::Consistency::LocalQuorum,
Self::EachQuorum => scylla::frame::types::Consistency::EachQuorum,
Self::Serial => scylla::frame::types::Consistency::Serial,
Self::LocalSerial => scylla::frame::types::Consistency::LocalSerial,
}
}
}

#[cfg(feature = "cql")]
impl ValueEnum for Consistency {
fn value_variants<'a>() -> &'a [Self] {
&[
Self::Any,
Self::One,
Self::Two,
Self::Three,
Self::Quorum,
Self::All,
Self::LocalOne,
Self::LocalQuorum,
Self::EachQuorum,
Self::Serial,
Self::LocalSerial,
]
}

fn from_str(s: &str, _ignore_case: bool) -> Result<Self, String> {
match s.to_lowercase().as_str() {
"any" => Ok(Self::Any),
"one" | "1" => Ok(Self::One),
"two" | "2" => Ok(Self::Two),
"three" | "3" => Ok(Self::Three),
"quorum" | "q" => Ok(Self::Quorum),
"all" => Ok(Self::All),
"local_one" | "localone" | "l1" => Ok(Self::LocalOne),
"local_quorum" | "localquorum" | "lq" => Ok(Self::LocalQuorum),
"each_quorum" | "eachquorum" | "eq" => Ok(Self::EachQuorum),
"serial" | "s" => Ok(Self::Serial),
"local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial),
s => Err(format!("Unknown consistency level {s}")),
}
}

fn to_possible_value(&self) -> Option<PossibleValue> {
match self {
Self::Any => Some(PossibleValue::new("ANY")),
Self::One => Some(PossibleValue::new("ONE")),
Self::Two => Some(PossibleValue::new("TWO")),
Self::Three => Some(PossibleValue::new("THREE")),
Self::Quorum => Some(PossibleValue::new("QUORUM")),
Self::All => Some(PossibleValue::new("ALL")),
Self::LocalOne => Some(PossibleValue::new("LOCAL_ONE")),
Self::LocalQuorum => Some(PossibleValue::new("LOCAL_QUORUM")),
Self::EachQuorum => Some(PossibleValue::new("EACH_QUORUM")),
Self::Serial => Some(PossibleValue::new("SERIAL")),
Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")),
}
}
}

#[cfg(feature = "cql")]
#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum SerialConsistency {
Serial,
#[default]
LocalSerial,
}

#[cfg(feature = "cql")]
impl SerialConsistency {
pub fn serial_consistency(&self) -> scylla::frame::types::SerialConsistency {
match self {
Self::Serial => scylla::frame::types::SerialConsistency::Serial,
Self::LocalSerial => scylla::frame::types::SerialConsistency::LocalSerial,
}
}
}

#[cfg(feature = "cql")]
impl ValueEnum for SerialConsistency {
fn value_variants<'a>() -> &'a [Self] {
&[Self::Serial, Self::LocalSerial]
}

fn from_str(s: &str, _ignore_case: bool) -> Result<Self, String> {
match s.to_lowercase().as_str() {
"serial" | "s" => Ok(Self::Serial),
"local_serial" | "localserial" | "ls" => Ok(Self::LocalSerial),
s => Err(format!("Unknown serial consistency level {s}")),
}
}

fn to_possible_value(&self) -> Option<PossibleValue> {
match self {
Self::Serial => Some(PossibleValue::new("SERIAL")),
Self::LocalSerial => Some(PossibleValue::new("LOCAL_SERIAL")),
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WeightedFunction {
pub name: String,
Expand Down Expand Up @@ -489,7 +302,7 @@ pub struct SchemaCommand {
#[clap(name = "workload", required = true, value_name = "PATH")]
pub workload: PathBuf,

// Cassandra connection settings.
// Connection settings.
#[clap(flatten)]
pub connection: ConnectionConf,
}
Expand Down Expand Up @@ -520,7 +333,7 @@ pub struct LoadCommand {
#[clap(name = "workload", required = true, value_name = "PATH")]
pub workload: PathBuf,

// Cassandra connection settings.
// Connection settings.
#[clap(flatten)]
pub connection: ConnectionConf,
}
Expand Down Expand Up @@ -642,7 +455,7 @@ pub struct RunCommand {
#[clap(short, long)]
pub quiet: bool,

// Cassandra connection settings.
// Connection settings.
#[clap(flatten)]
pub connection: ConnectionConf,

Expand Down
12 changes: 8 additions & 4 deletions src/report/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,13 @@ impl Display for RunConfigCmp<'_> {
self.line("Cluster", "", |conf| {
OptionDisplay(conf.cluster_name.clone())
}),
#[cfg(feature = "cql")]
self.line("Datacenter", "", |conf| {
conf.connection.datacenter.clone().unwrap_or_default()
conf.connection.db.datacenter.clone().unwrap_or_default()
}),
#[cfg(feature = "cql")]
self.line("Rack", "", |conf| {
conf.connection.rack.clone().unwrap_or_default()
conf.connection.db.rack.clone().unwrap_or_default()
}),
self.line("DB version", "", |conf| {
OptionDisplay(conf.db_version.clone())
Expand All @@ -507,11 +509,12 @@ impl Display for RunConfigCmp<'_> {
}),
#[cfg(feature = "cql")]
self.line("Consistency", "", |conf| {
conf.connection.consistency.consistency().to_string()
conf.connection.db.consistency.consistency().to_string()
}),
#[cfg(feature = "cql")]
self.line("Serial consistency", "", |conf| {
conf.connection
.db
.serial_consistency
.serial_consistency()
.to_string()
Expand Down Expand Up @@ -550,8 +553,9 @@ impl Display for RunConfigCmp<'_> {

let lines: Vec<Box<dyn Display>> = vec![
self.line("Threads", "", |conf| Quantity::from(conf.threads)),
#[cfg(feature = "cql")]
self.line("Connections", "", |conf| {
Quantity::from(conf.connection.count)
Quantity::from(conf.connection.db.count)
}),
self.line("Concurrency", "req", |conf| {
Quantity::from(conf.concurrency)
Expand Down
23 changes: 23 additions & 0 deletions src/scripting/alternator/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use clap::Parser;
use serde::{Deserialize, Serialize};

#[derive(Parser, Debug, Default, Serialize, Deserialize)]
pub struct DbConnectionConf {
/// Use AWS credentials and region from the environment.
/// Mutually exclusive with `access-key-id`, `secret-access-key` and `region`.
#[clap(long("aws-credentials"), conflicts_with_all = &["access_key_id", "secret_access_key", "region"])]
pub aws_credentials: bool,
Comment thread
vponomaryov marked this conversation as resolved.

/// Access key ID.
#[clap(long("access-key-id"), default_value = "")]
pub access_key_id: String,

/// Secret access key.
#[serde(skip_serializing)] // Don't save the secret to generated reports.
#[clap(long("secret-access-key"), default_value = "")]
pub secret_access_key: String,

/// Region.
#[clap(long("region"), default_value = "us-east-1")]
pub region: String,
}
30 changes: 24 additions & 6 deletions src/scripting/alternator/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,32 @@ use aws_sdk_dynamodb::Client;
pub async fn connect(conf: &ConnectionConf) -> Result<Context, AlternatorError> {
let address = conf.addresses.first().cloned().unwrap_or_default();

// TODO: use latte parameters for setting the configuration
let config = aws_config::defaults(BehaviorVersion::latest())
let mut config_loader = aws_config::defaults(BehaviorVersion::latest())
.endpoint_url(&address)
.region(Region::new("us-east-1"))
.credentials_provider(Credentials::new("", "", None, None, ""))
.retry_config(RetryConfig::standard().with_max_attempts(1))
.load()
.await;
.timeout_config(
aws_config::timeout::TimeoutConfig::builder()
.operation_timeout(conf.request_timeout)
.build(),
);

// We only specify custom credentials if aws_credentials flag is not set.
// If aws_credentials flag is set, the SDK will automatically use credentials from the environment.
if !conf.db.aws_credentials {
let creds = Credentials::new(
&conf.db.access_key_id,
&conf.db.secret_access_key,
None,
None,
"",
);

config_loader = config_loader
.credentials_provider(creds)
.region(Region::new(conf.db.region.clone()));
}

let config = config_loader.load().await;

let client = Client::new(&config);

Expand Down
Loading
Loading