diff --git a/.gitignore b/.gitignore index 3bf606ddd960..abb0c9064c66 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,4 @@ compile_commands.json .envrc clang_build .idea/ +nuke diff --git a/CMakeLists.txt b/CMakeLists.txt index 43d8eb444ffd..7cc93d580464 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,7 @@ target_sources(scylla-main generic_server.cc debug.cc init.cc - keys.cc + keys/keys.cc multishard_mutation_query.cc mutation_query.cc node_ops/task_manager_module.cc diff --git a/alternator/rmw_operation.hh b/alternator/rmw_operation.hh index d3094ddcb255..e1d1908e2077 100644 --- a/alternator/rmw_operation.hh +++ b/alternator/rmw_operation.hh @@ -15,7 +15,7 @@ #include "consumed_capacity.hh" #include "executor.hh" #include "tracing/trace_state.hh" -#include "keys.hh" +#include "keys/keys.hh" namespace alternator { diff --git a/alternator/serialization.hh b/alternator/serialization.hh index ff4e94f00783..7ea259dce8c4 100644 --- a/alternator/serialization.hh +++ b/alternator/serialization.hh @@ -13,7 +13,7 @@ #include #include "types/types.hh" #include "schema/schema_fwd.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "utils/rjson.hh" #include "utils/big_decimal.hh" diff --git a/api/storage_service.cc b/api/storage_service.cc index d777b8daaa10..ffb602bf2532 100644 --- a/api/storage_service.cc +++ b/api/storage_service.cc @@ -359,6 +359,9 @@ void set_repair(http_context& ctx, routes& r, sharded& repair, s // if the option is not sane, repair_start() throws immediately, so // convert the exception to an HTTP error throw httpd::bad_param_exception(e.what()); + } catch (const tablets_unsupported& e) { + throw base_exception("Cannot repair tablet keyspace. Use /storage_service/tablets/repair to repair tablet keyspaces.", + http::reply::status_type::forbidden); } }); diff --git a/cdc/cdc_partitioner.cc b/cdc/cdc_partitioner.cc index 9edcb564db0c..d24d42ee246e 100644 --- a/cdc/cdc_partitioner.cc +++ b/cdc/cdc_partitioner.cc @@ -12,7 +12,7 @@ #include "sstables/key.hh" #include "utils/class_registrator.hh" #include "cdc/generation.hh" -#include "keys.hh" +#include "keys/keys.hh" namespace cdc { diff --git a/cdc/generation.cc b/cdc/generation.cc index 12df8b565661..b13333d6201c 100644 --- a/cdc/generation.cc +++ b/cdc/generation.cc @@ -16,7 +16,7 @@ #include "gms/endpoint_state.hh" #include "gms/versioned_value.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "replica/database.hh" #include "db/system_keyspace.hh" #include "db/system_distributed_keyspace.hh" diff --git a/conf/scylla.yaml b/conf/scylla.yaml index 00bbd53f7ffc..36123f06da86 100644 --- a/conf/scylla.yaml +++ b/conf/scylla.yaml @@ -776,6 +776,35 @@ maintenance_socket: ignore # ... # +# +# Azure Key Vault host(s). +# +# The unique name of azure host/account config that can be referenced in table schema. +# +# host.yourdomain.com={ azure_tenant_id=, azure_client_id=, azure_client_secret=, azure_client_certificate_path=, master_key=/, truststore=/path/to/truststore.pem, priority_string=, key_cache_expiry=, key_cache_refersh=}:... +# +# Authentication can be explicit with Service Principal credentials. Either secret or certificate can be provided. +# If both are provided, the secret will be used. If no credentials are provided, the provider will try to detect them +# from the environment, the Azure CLI, and IMDS, in this specific order. +# +# master_key is a Vault key that will be used to wrap all keys used for actual encryption of scylla data. +# This key must be pre-created and the principal must have permissions for Wrapkey and Unwrapkey operations on this key. +# +# azure_hosts: +# : +# azure_tenant_id: (optional) +# azure_client_id: (optional) +# azure_client_secret: (optional) +# azure_client_certificate_path: (optional) +# master_key: / - named Vault key for key wrapping (optional) +# truststore: (optional) +# priority_string: (optional) +# key_cache_expiry: (optional) +# key_cache_refresh: (optional) +# : +# ... +# + # # Server-global user information encryption settings # @@ -862,3 +891,18 @@ rf_rack_valid_keyspaces: false # Uri for the vector store using dns name. Only http schema is supported. Port number is mandatory. # Default is empty, which means that the vector store is not used. # vector_store_uri: http://vector-store.dns.name:{port} + +# +# io-streaming rate limiting +# When setting this value to be non-zero scylla throttles disk throughput for +# stream (network) activities such as backup, repair, tablet migration and more. +# This limit is useful for user queries so the network interface does +# not get saturated by streaming activities. +# The recommended value is 75% of network bandwidth +# E.g for i4i.8xlarge (https://github.com/scylladb/scylla-machine-image/tree/next/common/aws_net_params.json): +# network: 18.75 GiB/s --> 18750 Mib/s --> 1875 MB/s (from network bits to network bytes: divide by 10, not 8) +# Converted to disk bytes: 1875 * 1000 / 1024 = 1831 MB/s (disk wise) +# 75% of disk bytes is: 0.75 * 1831 = 1373 megabytes/s +# stream_io_throughput_mb_per_sec: 1373 +# + diff --git a/configure.py b/configure.py index b5c49b230e67..6ffb3b25009b 100755 --- a/configure.py +++ b/configure.py @@ -551,6 +551,7 @@ def find_ninja(): 'test/boost/sstable_conforms_to_mutation_source_test', 'test/boost/sstable_datafile_test', 'test/boost/sstable_generation_test', + 'test/boost/sstable_inexact_index_test', 'test/boost/sstable_move_test', 'test/boost/sstable_mutation_test', 'test/boost/sstable_partition_index_cache_test', @@ -831,7 +832,7 @@ def find_ninja(): 'readers/mutation_reader.cc', 'readers/mutation_readers.cc', 'mutation_query.cc', - 'keys.cc', + 'keys/keys.cc', 'counters.cc', 'compress.cc', 'sstable_dict_autotrainer.cc', @@ -1055,6 +1056,11 @@ def find_ninja(): 'utils/s3/credentials_providers/aws_credentials_provider_chain.cc', 'utils/s3/utils/manip_s3.cc', 'utils/advanced_rpc_compressor.cc', + 'utils/azure/identity/credentials.cc', + 'utils/azure/identity/service_principal_credentials.cc', + 'utils/azure/identity/managed_identity_credentials.cc', + 'utils/azure/identity/azure_cli_credentials.cc', + 'utils/azure/identity/default_credentials.cc', 'gms/version_generator.cc', 'gms/versioned_value.cc', 'gms/gossiper.cc', @@ -1177,6 +1183,8 @@ def find_ninja(): 'ent/encryption/gcp_host.cc', 'ent/encryption/gcp_key_provider.cc', 'ent/encryption/utils.cc', + 'ent/encryption/azure_host.cc', + 'ent/encryption/azure_key_provider.cc', 'ent/ldap/ldap_connection.cc', 'multishard_mutation_query.cc', 'reader_concurrency_semaphore.cc', diff --git a/cql3/expr/expr-utils.hh b/cql3/expr/expr-utils.hh index f2cc5c5d5697..ab79093da3f6 100644 --- a/cql3/expr/expr-utils.hh +++ b/cql3/expr/expr-utils.hh @@ -6,7 +6,7 @@ #include "expression.hh" #include "bytes.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "utils/interval.hh" #include "cql3/expr/restrictions.hh" #include "cql3/assignment_testable.hh" diff --git a/db/cache_mutation_reader.hh b/db/cache_mutation_reader.hh index 08d31aed5978..1314f5bb1542 100644 --- a/db/cache_mutation_reader.hh +++ b/db/cache_mutation_reader.hh @@ -17,7 +17,7 @@ #include "partition_snapshot_row_cursor.hh" #include "read_context.hh" #include "readers/delegating.hh" -#include "clustering_key_filter.hh" +#include "keys/clustering_key_filter.hh" namespace cache { diff --git a/db/config.cc b/db/config.cc index fdfce7581d5d..a888a959e707 100644 --- a/db/config.cc +++ b/db/config.cc @@ -819,7 +819,7 @@ db::config::config(std::shared_ptr exts) , inter_dc_stream_throughput_outbound_megabits_per_sec(this, "inter_dc_stream_throughput_outbound_megabits_per_sec", value_status::Unused, 0, "Throttles all streaming file transfer between the data centers. This setting allows throttles streaming throughput betweens data centers in addition to throttling all network stream traffic as configured with stream_throughput_outbound_megabits_per_sec.") , stream_io_throughput_mb_per_sec(this, "stream_io_throughput_mb_per_sec", liveness::LiveUpdate, value_status::Used, 0, - "Throttles streaming I/O to the specified total throughput (in MiBs/s) across the entire system. Streaming I/O includes the one performed by repair and both RBNO and legacy topology operations such as adding or removing a node. Setting the value to 0 disables stream throttling.") + "Throttles streaming I/O to the specified total throughput (in MiBs/s) across the entire system. Streaming I/O includes the one performed by repair and both RBNO and legacy topology operations such as adding or removing a node. Setting the value to 0 disables stream throttling. It is recommended to set the value for this parameter to be 75% of network bandwidth") , stream_plan_ranges_fraction(this, "stream_plan_ranges_fraction", liveness::LiveUpdate, value_status::Used, 0.1, "Specify the fraction of ranges to stream in a single stream plan. Value is between 0 and 1.") , enable_file_stream(this, "enable_file_stream", liveness::LiveUpdate, value_status::Used, true, "Set true to use file based stream for tablet instead of mutation based stream") diff --git a/db/hints/internal/hint_sender.cc b/db/hints/internal/hint_sender.cc index ed8dbb901ea3..7157679e516f 100644 --- a/db/hints/internal/hint_sender.cc +++ b/db/hints/internal/hint_sender.cc @@ -175,14 +175,14 @@ future<> hint_sender::stop(drain should_drain) noexcept { // // The next call for send_hints_maybe() will send the last hints to the current end point and when it is // done there is going to be no more pending hints and the corresponding hints directory may be removed. - manager_logger.trace("Draining for {}: start", end_point_key()); + manager_logger.info("Draining for {}: start", end_point_key()); set_draining(); send_hints_maybe(); _ep_manager.flush_current_hints().handle_exception([] (auto e) { manager_logger.error("Failed to flush pending hints: {}. Ignoring...", e); }).get(); send_hints_maybe(); - manager_logger.trace("Draining for {}: end", end_point_key()); + manager_logger.info("Draining for {}: end", end_point_key()); } // TODO: Change this log to match the class name, but first make sure no test // relies on the old one. diff --git a/db/hints/manager.cc b/db/hints/manager.cc index 637944a95d59..185a8e76c536 100644 --- a/db/hints/manager.cc +++ b/db/hints/manager.cc @@ -539,7 +539,7 @@ future<> manager::change_host_filter(host_filter filter) { "change_host_filter: cannot change the configuration because hints all hints were drained"}); } - manager_logger.debug("change_host_filter: changing from {} to {}", _host_filter, filter); + manager_logger.info("change_host_filter: changing from {} to {}", _host_filter, filter); // Change the host_filter now and save the old one so that we can // roll back in case of failure @@ -611,11 +611,19 @@ future<> manager::change_host_filter(host_filter filter) { }); }); } catch (...) { + const sstring exception_message = eptr + ? seastar::format("{} + {}", eptr, std::current_exception()) + : seastar::format("{}", std::current_exception()); + + manager_logger.warn("Changing the host filter has failed: {}", exception_message); + if (eptr) { std::throw_with_nested(eptr); } throw; } + + manager_logger.info("The host filter has been changed successfully"); } bool manager::check_dc_for(endpoint_id ep) const noexcept { @@ -635,7 +643,7 @@ future<> manager::drain_for(endpoint_id host_id, gms::inet_address ip) noexcept co_return; } - manager_logger.trace("Draining starts for {}", host_id); + manager_logger.info("Draining starts for {}", host_id); const auto holder = seastar::gate::holder{_draining_eps_gate}; // As long as we hold on to this lock, no migration of hinted handoff to host IDs @@ -661,7 +669,7 @@ future<> manager::drain_for(endpoint_id host_id, gms::inet_address ip) noexcept return ep_man.with_file_update_mutex([&ep_man] -> future<> { return remove_file(ep_man.hints_dir().native()).then([&ep_man] { - manager_logger.debug("Removed hint directory for {}", ep_man.end_point_key()); + manager_logger.info("Removed hint directory for {}", ep_man.end_point_key()); }); }); }); @@ -722,7 +730,7 @@ future<> manager::drain_for(endpoint_id host_id, gms::inet_address ip) noexcept manager_logger.error("Exception when draining {}: {}", host_id, eptr); } - manager_logger.trace("drain_for: finished draining {}", host_id); + manager_logger.info("drain_for: finished draining {}", host_id); } void manager::update_backlog(size_t backlog, size_t max_backlog) { diff --git a/db/row_cache.cc b/db/row_cache.cc index e85836466d9e..30db32b7fbf5 100644 --- a/db/row_cache.cc +++ b/db/row_cache.cc @@ -25,7 +25,7 @@ #include "readers/nonforwardable.hh" #include "cache_mutation_reader.hh" #include "partition_snapshot_reader.hh" -#include "clustering_key_filter.hh" +#include "keys/clustering_key_filter.hh" #include "utils/assert.hh" #include "utils/updateable_value.hh" #include "utils/labels.hh" diff --git a/db/size_estimates_virtual_reader.cc b/db/size_estimates_virtual_reader.cc index e895fd9422cb..5b48f31a4555 100644 --- a/db/size_estimates_virtual_reader.cc +++ b/db/size_estimates_virtual_reader.cc @@ -12,7 +12,7 @@ #include "utils/assert.hh" -#include "clustering_bounds_comparator.hh" +#include "keys/clustering_bounds_comparator.hh" #include "replica/database_fwd.hh" #include "db/system_keyspace.hh" #include "dht/i_partitioner.hh" diff --git a/db/view/row_locking.hh b/db/view/row_locking.hh index 6c7a13b4fcba..510a29af71ea 100644 --- a/db/view/row_locking.hh +++ b/db/view/row_locking.hh @@ -30,7 +30,7 @@ #include "dht/decorated_key.hh" #include "utils/estimated_histogram.hh" #include "utils/latency.hh" -#include "keys.hh" +#include "keys/keys.hh" class row_locker { public: diff --git a/db/view/view.cc b/db/view/view.cc index 89f60fa0a0d3..31c76952b4aa 100644 --- a/db/view/view.cc +++ b/db/view/view.cc @@ -25,7 +25,7 @@ #include "db/view/base_info.hh" #include "replica/database.hh" -#include "clustering_bounds_comparator.hh" +#include "keys/clustering_bounds_comparator.hh" #include "cql3/statements/select_statement.hh" #include "cql3/util.hh" #include "cql3/restrictions/statement_restrictions.hh" @@ -45,7 +45,7 @@ #include "dht/sharder.hh" #include "gms/inet_address.hh" #include "gms/feature_service.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "locator/abstract_replication_strategy.hh" #include "locator/network_topology_strategy.hh" #include "mutation/mutation.hh" diff --git a/dht/auto_refreshing_sharder.hh b/dht/auto_refreshing_sharder.hh index 7ad6e6ed2205..de74449f5b21 100644 --- a/dht/auto_refreshing_sharder.hh +++ b/dht/auto_refreshing_sharder.hh @@ -41,7 +41,7 @@ public: virtual ~auto_refreshing_sharder(); - virtual unsigned shard_for_reads(const token& t) const override; + virtual std::optional try_get_shard_for_reads(const token& t) const override; virtual dht::shard_replica_set shard_for_writes(const token& t, std::optional sel) const override; diff --git a/dht/decorated_key.hh b/dht/decorated_key.hh index 657a3bbe722d..fcaa52e38d28 100644 --- a/dht/decorated_key.hh +++ b/dht/decorated_key.hh @@ -9,7 +9,7 @@ #pragma once -#include "keys.hh" +#include "keys/keys.hh" #include "schema/schema_fwd.hh" #include "dht/token.hh" #include "dht/i_partitioner_fwd.hh" diff --git a/dht/i_partitioner.cc b/dht/i_partitioner.cc index 203f61a0f31c..1a03af4fa73a 100644 --- a/dht/i_partitioner.cc +++ b/dht/i_partitioner.cc @@ -43,8 +43,8 @@ static_sharder::shard_of(const token& t) const { return dht::shard_of(_shard_count, _sharding_ignore_msb_bits, t); } -unsigned -static_sharder::shard_for_reads(const token& t) const { +std::optional +static_sharder::try_get_shard_for_reads(const token& t) const { return shard_of(t); } @@ -513,8 +513,8 @@ auto_refreshing_sharder::refresh() { }); } -unsigned auto_refreshing_sharder::shard_for_reads(const token& t) const { - return _sharder->shard_for_reads(t); +std::optional auto_refreshing_sharder::try_get_shard_for_reads(const token& t) const { + return _sharder->try_get_shard_for_reads(t); } dht::shard_replica_set diff --git a/dht/i_partitioner.hh b/dht/i_partitioner.hh index c9596b5a6744..50145589672c 100644 --- a/dht/i_partitioner.hh +++ b/dht/i_partitioner.hh @@ -10,7 +10,7 @@ #pragma once #include -#include "keys.hh" +#include "keys/keys.hh" #include #include #include "dht/token.hh" diff --git a/dht/ring_position.hh b/dht/ring_position.hh index dd9ededf45f5..54f6bafc106b 100644 --- a/dht/ring_position.hh +++ b/dht/ring_position.hh @@ -9,7 +9,7 @@ #pragma once -#include "keys.hh" +#include "keys/keys.hh" #include "dht/token.hh" #include "dht/decorated_key.hh" diff --git a/dht/token-sharding.hh b/dht/token-sharding.hh index 3cb4a0864891..86aa1b85724e 100644 --- a/dht/token-sharding.hh +++ b/dht/token-sharding.hh @@ -54,6 +54,12 @@ public: sharder(unsigned shard_count = smp::count, unsigned sharding_ignore_msb_bits = 0); virtual ~sharder() = default; + /** + * Returns the shard that handles a particular token for reads, or empty if this + * node doesn't contain data for it. + */ + virtual std::optional try_get_shard_for_reads(const token& t) const = 0; + /** * Returns the shard that handles a particular token for reads. * Use shard_for_writes() to determine the set of shards that should receive writes. @@ -67,7 +73,12 @@ public: * } * */ - virtual unsigned shard_for_reads(const token& t) const = 0; + unsigned shard_for_reads(const token& t) const { + // FIXME: Consider throwing when there is no owning shard on the current host rather than returning 0. + // It's a coordination mistake to route requests to non-owners. Topology coordinator should synchronize + // with request coordinators before moving the shard away. + return try_get_shard_for_reads(t).value_or(0); + } /** * Returns the set of shards which should receive a write to token t. @@ -134,7 +145,7 @@ public: virtual std::optional next_shard(const token& t) const; virtual token token_for_next_shard(const token& t, shard_id shard, unsigned spans = 1) const; - virtual unsigned shard_for_reads(const token& t) const override; + virtual std::optional try_get_shard_for_reads(const token& t) const override; virtual shard_replica_set shard_for_writes(const token& t, std::optional sel) const override; virtual token token_for_next_shard_for_reads(const token& t, shard_id shard, unsigned spans = 1) const override; virtual std::optional next_shard_for_reads(const token& t) const override; diff --git a/dht/token.cc b/dht/token.cc index eec7f6cf3189..412c78f3dfea 100644 --- a/dht/token.cc +++ b/dht/token.cc @@ -78,7 +78,7 @@ static float ratio_helper(int64_t a, int64_t b) { } std::map -token::describe_ownership(const std::vector& sorted_tokens) { +token::describe_ownership(const utils::chunked_vector& sorted_tokens) { std::map ownerships; auto i = sorted_tokens.begin(); diff --git a/dht/token.hh b/dht/token.hh index 98e753ca5642..b6acf63d1f1c 100644 --- a/dht/token.hh +++ b/dht/token.hh @@ -10,6 +10,7 @@ #include "bytes_fwd.hh" #include "types/types.hh" +#include "utils/chunked_vector.hh" #include #include @@ -194,7 +195,7 @@ public: * @param sortedtokens a sorted List of tokens * @return the mapping from 'token' to 'percentage of the ring owned by that token'. */ - static std::map describe_ownership(const std::vector& sorted_tokens); + static std::map describe_ownership(const utils::chunked_vector& sorted_tokens); static data_type get_token_validator(); diff --git a/docs/_utils/redirects.yaml b/docs/_utils/redirects.yaml index 5f415f837f9b..411290870909 100644 --- a/docs/_utils/redirects.yaml +++ b/docs/_utils/redirects.yaml @@ -172,4 +172,7 @@ /stable/upgrade/upgrade-opensource/upgrade-guide-from-4.5-to-4.6/metric-update-4.5-to-4.6.html: /stable/upgrade/index.html # Divide API reference to smaller files -# /stable/reference/api-reference.html: /stable/reference/api/index.html \ No newline at end of file +# /stable/reference/api-reference.html: /stable/reference/api/index.html + +# Fixed typo in the file name +/stable/operating-scylla/nodetool-commands/enbleautocompaction.html: /stable/operating-scylla/nodetool-commands/enableautocompaction.html \ No newline at end of file diff --git a/docs/operating-scylla/nodetool-commands/enbleautocompaction.rst b/docs/operating-scylla/nodetool-commands/enableautocompaction.rst similarity index 100% rename from docs/operating-scylla/nodetool-commands/enbleautocompaction.rst rename to docs/operating-scylla/nodetool-commands/enableautocompaction.rst diff --git a/docs/operating-scylla/nodetool.rst b/docs/operating-scylla/nodetool.rst index 605b8c8abefd..7bb57edee723 100644 --- a/docs/operating-scylla/nodetool.rst +++ b/docs/operating-scylla/nodetool.rst @@ -25,7 +25,7 @@ Nodetool nodetool-commands/disablebinary nodetool-commands/disablegossip nodetool-commands/drain - nodetool-commands/enbleautocompaction + nodetool-commands/enableautocompaction nodetool-commands/enablebackup nodetool-commands/enablebinary nodetool-commands/enablegossip @@ -98,7 +98,7 @@ Operations that are not listed below are currently not available. * :doc:`disablebinary ` - Disable native transport (binary protocol). * :doc:`disablegossip ` - Disable gossip (effectively marking the node down). * :doc:`drain ` - Drain the node (stop accepting writes and flush all column families). -* :doc:`enableautocompaction ` - Enable automatic compaction of a keyspace or table. +* :doc:`enableautocompaction ` - Enable automatic compaction of a keyspace or table. * :doc:`enablebackup ` - Enable incremental backup. * :doc:`enablebinary ` - Re-enable native transport (binary protocol). * :doc:`enablegossip ` - Re-enable gossip. diff --git a/docs/operating-scylla/security/encryption-at-rest.rst b/docs/operating-scylla/security/encryption-at-rest.rst index 5115d8b33b9f..b55eb43422e6 100644 --- a/docs/operating-scylla/security/encryption-at-rest.rst +++ b/docs/operating-scylla/security/encryption-at-rest.rst @@ -62,6 +62,8 @@ Table keys are used for encrypting SSTables. Depending on your key provider, thi * Replicated Key Provider - encrypted_keys table * KMIP Key Provider - KMIP server * KMS Key Provider - AWS +* GCP Key Provider - GCP +* Azure Key Provider - Azure * Local Key Provider - in a local file with multiple keys. You can provide your own key or ScyllaDB can make one for you. .. _ear-key-providers: @@ -98,6 +100,9 @@ When encrypting the system tables or SSTables, you need to state which provider * - GCP Key Provider - GcpKeyProviderFactory - Used key(s) provided by the GCP KMS service. + * - Azure Key Provider + - AzureKeyProviderFactory + - Used key(s) provided by the Azure Key Vault service. About Local Key Storage @@ -403,6 +408,112 @@ If you are using Google GCP KMS to encrypt tables or system information, add the .. include:: /rst_include/scylla-commands-restart-index.rst +.. _encryption-at-rest-set-azure: + +Set the Azure Host +---------------------- + +If you are using Azure Key Vault to encrypt tables or system information, you need to add the Azure information to the ``scylla.yaml`` configuration file. + +What You'll Need +================ + +Before configuring Azure Key Vault integration, ensure you have: + +* A Vault key. + +* An Azure identity to authenticate ScyllaDB against the Vault key. Choose one of: + + * **Managed Identity** (recommended for production) - No credentials are needed. ScyllaDB will automatically detect and use the default managed identity of the Azure VM. + * **Service Principal** - Requires providing a tenant ID, client ID, and a secret or certificate. + +* Sufficient permissions to perform the ``wrapkey`` and ``unwrapkey`` operations on the key. + If you are using RBAC, you can assign the "Key Vault Crypto User" role to your identity. + +.. note:: + ScyllaDB can also authenticate indirectly through a preconfigured Azure CLI, but this is offered only for testing purposes. + +Procedure +========= + +#. Edit the ``scylla.yaml`` file located in ``/etc/scylla/`` to add a new host in the Azure host(s) section: + + .. code-block:: yaml + + azure_hosts: + : + azure_tenant_id: (optional) + azure_client_id: (optional) + azure_client_secret: (optional) + azure_client_certificate_path: (optional) + master_key: / - named Vault key for encrypting data keys (optional) + truststore: (optional) + priority_string: (optional) + key_cache_expiry: (optional) + key_cache_refresh: (optional) + # : + + Where: + + * ```` - The name to identify the Azure host. You have to provide this name to encrypt a :ref:`new ` or :ref:`existing ` table. + * ``azure_tenant_id`` - The ID of the Azure tenant. Required only for authentication with a service principal. + * ``azure_client_id`` - The client ID of your service principal. Required only for authentication with a service principal. + * ``azure_client_secret`` - The secret of your service principal. Required only for authentication with a service principal. + * ``azure_client_certificate_path`` - The path to the PEM-encoded certificate and private key file of your service principal. Can be used instead of a secret. + * ``master_key`` - The / of your key. This parameter can be omitted if you specify it through the :ref:`encryption options ` of the table schema. + * ``truststore`` - Path to a PEM file with CA certificates to validate the server's TLS certificate. + * ``priority_string`` - The TLS priority string for TLS handshakes. + * ``key_cache_expiry`` - Key cache expiry period, after which keys will be re-requested from Azure Key Vault. Default is 600s. + * ``key_cache_refresh`` - Key cache refresh period - the frequency at which the cache is checked for expired entries. Default is 1200s. + + Example: + + .. tabs:: + + .. group-tab:: Managed Identity (Recommended for Production) + + .. code-block:: yaml + + azure_hosts: + my-azure1: + master_key: mykeyvault/mykey + + .. group-tab:: Service Principal with Secret + + .. code-block:: yaml + + azure_hosts: + my-azure1: + azure_tenant_id: 00000000-1111-2222-3333-444444444444 + azure_client_id: 55555555-6666-7777-8888-999999999999 + azure_client_secret: mysecret + master_key: mykeyvault/mykey + + .. note:: + Alternatively, ScyllaDB can obtain these credentials from the following environment variables: + ``AZURE_TENANT_ID``, ``AZURE_CLIENT_ID``, and ``AZURE_CLIENT_SECRET``. + + .. group-tab:: Service Principal with Certificate + + .. code-block:: yaml + + azure_hosts: + my-azure1: + azure_tenant_id: 00000000-1111-2222-3333-444444444444 + azure_client_id: 55555555-6666-7777-8888-999999999999 + azure_client_certificate_path: /path/to/certificate.pem + master_key: mykeyvault/mykey + + .. note:: + Alternatively, ScyllaDB can obtain these credentials from the following environment variables: + ``AZURE_TENANT_ID``, ``AZURE_CLIENT_ID``, and ``AZURE_CLIENT_CERTIFICATE_PATH``. + +#. Save the file. +#. Drain the node with :doc:`nodetool drain ` +#. Restart the scylla-server service. + +.. include:: /rst_include/scylla-commands-restart-index.rst + Encrypt Tables ----------------------------- @@ -421,6 +532,7 @@ Ensure you have an encryption key available: * If you are using AWS KMS, :ref:`set the KMS Host `. * If you are using KMIP, :ref:`set the KMIP Host `. * If you are using Google GCP KMS, :ref:`set the GCP Host `. +* If you are using Azure Key Vault, :ref:`set the Azure Host `. * If you want to create your own key, follow the procedure in :ref:`Create Encryption Keys `. * If you do not create your own key, use the following procedure for ScyllaDB to create a key for you (the default location ``/etc/scylla/data_encryption_keys`` may cause @@ -448,6 +560,7 @@ the ``user_info_encryption`` option: kmip_host: kms_host: gcp_host: + azure_host: Where: @@ -467,6 +580,8 @@ Where: Required if you use KMS. * ``gcp_host`` - The name of your :ref:`gcp_host ` group. Required if you use GCP. +* ``azure_host`` - The name of your :ref:`azure_host ` group. + Required if you use Azure. **Example** @@ -587,6 +702,8 @@ This procedure demonstrates how to encrypt a new table. } ; +.. _ear-create-table-master-key-override: + You can specify a different master key than the one configured for ``kms_host`` in the ``scylla.yaml`` file: .. code-block:: cql @@ -797,6 +914,18 @@ Once this encryption is enabled, it is used for all system data. Where ``gcp_host`` is the unique name of the GCP host specified in the scylla.yaml file. + Example for Azure: + + .. code-block:: none + + system_info_encryption: + enabled: True + cipher_algorithm: AES/CBC/PKCS5Padding + secret_key_strength: 128 + key_provider: AzureKeyProviderFactory + azure_host: myScylla + + Where ``azure_host`` is the unique name of the Azure host specified in the scylla.yaml file. #. Do not close the yaml file. Change the system key directory location according to your settings. diff --git a/ent/encryption/CMakeLists.txt b/ent/encryption/CMakeLists.txt index d1e19c8dcedf..daf925e58f8f 100644 --- a/ent/encryption/CMakeLists.txt +++ b/ent/encryption/CMakeLists.txt @@ -9,6 +9,8 @@ target_sources(scylla_encryption encrypted_file_impl.cc encryption.cc encryption_config.cc + azure_host.cc + azure_key_provider.cc gcp_host.cc gcp_key_provider.cc kmip_host.cc diff --git a/ent/encryption/azure_host.cc b/ent/encryption/azure_host.cc new file mode 100644 index 000000000000..951f9d98c739 --- /dev/null +++ b/ent/encryption/azure_host.cc @@ -0,0 +1,580 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include + +#include +#include + +#include "db/config.hh" +#include "utils/log.hh" +#include "utils/hash.hh" +#include "utils/rjson.hh" +#include "utils/base64.hh" +#include "utils/loading_cache.hh" +#include "utils/rest/client.hh" +#include "utils/azure/identity/exceptions.hh" +#include "utils/azure/identity/default_credentials.hh" +#include "utils/azure/identity/service_principal_credentials.hh" +#include "azure_host.hh" +#include "encryption.hh" +#include "encryption_exceptions.hh" + +using namespace std::chrono_literals; + +static logging::logger azlog("azure_vault"); + +namespace encryption { + +vault_error vault_error::make_error(const http::reply::status_type& status, std::string_view result) { + const auto& jres = rjson::parse(result); + const auto& error_details = rjson::get(jres, ERROR_KEY); + return vault_error( + status, + rjson::get(error_details, ERROR_CODE_KEY), + rjson::get(error_details, ERROR_MESSAGE_KEY) + ); +} + +class vault_log_filter : public rest::http_log_filter { +public: + enum class op_type { + wrapkey, + unwrapkey, + }; + + explicit vault_log_filter(op_type op) : _op(op) {} + + string_opt filter_header(std::string_view name, std::string_view value) const override { + if (boost::iequals(name, "Authorization") && value.starts_with("Bearer")) { + return REDACTED_VALUE; + } + return std::nullopt; + } + + string_opt filter_body(body_type type, std::string_view body) const override { + if ((_op == op_type::wrapkey && type == body_type::request) + || (_op == op_type::unwrapkey && type == body_type::response)) { + auto j = rjson::parse(body); + auto val = rjson::find(j, "value"); + if (val) { + val->SetString(REDACTED_VALUE); + return rjson::print(j); + } + } + return std::nullopt; + } +private: + op_type _op; +}; + +class azure_host::impl { +public: + static inline constexpr std::chrono::milliseconds default_expiry = 600s; + static inline constexpr std::chrono::milliseconds default_refresh = 1200s; + + impl(encryption_context*, const std::string& name, const host_options& options); + future<> init(); + const host_options& options() const; + future get_or_create_key(const key_info&, const option_override* = nullptr); + future get_key_by_id(const id_type&, const key_info&); +private: + encryption_context* _ctxt; + const std::string _name; + const std::string _log_prefix; + const host_options _options; + std::unique_ptr _credentials; + bool _initialized; + + struct attr_cache_key { + seastar::sstring master_key; + key_info info; + bool operator==(const attr_cache_key& v) const = default; + }; + + struct attr_cache_key_hash { + size_t operator()(const attr_cache_key& k) const { + return utils::tuple_hash()(std::tie(k.master_key, k.info.len)); + } + }; + + friend struct fmt::formatter; + + struct id_cache_key { + azure_host::id_type id; + bool operator==(const id_cache_key& v) const = default; + }; + + struct id_cache_key_hash { + size_t operator()(const id_cache_key& k) const { + return std::hash()(k.id); + } + }; + + friend struct fmt::formatter; + + template + using cache_type = utils::loading_cache< + Key, + Value, + 2, + utils::loading_cache_reload_enabled::yes, + utils::simple_entry_size, + Hash + >; + cache_type _attr_cache; + cache_type _id_cache; + + static constexpr char AKV_HOST_TEMPLATE[] = "{}.vault.azure.net"; + static constexpr char AKV_PATH_TEMPLATE[] = "/keys/{}/{}/{}?api-version=7.4"; + static constexpr char AKV_LATEST_VERSION[] = ""; // an empty version denotes the latest + static constexpr char AKV_WRAPKEY_OP[] = "wrapkey"; + static constexpr char AKV_UNWRAPKEY_OP[] = "unwrapkey"; + static constexpr char AKV_ENCRYPTION_ALG[] = "RSA-OAEP-256"; + static constexpr char AKV_TOKEN_RESOURCE_URI[] = "https://vault.azure.net"; // no trailing slash + + static std::tuple parse_key(std::string_view); + static std::tuple parse_vault(std::string_view); + future> make_creds(); + future send_request(const sstring& host, unsigned port, bool use_https, const sstring& path, const rjson::value& body, const rest::http_log_filter& filter); + future send_request_with_retry(const sstring& host, unsigned port, bool use_https, const sstring& path, const rjson::value& body, const rest::http_log_filter& filter); + future create_key(const attr_cache_key&); + future find_key(const id_cache_key&); +}; + +azure_host::impl::impl(encryption_context* ctxt, const std::string& name, const azure_host::host_options& options) + : _ctxt(ctxt) + , _name(name) + , _log_prefix(fmt::format("AzureVault:{}", name)) + , _options(options) + , _credentials() + , _initialized(false) + , _attr_cache(utils::loading_cache_config{ + .max_size = std::numeric_limits::max(), + .expiry = options.key_cache_expiry.value_or(default_expiry), + .refresh = options.key_cache_refresh.value_or(default_refresh)}, azlog, std::bind_front(&impl::create_key, this)) + , _id_cache(utils::loading_cache_config{ + .max_size = std::numeric_limits::max(), + .expiry = options.key_cache_expiry.value_or(default_expiry), + .refresh = options.key_cache_refresh.value_or(default_refresh)}, azlog, std::bind_front(&impl::find_key, this)) +{ + if (!_options.tenant_id.empty() && !_options.client_id.empty() && (!_options.client_secret.empty() || !_options.client_cert.empty())) { + _credentials = std::make_unique( + options.tenant_id, options.client_id, options.client_secret, options.client_cert, + options.authority, options.truststore, options.priority_string, _log_prefix); + return; + } + azlog.debug("[{}] No credentials configured. Falling back to default credentials.", _log_prefix); + _credentials = std::make_unique( + azure::default_credentials::all_sources, _options.imds_endpoint, + _options.truststore, _options.priority_string, _log_prefix); +} + +/** + * Wraps exceptions to encryption::base_error exceptions. + * Should be used in all public methods. + */ +template +static future wrap_exceptions(const std::string& context, Callable&& func) { + try { + co_return co_await func(); + } catch (base_error&) { + throw; + } catch (const std::invalid_argument& e) { + std::throw_with_nested(configuration_error(fmt::format("{}: {}", context, e.what()))); + } catch (const rjson::malformed_value& e) { + std::throw_with_nested(malformed_response_error(fmt::format("{}: {}", context, e.what()))); + } catch (...) { + std::throw_with_nested(service_error(fmt::format("{}: {}", context, std::current_exception()))); + } +} + +future<> azure_host::impl::init() { + if (_initialized) { + co_return; + } + if (_options.master_key.empty()) { + azlog.info("[{}] No master key configured. Not verifying.", _log_prefix); + co_return; + } + azlog.info("[{}] Verifying access to master key {}", _log_prefix, _options.master_key); + co_await wrap_exceptions("init", [this] -> future<> { + azlog.debug("[{}] Wrapping a dummy key", _log_prefix); + attr_cache_key k{ + .master_key = _options.master_key, + .info = key_info{ .alg = "AES", .len = 128 }, + }; + auto [key, id] = co_await create_key(k); + azlog.debug("[{}] Unwrapping the dummy key", _log_prefix); + auto data = co_await find_key({ .id = id }); + if (key->key() != data) { + throw service_error(fmt::format("[{}] Key verification failed", _log_prefix)); + } + _initialized = true; + }); +} + +const azure_host::host_options& azure_host::impl::options() const { + return _options; +} + +template +static T get_option(const encryption::azure_host::option_override* oov, std::optional C::* f, const T& def) { + if (oov) { + return (oov->*f).value_or(def); + } + return def; +}; + +future azure_host::impl::get_or_create_key(const key_info& info, const option_override* oov) { + attr_cache_key key { + .master_key = get_option(oov, &option_override::master_key, _options.master_key), + .info = info, + }; + + if (key.master_key.empty()) { + throw configuration_error(fmt::format("[{}] No master key set in azure host config or encryption attributes", _log_prefix)); + } + co_return co_await wrap_exceptions("get_or_create_key", [this, &key] -> future { + co_return co_await _attr_cache.get(key); + }); +} + +future azure_host::impl::get_key_by_id(const azure_host::id_type& id, const key_info& info) { + id_cache_key key { .id = id }; + co_return co_await wrap_exceptions("get_key_by_id", [this, &key, &info] -> future { + auto data = co_await _id_cache.get(key); + co_return make_shared(info, data); + }); +} + +std::tuple azure_host::impl::parse_key(std::string_view spec) { + auto i = spec.find_last_of('/'); + if (i == std::string_view::npos) { + throw std::invalid_argument(fmt::format("Invalid master key spec '{}'. Must be in format /", spec)); + } + if (i >= spec.size() - 1) { + throw std::invalid_argument(fmt::format("Invalid master key spec '{}'. Key name is missing. Expected format: /", spec)); + } + return std::make_tuple(std::string(spec.substr(0, i)), std::string(spec.substr(i + 1))); +} + +std::tuple azure_host::impl::parse_vault(std::string_view vault) { + static const boost::regex vault_name_re(R"([a-zA-Z0-9-]+)"); + static const boost::regex vault_endpoint_re(R"((https?)://([^/:]+)(?::(\d+))?)"); + + boost::smatch match; + std::string tmp{vault}; + + if (boost::regex_match(tmp, match, vault_name_re)) { + // If the vault is just a name, use the default Azure Key Vault endpoint. + return {"https", fmt::format(AKV_HOST_TEMPLATE, vault), 443}; + } + + if (boost::regex_match(tmp, match, vault_endpoint_re)) { + std::string scheme = match[1]; + std::string host = match[2]; + std::string port_str = match[3]; + + unsigned port = (port_str.empty()) ? (scheme == "https" ? 443 : 80) : std::stoi(port_str); + return {scheme, host, port}; + } + + throw std::invalid_argument(fmt::format("Invalid vault '{}'. Must be either a name or an endpoint in format: http(s)://[:port]", vault)); +} + +future> azure_host::impl::make_creds() { + auto creds = ::make_shared(); + if (!_options.priority_string.empty()) { + creds->set_priority_string(_options.priority_string); + } else { + creds->set_priority_string(db::config::default_tls_priority); + } + if (!_options.truststore.empty()) { + co_await creds->set_x509_trust_file(_options.truststore, seastar::tls::x509_crt_format::PEM); + } else { + co_await creds->set_system_trust(); + } + co_return creds; +} + +/** + * @brief Retries for transient errors. + * + * Retries are performed for 401, 408, 429, 500, 502, 503, and 504 errors. + * 401 _may_ indicate an edge case where the cached token expired during the request. As such, it is retried immediately. + * The rest of the error codes are taken from the generic retry policy of the Azure C++ SDK [1] and they follow an exponential backoff strategy. + * Three retries are attempted in total. + * The latencies between retries are: 100, 200 and 400 milliseconds. + * + * This retry policy is destined only for short-lived transient errors. + * Transient errors that require higher delays should be handled by the upper layers. + * Persistent throttling (429) errors are not expected, and they likely indicate a misconfiguration. + * + * [1] https://github.com/Azure/azure-sdk-for-cpp/blob/126452efd30860263398a152f11f337007f529f4/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp#L133 + */ +future azure_host::impl::send_request_with_retry(const sstring& host, unsigned port, bool use_https, const sstring& path, const rjson::value& body, const rest::http_log_filter& filter) { + constexpr int MAX_RETRIES = 3; + constexpr std::chrono::milliseconds DELTA_BACKOFF {100}; + std::chrono::milliseconds backoff; + + int retries = 0; + while (true) { + try { + co_return co_await send_request(host, port, use_https, path, body, filter); + } catch (azure::auth_error& e) { + std::throw_with_nested(permission_error(fmt::format("{}/{}", host, path))); + } catch (vault_error& e) { + auto status = e.status(); + + if (retries >= MAX_RETRIES) { + if (status == http::reply::status_type::unauthorized) { + std::throw_with_nested(permission_error(fmt::format("{}/{}", host, path))); + } else { + std::throw_with_nested(service_error(fmt::format("{}/{}", host, path))); + } + } + + // Always retry if the request is unauthorized, to catch races where + // the token expired while making the request. This is not optimal, + // as an unauthorized response may have a non-retryable cause, but + // there is no way to tell from the error code, and the error message + // is not meant for decision making. + if (status == http::reply::status_type::unauthorized) { + azlog.debug("[{}] {}/{}: Request failed with status {}. Reason: {}. Retrying...", + _log_prefix, host, path, static_cast(status), e.what()); + retries++; + continue; + } + + bool should_retry = + status == http::reply::status_type::request_timeout || + status == http::reply::status_type::too_many_requests || + status == http::reply::status_type::internal_server_error || + status == http::reply::status_type::bad_gateway || + status == http::reply::status_type::service_unavailable || + status == http::reply::status_type::gateway_timeout; + + if (!should_retry) { + std::throw_with_nested(service_error(fmt::format("{}/{}", host, path))); + } + + backoff = DELTA_BACKOFF * (1 << retries); + azlog.debug("[{}] {}/{}: Request failed with status {}. Reason: {}. Retrying in {} ms...", + _log_prefix, host, path, static_cast(status), e.what(), backoff.count()); + retries++; + } catch (...) { + std::throw_with_nested(network_error(fmt::format("{}/{}", host, path))); + } + co_await seastar::sleep(backoff); + } +} + +future azure_host::impl::send_request(const sstring& host, unsigned port, bool use_https, const sstring& path, const rjson::value& body, const rest::http_log_filter& filter) { + auto token = co_await _credentials->get_access_token(AKV_TOKEN_RESOURCE_URI); + + shared_ptr creds; + std::optional options; + + if (use_https) { + creds = co_await make_creds(); + // Do not wait when terminating the TLS connection. TLS close_notify + // alerts are ignored by the Key Vault service. + // + // Also, do not use numeric hosts as SNI hostnames. SNI works only with DNS hostnames. + // This by extension also disables hostname validation, but that's fine; numeric hosts + // may appear only in testing. + bool is_numeric_host = seastar::net::inet_address::parse_numerical(host).has_value(); + sstring server_name = is_numeric_host ? sstring{} : host; + options = { .wait_for_eof_on_shutdown = false, .server_name = server_name }; + } + + rest::httpclient client(host, port, std::move(creds), options); + client.target(path); + client.method(httpd::operation_type::POST); + client.add_header("Authorization", fmt::format("Bearer {}", token.token)); + client.add_header("Content-Type", "application/json"); + client.content(std::move(rjson::print(body))); + + azlog.trace("Sending request: {}", rest::redacted_request_type{ client.request(), filter }); + + auto res = co_await client.send(); + if (res.result() == http::reply::status_type::ok) { + azlog.trace("Got response: {}", rest::redacted_result_type{ res, filter }); + co_return rjson::parse(res.body()); + } else { + azlog.trace("Got unexpected response: {}", rest::redacted_result_type{ res, filter }); + throw vault_error::make_error(res.result(), res.body()); + } +} + +future azure_host::impl::create_key(const attr_cache_key& k) { + static const vault_log_filter filter{vault_log_filter::op_type::wrapkey}; + auto& info = k.info; + if (_ctxt && this_shard_id() != 0) { + auto [data, id] = co_await smp::submit_to(0, [this, k]() -> future> { + auto host = _ctxt->get_azure_host(_name); + auto [key, id] = co_await host->_impl->_attr_cache.get(k); + co_return std::make_tuple(key != nullptr ? key->key() : bytes{}, id); + }); + co_return key_and_id_type{ + data.empty() ? nullptr : make_shared(info, data), + id + }; + } + azlog.debug("[{}] Creating new key: {}", _log_prefix, info); + auto [vault, keyname] = parse_key(k.master_key); + auto [scheme, host, port] = parse_vault(vault); + auto key = make_shared(info); + auto path = fmt::format(AKV_PATH_TEMPLATE, keyname, AKV_LATEST_VERSION, AKV_WRAPKEY_OP); + auto body = [&key] { + auto b = rjson::empty_object(); + rjson::add(b, "alg", AKV_ENCRYPTION_ALG); + rjson::add(b, "value", base64url_encode(key->key())); + return b; + }(); + rjson::value resp; + try { + resp = co_await send_request_with_retry(host, port, scheme == "https", path, body, filter); + } catch (...) { + azlog.error("[{}] Failed to wrap key {} with master_key={}: {}", _log_prefix, info, k.master_key, std::current_exception()); + throw; + } + auto key_id = rjson::get(resp, "kid"); + auto cipher = rjson::get(resp, "value"); + boost::regex version_regex(R"foo(.*/([^/]+)$)foo"); + boost::smatch match; + if (!boost::regex_search(key_id, match, version_regex)) { + throw std::runtime_error(fmt::format("Failed to parse key version from key id {}", key_id)); + } + auto key_version = match[1].str(); + + auto sid = fmt::format("{}/{}/{}:{}", vault, keyname, key_version, cipher); + bytes id(sid.begin(), sid.end()); + + azlog.trace("[{}] Created key id {}", _log_prefix, sid); + co_return key_and_id_type{ key, id }; +} + +future azure_host::impl::find_key(const id_cache_key& k) { + static const vault_log_filter filter{vault_log_filter::op_type::unwrapkey}; + if (_ctxt && this_shard_id() != 0) { + co_return co_await smp::submit_to(0, [this, k]() -> future { + auto host = _ctxt->get_azure_host(_name); + auto bytes = co_await host->_impl->_id_cache.get(k); + co_return bytes; + }); + } + const auto id = to_string_view(k.id); + azlog.debug("[{}] Finding key: {}", _log_prefix, id); + + auto [vault, keyname, version, cipher] = [&id] { + // Regex for key ID in format: + // "//:" or "http(s)://://:" + // Captures: + // 1. Vault (either a name or an endpoint) + // 2. Key name + // 3. Key version + // 4. Cipher text for data encryption key + static const boost::regex id_re(R"foo(((?:https?://[^/]+)|[^/]+)/([^/]+)/([^:]+):(.+))foo"); + boost::match_results match; + if (!boost::regex_match(id.begin(), id.end(), match, id_re)) { + throw std::invalid_argument(fmt::format("Not a valid key id: {}", id)); + } + return std::make_tuple(match[1].str(), match[2].str(), match[3].str(), match[4].str()); + }(); + + auto [scheme, host, port] = parse_vault(vault); + auto path = seastar::format(AKV_PATH_TEMPLATE, keyname, version, AKV_UNWRAPKEY_OP); + auto body = [&cipher] { + auto b = rjson::empty_object(); + rjson::add(b, "alg", AKV_ENCRYPTION_ALG); + rjson::add(b, "value", cipher); + return b; + }(); + rjson::value resp; + try { + resp = co_await send_request_with_retry(host, port, scheme == "https", path, body, filter); + } catch (...) { + azlog.error("[{}] Failed to unwrap key {}: {}", _log_prefix, k.id, std::current_exception()); + throw; + } + auto data = base64url_decode(rjson::get(resp, "value")); + co_return data; +} + +// ==================== azure_host class implementation ==================== + +azure_host::azure_host(const std::string& name, const host_options& options) + : _impl(std::make_unique(nullptr, name, options)) +{} + +azure_host::azure_host(encryption_context& ctxt, const std::string& name, const host_options& options) + : _impl(std::make_unique(&ctxt, name, options)) +{} + +azure_host::azure_host(encryption_context& ctxt, const std::string& name, const std::unordered_map& map) + : azure_host(ctxt, name, [&map] { + host_options opts; + map_wrapper> m(map); + + opts.tenant_id = m("azure_tenant_id").value_or(""); + opts.client_id = m("azure_client_id").value_or(""); + opts.client_secret = m("azure_client_secret").value_or(""); + opts.client_cert = m("azure_client_certificate_path").value_or(""); + opts.authority = m("azure_authority_host").value_or(""); + opts.imds_endpoint = m("imds_endpoint").value_or(""); + + opts.master_key = m("master_key").value_or(""); + + opts.truststore = m("truststore").value_or(""); + opts.priority_string = m("priority_string").value_or(""); + + opts.key_cache_expiry = parse_expiry(m("key_cache_expiry")); + opts.key_cache_refresh = parse_expiry(m("key_cache_refresh")); + + return opts; + }()) +{} + +azure_host::~azure_host() = default; + +future<> azure_host::init() { + return _impl->init(); +} + +const azure_host::host_options& azure_host::options() const { + return _impl->options(); +} + +future azure_host::get_or_create_key(const key_info& info, const option_override* oov) { + return _impl->get_or_create_key(info, oov); +} + +future azure_host::get_key_by_id(const azure_host::id_type& id, const key_info& info) { + return _impl->get_key_by_id(id, info); +} + +} // namespace encryption + +template<> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + auto format(const encryption::azure_host::impl::attr_cache_key& d, fmt::format_context& ctxt) const { + return fmt::format_to(ctxt.out(), "{},{},{}", d.master_key, d.info.alg, d.info.len); + } +}; + +template<> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + auto format(const encryption::azure_host::impl::id_cache_key& d, fmt::format_context& ctxt) const { + return fmt::format_to(ctxt.out(), "{}", d.id); + } +}; \ No newline at end of file diff --git a/ent/encryption/azure_host.hh b/ent/encryption/azure_host.hh new file mode 100644 index 000000000000..3ae03d1182ce --- /dev/null +++ b/ent/encryption/azure_host.hh @@ -0,0 +1,119 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include + +#include +#include +#include + +#include "symmetric_key.hh" + +namespace encryption { + +class encryption_context; + +class vault_error : public std::exception { + // The error response always contains a code and a message. + // https://github.com/microsoft/api-guidelines/blob/vNext/azure/Guidelines.md#handling-errors + static constexpr char ERROR_KEY[] = "error"; + static constexpr char ERROR_CODE_KEY[] = "code"; + static constexpr char ERROR_MESSAGE_KEY[] = "message"; + http::reply::status_type _status; + std::string _code, _msg; +public: + vault_error(const http::reply::status_type& status, std::string_view code, std::string_view msg) + : _status(status) + , _code(code) + , _msg(fmt::format("{}: {}", code, msg)) + {} + const http::reply::status_type& status() const { + return _status; + } + const std::string& code() const { + return _code; + } + const char* what() const noexcept override { + return _msg.c_str(); + } + static vault_error make_error(const http::reply::status_type& status, std::string_view result); +}; + +// Common error codes from Azure Key Vault: +// https://learn.microsoft.com/en-us/azure/key-vault/general/rest-error-codes +// +// Not to be confused with the more specialized 'innererror' codes: +// https://learn.microsoft.com/en-us/azure/key-vault/general/common-error-codes +namespace vault_errors { + [[maybe_unused]] static constexpr char Unauthorized[] = "Unauthorized"; + [[maybe_unused]] static constexpr char Forbidden[] = "Forbidden"; + [[maybe_unused]] static constexpr char Throttled[] = "Throttled"; + [[maybe_unused]] static constexpr char KeyNotFound[] = "KeyNotFound"; +} + +class azure_host { +public: + class impl; + + using id_type = bytes; + using key_ptr = shared_ptr; + using key_and_id_type = std::tuple; + + struct host_options { + std::string tenant_id; + std::string client_id; + std::string client_secret; + std::string client_cert; + std::string authority; + std::string imds_endpoint; + + // Azure Key Vault Key to encrypt data keys with. Format: / + std::string master_key; + + // TLS options + std::string truststore; + std::string priority_string; + + std::optional key_cache_expiry; + std::optional key_cache_refresh; + }; + + azure_host(const std::string& name, const host_options&); // used for testing + azure_host(encryption_context&, const std::string& name, const host_options&); + azure_host(encryption_context&, const std::string& name, const std::unordered_map&); + ~azure_host(); + + /** + * Async initialization. + * + * Performs preliminary checks. + * Should precede any calls to the Key Management API. + */ + future<> init(); + const host_options& options() const; + + struct option_override { + std::optional master_key; + }; + + /** + * Key Management API to create or retrieve data keys. + * + * Throws encryption::base_error derivatives to indicate failure conditions + * such as host misconfigurations, network issues, or authentication failures. + */ + future get_or_create_key(const key_info&, const option_override* = nullptr); + future get_key_by_id(const id_type&, const key_info&); +private: + std::unique_ptr _impl; +}; + +} \ No newline at end of file diff --git a/ent/encryption/azure_key_provider.cc b/ent/encryption/azure_key_provider.cc new file mode 100644 index 000000000000..bbf309e3d633 --- /dev/null +++ b/ent/encryption/azure_key_provider.cc @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include "azure_key_provider.hh" +#include "azure_host.hh" + +namespace encryption { + +class azure_key_provider : public key_provider { +public: + azure_key_provider(::shared_ptr azure_host, std::string name, azure_host::option_override oov) + : _azure_host(std::move(azure_host)) + , _name(std::move(name)) + , _oov(std::move(oov)) + {} + future> key(const key_info& info, opt_bytes id) override { + if (id) { + return _azure_host->get_key_by_id(*id, info).then([id](key_ptr k) { + return make_ready_future>(std::tuple(k, id)); + }); + } + return _azure_host->get_or_create_key(info, &_oov).then([](std::tuple k_id) { + return make_ready_future>(k_id); + }); + } + void print(std::ostream& os) const override { + os << _name; + } +private: + ::shared_ptr _azure_host; + std::string _name; + azure_host::option_override _oov; +}; + +shared_ptr azure_key_provider_factory::get_provider(encryption_context& ctxt, const options& map) { + opt_wrapper opts(map); + auto azure_host = opts("azure_host"); + + + if (!azure_host) { + throw std::invalid_argument("azure_host must be provided"); + } + + azure_host::option_override oov { + .master_key = opts("master_key"), + }; + + auto host = ctxt.get_azure_host(*azure_host); + auto id = azure_host.value() + ":" + oov.master_key.value_or(host->options().master_key); + + auto provider = ctxt.get_cached_provider(id); + + if (!provider) { + provider = ::make_shared(host, *azure_host, std::move(oov)); + ctxt.cache_provider(id, provider); + } + + return provider; +} + +} diff --git a/ent/encryption/azure_key_provider.hh b/ent/encryption/azure_key_provider.hh new file mode 100644 index 000000000000..92b0ec7651ff --- /dev/null +++ b/ent/encryption/azure_key_provider.hh @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "encryption.hh" + +namespace encryption { + +class azure_key_provider_factory : public key_provider_factory { +public: + shared_ptr get_provider(encryption_context&, const options&) override; +}; + +/** + * See comment for AWS KMS regarding system key support. + */ +} diff --git a/ent/encryption/encryption.cc b/ent/encryption/encryption.cc index f4d102072174..85bd5e69a603 100644 --- a/ent/encryption/encryption.cc +++ b/ent/encryption/encryption.cc @@ -44,6 +44,8 @@ #include "kms_host.hh" #include "gcp_key_provider.hh" #include "gcp_host.hh" +#include "azure_key_provider.hh" +#include "azure_host.hh" #include "bytes.hh" #include "utils/class_registrator.hh" #include "cql3/query_processor.hh" @@ -70,6 +72,7 @@ static constexpr auto LOCAL_FILE_SYSTEM_KEY_PROVIDER_FACTORY = "LocalFileSystemK static constexpr auto KMIP_KEY_PROVIDER_FACTORY = "KmipKeyProviderFactory"; static constexpr auto KMS_KEY_PROVIDER_FACTORY = "KmsKeyProviderFactory"; static constexpr auto GCP_KEY_PROVIDER_FACTORY = "GcpKeyProviderFactory"; +static constexpr auto AZURE_KEY_PROVIDER_FACTORY = "AzureKeyProviderFactory"; bytes base64_decode(const sstring& s, size_t off, size_t len) { if (off >= s.size()) { @@ -276,6 +279,7 @@ class encryption_context_impl : public encryption_context { std::vector>> _per_thread_kmip_host_cache; std::vector>> _per_thread_kms_host_cache; std::vector>> _per_thread_gcp_host_cache; + std::vector>> _per_thread_azure_host_cache; std::vector> _per_thread_global_user_extension; std::unique_ptr _cfg; sharded* _qp;; @@ -291,6 +295,7 @@ class encryption_context_impl : public encryption_context { , _per_thread_kmip_host_cache(smp::count) , _per_thread_kms_host_cache(smp::count) , _per_thread_gcp_host_cache(smp::count) + , _per_thread_azure_host_cache(smp::count) , _per_thread_global_user_extension(smp::count) , _cfg(std::move(cfg)) , _qp(find_or_null(services)) @@ -328,6 +333,7 @@ class encryption_context_impl : public encryption_context { map[KMIP_KEY_PROVIDER_FACTORY] = std::make_unique(); map[KMS_KEY_PROVIDER_FACTORY] = std::make_unique(); map[GCP_KEY_PROVIDER_FACTORY] = std::make_unique(); + map[AZURE_KEY_PROVIDER_FACTORY] = std::make_unique(); return map; }(); @@ -374,55 +380,38 @@ class encryption_context_impl : public encryption_context { return k; } - shared_ptr get_kmip_host(const sstring& host) override { - auto& cache = _per_thread_kmip_host_cache[this_shard_id()]; - auto i = cache.find(host); - if (i != cache.end()) { - return i->second; + template + shared_ptr get_host(const sstring& host, CacheType& cache, const ConfigType& config_map) { + auto& host_cache = cache[this_shard_id()]; + auto it = host_cache.find(host); + if (it != host_cache.end()) { + return it->second; } - auto j = _cfg->kmip_hosts().find(host); - if (j != _cfg->kmip_hosts().end()) { - auto result = ::make_shared(*this, host, j->second); - cache.emplace(host, result); + auto config_it = config_map.find(host); + if (config_it != config_map.end()) { + auto result = ::make_shared(*this, host, config_it->second); + host_cache.emplace(host, result); return result; } - throw std::invalid_argument("No such host: "+ host); + throw std::invalid_argument("No such host: " + host); } - shared_ptr get_kms_host(const sstring& host) override { - auto& cache = _per_thread_kms_host_cache[this_shard_id()]; - auto i = cache.find(host); - if (i != cache.end()) { - return i->second; - } - - auto j = _cfg->kms_hosts().find(host); - if (j != _cfg->kms_hosts().end()) { - auto result = ::make_shared(*this, host, j->second); - cache.emplace(host, result); - return result; - } + shared_ptr get_kmip_host(const sstring& host) override { + return get_host(host, _per_thread_kmip_host_cache, _cfg->kmip_hosts()); + } - throw std::invalid_argument("No such host: "+ host); + shared_ptr get_kms_host(const sstring& host) override { + return get_host(host, _per_thread_kms_host_cache, _cfg->kms_hosts()); } shared_ptr get_gcp_host(const sstring& host) override { - auto& cache = _per_thread_gcp_host_cache[this_shard_id()]; - auto i = cache.find(host); - if (i != cache.end()) { - return i->second; - } - - auto j = _cfg->gcp_hosts().find(host); - if (j != _cfg->gcp_hosts().end()) { - auto result = ::make_shared(*this, host, j->second); - cache.emplace(host, result); - return result; - } + return get_host(host, _per_thread_gcp_host_cache, _cfg->gcp_hosts()); + } - throw std::invalid_argument("No such host: "+ host); + shared_ptr get_azure_host(const sstring& host) override { + return get_host(host, _per_thread_azure_host_cache, _cfg->azure_hosts()); } @@ -485,7 +474,8 @@ class encryption_context_impl : public encryption_context { _per_thread_kmip_host_cache[this_shard_id()].clear(); _per_thread_kms_host_cache[this_shard_id()].clear(); _per_thread_gcp_host_cache[this_shard_id()].clear(); - _per_thread_global_user_extension[this_shard_id()] = {}; + _per_thread_azure_host_cache[this_shard_id()].clear(); + _per_thread_global_user_extension[this_shard_id()] = {}; }); } @@ -1082,6 +1072,14 @@ future> register_extensions(const db::co }); } + if (!cfg.azure_hosts().empty()) { + // only pre-create on shard 0. + co_await parallel_for_each(cfg.azure_hosts(), [ctxt](auto& p) { + auto host = ctxt->get_azure_host(p.first); + return host->init(); + }); + } + replicated_key_provider_factory::init(exts); auto user_opts = maybe_get_options(cfg.user_info_encryption(), "user table encryption"); diff --git a/ent/encryption/encryption.hh b/ent/encryption/encryption.hh index d70fd3ae19d7..308753f684d9 100644 --- a/ent/encryption/encryption.hh +++ b/ent/encryption/encryption.hh @@ -152,6 +152,7 @@ class system_key; class kmip_host; class kms_host; class gcp_host; +class azure_host; /** * Context is a singleton object, shared across shards. I.e. even though there are obvious mutating @@ -169,6 +170,7 @@ public: virtual shared_ptr get_kmip_host(const sstring&) = 0; virtual shared_ptr get_kms_host(const sstring&) = 0; virtual shared_ptr get_gcp_host(const sstring&) = 0; + virtual shared_ptr get_azure_host(const sstring&) = 0; virtual shared_ptr get_cached_provider(const sstring& id) const = 0; virtual void cache_provider(const sstring& id, shared_ptr) = 0; diff --git a/ent/encryption/encryption_config.cc b/ent/encryption/encryption_config.cc index 63e553fcab75..6874dd3ffe19 100644 --- a/ent/encryption/encryption_config.cc +++ b/ent/encryption/encryption_config.cc @@ -110,6 +110,19 @@ auth_file can contain either a user, service or impersonated service account. master_key is an GCP KMS key name from which all keys used for actual encryption of scylla data will be derived. This key must be pre-created with access policy allowing the above credentials Encrypt and Decrypt operations. +)foo") + , azure_hosts(this, "azure_hosts", value_status::Used, { }, + R"foo(Azure Key Vault host(s). + +The unique name of Azure Key Vault host that can be referenced in table schema. + +host.yourdomain.com={ azure_tenant_id=, azure_client_id=, azure_client_secret=, azure_client_certificate_path=, master_key=/, truststore=/path/to/truststore.pem, priority_string=, key_cache_expiry=, key_cache_refersh=}:... + +Authentication can be explicit with Service Principal credentials or by resolving default credentials (see Azure docs). + +master_key is a Vault key name from which all keys used for actual encryption of scylla data will be derived. +This key must be pre-created with access policy allowing the above credentials Wrapkey and Unwrapkey operations. + )foo") , user_info_encryption(this, "user_info_encryption", value_status::Used, { { "enabled", "false" }, { CIPHER_ALGORITHM, diff --git a/ent/encryption/encryption_config.hh b/ent/encryption/encryption_config.hh index 189a08398048..44596c43f008 100644 --- a/ent/encryption/encryption_config.hh +++ b/ent/encryption/encryption_config.hh @@ -26,6 +26,7 @@ public: named_value kmip_hosts; named_value kms_hosts; named_value gcp_hosts; + named_value azure_hosts; named_value user_info_encryption; named_value allow_per_table_encryption; }; diff --git a/ent/encryption/gcp_host.cc b/ent/encryption/gcp_host.cc index 0fac0dd321fa..c3ea6d403fbd 100644 --- a/ent/encryption/gcp_host.cc +++ b/ent/encryption/gcp_host.cc @@ -854,7 +854,7 @@ future encryption::gcp_host::impl:: * a data key, and encrypt it as the key ID. * * For ID -> key, we simply split the ID into the encrypted key part, and - * the master key name part, decrypt the first using the second (AWS KMS Decrypt), + * the master key name part, decrypt the first using the second (GCP KMS Decrypt), * and create a local key using the result. * * Data recovery: diff --git a/init.hh b/init.hh index d3cb20120aac..80baf2a04e42 100644 --- a/init.hh +++ b/init.hh @@ -21,15 +21,10 @@ namespace db { class extensions; -class seed_provider_type; class config; -namespace view { -class view_update_generator; -} } namespace gms { -class feature_service; class inet_address; } diff --git a/install-dependencies.sh b/install-dependencies.sh index 4149627b461a..84501793e082 100755 --- a/install-dependencies.sh +++ b/install-dependencies.sh @@ -54,6 +54,10 @@ debian_base_packages=( slapd ldap-utils libcpp-jwt-dev + elfutils + curl + jq + git-lfs ) fedora_packages=( @@ -137,6 +141,8 @@ fedora_packages=( buildah https://github.com/scylladb/cassandra-stress/releases/download/v3.18.1/cassandra-stress-java21-3.18.1-1.noarch.rpm + elfutils + jq ) fedora_python3_packages=( diff --git a/clustering_bounds_comparator.hh b/keys/clustering_bounds_comparator.hh similarity index 100% rename from clustering_bounds_comparator.hh rename to keys/clustering_bounds_comparator.hh diff --git a/clustering_interval_set.hh b/keys/clustering_interval_set.hh similarity index 100% rename from clustering_interval_set.hh rename to keys/clustering_interval_set.hh diff --git a/clustering_key_filter.hh b/keys/clustering_key_filter.hh similarity index 100% rename from clustering_key_filter.hh rename to keys/clustering_key_filter.hh diff --git a/clustering_ranges_walker.hh b/keys/clustering_ranges_walker.hh similarity index 100% rename from clustering_ranges_walker.hh rename to keys/clustering_ranges_walker.hh diff --git a/compound.hh b/keys/compound.hh similarity index 100% rename from compound.hh rename to keys/compound.hh diff --git a/compound_compat.hh b/keys/compound_compat.hh similarity index 100% rename from compound_compat.hh rename to keys/compound_compat.hh diff --git a/full_position.hh b/keys/full_position.hh similarity index 98% rename from full_position.hh rename to keys/full_position.hh index 43d1caeaf5f1..5028a05adc71 100644 --- a/full_position.hh +++ b/keys/full_position.hh @@ -8,7 +8,7 @@ #pragma once -#include "keys.hh" +#include "keys/keys.hh" #include "mutation/position_in_partition.hh" struct full_position; diff --git a/keys.cc b/keys/keys.cc similarity index 100% rename from keys.cc rename to keys/keys.cc diff --git a/keys.hh b/keys/keys.hh similarity index 99% rename from keys.hh rename to keys/keys.hh index 54505d6ec9ec..42c4ee4e1a9e 100644 --- a/keys.hh +++ b/keys/keys.hh @@ -10,7 +10,7 @@ #include "bytes.hh" #include "types/types.hh" -#include "compound_compat.hh" +#include "keys/compound_compat.hh" #include "utils/managed_bytes.hh" #include "utils/hashing.hh" #include "utils/utf8.hh" diff --git a/locator/simple_strategy.cc b/locator/simple_strategy.cc index a611b1954dc0..b757f64e83dc 100644 --- a/locator/simple_strategy.cc +++ b/locator/simple_strategy.cc @@ -12,6 +12,7 @@ #include "simple_strategy.hh" #include "exceptions/exceptions.hh" #include "utils/assert.hh" +#include "utils/chunked_vector.hh" #include "utils/class_registrator.hh" #include @@ -31,7 +32,7 @@ simple_strategy::simple_strategy(replication_strategy_params params) : } future simple_strategy::calculate_natural_endpoints(const token& t, const token_metadata& tm) const { - const std::vector& tokens = tm.sorted_tokens(); + const utils::chunked_vector& tokens = tm.sorted_tokens(); if (tokens.empty()) { co_return host_id_set{}; diff --git a/locator/tablet_sharder.hh b/locator/tablet_sharder.hh index 9b52a8fb29c7..38a43b785999 100644 --- a/locator/tablet_sharder.hh +++ b/locator/tablet_sharder.hh @@ -40,7 +40,7 @@ private: return std::nullopt; }; - dht::shard_replica_set shard_for_writes(tablet_id tid, host_id host, std::optional sel = std::nullopt) const { + dht::shard_replica_set choose_shard_for_writes(tablet_id tid, host_id host, std::optional sel = std::nullopt) const { auto* trinfo = _tmap->get_tablet_transition_info(tid); auto& tinfo = _tmap->get_tablet_info(tid); dht::shard_replica_set shards; @@ -78,7 +78,7 @@ private: return shards; } - std::optional shard_for_reads(tablet_id tid, host_id host) const { + std::optional choose_shard_for_reads(tablet_id tid, host_id host) const { ensure_tablet_map(); auto* trinfo = _tmap->get_tablet_transition_info(tid); auto& tinfo = _tmap->get_tablet_info(tid); @@ -107,13 +107,10 @@ public: virtual ~tablet_sharder() = default; - virtual unsigned shard_for_reads(const token& t) const override { + virtual std::optional try_get_shard_for_reads(const token& t) const override { ensure_tablet_map(); auto tid = _tmap->get_tablet_id(t); - // FIXME: Consider throwing when there is no owning shard on the current host rather than returning 0. - // It's a coordination mistake to route requests to non-owners. Topology coordinator should synchronize - // with request coordinators before moving the shard away. - auto shard = shard_for_reads(tid, _host).value_or(0); + auto shard = choose_shard_for_reads(tid, _host); tablet_logger.trace("[{}] shard_of({}) = {}, tablet={}", _table, t, shard, tid); return shard; } @@ -121,7 +118,7 @@ public: virtual dht::shard_replica_set shard_for_writes(const token& t, std::optional sel = std::nullopt) const override { ensure_tablet_map(); auto tid = _tmap->get_tablet_id(t); - auto shards = shard_for_writes(tid, _host, sel); + auto shards = choose_shard_for_writes(tid, _host, sel); tablet_logger.trace("[{}] shard_for_writes({}) = {}, tablet={}", _table, t, shards, tid); return shards; } @@ -130,7 +127,7 @@ public: ensure_tablet_map(); std::optional tb = _tmap->get_tablet_id(t); while ((tb = _tmap->next_tablet(*tb))) { - auto r = shard_for_reads(*tb, _host); + auto r = choose_shard_for_reads(*tb, _host); auto next = _tmap->get_first_token(*tb); tablet_logger.trace("[{}] token_for_next_shard({}) = {{{}, {}}}, tablet={}", _table, t, next, r, *tb); return dht::shard_and_token{r.value_or(0), next}; diff --git a/locator/tablets.cc b/locator/tablets.cc index 8e88146679b5..2a6459875ee9 100644 --- a/locator/tablets.cc +++ b/locator/tablets.cc @@ -297,29 +297,12 @@ bool tablet_metadata::is_base_table(table_id id) const { return !_base_table.contains(id); } -void tablet_metadata::mutate_tablet_map(table_id id, noncopyable_function func) { - auto it = _tablets.find(id); - if (it == _tablets.end()) { - throw no_such_tablet_map(id); - } - auto tablet_map_copy = make_lw_shared(*it->second); - func(*tablet_map_copy); - auto new_map_ptr = lw_shared_ptr(std::move(tablet_map_copy)); - // share the tablet map with all co-located tables - for (auto colocated_id : _table_groups.at(id)) { - if (colocated_id != id) { - _tablets[colocated_id] = make_foreign(new_map_ptr); - } - } - it->second = make_foreign(std::move(new_map_ptr)); -} - future<> tablet_metadata::mutate_tablet_map_async(table_id id, noncopyable_function(tablet_map&)> func) { auto it = _tablets.find(id); if (it == _tablets.end()) { throw no_such_tablet_map(id); } - auto tablet_map_copy = make_lw_shared(*it->second); + auto tablet_map_copy = make_lw_shared(co_await it->second->clone_gently()); co_await func(*tablet_map_copy); auto new_map_ptr = lw_shared_ptr(std::move(tablet_map_copy)); // share the tablet map with all co-located tables @@ -455,6 +438,28 @@ tablet_map::tablet_map(size_t tablet_count) _tablets.resize(tablet_count); } +tablet_map tablet_map::clone() const { + return tablet_map(_tablets, _log2_tablets, _transitions, _resize_decision, _resize_task_info, _repair_scheduler_config); +} + +future tablet_map::clone_gently() const { + tablet_container tablets; + tablets.reserve(_tablets.size()); + for (const auto& t : _tablets) { + tablets.emplace_back(t); + co_await coroutine::maybe_yield(); + } + + transitions_map transitions; + transitions.reserve(_transitions.size()); + for (const auto& [id, trans] : _transitions) { + transitions.emplace(id, trans); + co_await coroutine::maybe_yield(); + } + + co_return tablet_map(std::move(tablets), _log2_tablets, std::move(transitions), _resize_decision, _resize_task_info, _repair_scheduler_config); +} + void tablet_map::check_tablet_id(tablet_id id) const { if (size_t(id) >= tablet_count()) { throw std::logic_error(format("Invalid tablet id: {} >= {}", id, tablet_count())); @@ -531,21 +536,14 @@ tablet_replica tablet_map::get_secondary_replica(tablet_id id) const { return replicas.at((size_t(id)+1) % replicas.size()); } -tablet_replica tablet_map::get_primary_replica_within_dc(tablet_id id, const topology& topo, sstring dc) const { - return maybe_get_primary_replica(id, get_tablet_info(id).replicas, [&] (const auto& tr) { - const auto& node = topo.get_node(tr.host); - return node.dc_rack().dc == dc; - }).value(); -} - std::optional tablet_map::maybe_get_selected_replica(tablet_id id, const topology& topo, const tablet_task_info& tablet_task_info) const { return maybe_get_primary_replica(id, get_tablet_info(id).replicas, [&] (const auto& tr) { return tablet_task_info.selected_by_filters(tr, topo); }); } -future> tablet_map::get_sorted_tokens() const { - std::vector tokens; +future> tablet_map::get_sorted_tokens() const { + utils::chunked_vector tokens; tokens.reserve(tablet_count()); for (auto id : tablet_ids()) { diff --git a/locator/tablets.hh b/locator/tablets.hh index c75b40237d18..519e7ec6cb0b 100644 --- a/locator/tablets.hh +++ b/locator/tablets.hh @@ -449,17 +449,29 @@ class tablet_map { public: using tablet_container = utils::chunked_vector; private: + using transitions_map = std::unordered_map; // The implementation assumes that _tablets.size() is a power of 2: // // _tablets.size() == 1 << _log2_tablets // tablet_container _tablets; size_t _log2_tablets; // log_2(_tablets.size()) - std::unordered_map _transitions; + transitions_map _transitions; resize_decision _resize_decision; tablet_task_info _resize_task_info; repair_scheduler_config _repair_scheduler_config; + // Internal constructor, used by clone() and clone_gently(). + tablet_map(tablet_container tablets, size_t log2_tablets, transitions_map transitions, + resize_decision resize_decision, tablet_task_info resize_task_info, repair_scheduler_config repair_scheduler_config) + : _tablets(std::move(tablets)) + , _log2_tablets(log2_tablets) + , _transitions(std::move(transitions)) + , _resize_decision(resize_decision) + , _resize_task_info(std::move(resize_task_info)) + , _repair_scheduler_config(std::move(repair_scheduler_config)) + {} + /// Returns the largest token owned by tablet_id when the tablet_count is `1 << log2_tablets`. dht::token get_last_token(tablet_id id, size_t log2_tablets) const; @@ -472,6 +484,15 @@ public: /// \param tablet_count The desired tablets to allocate. Must be a power of two. explicit tablet_map(size_t tablet_count); + tablet_map(tablet_map&&) = default; + tablet_map(const tablet_map&) = delete; + + tablet_map& operator=(tablet_map&&) = default; + tablet_map& operator=(const tablet_map&) = delete; + + tablet_map clone() const; + future clone_gently() const; + /// Returns tablet_id of a tablet which owns a given token. tablet_id get_tablet_id(token) const; @@ -501,7 +522,6 @@ public: /// Returns the primary replica for the tablet tablet_replica get_primary_replica(tablet_id id) const; - tablet_replica get_primary_replica_within_dc(tablet_id id, const topology& topo, sstring dc) const; /// Returns the secondary replica for the tablet, which is assumed to be directly following the primary replica in the replicas vector /// \throws std::runtime_error if the tablet has less than 2 replicas. @@ -511,7 +531,7 @@ public: std::optional maybe_get_selected_replica(tablet_id id, const topology& topo, const tablet_task_info& tablet_task_info) const; /// Returns a vector of sorted last tokens for tablets. - future> get_sorted_tokens() const; + future> get_sorted_tokens() const; /// Returns the id of the first tablet. tablet_id first_tablet() const { @@ -676,7 +696,6 @@ public: // Allow mutating a tablet_map // Uses the copy-modify-swap idiom. // If func throws, no changes are done to the tablet map. - void mutate_tablet_map(table_id, noncopyable_function func); future<> mutate_tablet_map_async(table_id, noncopyable_function(tablet_map&)> func); future<> clear_gently(); diff --git a/locator/token_metadata.cc b/locator/token_metadata.cc index ecd2acaccd37..2acb6bf56bd2 100644 --- a/locator/token_metadata.cc +++ b/locator/token_metadata.cc @@ -11,6 +11,7 @@ #include "locator/snitch_base.hh" #include "locator/abstract_replication_strategy.hh" #include "locator/tablets.hh" +#include "utils/chunked_vector.hh" #include "utils/log.hh" #include "partition_range_compat.hh" #include @@ -58,7 +59,7 @@ class token_metadata_impl final { std::optional _topology_change_info; - std::vector _sorted_tokens; + utils::chunked_vector _sorted_tokens; tablet_metadata _tablets; @@ -97,7 +98,7 @@ class token_metadata_impl final { token_metadata_impl(token_metadata::config cfg) noexcept : _topology(std::move(cfg.topo_cfg)) {}; token_metadata_impl(const token_metadata_impl&) = delete; // it's too huge for direct copy, use clone_async() token_metadata_impl(token_metadata_impl&&) noexcept = default; - const std::vector& sorted_tokens() const; + const utils::chunked_vector& sorted_tokens() const; future<> update_normal_tokens(std::unordered_set tokens, host_id endpoint); const token& first_token(const token& start) const; size_t first_token_index(const token& start) const; @@ -370,7 +371,7 @@ future<> token_metadata_impl::clear_gently() noexcept { } void token_metadata_impl::sort_tokens() { - std::vector sorted; + utils::chunked_vector sorted; sorted.reserve(_token_to_endpoint_map.size()); for (auto&& i : _token_to_endpoint_map) { @@ -394,7 +395,7 @@ void token_metadata::set_tablets(tablet_metadata tm) { _impl->set_tablets(std::move(tm)); } -const std::vector& token_metadata_impl::sorted_tokens() const { +const utils::chunked_vector& token_metadata_impl::sorted_tokens() const { return _sorted_tokens; } @@ -708,14 +709,12 @@ future<> token_metadata_impl::update_topology_change_info(dc_rack_fn& get_dc_rac // merge tokens from token_to_endpoint and bootstrap_tokens, // preserving tokens of leaving endpoints - auto all_tokens = std::vector(); + auto all_tokens = sorted_tokens(); all_tokens.reserve(sorted_tokens().size() + get_bootstrap_tokens().size()); - all_tokens.resize(sorted_tokens().size()); - std::copy(begin(sorted_tokens()), end(sorted_tokens()), begin(all_tokens)); for (const auto& p: get_bootstrap_tokens()) { all_tokens.push_back(p.first); } - std::sort(begin(all_tokens), end(all_tokens)); + std::sort(all_tokens.begin(), all_tokens.end()); auto prev_value = std::move(_topology_change_info); _topology_change_info.emplace(make_lw_shared(std::move(target_token_metadata)), @@ -817,7 +816,7 @@ void token_metadata_impl::del_replacing_endpoint(host_id existing_node) { } topology_change_info::topology_change_info(lw_shared_ptr target_token_metadata_, - std::vector all_tokens_, + utils::chunked_vector all_tokens_, token_metadata::read_new_t read_new_) : target_token_metadata(std::move(target_token_metadata_)) , all_tokens(std::move(all_tokens_)) @@ -860,7 +859,7 @@ void token_metadata::set_shared_token_metadata(shared_token_metadata& stm) { _shared_token_metadata = &stm; } -const std::vector& +const utils::chunked_vector& token_metadata::sorted_tokens() const { return _impl->sorted_tokens(); } diff --git a/locator/token_metadata.hh b/locator/token_metadata.hh index b986267ea415..8f9cf8b66bea 100644 --- a/locator/token_metadata.hh +++ b/locator/token_metadata.hh @@ -23,6 +23,7 @@ #include #include #include +#include "utils/chunked_vector.hh" #include "utils/phased_barrier.hh" #include "service/topology_state_machine.hh" @@ -158,7 +159,7 @@ public: return tmp; } private: - std::vector::const_iterator _cur_it; + utils::chunked_vector::const_iterator _cur_it; size_t _remaining = 0; const token_metadata_impl* _token_metadata = nullptr; @@ -184,7 +185,7 @@ public: token_metadata(token_metadata&&) noexcept; // Can't use "= default;" - hits some static_assert in unique_ptr token_metadata& operator=(token_metadata&&) noexcept; ~token_metadata(); - const std::vector& sorted_tokens() const; + const utils::chunked_vector& sorted_tokens() const; const tablet_metadata& tablets() const; tablet_metadata& tablets(); void set_tablets(tablet_metadata); @@ -365,11 +366,11 @@ private: struct topology_change_info { lw_shared_ptr target_token_metadata; - std::vector all_tokens; + utils::chunked_vector all_tokens; token_metadata::read_new_t read_new; topology_change_info(lw_shared_ptr target_token_metadata_, - std::vector all_tokens_, + utils::chunked_vector all_tokens_, token_metadata::read_new_t read_new_); future<> clear_gently(); }; diff --git a/locator/util.cc b/locator/util.cc index 17474179d3c4..871cc79a67a4 100644 --- a/locator/util.cc +++ b/locator/util.cc @@ -8,6 +8,7 @@ #include "locator/util.hh" #include "replica/database.hh" #include "gms/gossiper.hh" +#include "utils/chunked_vector.hh" #include namespace locator { @@ -28,7 +29,7 @@ construct_range_to_endpoint_map( // Caller is responsible to hold token_metadata valid until the returned future is resolved static future -get_all_ranges(const std::vector& sorted_tokens) { +get_all_ranges(const utils::chunked_vector& sorted_tokens) { if (sorted_tokens.empty()) co_return dht::token_range_vector(); int size = sorted_tokens.size(); @@ -49,14 +50,14 @@ get_all_ranges(const std::vector& sorted_tokens) { // Caller is responsible to hold token_metadata valid until the returned future is resolved future> get_range_to_address_map(locator::effective_replication_map_ptr erm, - const std::vector& sorted_tokens) { + const utils::chunked_vector& sorted_tokens) { co_return co_await construct_range_to_endpoint_map(erm, co_await get_all_ranges(sorted_tokens)); } // Caller is responsible to hold token_metadata valid until the returned future is resolved -static future> +static future> get_tokens_in_local_dc(const locator::token_metadata& tm) { - std::vector filtered_tokens; + utils::chunked_vector filtered_tokens; auto local_dc_filter = tm.get_topology().get_local_dc_filter(); for (auto token : tm.sorted_tokens()) { auto endpoint = tm.get_endpoint(token); diff --git a/locator/util.hh b/locator/util.hh index 8f2f74939201..b9f0d71607e5 100644 --- a/locator/util.hh +++ b/locator/util.hh @@ -25,5 +25,5 @@ namespace gms { namespace locator { future> describe_ring(const replica::database& db, const gms::gossiper& gossiper, const sstring& keyspace, bool include_only_local_dc = false); future> get_range_to_address_map( - locator::effective_replication_map_ptr erm, const std::vector& sorted_tokens); -} \ No newline at end of file + locator::effective_replication_map_ptr erm, const utils::chunked_vector& sorted_tokens); +} diff --git a/mutation/mutation.hh b/mutation/mutation.hh index 4df755e8d63d..bb6fbdd605fb 100644 --- a/mutation/mutation.hh +++ b/mutation/mutation.hh @@ -9,7 +9,7 @@ #pragma once #include "mutation_partition.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "schema/schema_fwd.hh" #include "utils/assert.hh" #include "utils/hashing.hh" diff --git a/mutation/mutation_compactor.hh b/mutation/mutation_compactor.hh index af1bcd0f4dcc..2ef74abd833d 100644 --- a/mutation/mutation_compactor.hh +++ b/mutation/mutation_compactor.hh @@ -13,7 +13,7 @@ #include "mutation_fragment_stream_validator.hh" #include "mutation_tombstone_stats.hh" #include "tombstone_gc.hh" -#include "full_position.hh" +#include "keys/full_position.hh" #include #include "utils/log.hh" diff --git a/mutation/mutation_fragment.cc b/mutation/mutation_fragment.cc index dbbe25f1c699..c49b01b02bd9 100644 --- a/mutation/mutation_fragment.cc +++ b/mutation/mutation_fragment.cc @@ -10,7 +10,7 @@ #include "mutation_fragment.hh" #include "mutation_fragment_v2.hh" -#include "clustering_interval_set.hh" +#include "keys/clustering_interval_set.hh" #include "utils/assert.hh" #include "utils/hashing.hh" #include "utils/xx_hasher.hh" diff --git a/mutation/mutation_partition.cc b/mutation/mutation_partition.cc index 67c6ceba32c3..627c7639c034 100644 --- a/mutation/mutation_partition.cc +++ b/mutation/mutation_partition.cc @@ -11,7 +11,7 @@ #include #include "mutation_partition.hh" -#include "clustering_interval_set.hh" +#include "keys/clustering_interval_set.hh" #include "converting_mutation_partition_applier.hh" #include "partition_builder.hh" #include "query-result-writer.hh" @@ -26,7 +26,7 @@ #include #include "types/map.hh" #include "compaction/compaction_garbage_collector.hh" -#include "clustering_key_filter.hh" +#include "keys/clustering_key_filter.hh" #include "mutation_partition_view.hh" #include "tombstone_gc.hh" #include "utils/assert.hh" diff --git a/mutation/mutation_partition.hh b/mutation/mutation_partition.hh index 0ee5227a2c77..cb34e9f80f93 100644 --- a/mutation/mutation_partition.hh +++ b/mutation/mutation_partition.hh @@ -17,7 +17,7 @@ #include "schema/schema_fwd.hh" #include "tombstone.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "position_in_partition.hh" #include "atomic_cell_or_collection.hh" #include "hashing_partition_visitor.hh" diff --git a/mutation/mutation_partition_v2.cc b/mutation/mutation_partition_v2.cc index 72591629e263..e76024260387 100644 --- a/mutation/mutation_partition_v2.cc +++ b/mutation/mutation_partition_v2.cc @@ -10,7 +10,7 @@ #include #include "mutation_partition_v2.hh" -#include "clustering_interval_set.hh" +#include "keys/clustering_interval_set.hh" #include "converting_mutation_partition_applier.hh" #include "partition_builder.hh" #include "query-result-writer.hh" diff --git a/mutation/partition_version.cc b/mutation/partition_version.cc index 799d6c580506..d8d6dfcddded 100644 --- a/mutation/partition_version.cc +++ b/mutation/partition_version.cc @@ -14,7 +14,7 @@ #include "utils/assert.hh" #include "utils/coroutine.hh" #include "real_dirty_memory_accounter.hh" -#include "clustering_interval_set.hh" +#include "keys/clustering_interval_set.hh" static void remove_or_mark_as_unique_owner(partition_version* current, mutation_cleaner* cleaner) { diff --git a/mutation/position_in_partition.hh b/mutation/position_in_partition.hh index 4ccc5ff8cce8..8908804916e5 100644 --- a/mutation/position_in_partition.hh +++ b/mutation/position_in_partition.hh @@ -10,8 +10,8 @@ #include "utils/assert.hh" #include "types/types.hh" -#include "keys.hh" -#include "clustering_bounds_comparator.hh" +#include "keys/keys.hh" +#include "keys/clustering_bounds_comparator.hh" #include "query-request.hh" #include diff --git a/mutation/range_tombstone.hh b/mutation/range_tombstone.hh index 646e04e3ecb1..36003d0fc0d2 100644 --- a/mutation/range_tombstone.hh +++ b/mutation/range_tombstone.hh @@ -10,9 +10,9 @@ #include #include "utils/hashing.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "mutation/tombstone.hh" -#include "clustering_bounds_comparator.hh" +#include "keys/clustering_bounds_comparator.hh" #include "mutation/position_in_partition.hh" /** diff --git a/mutation/range_tombstone_splitter.hh b/mutation/range_tombstone_splitter.hh index 59976e5d2a43..2417001c0db5 100644 --- a/mutation/range_tombstone_splitter.hh +++ b/mutation/range_tombstone_splitter.hh @@ -8,7 +8,7 @@ #pragma once -#include "clustering_ranges_walker.hh" +#include "keys/clustering_ranges_walker.hh" #include "mutation/mutation_fragment.hh" template diff --git a/partition_snapshot_reader.hh b/partition_snapshot_reader.hh index d89f7df8ff71..d1a4f52b905f 100644 --- a/partition_snapshot_reader.hh +++ b/partition_snapshot_reader.hh @@ -12,7 +12,7 @@ #include "readers/mutation_reader_fwd.hh" #include "readers/mutation_reader.hh" #include "readers/range_tombstone_change_merger.hh" -#include "clustering_key_filter.hh" +#include "keys/clustering_key_filter.hh" #include "query-request.hh" #include "db/partition_snapshot_row_cursor.hh" #include diff --git a/querier.hh b/querier.hh index 2c3a70b64ff5..609091e34259 100644 --- a/querier.hh +++ b/querier.hh @@ -13,7 +13,7 @@ #include "mutation/mutation_compactor.hh" #include "reader_concurrency_semaphore.hh" #include "readers/mutation_source.hh" -#include "full_position.hh" +#include "keys/full_position.hh" #include diff --git a/query-request.hh b/query-request.hh index 7d96514f8d15..66ec09af826a 100644 --- a/query-request.hh +++ b/query-request.hh @@ -17,7 +17,7 @@ #include "db/functions/function.hh" #include "db/functions/aggregate_function.hh" #include "db/consistency_level_type.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "dht/ring_position.hh" #include "enum_set.hh" #include "utils/interval.hh" diff --git a/query-result-reader.hh b/query-result-reader.hh index 485dbcac8a82..45a40f5ac665 100644 --- a/query-result-reader.hh +++ b/query-result-reader.hh @@ -11,7 +11,7 @@ #include "utils/assert.hh" #include "query-result.hh" -#include "full_position.hh" +#include "keys/full_position.hh" #include "idl/query.dist.hh" #include "idl/query.dist.impl.hh" diff --git a/query-result-writer.hh b/query-result-writer.hh index eb56a135cad8..c042b011e09c 100644 --- a/query-result-writer.hh +++ b/query-result-writer.hh @@ -13,7 +13,7 @@ #include "query-result.hh" #include "utils/digest_algorithm.hh" #include "utils/digester.hh" -#include "full_position.hh" +#include "keys/full_position.hh" #include "mutation/tombstone.hh" #include "idl/query.dist.hh" #include "idl/query.dist.impl.hh" diff --git a/query-result.hh b/query-result.hh index c712ee33112b..f93de8aebecd 100644 --- a/query-result.hh +++ b/query-result.hh @@ -11,7 +11,7 @@ #include "bytes_ostream.hh" #include "utils/digest_algorithm.hh" #include "query-request.hh" -#include "full_position.hh" +#include "keys/full_position.hh" #include #include #include diff --git a/readers/mutation_readers.cc b/readers/mutation_readers.cc index 1aa2d4db5593..e5548b1b8324 100644 --- a/readers/mutation_readers.cc +++ b/readers/mutation_readers.cc @@ -7,8 +7,8 @@ */ #include "utils/assert.hh" -#include "clustering_key_filter.hh" -#include "clustering_ranges_walker.hh" +#include "keys/clustering_key_filter.hh" +#include "keys/clustering_ranges_walker.hh" #include "mutation/mutation.hh" #include "mutation/mutation_partition.hh" #include "mutation/mutation_compactor.hh" diff --git a/repair/repair.cc b/repair/repair.cc index f9d610b6a891..5f66adb2ce27 100644 --- a/repair/repair.cc +++ b/repair/repair.cc @@ -1193,7 +1193,6 @@ future repair_service::do_repair_start(gms::gossip_address_map& addr_map, s size_t nr_tablet_table = 0; size_t nr_vnode_table = 0; - bool is_tablet = false; for (auto& table_name : cfs) { auto& t = db.find_column_family(keyspace, table_name); if (t.uses_tablets()) { @@ -1207,53 +1206,7 @@ future repair_service::do_repair_start(gms::gossip_address_map& addr_map, s if (nr_vnode_table != 0) { throw std::invalid_argument("Mixed vnode table and tablet table"); } - is_tablet = true; - } - if (is_tablet) { - // Reject unsupported options for tablet repair - if (!options.start_token.empty()) { - throw std::invalid_argument("The startToken option is not supported for tablet repair"); - } - if (!options.end_token.empty()) { - throw std::invalid_argument("The endToken option is not supported for tablet repair"); - } - if (options.small_table_optimization) { - throw std::invalid_argument("The small_table_optimization option is not supported for tablet repair"); - } - - // Reject unsupported option combinations. - if (!options.data_centers.empty() && !options.hosts.empty()) { - throw std::invalid_argument("Cannot combine dataCenters and hosts options."); - } - if (!options.ignore_nodes.empty() && !options.hosts.empty()) { - throw std::invalid_argument("Cannot combine ignore_nodes and hosts options."); - } - - std::unordered_set hosts; - for (const auto& n : options.hosts) { - try { - auto node = gms::inet_address(n); - hosts.insert(_gossiper.local().get_host_id(node)); - } catch(...) { - throw std::invalid_argument(fmt::format("Failed to parse node={} in hosts={} specified by user: {}", - n, options.hosts, std::current_exception())); - } - } - std::unordered_set ignore_nodes; - for (const auto& n : options.ignore_nodes) { - try { - auto node = gms::inet_address(n); - ignore_nodes.insert(_gossiper.local().get_host_id(node)); - } catch(...) { - throw std::invalid_argument(fmt::format("Failed to parse node={} in ignore_nodes={} specified by user: {}", - n, options.ignore_nodes, std::current_exception())); - } - } - - bool primary_replica_only = options.primary_range; - auto ranges_parallelism = options.ranges_parallelism == -1 ? std::nullopt : std::optional(options.ranges_parallelism); - co_await repair_tablets(id, keyspace, cfs, primary_replica_only, options.ranges, options.data_centers, hosts, ignore_nodes, ranges_parallelism); - co_return id.id; + throw tablets_unsupported(); } } @@ -1472,10 +1425,6 @@ future> repair::user_requested_repair_task_impl::expected_ co_return _ranges.size() * _cfs.size() * smp::count; } -std::optional repair::user_requested_repair_task_impl::expected_children_number() const { - return smp::count; -} - future repair_start(seastar::sharded& repair, sharded& am, sstring keyspace, std::unordered_map options) { return repair.invoke_on(0, [keyspace = std::move(keyspace), options = std::move(options), &am] (repair_service& local_repair) { @@ -1629,10 +1578,6 @@ future> repair::data_sync_repair_task_impl::expected_total co_return _cfs_size ? std::make_optional(_ranges.size() * _cfs_size * smp::count) : std::nullopt; } -std::optional repair::data_sync_repair_task_impl::expected_children_number() const { - return smp::count; -} - future<> repair_service::bootstrap_with_repair(locator::token_metadata_ptr tmptr, std::unordered_set bootstrap_tokens) { SCYLLA_ASSERT(this_shard_id() == 0); return seastar::async([this, tmptr = std::move(tmptr), tokens = std::move(bootstrap_tokens)] () mutable { @@ -2279,197 +2224,6 @@ future<> repair_service::replace_with_repair(std::unordered_map get_token_owners_in_dcs(std::vector data_centers, locator::effective_replication_map_ptr erm) { - auto dc_endpoints_map = erm->get_token_metadata().get_datacenter_token_owners(); - std::unordered_set dc_endpoints; - for (const sstring& dc : data_centers) { - auto it = dc_endpoints_map.find(dc); - if (it == dc_endpoints_map.end()) { - throw std::runtime_error(fmt::format("Unknown dc={}", dc)); - } - for (const auto& endpoint : it->second) { - dc_endpoints.insert(endpoint); - } - } - return dc_endpoints; -} - -// Repair all tablets belong to this node for the given table -future<> repair_service::repair_tablets(repair_uniq_id rid, sstring keyspace_name, std::vector table_names, bool primary_replica_only, dht::token_range_vector ranges_specified, std::vector data_centers, std::unordered_set hosts, std::unordered_set ignore_nodes, std::optional ranges_parallelism) { - std::vector task_metas; - for (auto& table_name : table_names) { - lw_shared_ptr t; - try { - t = _db.local().find_column_family(keyspace_name, table_name).shared_from_this(); - } catch (replica::no_such_column_family& e) { - rlogger.debug("repair[{}] Table {}.{} does not exist anymore", rid.uuid(), keyspace_name, table_name); - continue; - } - if (!t->uses_tablets()) { - throw std::runtime_error(format("repair[{}] Table {}.{} is not a tablet table", rid.uuid(), keyspace_name, table_name)); - } - table_id tid = t->schema()->id(); - // Invoke group0 read barrier before obtaining erm pointer so that it sees all prior metadata changes - auto dropped = !utils::get_local_injector().enter("repair_tablets_no_sync") && - co_await streaming::table_sync_and_check(_db.local(), _mm, tid); - if (dropped) { - rlogger.debug("repair[{}] Table {}.{} does not exist anymore", rid.uuid(), keyspace_name, table_name); - continue; - } - locator::effective_replication_map_ptr erm; - while (true) { - _repair_module->check_in_shutdown(); - erm = t->get_effective_replication_map(); - auto local_version = erm->get_token_metadata().get_version(); - const locator::tablet_map& tmap = erm->get_token_metadata_ptr()->tablets().get_tablet_map(tid); - if (!tmap.has_transitions() && co_await container().invoke_on(0, [local_version] (repair_service& rs) { - // We need to ensure that there is no ongoing global request. - return local_version == rs._tsm.local()._topology.version && !rs._tsm.local()._topology.is_busy(); - })) { - break; - } - rlogger.info("repair[{}] Topology is busy, waiting for it to quiesce", rid.uuid()); - erm = nullptr; - co_await container().invoke_on(0, [] (repair_service& rs) { - return rs._tsm.local().await_not_busy(); - }); - rlogger.info("repair[{}] Topology quiesced", rid.uuid()); - } - auto& tmap = erm->get_token_metadata_ptr()->tablets().get_tablet_map(tid); - struct repair_tablet_meta { - locator::tablet_id id; - dht::token_range range; - locator::host_id master_host_id; - shard_id master_shard_id; - locator::tablet_replica_set replicas; - }; - std::vector metas; - auto myhostid = erm->get_token_metadata_ptr()->get_my_id(); - auto mydc = erm->get_topology().get_datacenter(); - bool select_primary_ranges_within_dc = false; - // If the user specified the ranges option, ignore the primary_replica_only option. - // Since the ranges are requested explicitly. - if (!ranges_specified.empty()) { - primary_replica_only = false; - } - if (primary_replica_only) { - // The logic below follows existing vnode table repair. - // When "primary_range" option is on, neither data_centers nor hosts - // may be set, except data_centers may contain only local DC (-local) - if (data_centers.size() == 1 && data_centers[0] == mydc) { - select_primary_ranges_within_dc = true; - } else if (data_centers.size() > 0 || hosts.size() > 0) { - throw std::runtime_error("You need to run primary range repair on all nodes in the cluster."); - } - } - if (!hosts.empty()) { - if (!hosts.contains(myhostid)) { - throw std::runtime_error("The current host must be part of the repair"); - } - } - co_await tmap.for_each_tablet([&] (locator::tablet_id id, const locator::tablet_info& info) -> future<> { - auto range = tmap.get_token_range(id); - auto& replicas = info.replicas; - - if (primary_replica_only) { - const auto pr = select_primary_ranges_within_dc ? tmap.get_primary_replica_within_dc(id, erm->get_topology(), mydc) : tmap.get_primary_replica(id); - if (pr.host == myhostid) { - metas.push_back(repair_tablet_meta{id, range, myhostid, pr.shard, replicas}); - } - return make_ready_future<>(); - } - - bool found = false; - shard_id master_shard_id; - // Repair all tablets belong to this node - for (auto& r : replicas) { - if (r.host == myhostid) { - master_shard_id = r.shard; - found = true; - break; - } - } - if (found) { - metas.push_back(repair_tablet_meta{id, range, myhostid, master_shard_id, replicas}); - } - return make_ready_future<>(); - }); - - std::unordered_set dc_endpoints = get_token_owners_in_dcs(data_centers, erm); - if (!data_centers.empty() && !dc_endpoints.contains(myhostid)) { - throw std::runtime_error("The current host must be part of the repair"); - } - - size_t nr = 0; - for (auto& m : metas) { - nr++; - rlogger.debug("repair[{}] Collect {} out of {} tablets: table={}.{} tablet_id={} range={} replicas={} primary_replica_only={}", - rid.uuid(), nr, metas.size(), keyspace_name, table_name, m.id, m.range, m.replicas, primary_replica_only); - std::vector nodes; - auto master_shard_id = m.master_shard_id; - auto range = m.range; - dht::token_range_vector intersection_ranges; - if (!ranges_specified.empty()) { - for (auto& r : ranges_specified) { - // When the management tool, e.g., scylla manager, uses - // ranges option to select which ranges to repair for a - // tablet table to schedule repair work, it needs to - // disable tablet migration to make sure the node and token - // range mapping does not change. - // - // If the migration is not disabled, the ranges provided by - // the user might not match exactly with the token range of the - // tablets. We do the intersection of the ranges to repair - // as a best effort. - auto intersection_opt = range.intersection(r, dht::token_comparator()); - if (intersection_opt) { - intersection_ranges.push_back(intersection_opt.value()); - rlogger.debug("repair[{}] Select tablet table={}.{} tablet_id={} range={} replicas={} primary_replica_only={} ranges_specified={} intersection_ranges={}", - rid.uuid(), keyspace_name, table_name, m.id, m.range, m.replicas, primary_replica_only, ranges_specified, intersection_opt.value()); - } - co_await coroutine::maybe_yield(); - } - } else { - intersection_ranges.push_back(range); - } - if (intersection_ranges.empty()) { - rlogger.debug("repair[{}] Skip tablet table={}.{} tablet_id={} range={} replicas={} primary_replica_only={} ranges_specified={}", - rid.uuid(), keyspace_name, table_name, m.id, m.range, m.replicas, primary_replica_only, ranges_specified); - continue; - } - std::vector shards; - for (auto& r : m.replicas) { - auto shard = r.shard; - if (r.host != myhostid) { - bool select = data_centers.empty() ? true : dc_endpoints.contains(r.host); - - if (select && !hosts.empty()) { - select = hosts.contains(r.host); - } - - if (select && !ignore_nodes.empty()) { - select = !ignore_nodes.contains(r.host); - } - - if (select) { - rlogger.debug("repair[{}] Repair get neighbors table={}.{} hostid={} shard={} myhostid={}", - rid.uuid(), keyspace_name, table_name, r.host, shard, myhostid); - shards.push_back(shard); - nodes.push_back(r.host); - } - } - } - for (auto& r : intersection_ranges) { - rlogger.debug("repair[{}] Repair tablet task table={}.{} master_shard_id={} range={} neighbors={} replicas={}", - rid.uuid(), keyspace_name, table_name, master_shard_id, r, repair_neighbors(nodes, shards).shard_map, m.replicas); - task_metas.push_back(tablet_repair_task_meta{keyspace_name, table_name, tid, master_shard_id, r, repair_neighbors(nodes, shards), m.replicas, erm}); - co_await coroutine::maybe_yield(); - } - } - } - auto task = co_await _repair_module->make_and_start_task({}, rid, keyspace_name, tasks::task_id::create_null_id(), table_names, streaming::stream_reason::repair, std::move(task_metas), ranges_parallelism, service::default_session_id); -} - // It is called by the repair_tablet rpc verb to repair the given tablet future repair_service::repair_tablet(gms::gossip_address_map& addr_map, locator::tablet_metadata_guard& guard, locator::global_tablet_id gid, tasks::task_info global_tablet_repair_task_info, service::frozen_topology_guard topo_guard, std::optional rebuild_replicas) { auto id = _repair_module->new_repair_uniq_id(); @@ -2709,10 +2463,6 @@ future> repair::tablet_repair_task_impl::expected_total_wo co_return sz ? std::make_optional(sz) : std::nullopt; } -std::optional repair::tablet_repair_task_impl::expected_children_number() const { - return get_metas_size(); -} - node_ops_cmd_category categorize_node_ops_cmd(node_ops_cmd cmd) noexcept { switch (cmd) { case node_ops_cmd::removenode_prepare: diff --git a/repair/repair.hh b/repair/repair.hh index 942ecc435b00..2cad0337a5f2 100644 --- a/repair/repair.hh +++ b/repair/repair.hh @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -27,6 +28,10 @@ #include "tasks/types.hh" #include "gms/gossip_address_map.hh" +struct tablets_unsupported : std::runtime_error { + tablets_unsupported() : std::runtime_error("tablets are not supported for this operation") {} +}; + namespace tasks { namespace repair { class task_manager_module; @@ -91,6 +96,7 @@ constexpr shard_id repair_unspecified_shard = shard_id(-1); // repair_get_status(). The returned future becomes available quickly, // as soon as repair_get_status() can be used - it doesn't wait for the // repair to complete. +// repair_start() repairs only vnode keyspaces. future repair_start(seastar::sharded& repair, sharded& am, sstring keyspace, std::unordered_map options); diff --git a/repair/row_level.hh b/repair/row_level.hh index 69403749d14e..213d71394812 100644 --- a/repair/row_level.hh +++ b/repair/row_level.hh @@ -178,8 +178,6 @@ private: shared_ptr ops_info); public: - future<> repair_tablets(repair_uniq_id id, sstring keyspace_name, std::vector table_names, bool primary_replica_only = true, dht::token_range_vector ranges_specified = {}, std::vector dcs = {}, std::unordered_set hosts = {}, std::unordered_set ignore_nodes = {}, std::optional ranges_parallelism = std::nullopt); - future repair_tablet(gms::gossip_address_map& addr_map, locator::tablet_metadata_guard& guard, locator::global_tablet_id gid, tasks::task_info global_tablet_repair_task_info, service::frozen_topology_guard topo_guard, std::optional rebuild_replicas); private: diff --git a/repair/task_manager_module.hh b/repair/task_manager_module.hh index 4c2063027344..2458002856bc 100644 --- a/repair/task_manager_module.hh +++ b/repair/task_manager_module.hh @@ -74,7 +74,6 @@ protected: future<> run() override; virtual future> expected_total_workload() const override; - virtual std::optional expected_children_number() const override; }; class data_sync_repair_task_impl : public repair_task_impl { @@ -103,7 +102,6 @@ protected: future<> run() override; virtual future> expected_total_workload() const override; - virtual std::optional expected_children_number() const override; }; class tablet_repair_task_impl : public repair_task_impl { @@ -145,7 +143,6 @@ protected: future<> run() override; virtual future> expected_total_workload() const override; - virtual std::optional expected_children_number() const override; }; class shard_repair_task_impl : public repair_task_impl { diff --git a/replica/tablets.cc b/replica/tablets.cc index f178792dee61..9b75a5c0c98e 100644 --- a/replica/tablets.cc +++ b/replica/tablets.cc @@ -636,7 +636,8 @@ struct tablet_metadata_builder { } else { auto tablet_count = row.get_as("tablet_count"); auto tmap = tablet_map(tablet_count); - current = active_tablet_map{table, tmap, tmap.first_tablet()}; + auto first_tablet = tmap.first_tablet(); + current = active_tablet_map{table, std::move(tmap), first_tablet}; } // Resize decision fields are static columns, so set them only once per table. @@ -818,6 +819,15 @@ class tablet_sstable_set : public sstables::sstable_set_impl { uint64_t _bytes_on_disk = 0; public: + tablet_sstable_set(const tablet_sstable_set& o) + : _schema(o._schema) + , _tablet_map(o._tablet_map.clone()) + , _sstable_sets(o._sstable_sets) + , _sstable_set_ids(o._sstable_set_ids) + , _size(o._size) + , _bytes_on_disk(o._bytes_on_disk) + {} + tablet_sstable_set(schema_ptr s, const storage_group_manager& sgm, const locator::tablet_map& tmap) : _schema(std::move(s)) , _tablet_map(tmap.tablet_count()) diff --git a/schema/schema.hh b/schema/schema.hh index f7d1d29d8288..55d2d27f9e2e 100644 --- a/schema/schema.hh +++ b/schema/schema.hh @@ -21,7 +21,7 @@ #include #include #include "types/types.hh" -#include "compound.hh" +#include "keys/compound.hh" #include "gc_clock.hh" #include "compress.hh" #include "compaction/compaction_strategy_type.hh" diff --git a/service/pager/paging_state.cc b/service/pager/paging_state.cc index d8a2a0344024..b0f299ebd178 100644 --- a/service/pager/paging_state.cc +++ b/service/pager/paging_state.cc @@ -9,7 +9,7 @@ */ #include "bytes.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "paging_state.hh" #include #include "idl/paging_state.dist.hh" diff --git a/service/pager/paging_state.hh b/service/pager/paging_state.hh index d0430647fa35..9e26364790aa 100644 --- a/service/pager/paging_state.hh +++ b/service/pager/paging_state.hh @@ -13,7 +13,7 @@ #include #include "bytes_fwd.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "query-request.hh" #include "db/read_repair_decision.hh" #include "mutation/position_in_partition.hh" diff --git a/service/raft/group0_state_machine.cc b/service/raft/group0_state_machine.cc index 4a65fd11e938..6be212e0adbf 100644 --- a/service/raft/group0_state_machine.cc +++ b/service/raft/group0_state_machine.cc @@ -36,7 +36,6 @@ #include "db/system_keyspace.hh" #include "service/storage_proxy.hh" #include "service/raft/raft_group0_client.hh" -#include "partition_slice_builder.hh" #include "timestamp.hh" #include "utils/overloaded_functor.hh" #include "utils/to_string.hh" diff --git a/service/raft/group0_state_machine.hh b/service/raft/group0_state_machine.hh index 46dd8b9c288e..aff47ef2dfef 100644 --- a/service/raft/group0_state_machine.hh +++ b/service/raft/group0_state_machine.hh @@ -11,7 +11,7 @@ #include #include "data_dictionary/data_dictionary.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "service/broadcast_tables/experimental/lang.hh" #include "raft/raft.hh" #include "service/raft/group0_state_id_handler.hh" diff --git a/service/storage_proxy.cc b/service/storage_proxy.cc index 35eae79f72e8..92bd64a24a22 100644 --- a/service/storage_proxy.cc +++ b/service/storage_proxy.cc @@ -1164,13 +1164,24 @@ const dht::token& end_token(const dht::partition_range& r) { return r.end() ? r.end()->value().token() : max_token; } -static unsigned get_cas_shard(const schema& s, dht::token token) { - return s.table().shard_for_reads(token); +static unsigned get_cas_shard(const schema& s, dht::token token, const locator::effective_replication_map& erm) { + if (const auto shard = erm.get_sharder(s).try_get_shard_for_reads(token); shard) { + return *shard; + } + if (const auto& rs = erm.get_replication_strategy(); rs.uses_tablets()) { + const auto& tablet_map = erm.get_token_metadata().tablets().get_tablet_map(s.id()); + const auto tablet_id = tablet_map.get_tablet_id(token); + return tablet_map.get_primary_replica(tablet_id).shard % smp::count; + } else { + on_internal_error(paxos::paxos_state::logger, + format("failed to detect shard for reads for non-tablet-based rs {}, table {}.{}", + rs.get_type(), s.ks_name(), s.cf_name())); + } } cas_shard::cas_shard(const schema& s, dht::token token) : _token_guard(s.table(), token) - , _shard(get_cas_shard(s, token)) + , _shard(get_cas_shard(s, token, *_token_guard.get_erm())) { } @@ -6551,7 +6562,10 @@ future storage_proxy::cas(schema_ptr schema, cas_shard cas_shard, shared_p db::validate_for_cas(cl_for_paxos); db::validate_for_cas_learn(cl_for_learn, schema->ks_name()); - if (get_cas_shard(*schema, partition_ranges[0].start()->value().as_decorated_key().token()) != this_shard_id()) { + auto token_guard = std::move(cas_shard).token_guard(); + auto key = partition_ranges[0].start()->value().as_decorated_key(); + const auto token = key.token(); + if (get_cas_shard(*schema, token, *token_guard.get_erm()) != this_shard_id()) { on_internal_error(paxos::paxos_state::logger, "storage_proxy::cas called on a wrong shard"); } @@ -6566,9 +6580,9 @@ future storage_proxy::cas(schema_ptr schema, cas_shard cas_shard, shared_p shared_ptr handler; try { handler = seastar::make_shared(shared_from_this(), - std::move(cas_shard).token_guard(), + std::move(token_guard), query_options.trace_state, query_options.permit, - partition_ranges[0].start()->value().as_decorated_key(), + std::move(key), schema, cmd, cl_for_paxos, cl_for_learn, write_timeout, cas_timeout); } catch (exceptions::unavailable_exception& ex) { write ? get_stats().cas_write_unavailables.mark() : get_stats().cas_read_unavailables.mark(); @@ -6580,7 +6594,6 @@ future storage_proxy::cas(schema_ptr schema, cas_shard cas_shard, shared_p unsigned contentions = 0; - dht::token token = partition_ranges[0].start()->value().as_decorated_key().token(); utils::latency_counter lc; lc.start(); diff --git a/service/storage_service.cc b/service/storage_service.cc index a4919b6bb1d2..afcd1d32e248 100644 --- a/service/storage_service.cc +++ b/service/storage_service.cc @@ -10,6 +10,7 @@ */ #include "storage_service.hh" +#include "utils/chunked_vector.hh" #include "utils/disk_space_monitor.hh" #include "compaction/task_manager_module.hh" #include "gc_clock.hh" @@ -3549,9 +3550,9 @@ future> storage_service::effective_ownership( // // The call for get_range_for_endpoint is done once per endpoint const auto& tm = *erm->get_token_metadata_ptr(); - const auto tokens = co_await std::invoke([&]() -> future> { + const auto tokens = co_await std::invoke([&]() -> future> { if (!erm->get_replication_strategy().uses_tablets()) { - return make_ready_future>(tm.sorted_tokens()); + return make_ready_future>(tm.sorted_tokens()); } else { auto& cf = ss._db.local().find_column_family(keyspace_name, table_name); const auto& tablets = tm.tablets().get_tablet_map(cf.schema()->id()); @@ -4704,6 +4705,10 @@ future<> storage_service::do_drain() { // Need to stop transport before group0, otherwise RPCs may fail with raft_group_not_found. co_await stop_transport(); + // Drain view builder before group0, because the view builder uses group0 to coordinate view building. + // Drain after transport is stopped, because view_builder::drain aborts view writes for user writes as well. + co_await _view_builder.invoke_on_all(&db::view::view_builder::drain); + // group0 persistence relies on local storage, so we need to stop group0 first. // This must be kept in sync with defer_verbose_shutdown for group0 in main.cc to // handle the case when initialization fails before reaching drain_on_shutdown for ss. @@ -4719,7 +4724,6 @@ future<> storage_service::do_drain() { return bm.drain(); }); - co_await _view_builder.invoke_on_all(&db::view::view_builder::drain); co_await _db.invoke_on_all(&replica::database::drain); co_await _sys_ks.invoke_on_all(&db::system_keyspace::shutdown); co_await _repair.invoke_on_all(&repair_service::shutdown); diff --git a/service/topology_coordinator.cc b/service/topology_coordinator.cc index cb37d321b293..c49c11e86557 100644 --- a/service/topology_coordinator.cc +++ b/service/topology_coordinator.cc @@ -957,10 +957,10 @@ class topology_coordinator : public endpoint_lifecycle_subscriber { // the base table will coordinate the transition for the entire group. continue; } - locator::tablet_map old_tablets = tmptr->tablets().get_tablet_map(table_or_mv->id()); + locator::tablet_map old_tablets = co_await tmptr->tablets().get_tablet_map(table_or_mv->id()).clone_gently(); locator::replication_strategy_params params{repl_opts, old_tablets.tablet_count()}; auto new_strategy = locator::abstract_replication_strategy::create_replication_strategy("NetworkTopologyStrategy", params); - new_tablet_map = co_await new_strategy->maybe_as_tablet_aware()->reallocate_tablets(table_or_mv, tmptr, old_tablets); + new_tablet_map = co_await new_strategy->maybe_as_tablet_aware()->reallocate_tablets(table_or_mv, tmptr, std::move(old_tablets)); } catch (const std::exception& e) { error = e.what(); rtlogger.error("Couldn't process global_topology_request::keyspace_rf_change, error: {}," @@ -1972,6 +1972,12 @@ class topology_coordinator : public endpoint_lifecycle_subscriber { } } + void trigger_load_stats_refresh() { + (void)_tablet_load_stats_refresh.trigger().handle_exception([] (auto ep) { + rtlogger.warn("Error during tablet load stats refresh: {}", ep); + }); + } + future<> cancel_all_requests(group0_guard guard, std::unordered_set dead_nodes) { utils::chunked_vector muts; std::vector reject_join; @@ -2437,10 +2443,7 @@ class topology_coordinator : public endpoint_lifecycle_subscriber { muts.emplace_back(rtbuilder.build()); co_await _voter_handler.on_node_added(node.id, _as); co_await update_topology_state(take_guard(std::move(node)), std::move(muts), "bootstrap: read fence completed"); - // Make sure the load balancer knows the capacity for the new node immediately. - (void)_tablet_load_stats_refresh.trigger().handle_exception([] (auto ep) { - rtlogger.warn("Error during tablet load stats refresh: {}", ep); - }); + trigger_load_stats_refresh(); } break; case node_state::removing: { @@ -2499,6 +2502,7 @@ class topology_coordinator : public endpoint_lifecycle_subscriber { co_await _voter_handler.on_node_added(node.id, _as); co_await update_topology_state(take_guard(std::move(node)), std::move(muts), "replace: read fence completed"); + trigger_load_stats_refresh(); } break; default: diff --git a/service/vector_store_client.cc b/service/vector_store_client.cc index 63d930ce05a9..015652c7c1c3 100644 --- a/service/vector_store_client.cc +++ b/service/vector_store_client.cc @@ -13,7 +13,7 @@ #include "exceptions/exceptions.hh" #include "utils/sequential_producer.hh" #include "dht/i_partitioner.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "utils/rjson.hh" #include "schema/schema.hh" #include diff --git a/sstables/index_reader.hh b/sstables/index_reader.hh index 43cdf477a300..558ff1d013fa 100644 --- a/sstables/index_reader.hh +++ b/sstables/index_reader.hh @@ -50,6 +50,10 @@ struct open_rt_marker { // to a start of some partition or to EOF, // while "position" means a position that corresponds either to the start of some partition, // to the start of some clustering entry, or to EOF. +// +// Note: even though some methods of the index are inexact (i.e. they advance the index to *some* +// Data position close to the queried ring position), they are monotonic. +// I.e. if B >= A, then advance(B) >= advance(A). class abstract_index_reader { public: virtual ~abstract_index_reader() = default; @@ -58,19 +62,29 @@ public: // True iff lower bound is at EOF. virtual bool eof() const = 0; - // Advances lower bound to the first PK no smaller than pos. - // Returns true iff `key` is a partition key present in the sstable. + // If `key` is a partition key present in the sstable, advances lower bound to `key`. + // Otherwise advances lower bound to the some PK no greater than `key`. + // Returns `true` iff it's possible that `key` is a partition key present in the sstable. + // (In other words, if it returns `false`, then the key is definitely not present. + // Otherwise it's unknown if it's present). // // Precondition: pos >= lower bound // // Note: this is the most important and performance-sensitive method of the reader. // This is what's used by sstable readers to find positions for single-partition reads. virtual future advance_lower_and_check_if_present(dht::ring_position_view key) = 0; - // Advances lower bound to the first PK no smaller than pos. - // Precondition: pos >= lower bound - virtual future<> advance_to(dht::ring_position_view pos) = 0; - // Advances lower bound to the first PK which lies inside or after the range. + // Advances lower bound to the first PK greater than dk. + // + // Preconditions: dk >= lower bound, dk is present in the sstable + virtual future<> advance_past_definitely_present_partition(const dht::decorated_key& dk) = 0; + // Advances lower bound to dk. + // + // Preconditions: dk >= lower bound, dk is present in the sstable + virtual future<> advance_to_definitely_present_partition(const dht::decorated_key& dk) = 0; + // Advances lower bound to the first PK which lies inside or after the range, + // or to some close predecessor of that optimal PK. // Advances upper bound to the first PK which lies after the range. + // or to some close successor of that optimal PK. // Preconditions: // 1. next lower bound >= lower bound // 2. next upper bound >= upper bound @@ -123,11 +137,33 @@ public: virtual future<> read_partition_data() = 0; // Returns tombstone for the current partition, // if such information is available in the index. + // + // Note: it's an arbitrary decision of the writer of the index whether + // the the partition tombstone has been attached to a given index entry, + // and the user of the index reader should not assume that it has. + // + // The main use case for this information is reads which start in the + // middle of a large partition. The Data reader needs to know the partition + // header (full partition key and partition tombstone) to emit a partition, + // but the header is at the beginning of the partition, potentially far + // from the queried rows. + // Embedding the partition header in the index lets the Data reader skip + // avoid doing a separate disk read to get the header from the Data file. + // + // Thus, `partition_tombstone()` and `get_partition_key()` usually + // return an engaged optional at least for those partitions which have an + // intra-partition index (because that's when they can be used to skip + // a disk seek) but the reader shouldn't assume that. It should check + // if they are available, and if not, it should fall back to reading + // the partition header from the Data file. + // // Precondition: partition_data_ready() virtual std::optional partition_tombstone() = 0; - // Returns the key for current partition. + // Returns the key for current partition of the lower bound, + // if available (se the comment for partition_tombstone) in the index. + // // Precondition: partition_data_ready() - virtual partition_key get_partition_key() = 0; + virtual std::optional get_partition_key() = 0; // Returns data file positions corresponding to the bounds. // End position may be unset virtual data_file_positions_range data_file_positions() const = 0; @@ -985,7 +1021,7 @@ public: // Returns the key for current partition. // Can be called only when partition_data_ready(). - partition_key get_partition_key() override { + std::optional get_partition_key() override { return _alloc_section(_region, [this] { index_entry& e = current_partition_entry(_lower_bound); return e.get_key().to_partition_key(*_sstable->_schema); @@ -1077,6 +1113,14 @@ public: }); } + future<> advance_past_definitely_present_partition(const dht::decorated_key& dk) override { + return advance_to(_lower_bound, dht::ring_position_view::for_after_key(dk)); + } + + future<> advance_to_definitely_present_partition(const dht::decorated_key& dk) override { + return advance_to(_lower_bound, dht::ring_position_view(dk, dht::ring_position_view::after_key::no)); + } + // Like advance_to(dht::ring_position_view), but returns information whether the key was found // If upper_bound is provided, the upper bound within position is looked up future advance_lower_and_check_if_present(dht::ring_position_view key) override { @@ -1197,12 +1241,6 @@ public: return advance_to_next_partition(_lower_bound); } - // Positions the cursor on the first partition which is not smaller than pos (like std::lower_bound). - // Must be called for non-decreasing positions. - future<> advance_to(dht::ring_position_view pos) override { - return advance_to(_lower_bound, pos); - } - // Returns positions in the data file of the cursor. // End position may be unset data_file_positions_range data_file_positions() const override { diff --git a/sstables/key.hh b/sstables/key.hh index ef0b4849c06c..1ad94ce39225 100644 --- a/sstables/key.hh +++ b/sstables/key.hh @@ -11,8 +11,8 @@ #include "schema/schema_fwd.hh" #include #include "replica/database_fwd.hh" -#include "keys.hh" -#include "compound_compat.hh" +#include "keys/keys.hh" +#include "keys/compound_compat.hh" #include "dht/token.hh" namespace sstables { diff --git a/sstables/kl/reader.cc b/sstables/kl/reader.cc index e9d1c5ba0dcc..a7a25e008ae1 100644 --- a/sstables/kl/reader.cc +++ b/sstables/kl/reader.cc @@ -12,8 +12,8 @@ #include "sstables/sstable_mutation_reader.hh" #include "sstables/sstables.hh" #include "sstables/types.hh" -#include "clustering_key_filter.hh" -#include "clustering_ranges_walker.hh" +#include "keys/clustering_key_filter.hh" +#include "keys/clustering_ranges_walker.hh" #include "concrete_types.hh" #include "utils/to_string.hh" #include "utils/value_or_reference.hh" @@ -1208,7 +1208,7 @@ class sstable_mutation_reader : public mp_row_consumer_reader_k_l { } return (_index_in_current_partition ? _index_reader->advance_to_next_partition() - : get_index_reader().advance_to(dht::ring_position_view::for_after_key(*_current_partition_key))).then([this] { + : get_index_reader().advance_past_definitely_present_partition(*_current_partition_key)).then([this] { _index_in_current_partition = true; auto [start, end] = _index_reader->data_file_positions(); if (end && start > *end) { @@ -1229,7 +1229,7 @@ class sstable_mutation_reader : public mp_row_consumer_reader_k_l { return read_from_datafile(); } auto pk = _index_reader->get_partition_key(); - auto key = dht::decorate_key(*_schema, std::move(pk)); + auto key = dht::decorate_key(*_schema, std::move(pk.value())); _consumer.setup_for_partition(key.key()); on_next_partition(std::move(key), tombstone(*tomb)); return make_ready_future<>(); @@ -1299,7 +1299,7 @@ class sstable_mutation_reader : public mp_row_consumer_reader_k_l { return [this] { if (!_index_in_current_partition) { _index_in_current_partition = true; - return get_index_reader().advance_to(*_current_partition_key); + return get_index_reader().advance_to_definitely_present_partition(*_current_partition_key); } return make_ready_future(); }().then([this, pos] { diff --git a/sstables/mutation_fragment_filter.hh b/sstables/mutation_fragment_filter.hh index 6bce96bd88d0..90020c0d4638 100644 --- a/sstables/mutation_fragment_filter.hh +++ b/sstables/mutation_fragment_filter.hh @@ -10,8 +10,8 @@ #include "exceptions.hh" #include "mutation/mutation_fragment.hh" -#include "clustering_ranges_walker.hh" -#include "clustering_key_filter.hh" +#include "keys/clustering_ranges_walker.hh" +#include "keys/clustering_key_filter.hh" #include "readers/mutation_reader_fwd.hh" namespace sstables { diff --git a/sstables/mx/reader.cc b/sstables/mx/reader.cc index f8bb5d690cc0..1e7ccfc0f127 100644 --- a/sstables/mx/reader.cc +++ b/sstables/mx/reader.cc @@ -33,7 +33,7 @@ class mp_row_consumer_reader_mx : public mp_row_consumer_reader_base, public mut _permit.on_finish_sstable_read(); } - void on_next_partition(dht::decorated_key, tombstone); + virtual data_consumer::proceed on_next_partition(dht::decorated_key, tombstone); }; enum class row_processing_result { @@ -298,7 +298,11 @@ class mp_row_consumer_m { auto pk = key.to_partition_key(*_schema); setup_for_partition(pk); auto dk = dht::decorate_key(*_schema, pk); - _reader->on_next_partition(std::move(dk), tombstone(deltime)); + + auto should_proceed = _reader->on_next_partition(std::move(dk), tombstone(deltime)); + if (should_proceed == data_consumer::proceed::no) { + return data_consumer::proceed::no; + } return data_consumer::proceed(!_reader->is_buffer_full() && !need_preempt()); } @@ -1262,7 +1266,7 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { std::unique_ptr _index_reader; // We avoid unnecessary lookup for single partition reads thanks to this flag bool _single_partition_read = false; - const dht::partition_range& _pr; + std::reference_wrapper _pr; streamed_mutation::forwarding _fwd; mutation_reader::forwarding _fwd_mr; read_monitor& _monitor; @@ -1283,11 +1287,13 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { streamed_mutation::forwarding fwd, mutation_reader::forwarding fwd_mr, read_monitor& mon, - integrity_check integrity) + integrity_check integrity, + std::unique_ptr ir = nullptr) : mp_row_consumer_reader_mx(std::move(schema), permit, std::move(sst)) , _slice_holder(std::move(slice)) , _slice(_slice_holder.get()) , _consumer(this, _schema, std::move(permit), _slice, std::move(trace_state), fwd, _sst) + , _index_reader(std::move(ir)) // FIXME: I want to add `&& fwd_mr == mutation_reader::forwarding::no` below // but can't because many call sites use the default value for // `mutation_reader::forwarding` which is `yes`. @@ -1297,6 +1303,7 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { , _fwd_mr(fwd_mr) , _monitor(mon) , _integrity(integrity) { + sstlog.trace("mx_sstable_mutation_reader {}: init with _pr={}", fmt::ptr(this), _pr.get()); if (reversed()) { if (!_single_partition_read) { on_internal_error(sstlog, format( @@ -1343,7 +1350,8 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { } return (_index_in_current_partition ? _index_reader->advance_to_next_partition() - : get_index_reader().advance_to(dht::ring_position_view::for_after_key(*_current_partition_key))).then([this] { + : get_index_reader().advance_past_definitely_present_partition(*_current_partition_key)) + .then([this] { _index_in_current_partition = true; auto [start, end] = _index_reader->data_file_positions(); if (end && start > *end) { @@ -1363,8 +1371,12 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { sstlog.trace("reader {}: no tombstone", fmt::ptr(this)); return read_from_datafile(); } - auto pk = _index_reader->get_partition_key(); - auto key = dht::decorate_key(*_schema, std::move(pk)); + std::optional pk = _index_reader->get_partition_key(); + if (!pk) { + sstlog.trace("reader {}: no partition key", fmt::ptr(this)); + return read_from_datafile(); + } + auto key = dht::decorate_key(*_schema, std::move(*pk)); _consumer.setup_for_partition(key.key()); on_next_partition(std::move(key), tombstone(*tomb)); return make_ready_future<>(); @@ -1383,6 +1395,19 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { return make_ready_future<>(); } + if (_saved_partition_tombstone) { + // This is the special case where we ended reading last partition range + // only after parsing the partition key after the range, + // and then the reader was forwarded to that key. + // Since the parser can't be moved back, we serve the partition key + // and the tombstone that we saved after parsing. + auto tomb = *_saved_partition_tombstone; + _saved_partition_tombstone.reset(); + sstlog.trace("reader {}: serving partition key {} due to _saved_partition_tombstone {}", fmt::ptr(this), *_current_partition_key, tomb); + on_next_partition(*_current_partition_key, tomb); + return make_ready_future<>(); + } + if (!_consumer.is_mutation_end()) { throw malformed_sstable_exception(format("consumer not at partition boundary, position: {}", position_in_partition_view::printer(*_schema, _consumer.position())), _sst->get_filename()); @@ -1435,7 +1460,7 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { if (!_index_in_current_partition) { _index_in_current_partition = true; // FIXME reversed multi partition reads - return get_index_reader().advance_to(*_current_partition_key); + return get_index_reader().advance_to_definitely_present_partition(*_current_partition_key); } return make_ready_future(); }().then([this, pos = *pos] { @@ -1501,7 +1526,7 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { if (_single_partition_read) { _sst->get_stats().on_single_partition_read(); - const auto& key = dht::ring_position_view(_pr.start()->value()); + const auto& key = dht::ring_position_view(_pr.get().start()->value()); const auto present = co_await get_index_reader().advance_lower_and_check_if_present(key); @@ -1598,6 +1623,19 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { _partition_finished = true; } } + // Advances the index to the first position + // which hasn't been crossed by the Data file parser yet. + future<> advance_index_until_unseen_partition() { + while (true) { + auto [start, end] = _index_reader->data_file_positions(); + if (start >= _context->position()) { + sstlog.trace("mp_row_consumer_reader_mx {}: advance_index_until_unseen_partition(): advanced to {}", fmt::ptr(this), start); + co_return; + } else { + co_await _index_reader->advance_to_next_partition(); + } + } + } virtual future<> fast_forward_to(const dht::partition_range& pr) override { if (reversed()) { // FIXME @@ -1605,6 +1643,8 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { } return maybe_initialize().then([this, &pr] (bool initialized) { + _pr = pr; + sstlog.trace("mp_row_consumer_reader_mx {}: fast_forward_to({})", fmt::ptr(this), _pr); if (!initialized) { _end_of_stream = true; return make_ready_future<>(); @@ -1618,9 +1658,54 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { return f1.then([this] { auto [start, end] = _index_reader->data_file_positions(); parse_assert(bool(end), _sst->get_filename()); + sstlog.trace("mp_row_consumer_reader_mx {}: fast_forward_to({}), index returned range [{}, {}), parser currently at {}", fmt::ptr(this), _pr, start, *end, _context->position()); + if (start < _context->position()) { + sstlog.trace("mp_row_consumer_reader_mx {}: _saved_partition_tombstone={}", fmt::ptr(this), _saved_partition_tombstone); + // If we got here, the index returned a Data start which precedes + // the data parser's position. + // But by contract, fast_forward_to can only be used to move the reader forward. + // The new range must lie after the old range. + // + // So there are two ways for this to happen: + // 1. When reading the last range, the parser was advanced past the first + // partition key lying after the range. + // (This is signified by _saved_partition_tombstone). + // And then, the reader was advanced to a new range that + // to the best knowledge of the index, might contain that exact key. + // In this case, we have to start reading from that key (which we have remembered). + // 2. The parser was not advanced past the last range, but the inexact + // index returned a start position for the new range which in reality + // lies inside the old range. In this case, we should start reading + // from the first partition with position greater or equal to the parser's position. + + // The index *could* be in the current position, + // but setting `false` here is fine. + // Worst case, the reader will forward the index to the current position again. + // Lack of this optimization doesn't matter to a range read. + _index_in_current_partition = false; + if (_saved_partition_tombstone) { + // Case 1 from the comment above. + if (*end >= _context->position()) { + _read_enabled = true; + return _context->fast_forward_to(_context->position(), *end); + } else { + _read_enabled = false; + return make_ready_future<>(); + } + } else { + // Case 2 from the comment above. + return advance_index_until_unseen_partition().then([this] { + auto [start, end] = _index_reader->data_file_positions(); + _read_enabled = true; + _context->reset(indexable_element::partition); + return _context->fast_forward_to(start, *end); + }); + } + } if (start != *end) { _read_enabled = true; _index_in_current_partition = true; + _saved_partition_tombstone.reset(); _context->reset(indexable_element::partition); return _context->fast_forward_to(start, *end); } @@ -1718,6 +1803,48 @@ class mx_sstable_mutation_reader : public mp_row_consumer_reader_mx { sstlog.warn("Failed closing of sstable_mutation_reader: {}. Ignored since the reader is already done.", ep); }); } + + data_consumer::proceed on_next_partition(dht::decorated_key key, tombstone tomb) override { + if (_pr.get().before(key, dht::ring_position_comparator(*_schema))) { + sstlog.trace("mp_row_consumer_reader_mx {}: on_next_partition({}), _pr={}, skipping key before range", fmt::ptr(this), key, _pr); + // If we got here, then the index returned a Data file range which + // includes some partitions before the queried range. + // + // A BTI index is inexact in general (it can return a Data file range + // which includes some partitions before and after the queried range), + // but it is guaranteed to return the exact position if it's queried + // for a key which is present in the sstable. + // + // So, for a single partition read, if we parsed a partition which is before + // the queried range, it means that the queried partition doesn't exist in the sstable. + // (I.e. the index gave us a false positive). + // In this case, the read is over. There's nothing to read. + // + // Otherwise, for range reads, we are going to skip this partition + // (i.e. advance the index to the range after `key`) and resume reading. + _end_of_stream = _single_partition_read; + _before_partition = false; + _partition_finished = true; + _current_partition_key = std::move(key); + return data_consumer::proceed::no; + } else if (_pr.get().after(key, dht::ring_position_comparator(*_schema))) { + // If we got here, then the index returned a Data file range which + // includes some partitions after the queried range. + // The read is over. The new key and everything after it should be ignored. + sstlog.trace("mp_row_consumer_reader_mx {}: on_next_partition({}), _pr={}, skipping key after range", fmt::ptr(this), key, _pr); + _end_of_stream = true; + // The read is over for now, but the reader can be later forwarded to the key we just read. + // The parser can't move backwards, so we have to remember the key and tombstone + // to handle that case. + _saved_partition_tombstone = tomb; + _current_partition_key = std::move(key); + return data_consumer::proceed::no; + } else { + // This is the normal path. + sstlog.trace("mp_row_consumer_reader_mx {}: on_next_partition({}), _pr={}, consuming key in range", fmt::ptr(this), key, _pr); + return mp_row_consumer_reader_mx::on_next_partition(std::move(key), tomb); + } + } }; static mutation_reader make_reader( @@ -1730,10 +1857,11 @@ static mutation_reader make_reader( streamed_mutation::forwarding fwd, mutation_reader::forwarding fwd_mr, read_monitor& monitor, - integrity_check integrity) { + integrity_check integrity, + std::unique_ptr ir = nullptr) { return make_mutation_reader( std::move(sstable), std::move(schema), std::move(permit), range, - std::move(slice), std::move(trace_state), fwd, fwd_mr, monitor, integrity); + std::move(slice), std::move(trace_state), fwd, fwd_mr, monitor, integrity, std::move(ir)); } mutation_reader make_reader( @@ -1766,6 +1894,22 @@ mutation_reader make_reader( value_or_reference(std::move(slice)), std::move(trace_state), fwd, fwd_mr, monitor, integrity); } +mutation_reader make_reader_with_index_reader( + shared_sstable sstable, + schema_ptr schema, + reader_permit permit, + const dht::partition_range& range, + const query::partition_slice& slice, + tracing::trace_state_ptr trace_state, + streamed_mutation::forwarding fwd, + mutation_reader::forwarding fwd_mr, + read_monitor& monitor, + integrity_check integrity, + std::unique_ptr ir) { + return make_reader(std::move(sstable), std::move(schema), std::move(permit), range, + value_or_reference(slice), std::move(trace_state), fwd, fwd_mr, monitor, integrity, std::move(ir)); +} + /// a reader which does not support seeking to given position. /// /// unlike mx_sstable_mutation_reader which allows fast forwarding read, @@ -1859,7 +2003,7 @@ mutation_reader make_full_scan_reader( std::move(trace_state), monitor, integrity); } -void mp_row_consumer_reader_mx::on_next_partition(dht::decorated_key key, tombstone tomb) { +data_consumer::proceed mp_row_consumer_reader_mx::on_next_partition(dht::decorated_key key, tombstone tomb) { _partition_finished = false; _before_partition = false; _end_of_stream = false; @@ -1867,6 +2011,7 @@ void mp_row_consumer_reader_mx::on_next_partition(dht::decorated_key key, tombst push_mutation_fragment( mutation_fragment_v2(*_schema, _permit, partition_start(*_current_partition_key, tomb))); _sst->get_stats().on_partition_read(); + return data_consumer::proceed::yes; } // A validating consumer implementing the Consumer concept of data_consume_rows_context_m. @@ -1903,6 +2048,7 @@ class validating_consumer { mutation_fragment_stream_validator _validator; uint64_t _error_count = 0; std::optional _expected_pkey; + std::optional _last_pkey; std::optional _expected_clustering_block; position_in_partition _current_pos; bool _stop_after_partition_header = false; @@ -1976,6 +2122,9 @@ class validating_consumer { bool in_expected_clustering_block() const { return _expected_clustering_block && !_expected_clustering_block->done; } + partition_key get_last_pkey() { + return _last_pkey.value(); + } void report_error(sstring what) { ++_error_count; @@ -1983,10 +2132,10 @@ class validating_consumer { } data_consumer::proceed consume_partition_start(sstables::key_view key, sstables::deletion_time deltime) { - auto pk = key.to_partition_key(*_schema); - auto dk = dht::decorate_key(*_schema, pk); + _last_pkey = key.to_partition_key(*_schema); + auto dk = dht::decorate_key(*_schema, *_last_pkey); _current_pos = position_in_partition(position_in_partition::partition_start_tag_t{}); - sstlog.trace("validating_consumer {}: {}({}) _expected_pkey={}", fmt::ptr(this), __FUNCTION__, pk, _expected_pkey); + sstlog.trace("validating_consumer {}: {}({}) _expected_pkey={}", fmt::ptr(this), __FUNCTION__, _last_pkey, _expected_pkey); validate_fragment_order(mutation_fragment_v2::kind::partition_start, {}); if (_expected_pkey && !_expected_pkey->equal(*_schema, dk.key())) { report_error(format("mismatching index/data: partition mismatch: index: {}, data: {}", *_expected_pkey, dk.key())); @@ -2120,6 +2269,8 @@ future validate( big_index_reader = nullptr; } + consumer.reset_index_expected_partition(); + if (idx_reader) { co_await idx_reader->read_partition_data(); @@ -2129,15 +2280,21 @@ future validate( const auto index_pos = idx_reader->data_file_positions().start; const auto data_pos = context->position(); - sstlog.trace("validate(): index-data position check for partition {}: {} == {}", idx_reader->get_partition_key(), data_pos, index_pos); + auto pk = idx_reader->get_partition_key(); + + if (pk) { + sstlog.trace("validate(): index-data position check for partition {}: {} == {}", *pk, data_pos, index_pos); + } else { + sstlog.trace("validate(): index-data position check for a partition: {} == {}", data_pos, index_pos); + } if (index_pos != data_pos) { consumer.report_error(format("mismatching index/data: position mismatch: index: {}, data: {}", index_pos, data_pos)); } current_partition_pos = data_pos; - consumer.set_index_expected_partition(idx_reader->get_partition_key()); - } else { - consumer.reset_index_expected_partition(); + if (pk) { + consumer.set_index_expected_partition(*pk); + } } std::optional current_pi_block; @@ -2192,8 +2349,8 @@ future validate( // Check if promoted index still has more entries. if (idx_cursor && (current_pi_block = co_await idx_cursor->next_entry())) { consumer.report_error(format("mismatching index/data: promoted index has more blocks, but it is end of partition {} ({})", - idx_reader->get_partition_key().with_schema(*schema), - idx_reader->get_partition_key())); + consumer.get_last_pkey().with_schema(*schema), + consumer.get_last_pkey())); } if (idx_reader) { diff --git a/sstables/mx/reader.hh b/sstables/mx/reader.hh index f0ab3cca8c87..84c24427315e 100644 --- a/sstables/mx/reader.hh +++ b/sstables/mx/reader.hh @@ -12,6 +12,7 @@ #include "readers/mutation_reader.hh" #include "sstables/progress_monitor.hh" #include "sstables/types_fwd.hh" +#include "sstables/index_reader.hh" namespace sstables { namespace mx { @@ -44,6 +45,19 @@ mutation_reader make_reader( read_monitor& monitor, integrity_check integrity); +mutation_reader make_reader_with_index_reader( + shared_sstable sstable, + schema_ptr schema, + reader_permit permit, + const dht::partition_range& range, + const query::partition_slice& slice, + tracing::trace_state_ptr trace_state, + streamed_mutation::forwarding fwd, + mutation_reader::forwarding fwd_mr, + read_monitor& monitor, + integrity_check integrity, + std::unique_ptr); + // A reader which doesn't use the index at all. It reads everything from the // sstable and it doesn't support skipping. mutation_reader make_full_scan_reader( diff --git a/sstables/mx/types.hh b/sstables/mx/types.hh index 0b90db8e9928..8d6c8d965315 100644 --- a/sstables/mx/types.hh +++ b/sstables/mx/types.hh @@ -9,7 +9,7 @@ #pragma once #include "sstables/exceptions.hh" -#include "clustering_bounds_comparator.hh" +#include "keys/clustering_bounds_comparator.hh" #include namespace sstables { diff --git a/sstables/sstable_mutation_reader.hh b/sstables/sstable_mutation_reader.hh index 32bf1518cff3..e6daac984cb1 100644 --- a/sstables/sstable_mutation_reader.hh +++ b/sstables/sstable_mutation_reader.hh @@ -10,7 +10,7 @@ #include #include -#include "keys.hh" +#include "keys/keys.hh" #include #include #include "index_reader.hh" @@ -49,6 +49,17 @@ protected: bool _before_partition = true; std::optional _current_partition_key; + + // If inexact sstable indexes are used, the reader might only realize + // it has exceeded the queried range after it has already parsed + // the partition header following the range. + // + // In this case the parsed partition key and tombstone can't be emitted. + // However, the reader will need to emit them if its forwarded to that + // key later. It can't re-parse them, since the parser can't go backwards + // and recreating the context is awkward, so it saves the partition key + // and the tombstone and serves them to the consumer if that case happens. + std::optional _saved_partition_tombstone; public: mp_row_consumer_reader_base(shared_sstable sst); diff --git a/sstables/types.hh b/sstables/types.hh index e3115ad4cb36..021559ef6cc4 100644 --- a/sstables/types.hh +++ b/sstables/types.hh @@ -20,7 +20,6 @@ #include "sstables/key.hh" #include "sstables/file_writer.hh" #include "db/commitlog/replay_position.hh" -#include "version.hh" #include #include #include diff --git a/tasks/task_manager.cc b/tasks/task_manager.cc index 64f872143b2a..779878431079 100644 --- a/tasks/task_manager.cc +++ b/tasks/task_manager.cc @@ -121,10 +121,6 @@ future> task_manager::task::impl::expected_total_workload( return make_ready_future>(std::nullopt); } -std::optional task_manager::task::impl::expected_children_number() const { - return std::nullopt; -} - task_manager::task::progress task_manager::task::impl::get_binary_progress() const { return tasks::task_manager::task::progress{ .completed = is_complete(), @@ -133,20 +129,10 @@ task_manager::task::progress task_manager::task::impl::get_binary_progress() con } future task_manager::task::impl::get_progress() const { - auto children_num = _children.size(); - if (children_num == 0) { - co_return get_binary_progress(); + std::optional expected_workload = co_await expected_total_workload(); + if (!expected_workload && _children.size() == 0) { + co_return task_manager::task::progress{}; } - - std::optional expected_workload = std::nullopt; - auto expected_children_num = expected_children_number(); - // When get_progress is called, the task can have some of its children unregistered yet. - // Then if total workload is not known, progress obtained from children may be deceiving. - // In such a situation it's safer to return binary progress value. - if (expected_children_num.value_or(0) != children_num && !(expected_workload = co_await expected_total_workload())) { - co_return get_binary_progress(); - } - auto progress = co_await _children.get_progress(_status.progress_units); progress.total = expected_workload.value_or(progress.total); co_return progress; diff --git a/tasks/task_manager.hh b/tasks/task_manager.hh index 55669bff9010..7b5e733ae3b8 100644 --- a/tasks/task_manager.hh +++ b/tasks/task_manager.hh @@ -226,7 +226,6 @@ public: future<> finish_failed(std::exception_ptr ex, std::string error) noexcept; future<> finish_failed(std::exception_ptr ex) noexcept; virtual future> expected_total_workload() const; - virtual std::optional expected_children_number() const; task_manager::task::progress get_binary_progress() const; friend task; diff --git a/test/boost/CMakeLists.txt b/test/boost/CMakeLists.txt index 06bcb078f3b7..39c13016e095 100644 --- a/test/boost/CMakeLists.txt +++ b/test/boost/CMakeLists.txt @@ -233,6 +233,8 @@ add_scylla_test(sstable_datafile_test KIND SEASTAR) add_scylla_test(sstable_generation_test KIND BOOST) +add_scylla_test(sstable_inexact_index_test + KIND SEASTAR) add_scylla_test(sstable_move_test KIND SEASTAR) add_scylla_test(sstable_mutation_test diff --git a/test/boost/cache_mutation_reader_test.cc b/test/boost/cache_mutation_reader_test.cc index e32f616d6b77..04c177cf038c 100644 --- a/test/boost/cache_mutation_reader_test.cc +++ b/test/boost/cache_mutation_reader_test.cc @@ -14,7 +14,7 @@ #include "test/lib/scylla_test_case.hh" #include #include "schema/schema_builder.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "mutation/mutation_partition.hh" #include "mutation/partition_version.hh" #include "mutation/mutation.hh" diff --git a/test/boost/clustering_ranges_walker_test.cc b/test/boost/clustering_ranges_walker_test.cc index 4f129e0a59dd..358bb05b9e4c 100644 --- a/test/boost/clustering_ranges_walker_test.cc +++ b/test/boost/clustering_ranges_walker_test.cc @@ -11,7 +11,7 @@ #include "test/lib/scylla_test_case.hh" #include "test/lib/simple_schema.hh" -#include "clustering_ranges_walker.hh" +#include "keys/clustering_ranges_walker.hh" using namespace std::chrono_literals; diff --git a/test/boost/compound_test.cc b/test/boost/compound_test.cc index be714476195a..7b522a8674dd 100644 --- a/test/boost/compound_test.cc +++ b/test/boost/compound_test.cc @@ -12,8 +12,8 @@ #include "test/lib/random_utils.hh" #include "test/lib/test_utils.hh" -#include "compound.hh" -#include "compound_compat.hh" +#include "keys/compound.hh" +#include "keys/compound_compat.hh" #include "test/boost/range_assert.hh" #include "schema/schema_builder.hh" #include "dht/murmur3_partitioner.hh" diff --git a/test/boost/counter_test.cc b/test/boost/counter_test.cc index fe586134e6d9..353ba04a1f64 100644 --- a/test/boost/counter_test.cc +++ b/test/boost/counter_test.cc @@ -17,7 +17,7 @@ #include "test/lib/scylla_test_case.hh" #include "test/lib/test_utils.hh" #include "schema/schema_builder.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "mutation/mutation.hh" #include "mutation/frozen_mutation.hh" diff --git a/test/boost/encryption_at_rest_test.cc b/test/boost/encryption_at_rest_test.cc index b112f29afa76..9d72307a6622 100644 --- a/test/boost/encryption_at_rest_test.cc +++ b/test/boost/encryption_at_rest_test.cc @@ -17,16 +17,20 @@ #include #include #include +#include +#include #include #include +#include "ent/encryption/azure_host.hh" #include "ent/encryption/encryption.hh" #include "ent/encryption/symmetric_key.hh" #include "ent/encryption/local_file_provider.hh" #include "ent/encryption/encryption_exceptions.hh" #include "test/lib/tmpdir.hh" +#include "test/lib/test_utils.hh" #include "test/lib/random_utils.hh" #include "test/lib/cql_test_env.hh" #include "test/lib/cql_assertions.hh" @@ -40,6 +44,9 @@ #include "sstables/sstables.hh" #include "cql3/untyped_result_set.hh" #include "utils/rjson.hh" +#include "utils/azure/identity/exceptions.hh" +#include "utils/azure/identity/managed_identity_credentials.hh" +#include "utils/azure/identity/service_principal_credentials.hh" #include "replica/database.hh" #include "service/client_state.hh" @@ -492,22 +499,31 @@ class fake_proxy { auto do_io = [this, &addr, &dst_addr, port](connected_socket& src, connected_socket& dst) noexcept -> future<> { try { auto sin = src.input(); - auto dout = dst.output(); + auto dout = output_stream(dst.output().detach(), 1024); // note: have to have differing conditions for proxying // and shutdown, and need to check inside look, because // kmip connector caches connection -> not new socket. - while (_go_on && _do_proxy && !sin.eof()) { - auto buf = co_await sin.read(); - auto n = buf.size(); - testlog.trace("Read {} bytes: {}->{}:{}", n, addr, dst_addr, port); - if (_do_proxy) { - co_await dout.write(std::move(buf)); - co_await dout.flush(); - testlog.trace("Wrote {} bytes: {}->{}:{}", n, addr, dst_addr, port); + std::exception_ptr p; + try { + while (_go_on && _do_proxy && !sin.eof()) { + auto buf = co_await sin.read(); + auto n = buf.size(); + testlog.trace("Read {} bytes: {}->{}:{}", n, addr, dst_addr, port); + if (_do_proxy) { + co_await dout.write(std::move(buf)); + co_await dout.flush(); + testlog.trace("Wrote {} bytes: {}->{}:{}", n, addr, dst_addr, port); + } } + } catch (...) { + p = std::current_exception(); } + co_await dout.flush(); co_await dout.close(); co_await sin.close(); + if (p) { + std::rethrow_exception(p); + } } catch (...) { testlog.warn("Exception running proxy {}:{}->{}: {}", dst_addr, port, _address, std::current_exception()); } @@ -1283,3 +1299,609 @@ SEASTAR_TEST_CASE(test_kmip_provider_broken_sstables_on_restart, *check_run_test // Note: cannot do the above test for gcp, because we can't use false endpoints there. Could mess with address resolution, // but there is no infrastructure for that atm. + +/* + Simple test of Azure Key Provider. + + User1 is assumed to have permissions to wrap/unwrap using the given key. + User2 is assumed to _not_ have permissions to wrap/unwrap using the given key. + + This test is parameterized with env vars: + * ENABLE_AZURE_TEST - set to non-zero (1/true) to run Azure tests (enabled by default) + * ENABLE_AZURE_TEST_REAL - set to non-zero (1/true) to run tests against real Azure services (disabled by default, same tests run against a local mock server regardless) + * AZURE_TENANT_ID - the tenant where the principals live + * AZURE_USER_1_CLIENT_ID - the client ID of user1 + * AZURE_USER_1_CLIENT_SECRET - the secret of user1 + * AZURE_USER_1_CLIENT_CERTIFICATE - the PEM-encoded certificate and private key of user1 + * AZURE_USER_2_CLIENT_ID - the client ID of user2 + * AZURE_USER_2_CLIENT_SECRET - the secret of user2 + * AZURE_USER_2_CLIENT_CERTIFICATE - the PEM-encoded certificate and private key of user2 + * AZURE_KEY_NAME - set to / +*/ + +struct azure_test_env { + std::string key_name; + std::string tenant_id; + std::string user_1_client_id; + std::string user_1_client_secret; + std::string user_1_client_certificate; + std::string user_2_client_id; + std::string user_2_client_secret; + std::string user_2_client_certificate; + std::string authority_host; + std::string imds_endpoint; +}; + +static std::string get_mock_azure_addr() { + return tests::getenv_safe("MOCK_AZURE_VAULT_SERVER_HOST"); +} + +static unsigned long get_mock_azure_port() { + return std::stoul(tests::getenv_safe("MOCK_AZURE_VAULT_SERVER_PORT")); +} + +static future get_mock_azure_env(const tmpdir& tmp) { + co_return azure_test_env { + .key_name = fmt::format("http://{}:{}/mock-key", get_mock_azure_addr(), get_mock_azure_port()), + .tenant_id = "00000000-1111-2222-3333-444444444444", + .user_1_client_id = "mock-client-id", + .user_1_client_secret = "mock-client-secret", + .user_1_client_certificate = "test/resource/certs/scylla.pem", // a cert file with valid format - the contents won't be checked by the mock server + .user_2_client_id = "mock-client-id-invalid", + .user_2_client_secret = "mock-client-secret-invalid", + .user_2_client_certificate = "/dev/null", // a cert file with invalid format + .authority_host = fmt::format("http://{}:{}", get_mock_azure_addr(), get_mock_azure_port()), + .imds_endpoint = fmt::format("http://{}:{}", get_mock_azure_addr(), get_mock_azure_port()), + }; +} + +static azure_test_env get_real_azure_env() { + return azure_test_env { + .key_name = get_var_or_default("AZURE_KEY_NAME", ""), + .tenant_id = get_var_or_default("AZURE_TENANT_ID", ""), + .user_1_client_id = get_var_or_default("AZURE_USER_1_CLIENT_ID", ""), + .user_1_client_secret = get_var_or_default("AZURE_USER_1_CLIENT_SECRET", ""), + .user_1_client_certificate = get_var_or_default("AZURE_USER_1_CLIENT_CERTIFICATE", ""), + .user_2_client_id = get_var_or_default("AZURE_USER_2_CLIENT_ID", ""), + .user_2_client_secret = get_var_or_default("AZURE_USER_2_CLIENT_SECRET", ""), + .user_2_client_certificate = get_var_or_default("AZURE_USER_2_CLIENT_CERTIFICATE", ""), + .authority_host = "''", + .imds_endpoint = "''", + }; +} + +static future<> azure_test_helper(std::function(const tmpdir&, const azure_test_env&)> f, bool real_server = false) { + tmpdir tmp; + + auto env = real_server ? get_real_azure_env() : co_await get_mock_azure_env(tmp); + + if (real_server) { + if (env.key_name.empty()) { + BOOST_ERROR("No 'AZURE_KEY_NAME' provided"); + } + if (env.tenant_id.empty()) { + BOOST_ERROR("No 'AZURE_TENANT_ID' provided"); + } + if (env.user_1_client_id.empty() || env.user_1_client_secret.empty() || env.user_1_client_certificate.empty()) { + BOOST_ERROR("Missing or incompete credentials for user 1: All three of 'AZURE_USER_1_CLIENT_ID', 'AZURE_USER_1_CLIENT_SECRET' and 'AZURE_USER_1_CLIENT_CERTIFICATE' must be provided"); + } + if (env.user_2_client_id.empty() || env.user_2_client_secret.empty() || env.user_2_client_certificate.empty()) { + BOOST_ERROR("Missing or incompete credentials for user 2: All three of 'AZURE_USER_2_CLIENT_ID', 'AZURE_USER_2_CLIENT_SECRET' and 'AZURE_USER_2_CLIENT_CERTIFICATE' must be provided"); + } + } + + co_await f(tmp, env); +} + +static auto check_azure_mock_test_decorator() { + return check_run_test_decorator("ENABLE_AZURE_TEST", true); +} + +static auto check_azure_real_test_decorator() { + return boost::unit_test::precondition([](boost::unit_test::test_unit_id){ + return check_run_test("ENABLE_AZURE_TEST", true) + && check_run_test("ENABLE_AZURE_TEST_REAL", false); + }); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_imds, *check_azure_mock_test_decorator()) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + master_key: {0} + imds_endpoint: {1} + )foo" + , azure.key_name, azure.imds_endpoint + ); + + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml); + }, false); +} + +static future<> do_test_azure_provider_with_secret(bool real_server) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + master_key: {0} + azure_tenant_id: {1} + azure_client_id: {2} + azure_client_secret: {3} + azure_authority_host: {5} + )foo" + , azure.key_name, azure.tenant_id, azure.user_1_client_id, azure.user_1_client_secret, azure.user_1_client_certificate, azure.authority_host + ); + + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml); + }, real_server); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_secret, *check_azure_mock_test_decorator()) { + co_await do_test_azure_provider_with_secret(false); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_secret_real, *check_azure_real_test_decorator()) { + co_await do_test_azure_provider_with_secret(true); +} + +static future<> do_test_azure_provider_with_certificate(bool real_server) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + master_key: {0} + azure_tenant_id: {1} + azure_client_id: {2} + azure_client_certificate_path: {4} + azure_authority_host: {5} + )foo" + , azure.key_name, azure.tenant_id, azure.user_1_client_id, azure.user_1_client_secret, azure.user_1_client_certificate, azure.authority_host + ); + + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml); + }, real_server); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_certificate, *check_azure_mock_test_decorator()) { + co_await do_test_azure_provider_with_certificate(false); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_certificate_real, *check_azure_real_test_decorator()) { + co_await do_test_azure_provider_with_certificate(true); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_master_key_in_cf, *check_azure_mock_test_decorator()) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + azure_tenant_id: {1} + azure_client_id: {2} + azure_client_secret: {3} + azure_authority_host: {5} + )foo" + , azure.key_name, azure.tenant_id, azure.user_1_client_id, azure.user_1_client_secret, azure.user_1_client_certificate, azure.authority_host + ); + + // should fail + BOOST_REQUIRE_EXCEPTION( + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml) + , exceptions::configuration_exception, [](const exceptions::configuration_exception& e) { + try { + std::rethrow_if_nested(e); + } catch (const encryption::configuration_error& inner) { + return sstring(inner.what()).find("No master key set") != sstring::npos; + } catch (...) { + return false; // Unexpected nested exception type + } + return false; // No nested exception + } + ); + + // should be ok + co_await test_provider(fmt::format("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'master_key': '{}', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", azure.key_name) + , tmp, yaml + ); + }, false); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_no_host, *check_azure_mock_test_decorator()) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = R"foo( + azure_hosts: + )foo"; + + // should fail + BOOST_REQUIRE_EXCEPTION( + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml) + , std::invalid_argument + , [](const std::invalid_argument& e) { + return sstring(e.what()).find("No such host") != sstring::npos; + } + ); + }, false); +} + +/** + * Verify that the Azure key provider throws if the provided Service Principal + * credentials are incomplete. The provider will first fall back to the default + * credentials source to detect credentials from the system (env vars, Azure CLI, + * IMDS), and only after all these attempts fail will it throw. + * + * Note: Just in case we ever run these tests on Azure VMs, use a non-routable + * IP address for the IMDS endpoint to ensure the connection will fail. + */ +SEASTAR_TEST_CASE(test_azure_provider_with_incomplete_creds, *check_azure_mock_test_decorator()) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + master_key: {0} + azure_tenant_id: {1} + azure_client_id: {2} + imds_endpoint: http://192.0.2.1:80 + )foo" + , azure.key_name, azure.tenant_id, azure.user_1_client_id, azure.user_1_client_secret, azure.user_1_client_certificate + ); + + // should fail + BOOST_REQUIRE_EXCEPTION( + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml) + , permission_error, [](const encryption::permission_error& e) { + try { + std::rethrow_if_nested(e); + } catch (const azure::auth_error& inner) { + return sstring(inner.what()).find("No credentials found in any source.") != sstring::npos; + } catch (...) { + return false; // Unexpected nested exception type + } + return false; // No nested exception + } + ); + }, false); +} + +static future<> do_test_azure_provider_with_invalid_key(bool real_server) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto vault = azure.key_name.substr(0, azure.key_name.find_last_of('/')); + auto master_key = fmt::format("{}/nonexistentkey", vault); + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + master_key: {0} + azure_tenant_id: {1} + azure_client_id: {2} + azure_client_secret: {3} + azure_authority_host: {5} + )foo" + , master_key, azure.tenant_id, azure.user_1_client_id, azure.user_1_client_secret, azure.user_1_client_certificate, azure.authority_host + ); + + // should fail + BOOST_REQUIRE_EXCEPTION( + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml) + , service_error, [](const service_error& e) { + try { + std::rethrow_if_nested(e); + } catch (const encryption::vault_error& inner) { + // Both error codes are valid depending on the scope of the role assignment: + // - "Forbidden": key-scoped permissions + // - "KeyNotFound": vault-scoped permissions + return inner.code() == "Forbidden" || inner.code() == "KeyNotFound"; + } catch (...) { + return false; // Unexpected nested exception type + } + return false; // No nested exception + } + ); + }, real_server); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_invalid_key, *check_azure_mock_test_decorator()) { + co_await do_test_azure_provider_with_invalid_key(false); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_invalid_key_real, *check_azure_real_test_decorator()) { + co_await do_test_azure_provider_with_invalid_key(true); +} + +/** + * Verify that trying to access key materials with a user w/o permissions to wrap/unwrap using vault + * fails. + */ +static future<> do_test_azure_provider_with_invalid_user(bool real_server) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + master_key: {0} + azure_tenant_id: {1} + azure_client_id: {2} + azure_client_secret: {3} + azure_authority_host: {5} + )foo" + , azure.key_name, azure.tenant_id, azure.user_2_client_id, azure.user_2_client_secret, azure.user_2_client_certificate, azure.authority_host + ); + + // should fail + BOOST_REQUIRE_EXCEPTION( + co_await test_provider("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", tmp, yaml) + , service_error, [](const service_error& e) { + try { + std::rethrow_if_nested(e); + } catch (const encryption::vault_error& inner) { + return inner.code() == "Forbidden"; + } catch (...) { + return false; // Unexpected nested exception type + } + return false; // No nested exception + } + ); + }, real_server); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_invalid_user, *check_azure_mock_test_decorator()) { + co_await do_test_azure_provider_with_invalid_user(false); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_invalid_user_real, *check_azure_real_test_decorator()) { + co_await do_test_azure_provider_with_invalid_user(true); +} + +/** + * Verify that the secret has higher precedence that the certificate. + * Use the wrong user's certificate to make sure it causes the test to fail. + */ +static future<> do_test_azure_provider_with_both_secret_and_cert(bool real_server) { + co_await azure_test_helper([](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + azure_tenant_id: {1} + azure_client_id: {2} + azure_client_secret: {3} + azure_client_certificate_path: {4} + azure_authority_host: {5} + )foo" + , azure.key_name, azure.tenant_id, azure.user_1_client_id, azure.user_1_client_secret, azure.user_2_client_certificate, azure.authority_host + ); + + // should be ok + co_await test_provider(fmt::format("'key_provider': 'AzureKeyProviderFactory', 'azure_host': 'azure_test', 'master_key': '{}', 'cipher_algorithm':'AES/CBC/PKCS5Padding', 'secret_key_strength': 128", azure.key_name) + , tmp, yaml + ); + }, real_server); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_both_secret_and_cert, *check_azure_mock_test_decorator()) { + co_await do_test_azure_provider_with_both_secret_and_cert(false); +} + +SEASTAR_TEST_CASE(test_azure_provider_with_both_secret_and_cert_real, *check_azure_real_test_decorator()) { + co_await do_test_azure_provider_with_both_secret_and_cert(true); +} + +SEASTAR_TEST_CASE(test_azure_network_error, *check_azure_mock_test_decorator()) { + co_await azure_test_helper([&](const tmpdir& tmp, const azure_test_env& azure) -> future<> { + auto host_endpoint = fmt::format("{}:{}", get_mock_azure_addr(), get_mock_azure_port()); + auto key = azure.key_name.substr(azure.key_name.find_last_of('/') + 1); + co_await network_error_test_helper(tmp, host_endpoint, [&](const auto& proxy) { + auto yaml = fmt::format(R"foo( + azure_hosts: + azure_test: + master_key: http://{0}/{1} + azure_tenant_id: {2} + azure_client_id: {3} + azure_client_secret: {4} + azure_authority_host: {5} + key_cache_expiry: 1ms + )foo" + , proxy.address(), key, azure.tenant_id, azure.user_1_client_id, azure.user_1_client_secret, azure.authority_host + ); + return std::make_tuple(scopts_map({ { "key_provider", "AzureKeyProviderFactory" }, { "azure_host", "azure_test" } }), yaml); + }); + }, false); +} + +/* + * Utility function to spawn a dedicated mock server instance for a particular test case. + * Useful for tests that require error injection, where the server's global state needs to be configured accordingly. + * Since test.py may run tests in parallel, using the global server instance is not safe for such tests. + * + * The code was based on `kmip_test_helper()`. + */ +static future<> with_dedicated_azure_mock_server(const std::function(std::string host, unsigned int port)>& f) { + tmpdir tmp; + + auto pyexec = tests::proc::find_file_in_path("python"); + + promise> authority_promise; + auto fut = authority_promise.get_future(); + + BOOST_TEST_MESSAGE("Starting dedicated Azure Vault mock server"); + + auto python = co_await tests::proc::process_fixture::create(pyexec, + { // args + pyexec.string(), + "test/pylib/start_azure_vault_mock.py", + "--log-level", "INFO", + "--host", get_var_or_default("MOCK_AZURE_VAULT_SERVER_HOST", "127.0.0.1"), + "--port", "0", // random port + }, + // env + {}, + // stdout handler + tests::proc::process_fixture::create_copy_handler(std::cout), + // stderr handler + [authority_promise = std::move(authority_promise), b = false](std::string_view line) mutable -> future> { + static std::regex authority_ex(R"foo(Starting Azure Vault mock server on \('([\d\.]+)', (\d+)\))foo"); + + std::cerr << line << std::endl; + std::match_results m; + if (!b && std::regex_search(line.begin(), line.end(), m, authority_ex)) { + authority_promise.set_value(std::make_pair(m[1].str(), std::stoi(m[2].str()))); + BOOST_TEST_MESSAGE("Matched Azure Vault host and port: " + m[1].str() + ":" + m[2].str()); + b = true; + } + co_return continue_consuming{}; + } + ); + + std::exception_ptr ep; + + try { + // arbitrary timeout of 20s for the server to make some output. Very generous. + auto [host, port] = co_await with_timeout(std::chrono::steady_clock::now() + 20s, std::move(fut)); + + // wait for port. + auto sleep_interval = 100ms; + auto timeout = 5s; + auto end_time = seastar::lowres_clock::now() + timeout; + bool connected = false; + while (seastar::lowres_clock::now() < end_time) { + BOOST_TEST_MESSAGE(fmt::format("Connecting to {}:{}", host, port)); + try { + // TODO: seastar does not have a connect with timeout. That would be helpful here. But alas... + co_await seastar::connect(socket_address(net::inet_address(host), uint16_t(port))); + BOOST_TEST_MESSAGE("Dedicated Azure Vault mock server up and available"); + connected = true; + break; + } catch (...) { + } + co_await sleep(sleep_interval); + } + + if (!connected) { + throw std::runtime_error(fmt::format("Timed out connecting to Azure Vault mock server at {}:{}", host, port)); + } + + co_await f(host, port); + + } catch (timed_out_error&) { + ep = std::make_exception_ptr(std::runtime_error("Could not start dedicated Azure Vault mock server")); + } catch (...) { + ep = std::current_exception(); + } + + BOOST_TEST_MESSAGE("Stopping dedicated Azure Vault mock server"); + + python.terminate(); + co_await python.wait(); + + if (ep) { + std::rethrow_exception(ep); + } +} + +static future<> configure_azure_mock_server(const std::string& host, const unsigned int port, const std::string& service, const std::string& error_type, int repeat) { + auto cln = http::experimental::client(socket_address(net::inet_address(host), uint16_t(port))); + auto close_client = deferred_close(cln); + auto req = http::request::make("POST", host, "/config/error"); + req._headers["Content-Length"] = "0"; + req.query_parameters["service"] = service; + req.query_parameters["error_type"] = error_type; + req.query_parameters["repeat"] = std::to_string(repeat); + co_await cln.make_request(std::move(req), [](const http::reply&, input_stream&&) -> future<> { return seastar::make_ready_future(); }); +} + +SEASTAR_TEST_CASE(test_imds, *check_azure_mock_test_decorator()) { + co_await with_dedicated_azure_mock_server([](std::string host, unsigned int port) -> future<> { + // Create new credential object for each test case because it caches the token. + { + testlog.info("Testing IMDS success path"); + azure::managed_identity_credentials creds { fmt::format("{}:{}", host, port) }; + co_await creds.get_access_token("https://vault.azure.net/.default"); + } + + { + testlog.info("Testing IMDS transient errors"); + azure::managed_identity_credentials creds { fmt::format("{}:{}", host, port) }; + co_await configure_azure_mock_server(host, port, "imds", "InternalError", 1); + // expected to not throw + co_await creds.get_access_token("https://vault.azure.net/.default"); + } + + { + testlog.info("Testing IMDS non-transient errors"); + azure::managed_identity_credentials creds { fmt::format("{}:{}", host, port) }; + co_await configure_azure_mock_server(host, port, "imds", "NoIdentity", 1); + BOOST_REQUIRE_THROW( + co_await creds.get_access_token("https://vault.azure.net/.default"), + azure::creds_auth_error + ); + } + }); +} + +SEASTAR_TEST_CASE(test_entra_sts, *check_azure_mock_test_decorator()) { + co_await with_dedicated_azure_mock_server([](std::string host, unsigned int port) -> future<> { + auto make_entra_creds = [&] { + return azure::service_principal_credentials { + "00000000-1111-2222-3333-444444444444", + "mock-client-id", + "mock-client-secret", + "", + fmt::format("http://{}:{}", host, port), + }; + }; + + // Create new credential object for each test case because it caches the token. + { + testlog.info("Testing Entra STS success path"); + auto creds = make_entra_creds(); + co_await creds.get_access_token("https://vault.azure.net/.default"); + } + + { + testlog.info("Testing Entra STS transient errors"); + auto creds = make_entra_creds(); + co_await configure_azure_mock_server(host, port, "entra", "TemporarilyUnavailable", 1); + // expected to not throw + co_await creds.get_access_token("https://vault.azure.net/.default"); + } + + { + testlog.info("Testing Entra STS non-transient errors"); + auto creds = make_entra_creds(); + co_await configure_azure_mock_server(host, port, "entra", "InvalidSecret", 1); + BOOST_REQUIRE_THROW( + co_await creds.get_access_token("https://vault.azure.net/.default"), + azure::creds_auth_error + ); + } + }); +} + +SEASTAR_TEST_CASE(test_azure_host, *check_azure_mock_test_decorator()) { + co_await with_dedicated_azure_mock_server([](std::string host, unsigned int port) -> future<> { + key_info kinfo { .alg = "AES/CBC/PKCS5Padding", .len = 128}; + azure_host::host_options options { + .imds_endpoint = fmt::format("http://{}:{}", host, port), + .master_key = fmt::format("http://{}:{}/test-key", host, port), + }; + + { + testlog.info("Testing Key Vault success path"); + azure_host azhost {"azure_test", options}; + co_await azhost.get_or_create_key(kinfo, nullptr); + } + + { + testlog.info("Testing Key Vault transient errors"); + azure_host azhost {"azure_test", options}; + co_await configure_azure_mock_server(host, port, "vault", "Throttled", 1); + co_await azhost.get_or_create_key(kinfo, nullptr); + } + + { + testlog.info("Testing Key Vault non-transient errors"); + azure_host azhost {"azure_test", options}; + co_await configure_azure_mock_server(host, port, "vault", "Forbidden", 1); + BOOST_REQUIRE_THROW( + co_await azhost.get_or_create_key(kinfo, nullptr), + encryption::service_error + ); + } + }); +} diff --git a/test/boost/index_reader_test.cc b/test/boost/index_reader_test.cc index f42cd29c0212..ddab800f4bb5 100644 --- a/test/boost/index_reader_test.cc +++ b/test/boost/index_reader_test.cc @@ -83,7 +83,7 @@ SEASTAR_TEST_CASE(test_promoted_index_parsing_page_crossing_and_retries) { auto index = std::make_unique(sst, permit, trace, use_caching::yes, true); auto close_index = deferred_close(*index); - index->advance_to(dht::ring_position_view(pk)).get(); + index->advance_to_definitely_present_partition(pk).get(); index->read_partition_data().get(); auto cur = dynamic_cast(index->current_clustered_cursor()); @@ -196,12 +196,12 @@ SEASTAR_TEST_CASE(test_no_data_file_read_on_missing_clustering_keys_with_dense_i auto index = std::make_unique(sst, permit, trace, use_caching::yes, true); auto close_index = deferred_close(*index); - index->advance_to(dht::ring_position_view(pk)).get(); + index->advance_to_definitely_present_partition(pk).get(); index->read_partition_data().get(); auto index2 = std::make_unique(sst, permit, trace, use_caching::yes, true); auto close_index2 = deferred_close(*index2); - index2->advance_to(dht::ring_position_view(pk)).get(); + index2->advance_to_definitely_present_partition(pk).get(); index2->advance_to_next_partition().get(); auto next_partition_offset = index2->data_file_positions().start; diff --git a/test/boost/keys_test.cc b/test/boost/keys_test.cc index a2ebfe0659ba..d8978a423b39 100644 --- a/test/boost/keys_test.cc +++ b/test/boost/keys_test.cc @@ -9,7 +9,7 @@ #define BOOST_TEST_MODULE core #include -#include "keys.hh" +#include "keys/keys.hh" #include "schema/schema.hh" #include "schema/schema_builder.hh" #include "types/types.hh" diff --git a/test/boost/locator_topology_test.cc b/test/boost/locator_topology_test.cc index 4cc068a88a38..f9c66d2458d2 100644 --- a/test/boost/locator_topology_test.cc +++ b/test/boost/locator_topology_test.cc @@ -287,7 +287,7 @@ SEASTAR_THREAD_TEST_CASE(test_load_sketch) { node3_shards[1]++; auto table = table_id(utils::make_random_uuid()); - tab_meta.set_tablet_map(table, tmap); + tab_meta.set_tablet_map(table, std::move(tmap)); tm.set_tablets(std::move(tab_meta)); return make_ready_future<>(); }).get(); diff --git a/test/boost/mutation_test.cc b/test/boost/mutation_test.cc index 2ea4b9d30059..ac9f82466543 100644 --- a/test/boost/mutation_test.cc +++ b/test/boost/mutation_test.cc @@ -27,7 +27,7 @@ #include "replica/database.hh" #include "utils/UUID_gen.hh" -#include "clustering_interval_set.hh" +#include "keys/clustering_interval_set.hh" #include "schema/schema_builder.hh" #include "query-result-set.hh" #include "query-result-reader.hh" @@ -58,7 +58,7 @@ #include "mutation/mutation_rebuilder.hh" #include "mutation/mutation_partition.hh" #include "mutation/async_utils.hh" -#include "clustering_key_filter.hh" +#include "keys/clustering_key_filter.hh" #include "readers/from_mutations.hh" #include "readers/from_fragments.hh" diff --git a/test/boost/mvcc_test.cc b/test/boost/mvcc_test.cc index cfb3feda656a..6da7eb3d0dde 100644 --- a/test/boost/mvcc_test.cc +++ b/test/boost/mvcc_test.cc @@ -15,7 +15,7 @@ #include "mutation/partition_version.hh" #include "db/partition_snapshot_row_cursor.hh" #include "partition_snapshot_reader.hh" -#include "clustering_interval_set.hh" +#include "keys/clustering_interval_set.hh" #include "test/lib/scylla_test_case.hh" #include diff --git a/test/boost/network_topology_strategy_test.cc b/test/boost/network_topology_strategy_test.cc index 7a7012aa482b..cc31bbeae224 100644 --- a/test/boost/network_topology_strategy_test.cc +++ b/test/boost/network_topology_strategy_test.cc @@ -526,7 +526,7 @@ SEASTAR_THREAD_TEST_CASE(NetworkTopologyStrategy_tablets_test) { "NetworkTopologyStrategy", params); auto realloc_tab_awr_ptr = realloc_ars_ptr->maybe_as_tablet_aware(); BOOST_REQUIRE(realloc_tab_awr_ptr); - auto realloc_tmap = tab_awr_ptr->reallocate_tablets(s, stm.get(), tmap).get(); + auto realloc_tmap = tab_awr_ptr->reallocate_tablets(s, stm.get(), std::move(tmap)).get(); full_ring_check(realloc_tmap, realloc_ars_ptr, stm.get()); } } @@ -622,7 +622,7 @@ static void test_random_balancing(sharded& snitch, gms::inet_address auto inc_nts_ptr = dynamic_cast(inc_ars_ptr.get()); auto inc_tab_awr_ptr = inc_ars_ptr->maybe_as_tablet_aware(); BOOST_REQUIRE(inc_tab_awr_ptr); - auto inc_tmap = inc_tab_awr_ptr->reallocate_tablets(s, tmptr, tmap).get(); + auto inc_tmap = inc_tab_awr_ptr->reallocate_tablets(s, tmptr, tmap.clone_gently().get()).get(); full_ring_check(inc_tmap, ars_ptr, stm.get()); check_tablets_balance(inc_tmap, inc_nts_ptr, topo); } @@ -636,7 +636,7 @@ static void test_random_balancing(sharded& snitch, gms::inet_address auto dec_nts_ptr = dynamic_cast(dec_ars_ptr.get()); auto dec_tab_awr_ptr = dec_ars_ptr->maybe_as_tablet_aware(); BOOST_REQUIRE(dec_tab_awr_ptr); - auto dec_tmap = dec_tab_awr_ptr->reallocate_tablets(s, tmptr, tmap).get(); + auto dec_tmap = dec_tab_awr_ptr->reallocate_tablets(s, tmptr, tmap.clone_gently().get()).get(); full_ring_check(dec_tmap, ars_ptr, stm.get()); check_tablets_balance(dec_tmap, dec_nts_ptr, topo); } @@ -1309,10 +1309,9 @@ SEASTAR_THREAD_TEST_CASE(tablets_simple_rack_aware_view_pairing_test) { auto view_tmap = tab_awr_ptr->allocate_tablets_for_new_table(view_schema, tmptr, 1).get(); testlog.debug("view_table_id={}", view_table_id); - stm.mutate_token_metadata([&] (token_metadata& tm) { - tm.tablets().set_tablet_map(base_table_id, base_tmap); - tm.tablets().set_tablet_map(view_table_id, view_tmap); - return make_ready_future(); + stm.mutate_token_metadata([&] (token_metadata& tm) -> future<> { + tm.tablets().set_tablet_map(base_table_id, co_await base_tmap.clone_gently()); + tm.tablets().set_tablet_map(view_table_id, co_await view_tmap.clone_gently()); }).get(); tmptr = stm.get(); @@ -1463,10 +1462,9 @@ void test_complex_rack_aware_view_pairing_test(bool more_or_less) { auto view_tmap = tab_awr_ptr->allocate_tablets_for_new_table(view_schema, tmptr, 1).get(); testlog.debug("view_table_id={}", view_table_id); - stm.mutate_token_metadata([&] (token_metadata& tm) { - tm.tablets().set_tablet_map(base_table_id, base_tmap); - tm.tablets().set_tablet_map(view_table_id, view_tmap); - return make_ready_future(); + stm.mutate_token_metadata([&] (token_metadata& tm) -> future<> { + tm.tablets().set_tablet_map(base_table_id, co_await base_tmap.clone_gently()); + tm.tablets().set_tablet_map(view_table_id, co_await view_tmap.clone_gently()); }).get(); tmptr = stm.get(); diff --git a/test/boost/range_tombstone_list_test.cc b/test/boost/range_tombstone_list_test.cc index 58f989013cee..01f6d4ba8699 100644 --- a/test/boost/range_tombstone_list_test.cc +++ b/test/boost/range_tombstone_list_test.cc @@ -11,7 +11,7 @@ #include #include #include -#include "keys.hh" +#include "keys/keys.hh" #include "schema/schema_builder.hh" #include "mutation/range_tombstone_list.hh" #include "test/boost/range_tombstone_list_assertions.hh" diff --git a/test/boost/sstable_3_x_test.cc b/test/boost/sstable_3_x_test.cc index be111cb7e664..8fa029bbacb1 100644 --- a/test/boost/sstable_3_x_test.cc +++ b/test/boost/sstable_3_x_test.cc @@ -32,7 +32,7 @@ #include "test/lib/random_utils.hh" #include "test/lib/test_utils.hh" #include "sstables/types.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "types/types.hh" #include "types/user.hh" #include "partition_slice_builder.hh" diff --git a/test/boost/sstable_inexact_index_test.cc b/test/boost/sstable_inexact_index_test.cc new file mode 100644 index 000000000000..58732d26cf3f --- /dev/null +++ b/test/boost/sstable_inexact_index_test.cc @@ -0,0 +1,607 @@ +/* + * Copyright (C) 2025-present ScyllaDB + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include +#include "sstables/mx/reader.hh" +#include "test/lib/mutation_reader_assertions.hh" +#include "test/lib/simple_schema.hh" +#include "test/lib/sstable_utils.hh" +#include "utils/assert.hh" +#include "test/lib/sstable_test_env.hh" +#include "partition_slice_builder.hh" + +// We'd like to test edge cases in the sstable reader, +// related to the handling of inexact index answers. +// (I.e. a BTI index query may return a result off by one partition +// and the reader has to handle that). +// +// Triggering those edge cases might require a specific sequence of calls to the reader +// (e.g. advancing the reader to a range, and then advancing it to a later range which +// touches the first one) +// combined with a specific sequence of answers from the index +// (e.g. first answer being accurate and second answer being off by one). +// +// Open-coding an exhaustive list of such scenarios is impractical. +// But we can write the test in a nondeterministic form: +// +// 1. Whenever queried, let the index return *some* answer within the contract. +// 2. Advance the reader to *some* range, check output. +// 3. If possible, advance the reader to *some* range later than the first one, check output. +// +// and drive this syntax by rerunning the test scenario with +// all possible values for *some* explored. +// +// This is an API for writing such nondeterministic tests. +// +// The intended usage is: +// +// ``` +// nondeterministic_choice_stack ncs; +// do { +// switch (ncs.choose_up_to(3)) { +// // Fork the test scenario into 3 possible paths... +// ... +// } +// } while (ncs.rewind()); +// ``` +class nondeterministic_choice_stack { + std::vector> choices; + int next_frame = 0; +public: + // Nondeterministically choose a non-negative integer. + // This should be followed by mark_last_choice() if + // the returned integer is the maximum value which should be chosen. + // + // Can be used if the number of possible choices isn't known in advance. + // For most uses, choose_up_to(n) is powerful enough, and is easier to use. + int choose() { + if (next_frame >= int(choices.size())) { + choices.emplace_back(0, false); + } else if (next_frame + 1 == int(choices.size())) { + choices.back().first += 1; + } + return choices[next_frame++].first; + } + // Mark the last `choose()` fork as fully explored. + void mark_last_choice() { + SCYLLA_ASSERT(next_frame > 0); + if (next_frame != int(choices.size())) { + SCYLLA_ASSERT(choices[next_frame - 1].second); + } + choices[next_frame - 1].second = true; + } + // Nondeterministically choose an integer in range [0, n] + int choose_up_to(int n) { + int result = choose(); + if (result >= n) { + mark_last_choice(); + } + return result; + } + // Nondeterministically choose a boolean. + bool choose_bool() { + return choose_up_to(1); + } + // Proceed to the next iteration of the test. + bool rewind() { + while (!choices.empty() && choices.back().second) { + choices.pop_back(); + } + next_frame = 0; + return choices.size() > 0; + } +}; + +// Return the positions of all clustering key blocks (known to the index) +// in the Data file. +std::vector get_ck_positions(shared_sstable sst, reader_permit permit) { + std::vector result; + auto abstract_r = sst->make_index_reader(permit, nullptr, use_caching::no, false); + auto& r = dynamic_cast(*abstract_r); + r.advance_to(dht::partition_range::make_open_ended_both_sides()).get(); + r.read_partition_data().get(); + auto close_ir = deferred_close(r); + + while (!r.eof()) { + auto partition_pos = r.data_file_positions().start; + sstables::clustered_index_cursor* cur = r.current_clustered_cursor(); + while (auto ei_opt = cur->next_entry().get()) { + result.push_back(ei_opt.value().offset + partition_pos); + } + r.advance_to_next_partition().get(); + } + return result; +} + +// Return the positions of all partitions in the Data file. +std::vector get_partition_positions(shared_sstable sst, reader_permit permit) { + std::vector result; + { + auto ir = sst->make_index_reader(permit, nullptr, use_caching::no, false); + ir->advance_to(dht::partition_range::make_open_ended_both_sides()).get(); + ir->read_partition_data().get(); + auto close_ir = deferred_close(*ir); + while (!ir->eof()) { + result.push_back(ir->data_file_positions().start); + ir->advance_to_next_partition().get(); + } + } + return result; +} + +// An in-memory implementation of a sstable index, +// which can be off by one partition on queries, +// and might not be able to provide the partition key for the target partition. +// +// (I.e. it behaves like index which by default only stores key prefixes +// instead of whole keys). +// +// It is nondeterministic whether the index query resutls are exact or inexact, +// and whether the partition key is known or not. +// This allows for checking if the reader correctly handles either case. +struct inexact_partition_index : abstract_index_reader { + nondeterministic_choice_stack& _ncs; + std::vector _positions; + std::span _pk_positions; + std::span _ck_positions; + std::span _pks; + std::span _cks; + schema_ptr _s; + + int _lower = 0; + int _upper = 0; + + inexact_partition_index( + nondeterministic_choice_stack& ncs, + std::span pk_positions, + std::span ck_positions, + std::span pks, + std::span cks, + schema_ptr s) + : _ncs(ncs) + , _pk_positions(pk_positions) + , _ck_positions(ck_positions) + , _pks(pks) + , _cks(cks) + , _s(s) { + _positions.insert(_positions.end(), ck_positions.begin(), ck_positions.end()); + _positions.insert(_positions.end(), pk_positions.begin(), pk_positions.end()); + std::ranges::sort(_positions); + _upper = _positions.size() - 1; + } + + future<> close() noexcept override { + return make_ready_future<>(); + } + data_file_positions_range data_file_positions() const override { + return {_positions[_lower], _positions[_upper]}; + } + future> last_block_offset() override { + abort(); + } + future advance_lower_and_check_if_present(dht::ring_position_view rpv) override { + auto cmp = dht::ring_position_less_comparator(*_s); + auto pk_idx = std::ranges::lower_bound(_pks, rpv, cmp) - _pks.begin(); + if (dht::ring_position_comparator(*_s)(_pks[pk_idx], rpv) != 0) { + if (pk_idx != 0) { + if (_ncs.choose_bool()) { + // The index *might* conservatively return a partition immediately before the target. + pk_idx -= 1; + } + } + } + _lower = pk_idx_to_pos_idx(pk_idx); + testlog.trace("(true); + } + future<> advance_upper_past(position_in_partition_view pos) override { + auto cmp = position_in_partition::less_compare(*_s); + auto pk_idx = pos_idx_to_pk_idx(_lower); + _upper = pk_idx_to_pos_idx(pk_idx); + size_t ck_idx = std::ranges::upper_bound(_cks, pos, cmp) - _cks.begin(); + _upper = (ck_idx == _cks.size()) ? _upper + 1 : _upper + ck_idx; + testlog.trace("inexact_partition_index/advance_upper_past: _lower={}, _upper={}, pk_idx={}", _lower, _upper, pk_idx); + return make_ready_future<>(); + } + future<> advance_to_next_partition() override { + auto pk_idx = pos_idx_to_pk_idx(_lower); + testlog.trace(">inexact_partition_index/advance_to_next_partition: _lower={}, _upper={}, pk_idx={}", _lower, _upper, pk_idx); + pk_idx += 1; + _lower = pk_idx_to_pos_idx(pk_idx); + testlog.trace("(); + } + indexable_element element_kind() const override { + if (eof()) { + return indexable_element::partition; + } + if (_positions[_lower] == *std::ranges::lower_bound(_pk_positions, _positions[_lower])) { + return indexable_element::partition; + } else { + return indexable_element::cell; + } + } + uint64_t pk_idx_to_pos_idx(uint64_t pk_idx) { + return std::ranges::lower_bound(_positions, _pk_positions[pk_idx]) - _positions.begin(); + } + uint64_t pos_idx_to_pk_idx(uint64_t pos_idx) { + return std::ranges::upper_bound(_pk_positions, _positions[pos_idx]) - _pk_positions.begin() - 1; + } + future<> advance_past_definitely_present_partition(const dht::decorated_key& dk) override { + auto cmp = dht::ring_position_less_comparator(*_s); + size_t pk_idx = std::ranges::lower_bound(_pks, dk, cmp) - _pks.begin() + 1; + _lower = pk_idx_to_pos_idx(pk_idx); + testlog.trace("inexact_partition_index/advance_past_definitely_present_partition: _lower={}, _upper={}, pk_idx={}", _lower, _upper, pk_idx); + return make_ready_future<>(); + } + future<> advance_to_definitely_present_partition(const dht::decorated_key& dk) override { + auto cmp = dht::ring_position_less_comparator(*_s); + size_t pk_idx = std::ranges::lower_bound(_pks, dk, cmp) - _pks.begin(); + _lower = pk_idx_to_pos_idx(pk_idx); + testlog.trace("inexact_partition_index/advance_to_definitely_present_partition: _lower={}, _upper={}, pk_idx={}", _lower, _upper, pk_idx); + return make_ready_future<>(); + } + future<> advance_to(position_in_partition_view pos) override { + auto cmp = position_in_partition::less_compare(*_s); + auto pk_idx = pos_idx_to_pk_idx(_lower); + testlog.trace(">inexact_partition_index/advance_to(pipv={}): _lower={}, _upper={}, pk_idx={}", pos, _lower, _upper, pk_idx); + size_t ck_idx = std::max(0, std::ranges::upper_bound(_cks, pos, cmp) - _cks.begin() - 1); + _lower = pk_idx * (_cks.size() + 1) + ck_idx; + testlog.trace("(); + } + std::optional partition_tombstone() override { + return sstables::deletion_time { + std::numeric_limits::max(), + std::numeric_limits::min(), + }; + } + std::optional get_partition_key() override { + if (_ncs.choose_bool()) { + // The index *might* have the PK present. + return _pks[pos_idx_to_pk_idx(_lower)].key(); + } + return std::nullopt; + } + bool partition_data_ready() const override { + return true; + } + future<> read_partition_data() override { + return make_ready_future<>(); + } + future<> advance_reverse(position_in_partition_view pos) override { + abort(); + } + future<> advance_to(const dht::partition_range& pr) override { + auto cmp = dht::ring_position_less_comparator(*_s); + if (pr.start()) { + auto rpv = dht::ring_position_view::for_range_start(pr); + auto pk_idx = std::ranges::lower_bound(_pks, rpv, cmp) - _pks.begin(); + if (pk_idx != 0) { + if (_ncs.choose_bool()) { + // The index *might* conservatively return a partition immediately before the target. + pk_idx -= 1; + } + } + _lower = pk_idx_to_pos_idx(pk_idx); + } else { + _lower = 0; + } + if (pr.end()) { + auto rpv = dht::ring_position_view::for_range_end(pr); + size_t pk_idx = std::ranges::upper_bound(_pks, rpv, cmp) - _pks.begin(); + if (pk_idx != _pks.size()) { + if (_ncs.choose_bool()) { + // The index *might* conservatively return a partition immediately after the target. + pk_idx += 1; + } + } + _upper = pk_idx_to_pos_idx(pk_idx); + } else { + _upper = _positions.size() - 1; + } + testlog.trace("inexact_partition_index/advance_to(pr={}), _lower={}, _upper={}", pr, _lower, _upper); + return make_ready_future<>(); + } + future<> advance_reverse_to_next_partition() override { + abort(); + } + std::optional end_open_marker() const override { + return std::nullopt; + } + std::optional reverse_end_open_marker() const override { + abort(); + } + future<> prefetch_lower_bound(position_in_partition_view pos) override { + return make_ready_future<>(); + } + bool eof() const override { + return _lower == int(_positions.size() - 1); + } +}; + +// Test design: +// - Write an sstable. +// - Using a key universe which includes the written keys and some extra "in-between" keys, +// nondeterministically: +// - Construct a reader (using a nondeterministically-inexact index) on *some* partition range, +// check that it outputs the right mutation fragments. +// - Optionally, forward it to *some* range after the first one, check that the reader +// outputs the right mutation fragments. +SEASTAR_TEST_CASE(test_inexact_partition_index_range_query) { + return sstables::test_env::do_with_async([] (sstables::test_env& env) { + // Force a clustering block for every row, + // so that the non-full slice induces parser skips to the middle of the partition. + env.manager().set_promoted_index_block_size(1); + auto row_value = std::string(1024, 'a'); + + simple_schema table; + auto permit = env.make_reader_permit(); + + // Having more keys makes the test stronger but increases + // the number of test paths exponentially. + // + // 3 keys make the test too long for CI. + constexpr bool hard_mode = false; + + // Number of all possible partition keys. + size_t n_all_pks; + // Indices of present partitions in the list of possible keys. + std::vector present_pk_indices; + + if (hard_mode) { + n_all_pks = 7; + present_pk_indices = std::vector{1, 3, 5}; + } else { + n_all_pks = 3; + present_pk_indices = std::vector{0, 2}; + } + + // Generate the list of all possible partition keys. + std::vector all_dks; + for (size_t i = 0; i < n_all_pks; ++i) { + auto pk = partition_key::from_single_value(*table.schema(), serialized(format("pk{:010d}", i))); + all_dks.push_back(dht::decorate_key(*table.schema(), pk)); + } + std::ranges::sort(all_dks, dht::ring_position_less_comparator(*table.schema())); + + // Choose the written partition keys. + std::vector present_dks; + for (const auto& i : present_pk_indices) { + present_dks.push_back(all_dks[i]); + } + + // Generate the present mutations. + utils::chunked_vector muts; + auto cks = table.make_ckeys(3); + for (const auto& i : present_pk_indices) { + mutation m(table.schema(), all_dks[i]); + for (const auto& ck : cks) { + table.add_row(m, ck, row_value); + } + muts.push_back(std::move(m)); + } + + // Generate the sstable. + auto sst = make_sstable_containing(env.make_sstable(table.schema()), muts); + + // Use the index to find key positions. + std::vector partition_positions = get_partition_positions(sst, permit); + std::vector ck_positions = get_ck_positions(sst, permit); + partition_positions.push_back(sst->data_size()); + { + testlog.debug("Sstable initialized with:"); + size_t i = 0; + size_t j = 0; + for (const auto& dk : present_pk_indices) { + testlog.debug("{}: {} (dk={})", dk, partition_positions[i], all_dks[dk]); + while (j < ck_positions.size() && ck_positions[j] < partition_positions[i + 1]) { + testlog.debug(" ck at {}", ck_positions[j]); + ++j; + } + ++i; + } + testlog.debug("eof at {}", partition_positions.back()); + } + + nondeterministic_choice_stack ncs; + int test_cases = 0; + do { + testlog.debug("Starting test case {}", test_cases); + ++test_cases; + + bool do_slice = ncs.choose_bool(); + auto slice = do_slice + ? partition_slice_builder(*table.schema()) + .with_range(query::clustering_range::make_singular(cks[1])) + .build() + : table.schema()->full_slice(); + auto range = do_slice + ? query::clustering_range::make_singular(cks[1]) + : query::clustering_range::make_open_ended_both_sides(); + + dht::partition_range pr; + + std::optional reader; + constexpr int max_forwards = 1; + + // Special value -1 means -inf + // Special value all_dks.size() means +inf + int last_dk = -1; + bool last_inclusive = false; + + for (int range_index = 0; range_index <= max_forwards; ++range_index) { + // Choose some partition range after the last one + auto start_dk = last_dk; + if (last_inclusive) { + start_dk += 1; + } + if (start_dk >= int(all_dks.size())) { + break; + } + start_dk += ncs.choose_up_to(all_dks.size() - 1 - start_dk); + bool start_inclusive; + if (start_dk == -1 || (start_dk == last_dk && last_inclusive)) { + start_inclusive = true; + } else { + start_inclusive = ncs.choose_bool(); + } + int end_dk = start_dk; + if (!start_inclusive || start_dk == -1) { + end_dk += 1; + } + end_dk += ncs.choose_up_to(all_dks.size() - end_dk); + bool end_inclusive; + if (end_dk == int(all_dks.size()) || (end_dk == start_dk)) { + end_inclusive = true; + } else { + end_inclusive = ncs.choose_bool(); + } + if (start_dk == -1) { + if (end_dk == int(all_dks.size())) { + pr = dht::partition_range::make_open_ended_both_sides(); + } else { + pr = dht::partition_range::make_ending_with(dht::partition_range::bound(all_dks[end_dk], end_inclusive)); + } + } else if (end_dk == int(all_dks.size())) { + pr = dht::partition_range::make_starting_with(dht::partition_range::bound(all_dks[start_dk], start_inclusive)); + } else { + pr = dht::partition_range::make( + dht::partition_range::bound(all_dks[start_dk], start_inclusive), + dht::partition_range::bound(all_dks[end_dk], end_inclusive) + ); + } + + // If not created, create the reader on the chosen partition range, + // otherwise forward to it. + if (reader) { + testlog.debug("Forward to {}{}, {}{} (pr={})", + start_inclusive ? '[' : '(', + start_dk, + end_dk, + end_inclusive ? ']' : ')', + pr + ); + reader->fast_forward_to(pr); + } else { + testlog.debug("Create for {}{}, {}{} (pr={})", + start_inclusive ? '[' : '(', + start_dk, + end_dk, + end_inclusive ? ']' : ')', + pr + ); + reader = assert_that(sstables::mx::make_reader_with_index_reader( + sst, + table.schema(), + permit, + pr, + slice, + nullptr, + streamed_mutation::forwarding::no, + mutation_reader::forwarding::yes, + default_read_monitor(), + sstables::integrity_check::no, + std::make_unique(ncs, partition_positions, ck_positions, present_dks, cks, table.schema()) + )); + } + // Check that the reader produces correct output for the chosen partition range. + for (int i = start_dk + !start_inclusive; i < end_dk + end_inclusive; ++i) { + if (auto it = std::ranges::find(present_pk_indices, i); it != present_pk_indices.end()) { + testlog.debug("Check produces {} (dk={})", i, muts[it - present_pk_indices.begin()].decorated_key()); + reader->produces(muts[it - present_pk_indices.begin()].sliced({range})); + } + } + testlog.debug("Check produces end of stream"); + reader->produces_end_of_stream(); + last_dk = end_dk; + last_inclusive = end_inclusive; + } + } while (ncs.rewind()); + }); +} + +// Test that the reader can deal with a single-partition query +// for which the index returns a false positive. +SEASTAR_TEST_CASE(test_inexact_partition_index_singular_query) { + return sstables::test_env::do_with_async([] (sstables::test_env& env) { + simple_schema table; + auto permit = env.make_reader_permit(); + + // Number of all possible partition keys. + size_t n_all_pks = 3; + // Indices of present partitions in the list of possible keys. + std::vector present_pk_indices{0, 2}; + + // Generate the list of all possible partition keys. + std::vector all_dks; + for (size_t i = 0; i < n_all_pks; ++i) { + auto pk = partition_key::from_single_value(*table.schema(), serialized(format("pk{:010d}", i))); + all_dks.push_back(dht::decorate_key(*table.schema(), pk)); + } + std::ranges::sort(all_dks, dht::ring_position_less_comparator(*table.schema())); + + // Choose the written partition keys. + std::vector present_dks; + for (const auto& i : present_pk_indices) { + present_dks.push_back(all_dks[i]); + } + + // Generate the present mutations. + utils::chunked_vector muts; + for (const auto& i : present_pk_indices) { + mutation m(table.schema(), all_dks[i]); + muts.push_back(std::move(m)); + } + + // Generate the sstable. + auto sst = make_sstable_containing(env.make_sstable(table.schema()), muts); + + // Use the index to find key positions. + std::vector partition_positions = get_partition_positions(sst, permit); + partition_positions.push_back(sst->data_size()); + { + testlog.debug("Sstable initialized with:"); + size_t i = 0; + for (const auto& dk : present_pk_indices) { + testlog.debug("{}: {} (dk={})", dk, partition_positions[i], all_dks[dk]); + ++i; + } + testlog.debug("eof at {}", partition_positions.back()); + } + + nondeterministic_choice_stack ncs; + do { + auto slice = table.schema()->full_slice(); + auto pr = dht::partition_range::make_singular(all_dks[1]); + + testlog.debug("Create for pr={}", pr); + mutation_reader_assertions reader = assert_that(sstables::mx::make_reader_with_index_reader( + sst, + table.schema(), + permit, + pr, + slice, + nullptr, + streamed_mutation::forwarding::no, + mutation_reader::forwarding::yes, + default_read_monitor(), + sstables::integrity_check::no, + std::make_unique( + ncs, + partition_positions, + std::span(), + present_dks, + std::span(), + table.schema()) + )); + testlog.debug("Check that the reader produces nothing"); + reader.produces_end_of_stream(); + } while (ncs.rewind()); + }); +} diff --git a/test/boost/tablets_test.cc b/test/boost/tablets_test.cc index 23d7addc20bd..034bc92dfef4 100644 --- a/test/boost/tablets_test.cc +++ b/test/boost/tablets_test.cc @@ -209,7 +209,7 @@ SEASTAR_TEST_CASE(test_tablet_metadata_persistence) { verify_tablet_metadata_persistence(e, tm, ts); // Increase RF of table2 - tm.mutate_tablet_map(table2, [&] (tablet_map& tmap) { + tm.mutate_tablet_map_async(table2, [&] (tablet_map& tmap) { auto tb = tmap.first_tablet(); tb = *tmap.next_tablet(tb); @@ -234,7 +234,8 @@ SEASTAR_TEST_CASE(test_tablet_metadata_persistence) { tablet_replica {h1, 4}, session_id(utils::UUID_gen::get_time_UUID()) }); - }); + return make_ready_future(); + }).get(); verify_tablet_metadata_persistence(e, tm, ts); @@ -1440,11 +1441,12 @@ SEASTAR_THREAD_TEST_CASE(test_token_ownership_splitting) { } static -void apply_resize_plan(token_metadata& tm, const migration_plan& plan) { +future<> apply_resize_plan(token_metadata& tm, const migration_plan& plan) { for (auto [table_id, resize_decision] : plan.resize_plan().resize) { - tm.tablets().mutate_tablet_map(table_id, [&] (tablet_map& tmap) { + co_await tm.tablets().mutate_tablet_map_async(table_id, [&] (tablet_map& tmap) { resize_decision.sequence_number = tmap.resize_decision().sequence_number + 1; tmap.set_resize_decision(resize_decision); + return make_ready_future(); }); } } @@ -1485,28 +1487,30 @@ future<> handle_resize_finalize(cql_test_env& e, group0_guard& guard, const migr // Reflects the plan in a given token metadata as if the migrations were fully executed. static -void apply_plan(token_metadata& tm, const migration_plan& plan) { +future<> apply_plan(token_metadata& tm, const migration_plan& plan) { for (auto&& mig : plan.migrations()) { - tm.tablets().mutate_tablet_map(mig.tablet.table, [&] (tablet_map& tmap) { + co_await tm.tablets().mutate_tablet_map_async(mig.tablet.table, [&] (tablet_map& tmap) { auto tinfo = tmap.get_tablet_info(mig.tablet.tablet); testlog.trace("Replacing tablet {} replica from {} to {}", mig.tablet.tablet, mig.src, mig.dst); tinfo.replicas = replace_replica(tinfo.replicas, mig.src, mig.dst); tmap.set_tablet(mig.tablet.tablet, tinfo); + return make_ready_future(); }); } - apply_resize_plan(tm, plan); + co_await apply_resize_plan(tm, plan); } // Reflects the plan in a given token metadata as if the migrations were started but not yet executed. static -void apply_plan_as_in_progress(token_metadata& tm, const migration_plan& plan) { +future<> apply_plan_as_in_progress(token_metadata& tm, const migration_plan& plan) { for (auto&& mig : plan.migrations()) { - tm.tablets().mutate_tablet_map(mig.tablet.table, [&] (tablet_map& tmap) { + co_await tm.tablets().mutate_tablet_map_async(mig.tablet.table, [&] (tablet_map& tmap) { auto tinfo = tmap.get_tablet_info(mig.tablet.tablet); tmap.set_tablet_transition_info(mig.tablet.tablet, migration_to_transition_info(tinfo, mig)); + return make_ready_future(); }); } - apply_resize_plan(tm, plan); + co_await apply_resize_plan(tm, plan); } static @@ -1548,8 +1552,7 @@ void do_rebalance_tablets(cql_test_env& e, return; } stm.mutate_token_metadata([&] (token_metadata& tm) { - apply_plan(tm, plan); - return make_ready_future<>(); + return apply_plan(tm, plan); }).get(); if (auto_split && load_stats) { @@ -1612,8 +1615,7 @@ void rebalance_tablets_as_in_progress(tablet_allocator& talloc, shared_token_met break; } stm.mutate_token_metadata([&] (token_metadata& tm) { - apply_plan_as_in_progress(tm, plan); - return make_ready_future<>(); + return apply_plan_as_in_progress(tm, plan); }).get(); } } @@ -1621,18 +1623,18 @@ void rebalance_tablets_as_in_progress(tablet_allocator& talloc, shared_token_met // Completes any in progress tablet migrations. static void execute_transitions(shared_token_metadata& stm) { - stm.mutate_token_metadata([&] (token_metadata& tm) { + stm.mutate_token_metadata([&] (token_metadata& tm) -> future<> { for (auto&& [table, tables] : tm.tablets().all_table_groups()) { - tm.tablets().mutate_tablet_map(table, [&] (tablet_map& tmap) { + co_await tm.tablets().mutate_tablet_map_async(table, [&] (tablet_map& tmap) { for (auto&& [tablet, trinfo]: tmap.transitions()) { auto ti = tmap.get_tablet_info(tablet); ti.replicas = trinfo.next; tmap.set_tablet(tablet, ti); } tmap.clear_transitions(); + return make_ready_future(); }); } - return make_ready_future<>(); }).get(); } @@ -1899,11 +1901,11 @@ SEASTAR_THREAD_TEST_CASE(test_load_balancing_with_colocated_tablets) { } }); - tablet_map tmap1(tmap); + tablet_map tmap1 = co_await tmap.clone_gently(); tmeta.set_tablet_map(table1, std::move(tmap1)); co_await tmeta.set_colocated_table(table2, table1); - tablet_map tmap3(tmap); + tablet_map tmap3 = co_await tmap.clone_gently(); tmeta.set_tablet_map(table3, std::move(tmap3)); co_await tmeta.set_colocated_table(table4, table3); }); @@ -3047,9 +3049,9 @@ SEASTAR_THREAD_TEST_CASE(test_split_and_merge_of_colocated_tables) { } }); - tablet_map tmap1(tmap); + tablet_map tmap1 = co_await tmap.clone_gently(); tmeta.set_tablet_map(table1, std::move(tmap1)); - return tmeta.set_colocated_table(table2, table1); + co_await tmeta.set_colocated_table(table2, table1); }); auto& stm = e.shared_token_metadata().local(); @@ -3836,7 +3838,7 @@ static void execute_tablet_for_new_rf_test(calculate_tablet_replicas_for_new_rf_ stm.mutate_token_metadata([&] (token_metadata& tm) { tablet_metadata tab_meta; auto table = s->id(); - tab_meta.set_tablet_map(table, allocated_map); + tab_meta.set_tablet_map(table, std::move(allocated_map)); tm.set_tablets(std::move(tab_meta)); return make_ready_future<>(); }).get(); @@ -3846,7 +3848,7 @@ static void execute_tablet_for_new_rf_test(calculate_tablet_replicas_for_new_rf_ initial_rep_factor[dc] = std::stoul(shard_count); } - auto tablets = stm.get()->tablets().get_tablet_map(s->id()); + auto tablets = stm.get()->tablets().get_tablet_map(s->id()).clone_gently().get(); BOOST_REQUIRE_EQUAL(tablets.tablet_count(), tablet_count); for (auto tb : tablets.tablet_ids()) { const locator::tablet_info& ti = tablets.get_tablet_info(tb); @@ -3863,10 +3865,10 @@ static void execute_tablet_for_new_rf_test(calculate_tablet_replicas_for_new_rf_ } try { - tablet_map old_tablets = stm.get()->tablets().get_tablet_map(s->id()); + tablet_map old_tablets = stm.get()->tablets().get_tablet_map(s->id()).clone_gently().get(); locator::replication_strategy_params params{test_config.new_dc_rep_factor, old_tablets.tablet_count()}; auto new_strategy = abstract_replication_strategy::create_replication_strategy("NetworkTopologyStrategy", params); - auto tmap = new_strategy->maybe_as_tablet_aware()->reallocate_tablets(s, stm.get(), old_tablets).get(); + auto tmap = new_strategy->maybe_as_tablet_aware()->reallocate_tablets(s, stm.get(), std::move(old_tablets)).get(); auto const& ts = tmap.tablets(); BOOST_REQUIRE_EQUAL(ts.size(), tablet_count); diff --git a/test/boost/types_test.cc b/test/boost/types_test.cc index e478180b8b8f..2143df8a596f 100644 --- a/test/boost/types_test.cc +++ b/test/boost/types_test.cc @@ -14,7 +14,7 @@ #include #include "types/types.hh" #include "types/tuple.hh" -#include "compound.hh" +#include "keys/compound.hh" #include "db/marshal/type_parser.hh" #include "cql3/cql3_type.hh" #include "cql3/type_json.hh" diff --git a/test/cluster/test_repair.py b/test/cluster/test_repair.py index e9833c4f0e8f..cceae6a8b50d 100644 --- a/test/cluster/test_repair.py +++ b/test/cluster/test_repair.py @@ -165,8 +165,9 @@ async def do_batchlog_flush_in_repair(manager, cache_time_in_ms): nr_repairs = 2 * nr_repairs_per_node total_repair_duration = 0 + cfg = { 'tablets_mode_for_new_keyspaces': 'disabled' } cmdline = ["--repair-hints-batchlog-flush-cache-time-in-ms", str(cache_time_in_ms), "--smp", "1", "--logger-log-level", "api=trace"] - node1, node2 = await manager.servers_add(2, cmdline=cmdline, auto_rack_dc="dc1") + node1, node2 = await manager.servers_add(2, config=cfg, cmdline=cmdline, auto_rack_dc="dc1") cql = manager.get_cql() cql.execute("CREATE KEYSPACE ks WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 2}") @@ -216,41 +217,6 @@ async def test_batchlog_flush_in_repair_with_cache(manager): async def test_batchlog_flush_in_repair_without_cache(manager): await do_batchlog_flush_in_repair(manager, 0); -@pytest.mark.asyncio -@skip_mode('release', 'error injections are not supported in release mode') -async def test_repair_abort(manager): - cfg = {'tablets_mode_for_new_keyspaces': 'enabled'} - servers = await manager.servers_add(2, config=cfg, auto_rack_dc="dc1") - cql = manager.get_cql() - - cql.execute("CREATE KEYSPACE ks WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 2}") - cql.execute("CREATE TABLE ks.tbl (pk int, ck int, PRIMARY KEY (pk, ck)) WITH tombstone_gc = {'mode': 'repair'}") - - await manager.api.client.post(f"/task_manager/ttl", params={ "ttl": "100000" }, - host=servers[0].ip_addr) - - await manager.api.enable_injection(servers[0].ip_addr, "repair_tablet_repair_task_impl_run", False, {}) - - # Start repair. - sequence_number = await manager.api.client.post_json(f"/storage_service/repair_async/ks", host=servers[0].ip_addr) - - # Get repair id. - stats_list = await manager.api.client.get_json("/task_manager/list_module_tasks/repair", host=servers[0].ip_addr) - ids = [stats["task_id"] for stats in stats_list if stats["sequence_number"] == sequence_number] - assert len(ids) == 1 - id = ids[0] - - # Abort repair. - await manager.api.client.post("/storage_service/force_terminate_repair", host=servers[0].ip_addr) - - await manager.api.message_injection(servers[0].ip_addr, "repair_tablet_repair_task_impl_run") - await manager.api.disable_injection(servers[0].ip_addr, "repair_tablet_repair_task_impl_run") - - # Check if repair was aborted. - await manager.api.client.get_json(f"/task_manager/wait_task/{id}", host=servers[0].ip_addr) - statuses = await manager.api.client.get_json(f"/task_manager/task_status_recursive/{id}", host=servers[0].ip_addr) - assert all([status["state"] == "failed" for status in statuses]) - @pytest.mark.asyncio @skip_mode('release', 'error injections are not supported in release mode') async def test_keyspace_drop_during_data_sync_repair(manager): diff --git a/test/cluster/test_tablets.py b/test/cluster/test_tablets.py index 0b77d862e461..d3bcdde01f1f 100644 --- a/test/cluster/test_tablets.py +++ b/test/cluster/test_tablets.py @@ -1088,46 +1088,6 @@ async def test_tablet_split_finalization_with_migrations(manager: ManagerClient) logger.info("Waiting for migrations to complete") await log.wait_for("Tablet load balancer did not make any plan", from_mark=migration_mark) -@pytest.mark.asyncio -@skip_mode('release', 'error injections are not supported in release mode') -async def test_tablet_split_finalization_with_repair(manager: ManagerClient): - injection = "handle_tablet_resize_finalization_wait" - cfg = { - 'enable_tablets': True, - 'error_injections_at_startup': [ - injection, - "repair_tablets_no_sync", - 'short_tablet_stats_refresh_interval', - ] - } - servers = await manager.servers_add(2, config=cfg) - - cql = manager.get_cql() - await cql.run_async("CREATE KEYSPACE test WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1} AND tablets = {'initial': 4};") - await cql.run_async("CREATE TABLE test.test (pk int PRIMARY KEY, c int) WITH compaction = {'class': 'NullCompactionStrategy'};") - await asyncio.gather(*[cql.run_async(f"INSERT INTO test.test (pk, c) VALUES ({k}, {k%3});") for k in range(64)]) - await manager.api.keyspace_flush(servers[0].ip_addr, "test", "test") - - logs = [await manager.server_open_log(s.server_id) for s in servers] - marks = [await log.mark() for log in logs] - - logger.info("Trigger split in table") - await cql.run_async("ALTER TABLE test.test WITH tablets = {'min_tablet_count': 8};") - - logger.info("Wait for tablets to split") - done, pending = await asyncio.wait([asyncio.create_task(log.wait_for('handle_tablet_resize_finalization: waiting', from_mark=mark)) for log, mark in zip(logs, marks)], return_when=asyncio.FIRST_COMPLETED) - for task in pending: - task.cancel() - - async def repair(): - await manager.api.client.post(f"/storage_service/repair_async/test", host=servers[0].ip_addr) - - async def check_repair_waits(): - await logs[0].wait_for("Topology is busy, waiting for it to quiesce", from_mark=marks[0]) - await manager.api.message_injection(servers[0].ip_addr, injection) - - await asyncio.gather(repair(), check_repair_waits()) - @pytest.mark.asyncio @skip_mode('release', 'error injections are not supported in release mode') async def test_two_tablets_concurrent_repair_and_migration_repair_writer_level(manager: ManagerClient): diff --git a/test/cluster/test_tablets2.py b/test/cluster/test_tablets2.py index 7cbc8f2df66c..97cc7f709091 100644 --- a/test/cluster/test_tablets2.py +++ b/test/cluster/test_tablets2.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 # from typing import Any -from cassandra.query import SimpleStatement, ConsistencyLevel +from cassandra.query import ConsistencyLevel from test.pylib.internal_types import HostID, ServerInfo, ServerNum from test.pylib.manager_client import ManagerClient @@ -43,27 +43,6 @@ async def disable_injection_on(manager, error_name, servers): errs = [manager.api.disable_injection(s.ip_addr, error_name) for s in servers] await asyncio.gather(*errs) -async def repair_on_node(manager: ManagerClient, server: ServerInfo, servers: list[ServerInfo], keyspace, table = "test", ranges: str = ''): - node = server.ip_addr - await manager.servers_see_each_other(servers) - live_nodes_wanted = [s.ip_addr for s in servers] - live_nodes = await manager.api.get_alive_endpoints(node) - live_nodes_wanted.sort() - live_nodes.sort() - assert live_nodes == live_nodes_wanted - logger.info(f"Repair table on node {node} live_nodes={live_nodes} live_nodes_wanted={live_nodes_wanted}") - await manager.api.repair(node, keyspace, table, ranges) - -async def load_repair_history(cql, hosts): - all_rows = [] - for host in hosts: - logging.info(f'Query hosts={host}'); - rows = await cql.run_async("SELECT * from system.repair_history", host=host) - all_rows += rows - for row in all_rows: - logging.info(f"Got repair_history_entry={row}") - return all_rows - async def safe_server_stop_gracefully(manager, server_id, timeout: float = 60, reconnect: bool = False): # Explicitly close the driver to avoid reconnections if scylla fails to update gossiper state on shutdown. # It's a problem until https://github.com/scylladb/scylladb/issues/15356 is fixed. @@ -489,244 +468,6 @@ async def test_table_dropped_during_streaming(manager: ManagerClient): replica = await get_tablet_replica(manager, servers[0], ks, 'test2', tablet_token) assert replica == (s1_host_id, 0) -@pytest.mark.repair -@pytest.mark.asyncio -async def test_tablet_repair(manager: ManagerClient): - logger.info("Bootstrapping cluster") - cmdline = [ - '--logger-log-level', 'repair=trace', - '--task-ttl-in-seconds', '3600', # Make sure the test passes with non-zero task_ttl. - ] - servers = await manager.servers_add(3, cmdline=cmdline, property_file=[ - {"dc": "dc1", "rack": "r1"}, - {"dc": "dc1", "rack": "r1"}, - {"dc": "dc1", "rack": "r2"} - ]) - - await inject_error_on(manager, "tablet_allocator_shuffle", servers) - - cql = manager.get_cql() - async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', " - "'replication_factor': 2} AND tablets = {'initial': 32}") as ks: - await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int);") - - logger.info("Populating table") - - keys = range(256) - - stmt = cql.prepare(f"INSERT INTO {ks}.test (pk, c) VALUES (?, ?)") - stmt.consistency_level = ConsistencyLevel.ONE - - # Repair runs concurrently with tablet shuffling which exercises issues with serialization - # of repair and tablet migration. - # - # We do it 30 times because it's been experimentally shown to be enough to trigger the issue with high probability. - # Lack of proper synchronization would manifest as repair failure with the following cause: - # - # failed_because=std::runtime_error (multishard_writer: No shards for token 7505809055260144771 of test.test) - # - # ...which indicates that repair tried to stream data to a node which is no longer a tablet replica. - repair_cycles = 30 - for i in range(repair_cycles): - # Write concurrently with repair to increase the chance of repair having some discrepancy to resolve and send writes. - inserts_future = asyncio.gather(*[cql.run_async(stmt, [k, i]) for k in keys]) - - # Disable in the background so that repair is started with migrations in progress. - # We need to disable balancing so that repair which blocks on migrations eventually gets unblocked. - # Otherwise, shuffling would keep the topology busy forever. - disable_balancing_future = asyncio.create_task(manager.api.disable_tablet_balancing(servers[0].ip_addr)) - - await repair_on_node(manager, servers[0], servers, ks) - - await inserts_future - await disable_balancing_future - await manager.api.enable_tablet_balancing(servers[0].ip_addr) - - key_count = len(keys) - stmt = cql.prepare(f"SELECT * FROM {ks}.test;") - stmt.consistency_level = ConsistencyLevel.ALL - rows = await cql.run_async(stmt) - assert len(rows) == key_count - for r in rows: - assert r.c == repair_cycles - 1 - -# Reproducer for race between split and repair: https://github.com/scylladb/scylladb/issues/19378 -# Verifies repair will not complete with sstables that still require split, causing split -# execution to fail. -@pytest.mark.repair -@pytest.mark.asyncio -@skip_mode('release', 'error injections are not supported in release mode') -async def test_concurrent_tablet_repair_and_split(manager: ManagerClient): - logger.info("Bootstrapping cluster") - cmdline = [ - '--logger-log-level', 'raft_topology=debug', - '--target-tablet-size-in-bytes', '1024', - ] - servers = await manager.servers_add(3, cmdline=cmdline, config={ - 'error_injections_at_startup': ['short_tablet_stats_refresh_interval'] - }, property_file=[ - {"dc": "dc1", "rack": "r1"}, - {"dc": "dc1", "rack": "r1"}, - {"dc": "dc1", "rack": "r2"} - ]) - - await manager.api.disable_tablet_balancing(servers[0].ip_addr) - - cql = manager.get_cql() - async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', " - "'replication_factor': 2} AND tablets = {'initial': 32}") as ks: - await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int);") - - logger.info("Populating table") - - keys = range(5000) # Enough keys to trigger repair digest mismatch with a high chance. - stmt = cql.prepare(f"INSERT INTO {ks}.test (pk, c) VALUES (?, ?)") - stmt.consistency_level = ConsistencyLevel.ONE - - await inject_error_on(manager, "tablet_load_stats_refresh_before_rebalancing", servers) - - s0_log = await manager.server_open_log(servers[0].server_id) - s0_mark = await s0_log.mark() - - await asyncio.gather(*[cql.run_async(stmt, [k, -1]) for k in keys]) - - # split decision is sstable size based, so data must be flushed first - for server in servers: - await manager.api.flush_keyspace(server.ip_addr, ks) - - await manager.api.enable_injection(servers[0].ip_addr, "tablet_split_finalization_postpone", False) - await manager.api.enable_tablet_balancing(servers[0].ip_addr) - - logger.info("Waiting for split prepare...") - await s0_log.wait_for('Setting split ready sequence number to', from_mark=s0_mark) - s0_mark = await s0_log.mark() - logger.info("Waited for split prepare") - - # Balancer is re-enabled later for split execution - await asyncio.create_task(manager.api.disable_tablet_balancing(servers[0].ip_addr)) - - # Write concurrently with repair to increase the chance of repair having some discrepancy to resolve and send writes. - inserts_future = asyncio.gather(*[cql.run_async(stmt, [k, 1]) for k in keys]) - - await repair_on_node(manager, servers[0], servers, ks) - - await inserts_future - - logger.info("Waiting for split execute...") - await manager.api.disable_injection(servers[0].ip_addr, "tablet_split_finalization_postpone") - await manager.api.enable_tablet_balancing(servers[0].ip_addr) - await s0_log.wait_for('Detected tablet split for table', from_mark=s0_mark) - await inject_error_one_shot_on(manager, "tablet_split_finalization_postpone", servers) - logger.info("Waited for split execute...") - - key_count = len(keys) - stmt = cql.prepare(f"SELECT * FROM {ks}.test;") - stmt.consistency_level = ConsistencyLevel.ALL - rows = await cql.run_async(stmt) - assert len(rows) == key_count - -@pytest.mark.repair -@pytest.mark.asyncio -async def test_tablet_missing_data_repair(manager: ManagerClient): - logger.info("Bootstrapping cluster") - cmdline = [ - '--hinted-handoff-enabled', 'false', - ] - servers = await manager.servers_add(3, cmdline=cmdline, auto_rack_dc="dc1") - - cql = manager.get_cql() - async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', " - "'replication_factor': 3} AND tablets = {'initial': 32}") as ks: - await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int);") - - keys_list = [range(0, 100), range(100, 200), range(200, 300)] - keys_for_server = dict([(s.server_id, keys_list[idx]) for idx, s in enumerate(servers)]) - keys = range(0, 300) - - async def insert_with_down(down_server): - logger.info(f"Stopped server {down_server.server_id}") - logger.info(f"Insert into server {down_server.server_id}") - await asyncio.gather(*[cql.run_async(f"INSERT INTO {ks}.test (pk, c) VALUES ({k}, {k});") - for k in keys_for_server[down_server.server_id]]) - - cql = await safe_rolling_restart(manager, servers, with_down=insert_with_down) - - await repair_on_node(manager, servers[0], servers, ks) - - async def check_with_down(down_node): - logger.info("Checking table") - query = SimpleStatement(f"SELECT * FROM {ks}.test;", consistency_level=ConsistencyLevel.ONE) - rows = await cql.run_async(query) - assert len(rows) == len(keys) - for r in rows: - assert r.c == r.pk - - cql = await safe_rolling_restart(manager, servers, with_down=check_with_down) - - -@pytest.mark.repair -@pytest.mark.asyncio -async def test_tablet_repair_history(manager: ManagerClient): - logger.info("Bootstrapping cluster") - servers = await manager.servers_add(3, auto_rack_dc="dc1") - - rf = 3 - tablets = 8 - - cql = manager.get_cql() - async with new_test_keyspace(manager, f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {rf}}} AND tablets = {{'initial': {tablets}}}") as ks: - await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int) WITH tombstone_gc = {{'mode':'repair'}};") - - logger.info("Populating table") - - keys = range(256) - await asyncio.gather(*[cql.run_async(f"INSERT INTO {ks}.test (pk, c) VALUES ({k}, {k});") for k in keys]) - - hosts = await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60) - logging.info(f'Got hosts={hosts}'); - - await repair_on_node(manager, servers[0], servers, ks) - - all_rows = await load_repair_history(cql, hosts) - assert len(all_rows) == rf * tablets - -@pytest.mark.repair -@pytest.mark.asyncio -async def test_tablet_repair_ranges_selection(manager: ManagerClient): - logger.info("Bootstrapping cluster") - servers = await manager.servers_add(2, auto_rack_dc="dc1") - - rf = 2 - tablets = 4 - nr_ranges = 0; - - cql = manager.get_cql() - async with new_test_keyspace(manager, f"WITH replication = {{'class': 'NetworkTopologyStrategy', 'replication_factor': {rf}}} AND tablets = {{'initial': {tablets}}}") as ks: - await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int) WITH tombstone_gc = {{'mode':'repair'}};") - - logger.info("Populating table") - - keys = range(256) - await asyncio.gather(*[cql.run_async(f"INSERT INTO {ks}.test (pk, c) VALUES ({k}, {k});") for k in keys]) - - hosts = await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60) - logging.info(f'Got hosts={hosts}'); - - await repair_on_node(manager, servers[0], servers, ks, ranges='-4611686018427387905:-1,4611686018427387903:9223372036854775807') - nr_ranges = nr_ranges + 2 - await repair_on_node(manager, servers[0], servers, ks, ranges='-2000:-1000,1000:2000') - nr_ranges = nr_ranges + 2 - await repair_on_node(manager, servers[0], servers, ks, ranges='3000:-3000') - # The wrap around range (3000, -3000] will produce the following intersection range - # range=(minimum token,-4611686018427387905] ranges_specified={(3000,+inf), (-inf, -3000]} intersection_ranges=(minimum token,-4611686018427387905] - # range=(-4611686018427387905,-1] ranges_specified={(3000,+inf), (-inf, -3000]} intersection_ranges=(-4611686018427387905,-3000] - # range=(-1,4611686018427387903] ranges_specified={(3000,+inf), (-inf, -3000]} intersection_ranges=(3000,4611686018427387903] - # range=(4611686018427387903,9223372036854775807] ranges_specified={(3000,+inf), (-inf, -3000]} intersection_ranges=(4611686018427387903,9223372036854775807] - nr_ranges = nr_ranges + 4 - - all_rows = await load_repair_history(cql, hosts) - assert len(all_rows) == rf * nr_ranges; - @pytest.mark.asyncio async def test_tablet_cleanup(manager: ManagerClient): cmdline = ['--smp=2', '--commitlog-sync=batch'] @@ -1407,73 +1148,6 @@ async def test_tablet_storage_freeing(manager: ManagerClient): size_after = await manager.server_get_sstables_disk_usage(servers[0].server_id, ks, "test") assert size_before * 0.33 < size_after < size_before * 0.66 -@pytest.mark.asyncio -@skip_mode('release', 'error injections are not supported in release mode') -async def test_tombstone_gc_disabled_on_pending_replica(manager: ManagerClient): - logger.info("Bootstrapping cluster") - servers = [await manager.server_add()] - - await manager.api.disable_tablet_balancing(servers[0].ip_addr) - - cql = manager.get_cql() - async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1} AND tablets = {'initial': 4}") as ks: - await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int) WITH gc_grace_seconds = 0;") - - servers.append(await manager.server_add()) - - key = 7 # Whatever - tablet_token = 0 # Doesn't matter since there is one tablet - await cql.run_async(f"INSERT INTO {ks}.test (pk, c) VALUES ({key}, 1) USING timestamp 9") - rows = await cql.run_async(f"SELECT pk from {ks}.test") - assert len(rows) == 1 - - replica = await get_tablet_replica(manager, servers[0], ks, 'test', tablet_token) - - s0_host_id = await manager.get_host_id(servers[0].server_id) - s1_host_id = await manager.get_host_id(servers[1].server_id) - dst_shard = 0 - - await manager.api.enable_injection(servers[1].ip_addr, "stream_mutation_fragments", one_shot=True) - s1_log = await manager.server_open_log(servers[1].server_id) - s1_mark = await s1_log.mark() - - migration_task = asyncio.create_task( - manager.api.move_tablet(servers[0].ip_addr, ks, "test", replica[0], replica[1], s1_host_id, dst_shard, tablet_token)) - - await s1_log.wait_for('stream_mutation_fragments: waiting', from_mark=s1_mark) - s1_mark = await s1_log.mark() - - # write a tombstone with timestamp X to DB - await cql.run_async(f'DELETE FROM {ks}.test USING timestamp 10 WHERE pk = {key}') - - # flush both servers - for s in servers: - await manager.api.flush_keyspace(s.ip_addr, ks) - - await asyncio.sleep(1) - - # major compact both servers - for s in servers: - await manager.api.keyspace_compaction(s.ip_addr, ks) - - # write backdated data to test.test with timestamp X-1 with the same key as the tombstone - await cql.run_async(f'INSERT INTO {ks}.test (pk, c) VALUES ({key}, 0) USING timestamp 9') - - # release streaming - await manager.api.message_injection(servers[1].ip_addr, "stream_mutation_fragments") - await s1_log.wait_for('stream_mutation_fragments: done', from_mark=s1_mark) - - logger.info("Waiting for migration to finish") - await migration_task - logger.info("Migration done") - - for s in servers: - await manager.api.flush_keyspace(s.ip_addr, ks) - - # verify result - rows = await cql.run_async(f'SELECT pk, c FROM {ks}.test WHERE pk = {key};') - assert len(rows) == 0 - @pytest.mark.asyncio @skip_mode('release', 'error injections are not supported in release mode') async def test_schema_change_during_cleanup(manager: ManagerClient): diff --git a/test/cluster/test_tablets_lwt.py b/test/cluster/test_tablets_lwt.py index 21193b53ace7..3cbd8b4b4f82 100644 --- a/test/cluster/test_tablets_lwt.py +++ b/test/cluster/test_tablets_lwt.py @@ -50,10 +50,7 @@ async def test_lwt(manager: ManagerClient): cmdline = [ '--logger-log-level', 'paxos=trace' ] - config = { - 'rf_rack_valid_keyspaces': False - } - servers = await manager.servers_add(3, cmdline=cmdline, config=config) + servers = await manager.servers_add(3, cmdline=cmdline, auto_rack_dc='my_dc') cql = manager.get_cql() hosts = await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60) quorum = len(hosts) // 2 + 1 @@ -132,10 +129,7 @@ async def test_lwt_state_is_preserved_on_tablet_migration(manager: ManagerClient cmdline = [ '--logger-log-level', 'paxos=trace' ] - config = { - 'rf_rack_valid_keyspaces': False - } - servers = await manager.servers_add(3, cmdline=cmdline, config=config) + servers = await manager.servers_add(3, cmdline=cmdline, auto_rack_dc='my_dc') cql = manager.get_cql() async def set_injection(set_to: list[ServerInfo], injection: str): @@ -171,7 +165,7 @@ async def set_injection(set_to: list[ServerInfo], injection: str): await cql.run_async(lwt1, host=hosts[0]) # Step3: start n4, migrate the single table tablet from n2 to n4. - servers += [await manager.server_add(cmdline=cmdline, config=config)] + servers += [await manager.server_add(cmdline=cmdline, property_file={'dc': 'my_dc', 'rack': 'rack2'})] n4_host_id = await manager.get_host_id(servers[3].server_id) logger.info("Migrating the tablet from n2 to n4") n2_host_id = await manager.get_host_id(servers[1].server_id) @@ -208,7 +202,6 @@ async def set_injection(set_to: list[ServerInfo], injection: str): @skip_mode('release', 'error injections are not supported in release mode') async def test_no_lwt_with_tablets_feature(manager: ManagerClient): config = { - 'rf_rack_valid_keyspaces': False, 'error_injections_at_startup': [ { 'name': 'suppress_features', @@ -252,10 +245,7 @@ async def test_lwt_state_is_preserved_on_tablet_rebuild(manager: ManagerClient): cmdline = [ '--logger-log-level', 'paxos=trace' ] - config = { - 'rf_rack_valid_keyspaces': False - } - servers = await manager.servers_add(3, cmdline=cmdline, config=config) + servers = await manager.servers_add(3, cmdline=cmdline, auto_rack_dc='my_dc') cql = manager.get_cql() logger.info("Resolve hosts") @@ -301,7 +291,9 @@ async def set_injection(set_to: list[ServerInfo], injection: str): use_host_id=True, wait_replaced_dead=True ) - servers += [await manager.server_add(replace_cfg=replace_cfg, cmdline=cmdline, config=config)] + servers += [await manager.server_add(replace_cfg=replace_cfg, + cmdline=cmdline, + property_file={'dc': 'my_dc', 'rack': 'rack1'})] # Check we've actually run rebuild logs = [] for s in servers: @@ -332,11 +324,8 @@ async def test_lwt_concurrent_base_table_recreation(manager: ManagerClient): cmdline = [ '--logger-log-level', 'paxos=trace' ] - config = { - 'rf_rack_valid_keyspaces': False - } - server = await manager.server_add(cmdline=cmdline, config=config) + server = await manager.server_add(cmdline=cmdline) cql = manager.get_cql() logger.info("Create a keyspace") @@ -377,7 +366,6 @@ async def recreate_table(): async def test_lwt_timeout_while_creating_paxos_state_table(manager: ManagerClient, build_mode): timeout = 10000 if build_mode == 'debug' else 1000 config = { - 'rf_rack_valid_keyspaces': False, 'write_request_timeout_in_ms': timeout, 'error_injections_at_startup': [ { @@ -414,11 +402,8 @@ async def test_lwt_for_tablets_is_not_supported_without_raft(manager: ManagerCli cmdline = [ '--logger-log-level', 'paxos=trace' ] - config = { - 'rf_rack_valid_keyspaces': False - } - servers = [await manager.server_add(cmdline=cmdline, config=config)] + servers = [await manager.server_add(cmdline=cmdline)] cql = manager.get_cql() hosts = await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60) @@ -456,7 +441,6 @@ async def test_paxos_state_table_permissions(manager: ManagerClient): '--logger-log-level', 'paxos=trace' ] config = { - 'rf_rack_valid_keyspaces': False, 'authenticator': 'PasswordAuthenticator', 'authorizer': 'CassandraAuthorizer' } @@ -539,3 +523,56 @@ async def test_paxos_state_table_permissions(manager: ManagerClient): # Relogin as a default user to be able to DROP the test keyspace await manager.driver_connect() cql = manager.get_cql() + + +@pytest.mark.asyncio +async def test_lwt_coordinator_shard(manager: ManagerClient): + # The test checks that an LWT coordinator runs on a replica shard, and not on a 'default' (zero) shard. + # Scenario: + # 1. Start a cluster with one node with --smp 2 + # 2. Create a table with one tablet and one replica, move tablet to the second shard + # 3. Start another node + # 4. Run an LWT on the second node, check the logs and assert that an LWT was executed on shard 1 + # Note: Before the changes in storage_proxy/get_cas_shard, the last LWT would be executed + # on shard 0 of the second node because this node doesn't host any replicas of the + # tablet, and sharder.shard_for_reads would return 0 as the 'default' shard. + + logger.info("Starting the first node") + cmdline = [ + '--logger-log-level', 'paxos=trace', + '--smp', '2' + ] + servers = [await manager.server_add(cmdline=cmdline)] + cql = manager.get_cql() + + logger.info("Disable tablet balancing") + await asyncio.gather(*(manager.api.disable_tablet_balancing(s.ip_addr) for s in servers)) + + logger.info("Create a keyspace") + async with new_test_keyspace(manager, "WITH replication = {'class': 'NetworkTopologyStrategy', 'replication_factor': 1} AND tablets = {'initial': 1}") as ks: + logger.info("Create a table") + await cql.run_async(f"CREATE TABLE {ks}.test (pk int PRIMARY KEY, c int);") + + logger.info("Migrating the tablet to shard 1") + n1_host_id = await manager.get_host_id(servers[0].server_id) + tablets = await get_all_tablet_replicas(manager, servers[0], ks, 'test') + assert len(tablets) == 1 + tablet = tablets[0] + assert len(tablet.replicas) == 1 + old_replica = tablet.replicas[0] + await manager.api.move_tablet(servers[0].ip_addr, ks, "test", *old_replica, + *(n1_host_id, 1), tablet.last_token) + + logger.info("Starting a second node") + servers += [await manager.server_add(cmdline=cmdline)] + + logger.info("Wait for cql and get hosts") + hosts = await wait_for_cql_and_get_hosts(cql, servers, time.time() + 60) + + logger.info(f"Execute CAS on {hosts[1]}") + await cql.run_async(f"INSERT INTO {ks}.test (pk, c) VALUES (1, 1) IF NOT EXISTS", host=hosts[1]) + + n2_log = await manager.server_open_log(servers[1].server_id) + matches = await n2_log.grep("CAS\\[0\\] successful") + assert len(matches) == 1 + assert "shard 1" in matches[0][0] diff --git a/test/lib/index_reader_assertions.hh b/test/lib/index_reader_assertions.hh index 6e18a8cb0848..0cdf3d3b9d86 100644 --- a/test/lib/index_reader_assertions.hh +++ b/test/lib/index_reader_assertions.hh @@ -34,14 +34,15 @@ public: auto prev = dht::ring_position::min(); _r->read_partition_data().get(); while (!_r->eof()) { - auto k = _r->get_partition_key(); - auto rp = dht::ring_position(dht::decorate_key(s, k)); + if (auto k = _r->get_partition_key()) { + auto rp = dht::ring_position(dht::decorate_key(s, *k)); if (rp_cmp(prev, rp) >= 0) { BOOST_FAIL(format("Partitions have invalid order: {} >= {}", prev, rp)); } prev = rp; + } if (auto* r = dynamic_cast(_r.get())) { sstables::clustered_index_cursor* cur = r->current_clustered_cursor(); diff --git a/test/lib/mutation_assertions.cc b/test/lib/mutation_assertions.cc index 969e2f166682..d093864edc7d 100644 --- a/test/lib/mutation_assertions.cc +++ b/test/lib/mutation_assertions.cc @@ -2,7 +2,7 @@ // SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 #include "mutation_assertions.hh" -#include "clustering_interval_set.hh" +#include "keys/clustering_interval_set.hh" mutation_partition_assertion& mutation_partition_assertion::is_equal_to(const schema& s, const mutation_partition& other, diff --git a/test/lib/simple_schema.hh b/test/lib/simple_schema.hh index 9ad6609fcbb4..65ba3e4f20d1 100644 --- a/test/lib/simple_schema.hh +++ b/test/lib/simple_schema.hh @@ -14,7 +14,7 @@ #include "schema/schema.hh" #include "schema/schema_registry.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "mutation/mutation_fragment.hh" #include "mutation/mutation.hh" #include "schema/schema_builder.hh" diff --git a/test/lib/sstable_utils.hh b/test/lib/sstable_utils.hh index da727eab5ccb..f32d83662c5b 100644 --- a/test/lib/sstable_utils.hh +++ b/test/lib/sstable_utils.hh @@ -69,7 +69,9 @@ public: try { while (!ir->eof()) { co_await ir->read_partition_data(); - auto pk = ir->get_partition_key(); + // In general the index might not be able to return the key, + // but this helper is only used for tests of indexes which are able to do that. + auto pk = ir->get_partition_key().value(); entries.emplace_back(index_entry{sstables::key::from_partition_key(*s, pk), pk, ir->get_promoted_index_size()}); co_await ir->advance_to_next_partition(); diff --git a/test/perf/tablet_load_balancing.cc b/test/perf/tablet_load_balancing.cc index 2025b9ad0682..aa87d92aaf83 100644 --- a/test/perf/tablet_load_balancing.cc +++ b/test/perf/tablet_load_balancing.cc @@ -91,11 +91,12 @@ size_t get_tablet_count(const tablet_metadata& tm) { } static -void apply_resize_plan(token_metadata& tm, const migration_plan& plan) { +future<> apply_resize_plan(token_metadata& tm, const migration_plan& plan) { for (auto [table_id, resize_decision] : plan.resize_plan().resize) { - tm.tablets().mutate_tablet_map(table_id, [&resize_decision] (tablet_map& tmap) { + co_await tm.tablets().mutate_tablet_map_async(table_id, [&resize_decision] (tablet_map& tmap) { resize_decision.sequence_number = tmap.resize_decision().sequence_number + 1; tmap.set_resize_decision(resize_decision); + return make_ready_future(); }); } for (auto table_id : plan.resize_plan().finalize_resize) { @@ -108,15 +109,16 @@ void apply_resize_plan(token_metadata& tm, const migration_plan& plan) { // Reflects the plan in a given token metadata as if the migrations were fully executed. static -void apply_plan(token_metadata& tm, const migration_plan& plan) { +future<> apply_plan(token_metadata& tm, const migration_plan& plan) { for (auto&& mig : plan.migrations()) { - tm.tablets().mutate_tablet_map(mig.tablet.table, [&mig] (tablet_map& tmap) { + co_await tm.tablets().mutate_tablet_map_async(mig.tablet.table, [&mig] (tablet_map& tmap) { auto tinfo = tmap.get_tablet_info(mig.tablet.tablet); tinfo.replicas = replace_replica(tinfo.replicas, mig.src, mig.dst); tmap.set_tablet(mig.tablet.tablet, tinfo); + return make_ready_future(); }); } - apply_resize_plan(tm, plan); + co_await apply_resize_plan(tm, plan); } using seconds_double = std::chrono::duration; @@ -185,8 +187,7 @@ rebalance_stats rebalance_tablets(cql_test_env& e, locator::load_stats_ptr load_ return stats; } stm.mutate_token_metadata([&] (token_metadata& tm) { - apply_plan(tm, plan); - return make_ready_future<>(); + return apply_plan(tm, plan); }).get(); } throw std::runtime_error("rebalance_tablets(): convergence not reached within limit"); diff --git a/test/pylib/azure_vault_server_mock.py b/test/pylib/azure_vault_server_mock.py new file mode 100644 index 000000000000..05278739ca11 --- /dev/null +++ b/test/pylib/azure_vault_server_mock.py @@ -0,0 +1,649 @@ +# +# Copyright (C) 2025-present ScyllaDB +# +# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 +# + +# Mock Azure Vault server for local testing and error injection. + +import re +import os +import json +import uuid +import time +import base64 +import urllib +import asyncio +import logging +import threading +from functools import partial +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import rsa, padding + +UNKNOWN_RESOURCE = r""" + + + +404 - File or directory not found. + + + + +
+
+

404 - File or directory not found.

+

The resource you are looking for might have been removed, had its name changed, or is temporarily unavailable.

+
+
+ + +""" + +INVALID_CLIENT_ID = "mock-client-id-invalid" + + +class Error: + def __init__(self, error_type: str, count: int): + self.error_type = error_type + self.count = count + + +class AzureVault: + """ + Azure Key Vault service. + + Supported operations: wrapkey, unwrapkey + Supported algorithms: RSA-OAEP-SHA256 + + If the key is not found, it is created automatically. + """ + class RSAKey: + @staticmethod + def _create_version(): + version = uuid.uuid4().hex + private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096) + public_key = private_key.public_key() + return version, private_key, public_key + + def __init__(self): + version, prvkey, pubkey = self._create_version() + self.default = version + self.versions = {version: (prvkey, pubkey)} + + def get_version(self, version = None): + return self.versions.get(version or self.default) + + def create_version(self): + version, prvkey, pubkey = self._create_version() + self.versions[version] = (prvkey, pubkey) + return version + + def encrypt(self, message: bytes, version: str = None) -> bytes: + _, pubkey = self.get_version(version) + return pubkey.encrypt(message, padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None + )) + + def decrypt(self, message: bytes, version = None) -> bytes: + prvkey, _ = self.get_version(version) + return prvkey.decrypt(message, padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None + )) + + INVALID_KEY = "nonexistentkey" + + ERROR_TEMPLATES = { + 'KeyNotFound': (404, { + "error": { + "code": "KeyNotFound", + "message": "A key with (name/id) {key_name}/{key_version} was not found in this key vault. If you recently deleted this key you may be able to recover it using the correct recovery command. For help resolving this issue, please see https://go.microsoft.com/fwlink/?linkid=2125182" + } + }), + 'Forbidden': (403, { + "error": { + "code": "Forbidden", + "message": "Caller is not authorized to perform action on resource." + } + }), + 'Unauthorized': (401, { + "error": { + "code": "Unauthorized", + "message": "[TokenExpired] Error validating token: 'S2S12086'." + } + }), + 'Throttled': (429, { + "error": { + "code": "Throttled", + "message": "Request was not processed because too many requests were received. Reason: VaultRequestTypeLimitReached" + } + }), + 'BadRequest': (400, { + "error": { + "code": "BadRequest", + "message": "Property has invalid value\r\n" + } + }) + } + + def __init__(self): + self.keys = {} + self.error_config = None + self.lock = threading.Lock() + + def get_or_create_key(self, key_name: str): + with self.lock: + if key_name not in self.keys: + self.keys[key_name] = self.RSAKey() + return self.keys[key_name] + + def get_key(self, key_name: str): + with self.lock: + return self.keys.get(key_name) + + def _error_response(self, error_type, **fmt_args): + if error_type not in self.ERROR_TEMPLATES: + raise ValueError(f"Unknown error type: {error_type}") + status_code, template = self.ERROR_TEMPLATES[error_type] + response = template.copy() + if fmt_args: + response['error'] = template['error'].copy() + response['error']['message'] = template['error']['message'].format(**fmt_args) + return status_code, response + + def wrapkey(self, data: str, key_name: str, key_version: str = None): + try: + if key_name == self.INVALID_KEY: + return self._error_response('KeyNotFound', + key_name=key_name, + key_version=key_version or '') + j = json.loads(data) + alg, value = j['alg'], j['value'] + # TODO: Validate the algorithm, we only support RSA-OAEP-SHA256 + key = self.get_or_create_key(key_name) + key_version = key_version or key.default + + with self.lock: + if self.error_config is not None: + error = self.error_config + if error.count > 0: + error.count -= 1 + return self._error_response(error.error_type, + key_name=key_name, + key_version=key_version) + + wrapped = key.encrypt(bytes(value, 'utf-8'), key_version) + ret = base64.urlsafe_b64encode(wrapped).decode('utf-8') + return (200, { + "kid": f"https://mock-vault/keys/{key_name}/{key_version}", + "value": ret + }) + except Exception as e: + return (400, { + "error": { + "code": "BadRequest", + "message": f"Failed to wrapkey: {str(e)}" + } + }) + + def unwrapkey(self, data: str, key_name: str, key_version: str = None): + try: + if key_name == self.INVALID_KEY: + return self._error_response('KeyNotFound', + key_name=key_name, + key_version=key_version or '') + j = json.loads(data) + alg, value = j['alg'], j['value'] + # TODO: Validate the algorithm, we only support RSA-OAEP-SHA256 + key = self.get_key(key_name) + version = key.get_version(key_version) + + with self.lock: + if self.error_config is not None: + error = self.error_config + if error.count > 0: + error.count -= 1 + return self._error_response(error.error_type, + key_name=key_name, + key_version=key_version) + + if not (key and version): + return self._error_response('KeyNotFound', key_name=key_name, key_version=key_version or '') + unwrapped = key.decrypt(base64.urlsafe_b64decode(value), key_version) + return (200, { + "kid": f"https://mock-vault/keys/{key_name}", + "value": unwrapped.decode('utf-8') + }) + except Exception as e: + return (400, { + "error": { + "code": "BadRequest", + "message": f"Failed to unwrapkey: {str(e)}" + } + }) + + def set_error_config(self, error_type, count): + with self.lock: + self.error_config = Error(error_type, count) + + +class AzureIMDS: + """ + Azure Instance Metadata Service (IMDS) + + Supports only token requests. + The token is a very simple JWT, with no signature or validation. + The generated tokens are cached and they are valid for 24 hours. + """ + TOKEN_DURATION = 86400 # 24 hours + + ERROR_TEMPLATES = { + 'NotFound': (404, { + 'error': 'not_found', + 'error_description': 'IMDS endpoint is updating.', + }), + 'InternalError': (500, { + 'error': 'unknown', + 'error_description': 'Failed to retrieve token from the Active directory. For details see logs in .', + }), + 'Throttled': (429, { + 'error': 'throttled', + 'error_description': 'API Rate Limits have been exceeded.', + }), + 'NoIdentity': (401, { + 'error': 'invalid_request', + 'error_description': 'Identity not found.', # E.g., no Managed Identity assigned to the VM + }), + 'MultipleIdentities': (401, { + 'error': 'invalid_request', + 'error_description': 'Multiple user assigned identities exist, please specify the clientId / resourceId of the identity in the token request', + }), + } + + def __init__(self): + self.token = None + self.token_resource = None + self.expires = 0 + self.error_config = None + self.lock = threading.Lock() + + @staticmethod + def base64url_encode(data: bytes) -> bytes: + return base64.urlsafe_b64encode(data).rstrip(b'=') + + @staticmethod + def generate_fake_jwt(resource: str, duration: int = TOKEN_DURATION): + epoch = int(time.time()) + expires = epoch + duration + header = {"alg": "none", "typ": "JWT"} + payload = { + "aud": resource, + "iss": "https://sts.windows.net/mock-tenant-id/", + "sub": "mock-client-id", + "exp": expires, + "nbf": epoch, + "iat": epoch + } + + jwt_header = AzureIMDS.base64url_encode(json.dumps(header).encode()) + jwt_payload = AzureIMDS.base64url_encode(json.dumps(payload).encode()) + + token = jwt_header + b"." + jwt_payload + b"." + return (token.decode(), expires) + + def expired(self): + return time.time() >= self.expires + + def _error_response(self, error_type): + if error_type not in self.ERROR_TEMPLATES: + raise ValueError(f"Unknown error type: {error_type}") + return self.ERROR_TEMPLATES[error_type] + + def get_access_token(self, resource: str): + with self.lock: + if self.error_config is not None: + error = self.error_config + if error.count > 0: + error.count -= 1 + return self._error_response(error.error_type) + if not self.token or self.token_resource != resource or self.expired(): + self.token, self.expires = self.generate_fake_jwt(resource) + self.token_resource = resource + return (200, { + "access_token": self.token, + "expires_in": str(int(self.expires - time.time())), + "resource": resource, + "token_type": "Bearer" + }) + + def set_error_config(self, error_type, count): + with self.lock: + self.error_config = Error(error_type, count) + + +class AzureEntraSTS: + """ + Azure Entra Security Token Service (STS) + + Supports token requests with Service Principal credentials (secrets, certificates). + Validates the request form, but not the actual values. + The generated tokens are not cached, and are valid for 1 hour. + """ + TOKEN_DURATION = 3600 # 1 hour + + # https://learn.microsoft.com/en-us/entra/identity-platform/reference-error-codes#handling-error-codes-in-your-application + ERROR_TEMPLATES = { + 'InvalidRequest': (400, { + 'error': 'invalid_request', + 'error_description': "AADSTS900144: The request body must contain the following parameter: '{missing_param}'", + 'error_codes': [900144], + }), + 'InvalidClient': (400, { + 'error': 'invalid_client', + 'error_description': "AADSTS7000216: 'client_assertion', 'client_secret' or 'request' is required for the 'client_credentials' grant type.", + 'error_codes': [7000216], + }), + 'InvalidSecret': (401, { + 'error': 'invalid_client', + 'error_description': "AADSTS7000215: Invalid client secret is provided.", + 'error_codes': [7000215], + }), + 'TemporarilyUnavailable': (503, { + 'error': 'temporarily_unavailable', + 'error_description': 'The server is temporarily too busy to handle the request.', + 'error_codes': [90006], + }), + } + + @staticmethod + def base64url_encode(data: bytes) -> bytes: + return base64.urlsafe_b64encode(data).rstrip(b'=') + + @staticmethod + def generate_fake_jwt(client_id: str, tenant_id: str, resource: str, duration: int = TOKEN_DURATION): + epoch = int(time.time()) + expires = epoch + duration + header = {"alg": "none", "typ": "JWT"} + payload = { + "aud": resource, + "iss": f"https://sts.windows.net/{tenant_id}/", + "appid": client_id, + "tid": tenant_id, + "oid": "00000000-0000-0000-0000-000000000000", + "sub": client_id, + "exp": expires, + "nbf": epoch, + "iat": epoch + } + + jwt_header = AzureEntraSTS.base64url_encode(json.dumps(header).encode()) + jwt_payload = AzureEntraSTS.base64url_encode(json.dumps(payload).encode()) + + token = jwt_header + b"." + jwt_payload + b"." + return token.decode() + + def __init__(self): + self.error_config = None + self.lock = threading.Lock() + + def _error_response(self, error_type, **kwargs): + if error_type not in self.ERROR_TEMPLATES: + raise ValueError(f"Unknown error type: {error_type}") + status_code, template = self.ERROR_TEMPLATES[error_type] + response = template.copy() + if kwargs: + response['error_description'] = response['error_description'].format(**kwargs) + return status_code, response + + def get_access_token(self, tenant_id, form_data): + return_error = False + with self.lock: + if self.error_config is not None: + error = self.error_config + if error.count > 0: + error.count -= 1 + return_error = True + if return_error: + return self._error_response(error.error_type) + # Just check that all the required parameters are provided. + # Ignore their values. + required_params = ['client_id', 'scope', 'grant_type'] + missing_params = [param for param in required_params if param not in form_data] + if missing_params: + return self._error_response('InvalidRequest', missing_params=missing_params[0]) + if 'client_secret' not in form_data and 'client_assertion' not in form_data: + return self._error_response('InvalidClient') + if 'client_assertion' in form_data and 'client_assertion_type' not in form_data: + return self._error_response('InvalidRequest', missing_params='client_assertion_type') + + client_id=form_data['client_id'][0] + resource = form_data['scope'][0].removesuffix('.default') + + token = self.generate_fake_jwt(client_id, tenant_id, resource) + return (200, { + "token_type": "Bearer", + "expires_in": self.TOKEN_DURATION, + "ext_expires_in": self.TOKEN_DURATION, + "access_token": token + }) + + def set_error_config(self, error_type, count): + with self.lock: + self.error_config = Error(error_type, count) + + +class RequestHandler(BaseHTTPRequestHandler): + """ + HTTP request handler + + Supported endpoints: + - POST /keys/{key_name}/{key_version}/{wrapkey, unwrapkey} (Vault) + - GET /metadata/identity/oauth2/token (IMDS) + - POST /{tenant_id}/oauth2/v2.0/token (Entra STS) + - POST /config/error?service={vault,imds,entra}&error_type={KeyNotFound,Forbidden,...}&repeat={count} (Error injection) + """ + keyop_re = re.compile(r'^/keys/([^/]+)/([^/]*)/([^/?]+)(\?.*)?$') + entra_token_re = re.compile(r'^/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})/oauth2/v2\.0/token$') + imds_token_path = '/metadata/identity/oauth2/token' + error_path = "/config/error?" + + def __init__(self, vault, imds, entra, logger, *args, **kwargs): + self.vault = vault + self.imds = imds + self.entra = entra + self.logger = logger + super().__init__(*args, **kwargs) + + def log_message(self, format, *args): + if not self.logger.isEnabledFor(logging.DEBUG): + return + self.logger.debug("%s - - [%s] %s", + self.client_address[0], + self.log_date_time_string(), + format % args) + + def handle_one_request(self): + # Check for TLS handshake (first byte is typically 0x16 in TLS) + try: + first_byte = self.rfile.peek(1)[0] + if first_byte == 0x16: + self.close_connection = True + self.logger.error("TLS/HTTPS connection attempted. This server only supports plain HTTP.") + return + except (IndexError, AttributeError): + pass + super().handle_one_request() + + def _extract_client_id(self): + """Extract client ID from Authorization header if present""" + auth_header = self.headers.get('Authorization', '') + if auth_header.startswith('Bearer '): + try: + token = auth_header[7:] + parts = token.split('.') + if len(parts) >= 2: + padding = '=' * (4 - len(parts[1]) % 4) + payload = json.loads(base64.urlsafe_b64decode(parts[1] + padding).decode('utf-8')) + return payload.get('appid') + except Exception: + pass + return None + + def _send_response(self, status_code, mime_type, response): + self.send_response(status_code) + self.send_header('Content-Type', mime_type) + self.send_header('Content-Length', str(len(response))) + self.end_headers() + self.wfile.write(response) + + def do_POST(self): + # Error injection endpoint + if self.path.startswith(self.error_path): + get = lambda params, key : params[key][0] if key in params and params[key] else None + parsed_url = urllib.parse.urlparse(self.path) + params = urllib.parse.parse_qs(parsed_url.query) + + service = get(params, 'service') + error_type = get(params, 'error_type') + repeat = int(get(params, 'repeat') or '1') + + if None in (service, error_type): + resp = {"error": {"code": "BadParameter", "message": "Query parameters 'service' and 'error_type' are required."}} + self._send_response(400, 'application/json', json.dumps(resp).encode('utf-8')) + return + + svc = getattr(self, service, None) + if svc is None: + resp = {"error": {"code": "BadParameter", "message": "Invalid value for 'service'. Must be one of: 'vault', 'imds', 'entra'."}} + self._send_response(400, 'application/json', json.dumps(resp).encode('utf-8')) + return + + svc.set_error_config(error_type, repeat) + self.send_response(200) + self.end_headers() + return + + # Entra token request endpoint + match = self.entra_token_re.match(self.path) + if match: + tenant_id = match.group(1) + content_length = int(self.headers.get('Content-Length', 0)) + body = self.rfile.read(content_length).decode('utf-8') + form_data = urllib.parse.parse_qs(body) + + status, resp = self.entra.get_access_token(tenant_id, form_data) + self._send_response(status, 'application/json', json.dumps(resp).encode('utf-8')) + return + + # Vault key operations endpoint + match = self.keyop_re.match(self.path) + if match: + key_name, key_version, operation, _ = match.groups() + content_length = self.headers['Content-Length'] + data = self.rfile.read(int(content_length) if content_length else -1) + + if self._extract_client_id() == INVALID_CLIENT_ID: + status, resp = self.vault._error_response('Forbidden') + self._send_response(status, 'application/json', json.dumps(resp).encode('utf-8')) + return + + if operation == 'wrapkey': + status, resp = self.vault.wrapkey(data, key_name, key_version) + self._send_response(status, 'application/json', json.dumps(resp).encode('utf-8')) + elif operation == 'unwrapkey': + status, resp = self.vault.unwrapkey(data, key_name, key_version) + self._send_response(status, 'application/json', json.dumps(resp).encode('utf-8')) + else: + resp = {"error":{"code":"BadParameter","message":f"Method POST does not allow operation '{operation}'"}} + self._send_response(400, 'application/json', json.dumps(resp).encode('utf-8')) + return + + self._send_response(404, 'text/html', UNKNOWN_RESOURCE.encode('utf-8')) + + def do_GET(self): + parsed_path = urllib.parse.urlparse(self.path) + if parsed_path.path == self.imds_token_path: + # IMDS token request endpoint + metadata_header = self.headers.get('Metadata') + if metadata_header != 'true': + resp = {'error': 'invalid_request', 'error_description': 'Metadata header missing or invalid.'} + self._send_response(400, 'application/json', json.dumps(resp).encode('utf-8')) + return + query_params = urllib.parse.parse_qs(parsed_path.query) + for param in ['api-version', 'resource']: + if param not in query_params: + resp = {'error': 'invalid_request', 'error_description': f'Missing required parameter "{param}". Fix the request and retry.'} + self._send_response(400, 'application/json', json.dumps(resp).encode('utf-8')) + return + status, resp = self.imds.get_access_token(query_params['resource'][0]) + self._send_response(status, 'application/json', json.dumps(resp).encode('utf-8')) + else: + self._send_response(404, 'text/plain', '') + + +class MockAzureVaultServer: + def __init__(self, host, port, logger): + self.logger = logger + self.vault = AzureVault() + self.imds = AzureIMDS() + self.entra = AzureEntraSTS() + handler = partial(RequestHandler, self.vault, self.imds, self.entra, self.logger) + self.server = ThreadingHTTPServer((host, port), handler) + self.server_thread = None + self.server.request_queue_size = 10 + self.is_stopped = asyncio.Event() + self.is_stopped.set() + self.envs = {'MOCK_AZURE_VAULT_SERVER_PORT': f'{port}', 'MOCK_AZURE_VAULT_SERVER_HOST': f'{host}'} + + def _set_environ(self): + for key, value in self.envs.items(): + os.environ[key] = value + + def _unset_environ(self): + for key in self.envs.keys(): + del os.environ[key] + + def get_envs_settings(self): + return self.envs + + async def start(self): + if self.is_stopped.is_set(): + self.logger.info(f'Starting Azure Vault mock server on {self.server.server_address}') + self._set_environ() + loop = asyncio.get_running_loop() + self.server_thread = loop.run_in_executor(None, self.server.serve_forever) + self.is_stopped.clear() + + async def stop(self): + if not self.is_stopped.is_set(): + self.logger.info(f'Stopping Azure Vault mock server') + self._unset_environ() + self.server.shutdown() + self.server.server_close() + await self.server_thread + self.is_stopped.set() + + async def run(self): + try: + await self.start() + await self.is_stopped.wait() + except (Exception, asyncio.CancelledError) as e: + self.logger.error(f'Server error: {e}') + await self.stop() \ No newline at end of file diff --git a/test/pylib/rest_client.py b/test/pylib/rest_client.py index 0fadcfe9e7d8..6bf06c35ade6 100644 --- a/test/pylib/rest_client.py +++ b/test/pylib/rest_client.py @@ -441,14 +441,26 @@ async def get_raft_leader(self, node_ip: str, group_id: Optional[str] = None) -> async def repair(self, node_ip: str, keyspace: str, table: str, ranges: str = '') -> None: """Repair the given table and wait for it to complete""" - if ranges: - params = {"columnFamilies": table, "ranges": ranges} + vnode_keyspaces = await self.client.get_json(f"/storage_service/keyspaces", host=node_ip, params={"replication": "vnodes"}) + if keyspace in vnode_keyspaces: + if ranges: + params = {"columnFamilies": table, "ranges": ranges} + else: + params = {"columnFamilies": table} + sequence_number = await self.client.post_json(f"/storage_service/repair_async/{keyspace}", host=node_ip, params=params) + status = await self.client.get_json(f"/storage_service/repair_status", host=node_ip, params={"id": str(sequence_number)}) + if status != 'SUCCESSFUL': + raise Exception(f"Repair id {sequence_number} on node {node_ip} for table {keyspace}.{table} failed: status={status}") else: - params = {"columnFamilies": table} - sequence_number = await self.client.post_json(f"/storage_service/repair_async/{keyspace}", host=node_ip, params=params) - status = await self.client.get_json(f"/storage_service/repair_status", host=node_ip, params={"id": str(sequence_number)}) - if status != 'SUCCESSFUL': - raise Exception(f"Repair id {sequence_number} on node {node_ip} for table {keyspace}.{table} failed: status={status}") + if ranges: + raise ValueError(f"Ranges parameter is not supported for tablet keyspaces") + params={ + "ks": keyspace, + "table": table, + "tokens": "all", + "await_completion": "true", + } + await self.client.post_json(f"/storage_service/tablets/repair", host=node_ip, params=params) def __get_autocompaction_url(self, keyspace: str, table: Optional[str] = None) -> str: """Return autocompaction url for the given keyspace/table""" diff --git a/test/pylib/start_azure_vault_mock.py b/test/pylib/start_azure_vault_mock.py new file mode 100755 index 000000000000..4f21eb744a87 --- /dev/null +++ b/test/pylib/start_azure_vault_mock.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# +# Copyright (C) 2025-present ScyllaDB +# +# SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 +# + +import signal +import logging +import asyncio +import argparse +from azure_vault_server_mock import MockAzureVaultServer + + +async def run(): + parser = argparse.ArgumentParser(description="Start Azure Vault mock server") + parser.add_argument('--host', default='127.0.0.1') + parser.add_argument('--port', type=int, default=5467) + parser.add_argument('--log-level', default=logging.WARNING, + choices=logging.getLevelNamesMapping().keys(), + help="Set log level") + args = parser.parse_args() + logging.basicConfig(level=args.log_level, format='%(levelname)s %(asctime)s - %(message)s') + server = MockAzureVaultServer(args.host, args.port, logging.getLogger('azure-vault-server')) + + print('Starting Azure Vault mock server') + await server.start() + signal.sigwait({signal.SIGINT, signal.SIGTERM}) + print('Stopping Azure Vault mock server') + await server.stop() + + +if __name__ == '__main__': + asyncio.run(run()) \ No newline at end of file diff --git a/test/pylib/suite/base.py b/test/pylib/suite/base.py index 26cf731bb5b0..3a23db050568 100644 --- a/test/pylib/suite/base.py +++ b/test/pylib/suite/base.py @@ -33,6 +33,7 @@ from test.pylib.resource_gather import get_resource_gather, setup_cgroup from test.pylib.s3_proxy import S3ProxyServer from test.pylib.s3_server_mock import MockS3Server +from test.pylib.azure_vault_server_mock import MockAzureVaultServer from test.pylib.util import LogPrefixAdapter, get_xdist_worker_id if TYPE_CHECKING: @@ -575,6 +576,14 @@ async def make_async_finalize(): await proxy_s3_server.start() TestSuite.artifacts.add_exit_artifact(None, proxy_s3_server.stop) + azure_vault_server = MockAzureVaultServer( + host=await hosts.lease_host(), + port=5467, + logger=LogPrefixAdapter(logger=logging.getLogger("azure_vault"), extra={"prefix": "azure_vault"}), + ) + await azure_vault_server.start() + TestSuite.artifacts.add_exit_artifact(None, azure_vault_server.stop) + def find_suite_config(path: pathlib.Path) -> pathlib.Path: for directory in (path.joinpath("_") if path.is_dir() else path).absolute().relative_to(TEST_DIR).parents: diff --git a/test/rest_api/test_repair_task.py b/test/rest_api/test_repair_task.py index 66503faa161d..97431fe02765 100644 --- a/test/rest_api/test_repair_task.py +++ b/test/rest_api/test_repair_task.py @@ -14,105 +14,105 @@ def run_repair_and_wait(rest_api, keyspace): resp.raise_for_status() return sequence_number -def test_repair_task_progress_finished_task(cql, this_dc, rest_api): +def test_repair_task_progress_finished_task(cql, this_dc, rest_api, test_keyspace_vnodes): drain_module_tasks(rest_api, module_name) with set_tmp_task_ttl(rest_api, long_time): # Insert some data. - with new_test_keyspace(cql, f"WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', '{this_dc}' : 1 }}") as keyspace: - schema = 'p int, v text, primary key (p)' - with new_test_table(cql, keyspace, schema) as t0: - stmt = cql.prepare(f"INSERT INTO {t0} (p, v) VALUES (?, ?)") - cql.execute(stmt, [0, 'hello']) - cql.execute(stmt, [1, 'world']) - with new_test_table(cql, keyspace, schema) as t1: - stmt = cql.prepare(f"INSERT INTO {t1} (p, v) VALUES (?, ?)") - cql.execute(stmt, [2, 'hello']) - cql.execute(stmt, [3, 'world']) - - sequence_number = run_repair_and_wait(rest_api, keyspace) - - # Get all repairs. - ids = [task["task_id"] for task in list_tasks(rest_api, "repair") if task["sequence_number"] == sequence_number] - assert len(ids) == 1, "Wrong number of internal repair tasks" - status_tree = get_task_status_recursively(rest_api, ids[0]) - status = status_tree[0] - assert status["progress_completed"] == status["progress_total"], "Incorrect task progress" - - assert "children_ids" in status, "Shard tasks weren't created" - children = [s for s in status_tree if s["parent_id"] == status["id"]] - assert all([child["progress_completed"] == child["progress_total"] for child in children]), "Some shard tasks have incorrect progress" - - assert sum([child["progress_total"] for child in children]) == status["progress_total"], "Total progress of parent is not equal to children total progress sum" - assert sum([child["progress_completed"] for child in children]) == status["progress_completed"], "Completed progress of parent is not equal to children completed progress sum" + keyspace = test_keyspace_vnodes + schema = 'p int, v text, primary key (p)' + with new_test_table(cql, keyspace, schema) as t0: + stmt = cql.prepare(f"INSERT INTO {t0} (p, v) VALUES (?, ?)") + cql.execute(stmt, [0, 'hello']) + cql.execute(stmt, [1, 'world']) + with new_test_table(cql, keyspace, schema) as t1: + stmt = cql.prepare(f"INSERT INTO {t1} (p, v) VALUES (?, ?)") + cql.execute(stmt, [2, 'hello']) + cql.execute(stmt, [3, 'world']) + + sequence_number = run_repair_and_wait(rest_api, keyspace) + + # Get all repairs. + ids = [task["task_id"] for task in list_tasks(rest_api, "repair") if task["sequence_number"] == sequence_number] + assert len(ids) == 1, "Wrong number of internal repair tasks" + status_tree = get_task_status_recursively(rest_api, ids[0]) + status = status_tree[0] + assert status["progress_completed"] == status["progress_total"], "Incorrect task progress" + + assert "children_ids" in status, "Shard tasks weren't created" + children = [s for s in status_tree if s["parent_id"] == status["id"]] + assert all([child["progress_completed"] == child["progress_total"] for child in children]), "Some shard tasks have incorrect progress" + + assert sum([child["progress_total"] for child in children]) == status["progress_total"], "Total progress of parent is not equal to children total progress sum" + assert sum([child["progress_completed"] for child in children]) == status["progress_completed"], "Completed progress of parent is not equal to children completed progress sum" drain_module_tasks(rest_api, module_name) -def test_repair_task_tree(cql, this_dc, rest_api): +def test_repair_task_tree(cql, this_dc, rest_api, test_keyspace_vnodes): drain_module_tasks(rest_api, module_name) with set_tmp_task_ttl(rest_api, long_time): - with new_test_keyspace(cql, f"WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', '{this_dc}' : 1 }}") as keyspace: - schema = 'p int, v text, primary key (p)' - with new_test_table(cql, keyspace, schema) as t0: - stmt = cql.prepare(f"INSERT INTO {t0} (p, v) VALUES (?, ?)") - cql.execute(stmt, [0, 'hello']) - cql.execute(stmt, [1, 'world']) - with new_test_table(cql, keyspace, schema) as t1: - stmt = cql.prepare(f"INSERT INTO {t1} (p, v) VALUES (?, ?)") - cql.execute(stmt, [2, 'hello']) - cql.execute(stmt, [3, 'world']) - - run_repair_and_wait(rest_api, keyspace) - - ids = [task["task_id"] for task in list_tasks(rest_api, module_name)] - assert ids, "repair task was not created" - assert len(ids) == 1, "incorrect number of non-internal tasks" - - status_tree = get_task_status_recursively(rest_api, ids[0]) - status = status_tree[0] - assert status["state"] == "done", f"tasks with id {status['id']} failed" - - repair_tree_depth = 1 - check_child_parent_relationship(rest_api, status_tree, status, False) + keyspace = test_keyspace_vnodes + schema = 'p int, v text, primary key (p)' + with new_test_table(cql, keyspace, schema) as t0: + stmt = cql.prepare(f"INSERT INTO {t0} (p, v) VALUES (?, ?)") + cql.execute(stmt, [0, 'hello']) + cql.execute(stmt, [1, 'world']) + with new_test_table(cql, keyspace, schema) as t1: + stmt = cql.prepare(f"INSERT INTO {t1} (p, v) VALUES (?, ?)") + cql.execute(stmt, [2, 'hello']) + cql.execute(stmt, [3, 'world']) + + run_repair_and_wait(rest_api, keyspace) + + ids = [task["task_id"] for task in list_tasks(rest_api, module_name)] + assert ids, "repair task was not created" + assert len(ids) == 1, "incorrect number of non-internal tasks" + + status_tree = get_task_status_recursively(rest_api, ids[0]) + status = status_tree[0] + assert status["state"] == "done", f"tasks with id {status['id']} failed" + + repair_tree_depth = 1 + check_child_parent_relationship(rest_api, status_tree, status, False) drain_module_tasks(rest_api, module_name) -def test_repair_task_progress(cql, this_dc, rest_api): +def test_repair_task_progress(cql, this_dc, rest_api, test_keyspace_vnodes): drain_module_tasks(rest_api, module_name) with set_tmp_task_ttl(rest_api, long_time): # Insert some data. - with new_test_keyspace(cql, f"WITH REPLICATION = {{ 'class' : 'NetworkTopologyStrategy', '{this_dc}' : 1 }}") as keyspace: - schema = 'p int, v text, primary key (p)' - with new_test_table(cql, keyspace, schema) as t0: - stmt = cql.prepare(f"INSERT INTO {t0} (p, v) VALUES (?, ?)") - cql.execute(stmt, [0, 'hello']) - cql.execute(stmt, [1, 'world']) - with new_test_table(cql, keyspace, schema) as t1: - stmt = cql.prepare(f"INSERT INTO {t1} (p, v) VALUES (?, ?)") - cql.execute(stmt, [2, 'hello']) - cql.execute(stmt, [3, 'world']) - - injection = "repair_shard_repair_task_impl_do_repair_ranges" - with scylla_inject_error(rest_api, injection, True): - resp = rest_api.send("POST", f"storage_service/repair_async/{keyspace}") - resp.raise_for_status() - sequence_number = resp.json() - - # Get all repairs. - statuses = [] - while not statuses or "children_ids" not in statuses[0]: - statuses = [get_task_status(rest_api, task["task_id"]) for task in list_tasks(rest_api, "repair") if task["sequence_number"] == sequence_number] - assert len(statuses) == 1, "Wrong number of internal repair tasks" - status = statuses[0] - - for child_ident in status["children_ids"]: - # Check if task state is correct. - child_status = get_task_status(rest_api, child_ident["task_id"]) - assert child_status["state"] == "running", "Incorrect task progress" - assert child_status["progress_completed"] * 2 <= child_status["progress_total"], "Incorrect task progress" - - resp = rest_api.send("POST", f"v2/error_injection/injection/{injection}/message") - resp.raise_for_status() - - child_status = wait_for_task(rest_api, status["id"]) - status_tree = get_task_status_recursively(rest_api, status["id"]) - for child_status in get_children(status_tree, status["id"]): - assert child_status["progress_completed"] == child_status["progress_total"], "Incorrect task progress" + keyspace = test_keyspace_vnodes + schema = 'p int, v text, primary key (p)' + with new_test_table(cql, keyspace, schema) as t0: + stmt = cql.prepare(f"INSERT INTO {t0} (p, v) VALUES (?, ?)") + cql.execute(stmt, [0, 'hello']) + cql.execute(stmt, [1, 'world']) + with new_test_table(cql, keyspace, schema) as t1: + stmt = cql.prepare(f"INSERT INTO {t1} (p, v) VALUES (?, ?)") + cql.execute(stmt, [2, 'hello']) + cql.execute(stmt, [3, 'world']) + + injection = "repair_shard_repair_task_impl_do_repair_ranges" + with scylla_inject_error(rest_api, injection, True): + resp = rest_api.send("POST", f"storage_service/repair_async/{keyspace}") + resp.raise_for_status() + sequence_number = resp.json() + + # Get all repairs. + statuses = [] + while not statuses or "children_ids" not in statuses[0]: + statuses = [get_task_status(rest_api, task["task_id"]) for task in list_tasks(rest_api, "repair") if task["sequence_number"] == sequence_number] + assert len(statuses) == 1, "Wrong number of internal repair tasks" + status = statuses[0] + + for child_ident in status["children_ids"]: + # Check if task state is correct. + child_status = get_task_status(rest_api, child_ident["task_id"]) + assert child_status["state"] == "running", "Incorrect task progress" + assert child_status["progress_completed"] * 2 <= child_status["progress_total"], "Incorrect task progress" + + resp = rest_api.send("POST", f"v2/error_injection/injection/{injection}/message") + resp.raise_for_status() + + child_status = wait_for_task(rest_api, status["id"]) + status_tree = get_task_status_recursively(rest_api, status["id"]) + for child_status in get_children(status_tree, status["id"]): + assert child_status["progress_completed"] == child_status["progress_total"], "Incorrect task progress" drain_module_tasks(rest_api, module_name) diff --git a/tools/scylla-local-file-key-generator.cc b/tools/scylla-local-file-key-generator.cc index 5f46262b124b..159e3fc99183 100644 --- a/tools/scylla-local-file-key-generator.cc +++ b/tools/scylla-local-file-key-generator.cc @@ -12,7 +12,7 @@ #include #include -#include "compound.hh" +#include "keys/compound.hh" #include "db/marshal/type_parser.hh" #include "schema/schema_builder.hh" #include "tools/utils.hh" diff --git a/tools/scylla-sstable.cc b/tools/scylla-sstable.cc index 87c9d9a63aff..2ea70890f5aa 100644 --- a/tools/scylla-sstable.cc +++ b/tools/scylla-sstable.cc @@ -1081,8 +1081,10 @@ void dump_index_operation(schema_ptr schema, reader_permit permit, const std::ve auto pkey = idx_reader->get_partition_key(); writer.StartObject(); - writer.Key("key"); - writer.DataKey(*schema, pkey); + if (pkey) { + writer.Key("key"); + writer.DataKey(*schema, *pkey); + } writer.Key("pos"); writer.Uint64(pos); writer.EndObject(); diff --git a/tools/scylla-types.cc b/tools/scylla-types.cc index d187a926087a..e1f5ca96ff61 100644 --- a/tools/scylla-types.cc +++ b/tools/scylla-types.cc @@ -9,7 +9,7 @@ #include #include -#include "compound.hh" +#include "keys/compound.hh" #include "db/marshal/type_parser.hh" #include "schema/schema_builder.hh" #include "tools/utils.hh" diff --git a/tools/toolchain/image b/tools/toolchain/image index dec0204e4239..2fef394a4b85 100644 --- a/tools/toolchain/image +++ b/tools/toolchain/image @@ -1 +1 @@ -docker.io/scylladb/scylla-toolchain:fedora-42-20250630 +docker.io/scylladb/scylla-toolchain:fedora-42-20250727 diff --git a/tools/toolchain/optimized_clang.sh b/tools/toolchain/optimized_clang.sh index 4a1656c15b21..d9ad669d30e3 100755 --- a/tools/toolchain/optimized_clang.sh +++ b/tools/toolchain/optimized_clang.sh @@ -65,7 +65,7 @@ SCYLLA_BUILD_DIR_FULLPATH="${SCYLLA_DIR}"/"${SCYLLA_BUILD_DIR}" SCYLLA_NINJA_FILE_FULLPATH="${SCYLLA_DIR}"/"${SCYLLA_NINJA_FILE}" # Which LLVM release to build in order to compile Scylla -LLVM_CLANG_TAG=20.1.7 +LLVM_CLANG_TAG=20.1.8 CLANG_SUFFIX=20 CLANG_ARCHIVE=$(cd "${SCYLLA_DIR}" && realpath -m "${CLANG_ARCHIVE}") diff --git a/utils/CMakeLists.txt b/utils/CMakeLists.txt index a53169d7d488..a92db1ca467a 100644 --- a/utils/CMakeLists.txt +++ b/utils/CMakeLists.txt @@ -63,7 +63,12 @@ target_sources(utils s3/credentials_providers/instance_profile_credentials_provider.cc s3/credentials_providers/sts_assume_role_credentials_provider.cc s3/credentials_providers/aws_credentials_provider_chain.cc - s3/utils/manip_s3.cc) + s3/utils/manip_s3.cc + azure/identity/credentials.cc + azure/identity/service_principal_credentials.cc + azure/identity/managed_identity_credentials.cc + azure/identity/azure_cli_credentials.cc + azure/identity/default_credentials.cc) target_include_directories(utils PUBLIC ${CMAKE_SOURCE_DIR} diff --git a/utils/azure/identity/azure_cli_credentials.cc b/utils/azure/identity/azure_cli_credentials.cc new file mode 100644 index 000000000000..3eee69bedaf5 --- /dev/null +++ b/utils/azure/identity/azure_cli_credentials.cc @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include +#include +#include +#include +#include + +#include "utils/rjson.hh" +#include "utils/exceptions.hh" +#include "exceptions.hh" +#include "azure_cli_credentials.hh" + +namespace azure { + +azure_cli_credentials::azure_cli_credentials(const sstring& logctx) + : credentials(logctx) +{} + +access_token azure_cli_credentials::make_token(const rjson::value& json, const resource_type& resource_uri) { + auto token = rjson::get(json, "accessToken"); + auto expires_on = rjson::get(json, "expires_on"); + return { token, std::chrono::system_clock::from_time_t(expires_on), resource_uri }; +} + +std::vector azure_cli_credentials::make_env() { + std::vector vec; + vec.reserve(3); + + auto path = std::getenv("PATH"); + auto home = std::getenv("HOME"); + auto azure_config_dir = std::getenv("AZURE_CONFIG_DIR"); + + const auto DEFAULT_PATH = "/usr/bin:/usr/local/bin"; + if (path && *path) { + vec.emplace_back(seastar::format("PATH={}:{}", path, DEFAULT_PATH)); + } else { + vec.emplace_back(seastar::format("PATH={}", DEFAULT_PATH)); + } + if (home) { + vec.emplace_back(seastar::format("HOME={}", home)); + } + if (azure_config_dir) { + vec.emplace_back(seastar::format("AZURE_CONFIG_DIR={}", azure_config_dir)); + } + return vec; +} + +future<> azure_cli_credentials::refresh(const resource_type& resource_uri) { + try { + co_await do_refresh(resource_uri); + } catch (auth_error&) { + throw; + } catch (...) { + std::throw_with_nested(auth_error(fmt::format("{}", std::current_exception()))); + } +} + +future<> azure_cli_credentials::do_refresh(const resource_type& resource_uri) { + // This timeout is purely a safeguard for badly-behaved CLIs. + // It is not expected to be reached under normal circumstances. + const auto timeout = std::chrono::seconds(5); + + az_creds_logger.debug("[{}] Refreshing token", *this); + using namespace seastar::experimental; + constexpr char SHELL[] = "/bin/sh"; + const auto azcmd = seastar::format("az account get-access-token --resource {}", resource_uri); + spawn_parameters params = { + .argv = { SHELL, "-c", azcmd }, + .env = make_env(), + }; + auto process = co_await spawn_process(SHELL, params); + + auto cout = process.cout(); + auto cerr = process.cerr(); + sstring output; + sstring error; + + co_await with_timeout(timer<>::clock::now() + timeout, [&] -> future<> { + auto read_cout = [&] -> future<> { output = co_await util::read_entire_stream_contiguous(cout); }; + auto read_cerr = [&] -> future<> { error = co_await util::read_entire_stream_contiguous(cerr); }; + co_await when_all_succeed(read_cout, read_cerr); + }()).handle_exception([&] (std::exception_ptr ep) { + if (try_catch(ep)) { + az_creds_logger.debug("[{}] Azure CLI not responding. Killing it forcefully...", *this); + process.kill(); + } + std::rethrow_exception(ep); + }).finally(seastar::coroutine::lambda([&] -> future<> { + auto wstatus = co_await process.wait(); + auto* exited = std::get_if(&wstatus); + auto* signaled = std::get_if(&wstatus); + if (exited && exited->exit_code != EXIT_SUCCESS) { + az_creds_logger.debug("[{}] Azure CLI failed with exit status ({}): {}", *this, exited->exit_code, error); + throw auth_error(seastar::format("Azure CLI failed with exit status ({})", exited->exit_code)); + } + if (signaled) { + az_creds_logger.debug("[{}] Azure CLI was terminated by signal: {} ({})", *this, signaled->terminating_signal, strsignal(signaled->terminating_signal)); + throw auth_error(seastar::format("Azure CLI was terminated by signal: {} ({})", signaled->terminating_signal, strsignal(signaled->terminating_signal))); + } + })); + _token = make_token(rjson::parse(output), resource_uri); +} + +} \ No newline at end of file diff --git a/utils/azure/identity/azure_cli_credentials.hh b/utils/azure/identity/azure_cli_credentials.hh new file mode 100644 index 000000000000..d79acf52aeb3 --- /dev/null +++ b/utils/azure/identity/azure_cli_credentials.hh @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "utils/rjson.hh" +#include "credentials.hh" + +namespace azure { + +class azure_cli_credentials : public credentials { + static constexpr char NAME[] = "AzureCliCredentials"; + + std::string_view get_name() const override { return NAME; }; + static std::vector make_env(); + future<> refresh(const resource_type& resource_uri) override; + future<> do_refresh(const resource_type& resource_uri); + access_token make_token(const rjson::value&, const resource_type&); +public: + azure_cli_credentials(const sstring& logctx = ""); +}; + +} + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + auto format(const azure::azure_cli_credentials& creds, fmt::format_context& ctxt) const { + return fmt::format_to(ctxt.out(), "{}", *dynamic_cast(&creds)); + } +}; \ No newline at end of file diff --git a/utils/azure/identity/credentials.cc b/utils/azure/identity/credentials.cc new file mode 100644 index 000000000000..09fc329b5c96 --- /dev/null +++ b/utils/azure/identity/credentials.cc @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include + +#include "credentials.hh" + +logger az_creds_logger("azure_creds"); + +namespace azure { + +access_token::access_token(const sstring& token, const timestamp_type& expiry, const resource_type& resource_uri) + : token(token) + , expiry(expiry) + , resource_uri(resource_uri) +{} + +bool access_token::empty() const { + return token.empty(); +} + +bool access_token::expired() const { + if (empty()) { + return true; + } + return timeout_clock::now() >= this->expiry; +} + +future credentials::get_access_token(const resource_type& resource_uri) { + if (_token.expired() || _token.resource_uri != resource_uri) { + co_await refresh(resource_uri); + } + co_return _token; +} + +} diff --git a/utils/azure/identity/credentials.hh b/utils/azure/identity/credentials.hh new file mode 100644 index 000000000000..5ba01ede7feb --- /dev/null +++ b/utils/azure/identity/credentials.hh @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include + +#include + +#include "utils/log.hh" + +using namespace seastar; + +extern logger az_creds_logger; + +namespace azure { + +using timeout_clock = std::chrono::system_clock; +using timestamp_type = typename timeout_clock::time_point; +using resource_type = sstring; + +struct access_token { + sstring token; + timestamp_type expiry; + resource_type resource_uri; + + access_token() = default; + access_token(const sstring& token, const timestamp_type& expiry, const resource_type& resource_uri); + + bool empty() const; + bool expired() const; +}; + +class credentials { +protected: + access_token _token; + // An externally provided context string for logging. + // Allows disambiguating logs from multiple instances of the same credential type. + const sstring _logctx; +private: + friend struct fmt::formatter; + + virtual std::string_view get_name() const = 0; + virtual future<> refresh(const resource_type& resource_uri) = 0; +public: + explicit credentials(const sstring& logctx = "") : _logctx(logctx) {} + virtual ~credentials() = default; + + // Throws azure::auth_error on failure. + future get_access_token(const resource_type& resource_uri); +}; + +} + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + auto format(const azure::credentials& creds, fmt::format_context& ctxt) const { + if (creds._logctx.empty()) { + return fmt::format_to(ctxt.out(), "{}", creds.get_name()); + } else { + return fmt::format_to(ctxt.out(), "{}/{}", creds._logctx, creds.get_name()); + } + } +}; \ No newline at end of file diff --git a/utils/azure/identity/default_credentials.cc b/utils/azure/identity/default_credentials.cc new file mode 100644 index 000000000000..62a6fd33dcc7 --- /dev/null +++ b/utils/azure/identity/default_credentials.cc @@ -0,0 +1,146 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include + +#include "exceptions.hh" +#include "default_credentials.hh" +#include "azure_cli_credentials.hh" +#include "managed_identity_credentials.hh" +#include "service_principal_credentials.hh" + +namespace azure { + +default_credentials::default_credentials(const source_set& sources, + const sstring& imds_endpoint, const sstring& truststore, + const sstring& priority_string, const sstring& logctx) + : credentials(logctx) + , _sources(sources) + , _truststore(truststore) + , _priority_string(priority_string) + , _imds_endpoint(imds_endpoint) +{} + +future<> default_credentials::refresh(const resource_type& resource_uri) { + if (!_creds) { + co_await detect(resource_uri); + } + SCYLLA_ASSERT(_creds); + _token = co_await _creds->get_access_token(resource_uri); +} + +future<> default_credentials::detect(const resource_type& resource_uri) { + if (_creds) { + co_return; + } + if (_sources.contains()) { + az_creds_logger.debug("[{}] Detecting credentials in environment", *this); + if (auto creds = co_await get_credentials_from_env(resource_uri)) { + az_creds_logger.debug("[{}] Credentials found in environment!", *this); + _creds = std::move(*creds); + co_return; + } + } + if (_sources.contains()) { + az_creds_logger.debug("[{}] Detecting credentials in CLI", *this); + if (auto creds = co_await get_credentials_from_azure_cli(resource_uri)) { + az_creds_logger.debug("[{}] Credentials found in CLI!", *this); + _creds = std::move(*creds); + co_return; + } + } + if (_sources.contains()) { + az_creds_logger.debug("[{}] Detecting credentials in IMDS", *this); + if (auto creds = co_await get_credentials_from_imds(resource_uri)) { + az_creds_logger.debug("[{}] Credentials found in IMDS!", *this); + _creds = std::move(*creds); + co_return; + } + } + throw auth_error("No credentials found in any source."); +} + +future default_credentials::get_credentials_from_env(const resource_type& resource_uri) { + auto tenant_id = std::getenv("AZURE_TENANT_ID"); + auto client_id = std::getenv("AZURE_CLIENT_ID"); + auto client_secret = std::getenv("AZURE_CLIENT_SECRET"); + auto client_certificate_path = std::getenv("AZURE_CLIENT_CERTIFICATE_PATH"); + auto authority = std::getenv("AZURE_AUTHORITY_HOST"); + auto creds_found = tenant_id || client_id || client_secret || client_certificate_path; + auto creds_complete = tenant_id && client_id && (client_secret || client_certificate_path); + if (!creds_found) { + az_creds_logger.debug("[{}] No credentials found in environment", *this); + co_return std::nullopt; + } else if (!creds_complete) { + az_creds_logger.debug("[{}] Incomplete credentials. Both 'AZURE_TENANT_ID' and 'AZURE_CLIENT_ID', and at least one of 'AZURE_CLIENT_SECRET', 'AZURE_CLIENT_CERTIFICATE_PATH' must be provided. Currently:", *this); + az_creds_logger.debug("[{}] Tenant ID is {}set", *this, tenant_id ? "" : "NOT "); + az_creds_logger.debug("[{}] Client ID is {}set", *this, client_id ? "" : "NOT "); + az_creds_logger.debug("[{}] Client Secret is {}set", *this, client_secret ? "" : "NOT "); + az_creds_logger.debug("[{}] Client Certificate is {}set", *this, client_certificate_path ? "" : "NOT "); + az_creds_logger.debug("[{}] Authority host is {}set", *this, authority ? "" : "NOT "); + co_return std::nullopt; + } + auto creds = std::make_unique( + tenant_id, + client_id, + client_secret ? client_secret : "", + client_certificate_path ? client_certificate_path : "", + authority ? authority : "", + _truststore, + _priority_string, + _logctx); + try { + co_await creds->get_access_token(resource_uri); + } catch (auth_error& e) { + az_creds_logger.debug("[{}] Failed to obtain token from environment: {}", *this, e.what()); + co_return std::nullopt; + } + co_return creds; +} + +future default_credentials::get_credentials_from_azure_cli(const resource_type& resource_uri) { + auto creds = std::make_unique(_logctx); + try { + co_await creds->get_access_token(resource_uri); + } catch (auth_error& e) { + az_creds_logger.debug("[{}] Failed to obtain token from Azure CLI: {}", *this, e.what()); + co_return std::nullopt; + } + co_return creds; +} + +/** + * @brief Get Managed Identity credentials from Azure Instance Metadata Service (IMDS). + * + * Assume the node is an Azure VM, and try to query the IMDS token endpoint. + * + * If it's not an Azure VM, the service will be unreachable and + * `get_access_token()` will throw a `timed_out_error`. + * + * If it is an Azure VM, but the node has none or multiple user-assigned + * managed identities, the request will fail with a 400 error. + * + * If it is an Azure VM and it has a system-assigned managed identity, or exactly + * one user-assigned managed identity, the request will succeed. + */ +future default_credentials::get_credentials_from_imds(const resource_type& resource_uri) { + auto creds = std::make_unique(_imds_endpoint, _logctx); + try { + co_await creds->get_access_token(resource_uri); + } catch (timed_out_error&) { + az_creds_logger.debug("[{}] Failed to connect to IMDS.", *this); + co_return std::nullopt; + } catch (auth_error& e) { + az_creds_logger.debug("[{}] Got unexpected return from IMDS: {}", *this, e.what()); + co_return std::nullopt; + } + co_return creds; +} + +} \ No newline at end of file diff --git a/utils/azure/identity/default_credentials.hh b/utils/azure/identity/default_credentials.hh new file mode 100644 index 000000000000..5b7335b34291 --- /dev/null +++ b/utils/azure/identity/default_credentials.hh @@ -0,0 +1,63 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "enum_set.hh" +#include "credentials.hh" + +namespace azure { + +class default_credentials : public credentials { +public: + enum class source : uint8_t { + Env, + AzureCli, + Imds, + }; + using source_set = enum_set>; + static constexpr source_set all_sources = source_set::full(); + + default_credentials(const source_set& sources = all_sources, + const sstring& imds_endpoint = "", + const sstring& truststore = "", + const sstring& priority_string = "", + const sstring& logctx = ""); +private: + static constexpr char NAME[] = "DefaultCredentials"; + source_set _sources; + std::unique_ptr _creds; + // TLS options. + sstring _truststore; + sstring _priority_string; + + sstring _imds_endpoint; + + std::string_view get_name() const override { return NAME; }; + future<> refresh(const resource_type& resource_uri) override; + future<> detect(const resource_type& resource_uri); + + using credentials_opt = std::optional>; + future get_credentials_from_env(const resource_type& resource_uri); + future get_credentials_from_azure_cli(const resource_type& resource_uri); + future get_credentials_from_imds(const resource_type& resource_uri); +}; + +} + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + auto format(const azure::default_credentials& creds, fmt::format_context& ctxt) const { + return fmt::format_to(ctxt.out(), "{}", *dynamic_cast(&creds)); + } +}; \ No newline at end of file diff --git a/utils/azure/identity/exceptions.hh b/utils/azure/identity/exceptions.hh new file mode 100644 index 000000000000..92a89e5f87fb --- /dev/null +++ b/utils/azure/identity/exceptions.hh @@ -0,0 +1,94 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include + +#include "utils/rjson.hh" + +namespace azure { + +class auth_error : public std::exception { + std::string _msg; +public: + auth_error(std::string_view msg) : _msg(msg) {} + const char* what() const noexcept override { + return _msg.c_str(); + } +}; + +// Exception representing an error from Azure Entra or IMDS. +class creds_auth_error : public auth_error { + // The response body may contain an 'error' and 'error_description'. + // https://learn.microsoft.com/en-us/entra/identity-platform/reference-error-codes + // https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling + static constexpr char ERROR_CODE_KEY[] = "error"; + static constexpr char ERROR_DESCRIPTION_KEY[] = "error_description"; + http::reply::status_type _status; + std::string _code; +public: + creds_auth_error(const http::reply::status_type& status) + : auth_error(fmt::format("{} ({})", status, fmt::to_string(status))) + , _status(status) + , _code("") + {} + creds_auth_error(const http::reply::status_type& status, std::string_view code, std::string_view desc) + : auth_error(fmt::format("{}: {}", code, desc)) + , _status(status) + , _code(code) + {} + const http::reply::status_type& status() const { + return _status; + } + const std::string& code() const { + return _code; + } + static creds_auth_error make_error(const http::reply::status_type& status, std::string_view result) { + const auto& jres = rjson::try_parse(result); + if (!jres) { + return creds_auth_error(status); + } + const auto& code = rjson::get_opt(*jres, ERROR_CODE_KEY); + const auto& desc = rjson::get_opt(*jres, ERROR_DESCRIPTION_KEY); + if (code && desc) { + return creds_auth_error(status, *code, *desc); + } + return creds_auth_error(status); + } +}; + +// Common error codes from Microsoft Entra: +// https://learn.microsoft.com/en-us/entra/identity-platform/reference-error-codes#handling-error-codes-in-your-application +namespace entra_errors { + [[maybe_unused]] static constexpr char InvalidRequest[] = "invalid_request"; + [[maybe_unused]] static constexpr char InvalidGrant[] = "invalid_grant"; + [[maybe_unused]] static constexpr char UnauthorizedClient[] = "unauthorized_client"; + [[maybe_unused]] static constexpr char InvalidClient[] = "invalid_client"; + [[maybe_unused]] static constexpr char UnsupportedGrantType[] = "unsupported_grant_type"; + [[maybe_unused]] static constexpr char InvalidResource[] = "invalid_resource"; + [[maybe_unused]] static constexpr char InteractionRequired[] = "interaction_required"; + [[maybe_unused]] static constexpr char TemporarilyUnavailable[] = "temporarily_unavailable"; +} + +// Common error codes from IMDS: +// https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#http-response-reference +namespace imds_errors { + [[maybe_unused]] static constexpr char InvalidResource[] = "invalid_resource"; + [[maybe_unused]] static constexpr char BadRequest102[] = "bad_request_102"; + [[maybe_unused]] static constexpr char UnknownSource[] = "unknown_source"; + [[maybe_unused]] static constexpr char InvalidRequest[] = "invalid_request"; + [[maybe_unused]] static constexpr char UnauthorizedClient[] = "unauthorized_client"; + [[maybe_unused]] static constexpr char AccessDenied[] = "access_denied"; + [[maybe_unused]] static constexpr char UnsupportedResponseType[] = "unsupported_response_type"; + [[maybe_unused]] static constexpr char InvalidScope[] = "invalid_scope"; + [[maybe_unused]] static constexpr char Unknown[] = "unknown"; +} + +} \ No newline at end of file diff --git a/utils/azure/identity/managed_identity_credentials.cc b/utils/azure/identity/managed_identity_credentials.cc new file mode 100644 index 000000000000..8533b61eb592 --- /dev/null +++ b/utils/azure/identity/managed_identity_credentials.cc @@ -0,0 +1,194 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#include + +#include +#include +#include +#include + +#include "utils/rest/client.hh" +#include "exceptions.hh" +#include "managed_identity_credentials.hh" + +namespace azure { + +class mi_log_filter : public rest::nop_log_filter { +public: + string_opt filter_body(body_type type, std::string_view body) const override { + if (type == body_type::response && !body.empty()) { + auto j = rjson::parse(body); + auto val = rjson::find(j, "access_token"); + if (val) { + val->SetString(REDACTED_VALUE); + return rjson::print(j); + } + } + return std::nullopt; + } +}; + +managed_identity_credentials::managed_identity_credentials(const sstring& endpoint, const sstring& logctx) + : credentials(logctx) +{ + if (endpoint.empty()) { + return; + } + // Regex for the IMDS endpoint. + // Expected format: [http://][:port] + static const boost::regex uri_pattern(R"((?:http://)?([^/:]+)(?::(\d+))?)"); + boost::match_results match; + std::string_view endpoint_view{endpoint}; + if (boost::regex_match(endpoint_view.begin(), endpoint_view.end(), match, uri_pattern)) { + _host = match[1].str(); + if (match[2].matched) { + _port = std::stoi(match[2].str()); + } + } else { + throw std::invalid_argument(fmt::format("Invalid IMDS endpoint '{}'. Expected format: [http://][:port]", endpoint)); + } +} + +access_token managed_identity_credentials::make_token(const rjson::value& json, const resource_type& resource_uri) { + auto token = rjson::get(json, "access_token"); + auto expires_in_str = rjson::get(json, "expires_in"); + if (auto expires_in_int = std::atoi(expires_in_str.c_str())) { + return { token, timeout_clock::now() + std::chrono::seconds(expires_in_int), resource_uri }; + } + throw std::runtime_error(seastar::format("Invalid expires_in value: {}", expires_in_str)); +} + +/** + * @brief Retries for transient errors. + * + * Retries are performed for 404, 429, and 5xx errors, using an exponential backoff strategy. + * Three retries are attempted in total. + * The latencies between retries are 0, 100, and 300 milliseconds. + * + * Based on the official Azure docs, but with a reduced number of retries and backoff delay + * to prioritize responsiveness (longer transient errors should be handled by upper layers): + * https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#retry-guidance + * + * @param func A callable that returns a future<>, representing the asynchronous operation to retry. + * @return A future<> that resolves to the result of the operation if successful, or propagates the error if all retries fail. + */ +future<> managed_identity_credentials::with_retries(std::function()> func) { + constexpr int MAX_RETRIES = 3; + constexpr std::chrono::milliseconds DELTA_BACKOFF {100}; + std::chrono::milliseconds backoff; + + int retries = 0; + while (true) { + try { + co_return co_await func(); + } catch (const creds_auth_error& e) { + auto status = e.status(); + bool should_retry = + status == http::reply::status_type::not_found || + status == http::reply::status_type::too_many_requests || + http::reply::classify_status(status) == http::reply::status_class::server_error; + + if (retries >= MAX_RETRIES || !should_retry) { + throw; + } + + backoff = DELTA_BACKOFF * ((1 << retries) - 1); + az_creds_logger.debug("[{}] Token request failed with status {}. Reason: {}. Retrying in {} ms...", + *this, static_cast(status), e.what(), backoff.count()); + + retries++; + } + co_await seastar::sleep(backoff); + } +} + +/** + * @brief Check connectivity to the IMDS endpoint + * + * Azure offers no node-local indicator to determine if the node is an Azure VM. + * The only way to check is to try to connect to the IMDS endpoint. + * + * This method attempts to establish a TCP connection to the IMDS endpoint and + * assumes the service is unreachable after 3 seconds. + * + * Ideally, we should be able to do that via the rest HTTP client, but it + * offers no timeout or cancellation API, so we would have to wait for the + * system timeout (connect(2) ETIMEDOUT), which can take several minutes. + * + * @throws timed_out_error if the connection attempt times out. + */ +future<> managed_identity_credentials::check_connectivity() { + az_creds_logger.debug("[{}] Checking connectivity to IMDS endpoint", *this); + + const auto timeout = std::chrono::seconds(3); + auto sock = make_socket(); + auto addr = co_await net::dns::resolve_name(_host, net::inet_address::family::INET); + gate g; + std::exception_ptr ex; + + try { + co_await with_timeout(timer<>::clock::now() + timeout, with_gate(g, [&sock, addr, this]() { + return sock.connect(socket_address(addr, static_cast(_port))); + })); + } catch (const timed_out_error&) { + ex = std::current_exception(); + } + + sock.shutdown(); + co_await g.close(); + + if (ex) { + std::rethrow_exception(ex); + } +} + +/** + * Token request from IMDS. + * https://learn.microsoft.com/en-us/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + */ +future<> managed_identity_credentials::refresh(const resource_type& resource_uri) { + if (!_connectivity_checked) { + co_await check_connectivity(); + _connectivity_checked = true; + } + + az_creds_logger.debug("[{}] Refreshing token", *this); + + co_await with_retries([this, &resource_uri] () -> future<> { + static const mi_log_filter filter{}; + + const auto op = httpd::operation_type::GET; + const auto path = seastar::format(IMDS_TOKEN_PATH_TEMPLATE, IMDS_API_VERSION, resource_uri); + + rest::httpclient client{_host, _port, nullptr, std::nullopt}; + client.target(path); + client.method(op); + client.add_header("Metadata", "true"); + + if (az_creds_logger.is_enabled(log_level::trace)) { + az_creds_logger.trace("[{}] Sending request: {}", *this, rest::redacted_request_type{ client.request(), filter }); + } + + auto res = co_await client.send(); + if (res.result() == http::reply::status_type::ok) { + if (az_creds_logger.is_enabled(log_level::trace)) { + az_creds_logger.trace("[{}] Got response: {}", *this, rest::redacted_result_type{ res, filter }); + } + } else { + if (az_creds_logger.is_enabled(log_level::trace)) { + az_creds_logger.trace("[{}] Got unexpected response: {}", *this, rest::redacted_result_type{ res, filter }); + } + throw creds_auth_error::make_error(res.result(), res.body()); + } + _token = make_token(rjson::parse(res.body()), resource_uri); + }); +} + +} \ No newline at end of file diff --git a/utils/azure/identity/managed_identity_credentials.hh b/utils/azure/identity/managed_identity_credentials.hh new file mode 100644 index 000000000000..7885602e66ad --- /dev/null +++ b/utils/azure/identity/managed_identity_credentials.hh @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "utils/rjson.hh" +#include "credentials.hh" + +namespace azure { + +class managed_identity_credentials : public credentials { + static constexpr char NAME[] = "ManagedIdentityCredentials"; + static constexpr char IMDS_HOST[] = "169.254.169.254"; + static constexpr char IMDS_API_VERSION[] = "2018-02-01"; + static constexpr char IMDS_TOKEN_PATH_TEMPLATE[] = "/metadata/identity/oauth2/token?api-version={}&resource={}"; + + sstring _host{IMDS_HOST}; + unsigned _port{80}; + bool _connectivity_checked{false}; + + std::string_view get_name() const override { return NAME; }; + future<> with_retries(std::function()>); + future<> check_connectivity(); + future<> refresh(const resource_type& resource_uri) override; + access_token make_token(const rjson::value&, const resource_type&); +public: + managed_identity_credentials(const sstring& endpoint = "", const sstring& logctx = ""); +}; + +} + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + auto format(const azure::managed_identity_credentials& creds, fmt::format_context& ctxt) const { + return fmt::format_to(ctxt.out(), "{}", *dynamic_cast(&creds)); + } +}; \ No newline at end of file diff --git a/utils/azure/identity/service_principal_credentials.cc b/utils/azure/identity/service_principal_credentials.cc new file mode 100644 index 000000000000..5c03fa1bb1be --- /dev/null +++ b/utils/azure/identity/service_principal_credentials.cc @@ -0,0 +1,334 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#define CPP_JWT_USE_VENDORED_NLOHMANN_JSON +#include +#include + +#include +#include + +#include "db/config.hh" +#include "types/types.hh" +#include "utils/base64.hh" +#include "utils/hashers.hh" +#include "utils/rest/client.hh" +#include "exceptions.hh" +#include "service_principal_credentials.hh" + +namespace azure { + +class sp_log_filter : public rest::nop_log_filter { +public: + string_opt filter_body(body_type type, std::string_view body) const override { + if (type == body_type::request) { + for (const auto& param : {"client_secret=", "client_assertion="}) { + size_t start_pos = body.find(param); + if (start_pos == std::string_view::npos) { + continue; + } + size_t value_start_pos = start_pos + std::string(param).length(); + size_t end_pos = body.find('&', value_start_pos); + if (end_pos == std::string_view::npos) { + end_pos = body.length(); + } + return fmt::format("{}{}{}", + body.substr(0, value_start_pos), + REDACTED_VALUE, + body.substr(end_pos)); + } + } else if (type == body_type::response) { + if (!body.empty()) { + auto j = rjson::parse(body); + auto val = rjson::find(j, "access_token"); + if (val) { + val->SetString(REDACTED_VALUE); + return rjson::print(j); + } + } + } + return std::nullopt; + } +}; + +service_principal_credentials::service_principal_credentials(const sstring& tenant_id, + const sstring& client_id, const sstring& client_secret, const sstring& client_cert, + const sstring& authority, const sstring& truststore, const sstring& priority_string, + const sstring& logctx) + : credentials(logctx) + , _tenant_id(tenant_id) + , _client_id(client_id) + , _client_secret(client_secret) + , _client_cert(client_cert) + , _truststore(truststore) + , _priority_string(priority_string) +{ + if (authority.empty()) { + return; + } + // Regex for the authentication authority URL. + // Expected format: [http(s)://][:port] + static const boost::regex uri_pattern(R"((?:(https?)://)?([^/:]+)(?::(\d+))?)"); + boost::match_results match; + std::string_view authority_view{authority}; + if (boost::regex_match(authority_view.begin(), authority_view.end(), match, uri_pattern)) { + if (match[1].matched) { + _is_secured = (match[1].str() == "https"); + } + _host = match[2].str(); + if (match[3].matched) { + _port = std::stoi(match[3].str()); + } + } else { + throw std::invalid_argument(fmt::format("Invalid authentication authority URL '{}'. Expected format: [http(s)://][:port]", authority)); + } +} + +static future<::shared_ptr> make_creds(const sstring& truststore, const sstring& priority_string) { + auto creds = seastar::make_shared(); + if (!priority_string.empty()) { + creds->set_priority_string(priority_string); + } else { + creds->set_priority_string(db::config::default_tls_priority); + } + if (!truststore.empty()) { + co_await creds->set_x509_trust_file(truststore, seastar::tls::x509_crt_format::PEM); + } else { + co_await creds->set_system_trust(); + } + co_return creds; +} + +/** + * @brief Retries for transient errors. + * + * Retries are performed for 408, 429, 500, 502, 503, and 504 errors, using an exponential backoff strategy. + * Three retries are attempted in total. + * The latencies between retries are 0, 100, and 300 milliseconds. + * + * The error codes are based on the generic retry policy of the Azure C++ SDK: + * https://github.com/Azure/azure-sdk-for-cpp/blob/126452efd30860263398a152f11f337007f529f4/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp#L133 + * + * The exponential backoff strategy follows the one for the `managed_identity_credentials`. + * + * @param func The asynchronous function to execute and retry on failure. + * @return The result of the function if successful, or throws if all retries fail or a non-retriable error occurs. + */ +future service_principal_credentials::with_retries(std::function()> func) { + constexpr int MAX_RETRIES = 3; + constexpr std::chrono::milliseconds DELTA_BACKOFF {100}; + std::chrono::milliseconds backoff; + + int retries = 0; + while (true) { + try { + co_return co_await func(); + } catch (const creds_auth_error& e) { + auto status = e.status(); + bool should_retry = + status == http::reply::status_type::request_timeout || + status == http::reply::status_type::too_many_requests || + status == http::reply::status_type::internal_server_error || + status == http::reply::status_type::bad_gateway || + status == http::reply::status_type::service_unavailable || + status == http::reply::status_type::gateway_timeout; + + if (retries >= MAX_RETRIES || !should_retry) { + throw; + } + + backoff = DELTA_BACKOFF * ((1 << retries) - 1); + az_creds_logger.debug("[{}] Token request failed with status {}. Reason: {}. Retrying in {} ms...", + *this, static_cast(status), e.what(), backoff.count()); + + retries++; + } + co_await seastar::sleep(backoff); + } +} + +future service_principal_credentials::post(const sstring& body) { + static const sp_log_filter filter{}; + const auto op = httpd::operation_type::POST; + const auto path = seastar::format(AZURE_ENTRA_ID_TOKEN_PATH_TEMPLATE, _tenant_id); + const auto mime_type = MIME_TYPE; + + shared_ptr creds; + std::optional options; + + if (_is_secured) { + creds = co_await make_creds(_truststore, _priority_string); + options = { .wait_for_eof_on_shutdown = false, .server_name = _host }; + } + + rest::httpclient client{_host, _port, std::move(creds), options}; + client.target(path); + client.method(op); + client.add_header("Content-Type", mime_type); + client.content(std::move(body)); + + if (az_creds_logger.is_enabled(log_level::trace)) { + az_creds_logger.trace("[{}] Sending request: {}", *this, rest::redacted_request_type{ client.request(), filter }); + } + + auto res = co_await client.send(); + if (res.result() == http::reply::status_type::ok) { + if (az_creds_logger.is_enabled(log_level::trace)) { + az_creds_logger.trace("[{}] Got response: {}", *this, rest::redacted_result_type{ res, filter }); + } + } else { + if (az_creds_logger.is_enabled(log_level::trace)) { + az_creds_logger.trace("[{}] Got unexpected response: {}", *this, rest::redacted_result_type{ res, filter }); + } + throw creds_auth_error::make_error(res.result(), res.body()); + } + co_return res.body(); +} + +access_token service_principal_credentials::make_token(const rjson::value& json, const resource_type& resource_uri) { + auto token = rjson::get(json, "access_token"); + auto expires_in = timeout_clock::now() + std::chrono::seconds(rjson::get(json, "expires_in")); + return { token, expires_in, resource_uri }; +} + +future<> service_principal_credentials::refresh(const resource_type& resource_uri) { + az_creds_logger.debug("[{}] Refreshing token for principal: {}", *this, _client_id); + if (_client_secret != "") { + co_await refresh_with_secret(resource_uri); + } else { + co_await refresh_with_certificate(resource_uri); + } +} + +/** + * Token request with secret. + * Based on: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow#first-case-access-token-request-with-a-shared-secret + */ +future<> service_principal_credentials::refresh_with_secret(const resource_type& resource_uri) { + // Scopes for the client credentials flow must contain only one resource + // identifier and only the .default scope. + auto scope = seastar::format("{}/.default", resource_uri); + sstring grant_type = "client_credentials"; + sstring body = seastar::format( + "client_id={}&scope={}&client_secret={}&grant_type={}", + _client_id, + seastar::http::internal::url_encode(scope), + seastar::http::internal::url_encode(_client_secret), + grant_type); + auto resp = co_await with_retries([&] { return post(body); }); + _token = make_token(rjson::parse(resp), resource_uri); +} + +static future read_pem_artifact(const sstring& key_file, const sstring& header) { + auto contents = co_await util::read_entire_file_contiguous(std::filesystem::path(key_file.c_str())); + std::istringstream stream(contents); + std::string line; + std::string key_data; + bool in_key = false; + while (std::getline(stream, line)) { + if (line.find(header) != std::string::npos) { + in_key = !in_key; + key_data += line + "\n"; + } else if (in_key) { + key_data += line + "\n"; + } + } + co_return key_data; +} + +/** + * Argument is expected to be a PEM encoded certificate. + * Thumbprint computed as the base64url-encoded SHA-256 hash of the certificate's DER encoding. + * https://learn.microsoft.com/en-us/entra/identity-platform/certificate-credentials#header + * https://datatracker.ietf.org/doc/html/rfc7517#section-4.9 + */ +std::string compute_thumbprint(const std::string& pem_cert) { + BIO* bio = BIO_new_mem_buf(pem_cert.data(), pem_cert.size()); + if (!bio) { + on_internal_error(az_creds_logger, "Error creating BIO object"); + } + + X509* cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr); + BIO_free(bio); + if (!cert) { + on_internal_error(az_creds_logger, "Error reading certificate from memory"); + } + + // Convert to DER format + unsigned char* der = nullptr; + int der_len = i2d_X509(cert, &der); + X509_free(cert); + if (der_len < 0) { + on_internal_error(az_creds_logger, "Error converting certificate to DER"); + } + + std::string thumbprint(reinterpret_cast(der), der_len); + OPENSSL_free(der); + bytes sha = [&thumbprint] { + sha256_hasher hasher; + hasher.update(thumbprint.data(), thumbprint.size()); + return hasher.finalize(); + }(); + return base64url_encode(sha); +} + +/** + * Token request with certificate. + * Based on: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow#second-case-access-token-request-with-a-certificate + */ +future<> service_principal_credentials::refresh_with_certificate(const resource_type& resource_uri) { + // Scopes for the client credentials flow must contain only one resource + // identifier and only the .default scope. + auto scope = seastar::format("{}/.default", resource_uri); + sstring grant_type = "client_credentials"; + sstring client_assertion_type = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; + std::string private_key; + try { + private_key = co_await read_pem_artifact(_client_cert, PEM_STRING_PKCS8INF); + } catch (const std::exception& e) { + throw std::runtime_error(seastar::format("Failed to read private key from certificate file {}: {}", _client_cert, e.what())); + } + if (private_key.empty()) { + throw std::invalid_argument(seastar::format("Private key not found in certificate file {}", _client_cert)); + } + std::string cert = co_await read_pem_artifact(_client_cert, PEM_STRING_X509); + if (cert.empty()) { + throw std::invalid_argument(seastar::format("Certificate not found in certificate file {}", _client_cert)); + } + auto alg = "RS256"; // docs suggest PS256, but it's not supported by jwt-cpp + auto thumbprint = compute_thumbprint(cert); + + using namespace jwt::params; + jwt::jwt_object obj{algorithm(alg), secret(private_key), headers({{"x5t#S256", thumbprint }})}; + + const auto host = AZURE_ENTRA_ID_HOST; + const auto path = seastar::format(AZURE_ENTRA_ID_TOKEN_PATH_TEMPLATE, _tenant_id); + auto uri = seastar::format("https://{}{}", host, path); + using jwt_id = utils::tagged_uuid; + obj.add_claim("aud", uri) + .add_claim("exp", timeout_clock::now() + std::chrono::minutes(10)) + .add_claim("iss", _client_id) + .add_claim("jti", jwt_id::create_random_id().to_sstring()) + .add_claim("nbf", timeout_clock::now()) + .add_claim("sub", _client_id) + .add_claim("iat", timeout_clock::now()) + ; + auto sign = obj.signature(); + sstring body = seastar::format( + "client_id={}&scope={}&client_assertion_type={}&client_assertion={}&grant_type={}", + _client_id, + seastar::http::internal::url_encode(scope), + seastar::http::internal::url_encode(client_assertion_type), + sign, + grant_type); + auto resp = co_await with_retries([&] { return post(body); }); + _token = make_token(rjson::parse(resp), resource_uri); +} + +} \ No newline at end of file diff --git a/utils/azure/identity/service_principal_credentials.hh b/utils/azure/identity/service_principal_credentials.hh new file mode 100644 index 000000000000..740c9a8c797d --- /dev/null +++ b/utils/azure/identity/service_principal_credentials.hh @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2025 ScyllaDB + * + */ + +/* + * SPDX-License-Identifier: LicenseRef-ScyllaDB-Source-Available-1.0 + */ + +#pragma once + +#include "utils/rjson.hh" +#include "credentials.hh" + +namespace azure { + +class service_principal_credentials : public credentials { + static constexpr char NAME[] = "ServicePrincipalCredentials"; + static constexpr char AZURE_ENTRA_ID_HOST[] = "login.microsoftonline.com"; + static constexpr char AZURE_ENTRA_ID_TOKEN_PATH_TEMPLATE[] = "/{}/oauth2/v2.0/token"; + static constexpr char MIME_TYPE[] = "application/x-www-form-urlencoded"; + sstring _tenant_id; + sstring _client_id; + sstring _client_secret; + sstring _client_cert; + // TLS options + sstring _truststore; + sstring _priority_string; + // Endpoint + sstring _host{AZURE_ENTRA_ID_HOST}; + unsigned _port{443}; + bool _is_secured{true}; + + std::string_view get_name() const override { return NAME; }; + future post(const sstring& body); + future with_retries(std::function()>); + future<> refresh(const resource_type& resource_uri) override; + future<> refresh_with_secret(const resource_type& resource_uri); + future<> refresh_with_certificate(const resource_type& resource_uri); + access_token make_token(const rjson::value&, const resource_type&); +public: + service_principal_credentials(const sstring& tenant_id, const sstring& client_id, + const sstring& client_secret, const sstring& client_cert, + const sstring& authority, + const sstring& truststore = "", const sstring& priority_string = "", + const sstring& logctx = ""); +}; + +} + +template <> +struct fmt::formatter { + constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } + auto format(const azure::service_principal_credentials& creds, fmt::format_context& ctxt) const { + return fmt::format_to(ctxt.out(), "{}", *dynamic_cast(&creds)); + } +}; \ No newline at end of file diff --git a/utils/base64.cc b/utils/base64.cc index 11523f6c5618..1116370a11c5 100644 --- a/utils/base64.cc +++ b/utils/base64.cc @@ -132,3 +132,36 @@ bool base64_begins_with(std::string_view base, std::string_view operand) { const std::string operand_remainder = base64_decode_string(operand.substr(operand.size() - 4)); return base_remainder.starts_with(operand_remainder); } + +std::string base64url_encode(bytes_view in) { + std::string str = base64_encode(in); + for (char& c : str) { + if (c == '+') { + c = '-'; + } else if (c == '/') { + c = '_'; + } + } + str.erase(std::find(str.begin(), str.end(), '='), str.end()); + return str; +} + +bytes base64url_decode(std::string_view in) { + std::string str{in}; + size_t mod = str.size() % 4; + if (mod == 1) { + std::invalid_argument(seastar::format("Base64 encoded length is invalid: {}", str.size())); + } else if (mod == 2) { + str.append("==", 2); + } else if (mod == 3) { + str.append("=", 1); + } + for (char& c : str) { + if (c == '-') { + c = '+'; + } else if (c == '_') { + c = '/'; + } + } + return base64_decode(str); +} diff --git a/utils/base64.hh b/utils/base64.hh index a28c0833deb1..2090083cccef 100644 --- a/utils/base64.hh +++ b/utils/base64.hh @@ -12,8 +12,10 @@ #include "bytes_fwd.hh" std::string base64_encode(bytes_view); +std::string base64url_encode(bytes_view); bytes base64_decode(std::string_view); +bytes base64url_decode(std::string_view); size_t base64_decoded_len(std::string_view str); diff --git a/utils/rest/client.cc b/utils/rest/client.cc index 6fad04e90141..89cf350ab4db 100644 --- a/utils/rest/client.cc +++ b/utils/rest/client.cc @@ -145,4 +145,28 @@ fmt::formatter::format(const rest::httpclient::re return os; } +auto +fmt::formatter::format(const rest::redacted_request_type& rr, fmt::format_context& ctx) const -> decltype(ctx.out()) { + const auto& r = rr.original; + auto os = fmt::format_to(ctx.out(), "{} {} HTTP/{}{}", r._method, r._url, r._version, linesep); + for (auto& [k, v] : r._headers) { + os = fmt::format_to(os, "{}: {}{}", k, rr.filter_header(k, v).value_or(v), linesep); + } + os = fmt::format_to(os, "{}{}", linesep, rr.filter_body(r.content).value_or(r.content)); + return os; +} +auto +fmt::formatter::format(const rest::redacted_result_type& rr, fmt::format_context& ctx) const -> decltype(ctx.out()) { + const auto& r = rr.original; + auto s = r.reply.response_line(); + // remove the trailing \r\n from response_line string. we want our own linebreak, hence substr. + auto os = fmt::format_to(ctx.out(), "{}{}", std::string_view(s).substr(0, s.size()-2), linesep); + for (auto& [k, v] : r.reply._headers) { + os = fmt::format_to(os, "{}: {}{}", k, rr.filter_header(k, v).value_or(v), linesep); + } + auto redacted_body_opt = rr.filter_body(r.body()); + auto redacted_body_view = redacted_body_opt.has_value() ? *redacted_body_opt : r.body(); + os = fmt::format_to(os, "{}{}", linesep, redacted_body_view); + return os; +} diff --git a/utils/rest/client.hh b/utils/rest/client.hh index d5591e76fdd1..24e59ddf78f7 100644 --- a/utils/rest/client.hh +++ b/utils/rest/client.hh @@ -75,6 +75,45 @@ private: request_type _req; }; +// Interface for redacting sensitive data from HTTP requests and responses before logging. +class http_log_filter { +public: + static constexpr char REDACTED_VALUE[] = "[REDACTED]"; + enum class body_type { + request, + response, + }; + using string_opt = std::optional; + // Filter a request/response header. + // Returns an optional containing the filtered value. If no filtering is required, the optional is not engaged. + virtual string_opt filter_header(std::string_view name, std::string_view value) const = 0; + // Filter the request/response body. + // Returns an optional containing the filtered value. If no filtering is required, the optional is not engaged. + virtual string_opt filter_body(body_type type, const std::string_view body) const = 0; +}; + +class nop_log_filter: public http_log_filter { +public: + virtual string_opt filter_header(std::string_view name, std::string_view value) const { return std::nullopt; } + virtual string_opt filter_body(body_type type, std::string_view body) const { return std::nullopt; } +}; + +template +struct redacted { + const T& original; + const http_log_filter& filter; + + std::optional filter_header(std::string_view name, std::string_view value) const { + return filter.filter_header(name, value); + } + std::optional filter_body(std::string_view value) const { + return filter.filter_body(type, value); + } +}; + +using redacted_request_type = redacted; +using redacted_result_type = redacted; + } template <> @@ -86,3 +125,13 @@ template <> struct fmt::formatter : fmt::formatter { auto format(const rest::httpclient::result_type&, fmt::format_context& ctx) const -> decltype(ctx.out()); }; + +template <> +struct fmt::formatter : fmt::formatter { + auto format(const rest::redacted_request_type&, fmt::format_context& ctx) const -> decltype(ctx.out()); +}; + +template <> +struct fmt::formatter : fmt::formatter { + auto format(const rest::redacted_result_type&, fmt::format_context& ctx) const -> decltype(ctx.out()); +}; \ No newline at end of file diff --git a/validation.cc b/validation.cc index e90ba68bd72e..3fe89bf07b3f 100644 --- a/validation.cc +++ b/validation.cc @@ -10,7 +10,7 @@ #include "validation.hh" #include "schema/schema.hh" -#include "keys.hh" +#include "keys/keys.hh" #include "data_dictionary/data_dictionary.hh" #include "exceptions/exceptions.hh"