|
| 1 | +//! Interactions with the AZURE KEY VAULT SDK |
| 2 | +
|
| 3 | +use std::sync::Arc; |
| 4 | + |
| 5 | +use azure_identity::DefaultAzureCredential; |
| 6 | +use azure_security_keyvault_keys::{ |
| 7 | + KeyClient, |
| 8 | + models::{KeyOperationsParameters, JsonWebKeyEncryptionAlgorithm}, |
| 9 | +}; |
| 10 | +use crate::{consts, metrics}; |
| 11 | +use base64::Engine; |
| 12 | + |
| 13 | +use std::time::Instant; |
| 14 | +use common_utils::errors::CustomResult; |
| 15 | +use error_stack::{report, ResultExt}; |
| 16 | +use router_env::logger; |
| 17 | + |
| 18 | + |
| 19 | +/// Configuration parameters required for constructing a [`AzureKeyVaultClient`]. |
| 20 | +#[derive(Clone, Debug, Default, serde::Deserialize)] |
| 21 | +#[serde(default)] |
| 22 | +pub struct AzureKeyVaultConfig { |
| 23 | + /// key name of Azure Key vault used to encrypt or decrypt data |
| 24 | + pub key_name: String, |
| 25 | + /// The Azure vault url of the Key vault. |
| 26 | + pub vault_url: String, |
| 27 | + /// version of the key name |
| 28 | + pub version: String, |
| 29 | +} |
| 30 | + |
| 31 | +impl AzureKeyVaultConfig { |
| 32 | + /// Verifies that the [`AzureKeyVaultClient`] configuration is usable. |
| 33 | + pub fn validate(&self) -> Result<(), &'static str> { |
| 34 | + use common_utils::{ext_traits::ConfigExt, fp_utils::when}; |
| 35 | + |
| 36 | + when(self.key_name.is_default_or_empty(), || { |
| 37 | + Err("Azure Key Vault key name must not be empty") |
| 38 | + })?; |
| 39 | + |
| 40 | + when(self.vault_url.is_default_or_empty(), || { |
| 41 | + Err("Azure Key Vault url must not be empty") |
| 42 | + }) |
| 43 | + } |
| 44 | +} |
| 45 | + |
| 46 | +/// Client for AZURE KEY VAULT operations. |
| 47 | +#[derive(Clone)] |
| 48 | +pub struct AzureKeyVaultClient { |
| 49 | + inner_client: Arc<KeyClient>, |
| 50 | + key_name: String, |
| 51 | + version: String, |
| 52 | +} |
| 53 | + |
| 54 | +impl AzureKeyVaultClient { |
| 55 | + /// Constructs a new Azure Key Vault client. |
| 56 | + pub async fn new(config: &AzureKeyVaultConfig) -> Result<Self, AzureKeyVaultError> { |
| 57 | + let credential = DefaultAzureCredential::new() |
| 58 | + .map_err(|_| AzureKeyVaultError::AzureKeyVaultClientInitializationFailed)?; |
| 59 | + |
| 60 | + Ok(Self { |
| 61 | + inner_client: Arc::new( |
| 62 | + KeyClient::new( |
| 63 | + &config.vault_url, |
| 64 | + credential.clone(), |
| 65 | + None |
| 66 | + ) |
| 67 | + .map_err( |
| 68 | + |_| AzureKeyVaultError::AzureKeyVaultClientInitializationFailed)? |
| 69 | + ), |
| 70 | + key_name: config.key_name.clone(), |
| 71 | + version: config.version.clone(), |
| 72 | + }) |
| 73 | + } |
| 74 | + /// Decrypts the provided base64-encoded encrypted data using the AZURE KEY VAULT SDK. We assume that |
| 75 | + /// the SDK has the values required to interact with the AZURE KEY VAULT APIs (`AZURE_TENANT_ID`, |
| 76 | + /// `AZURE_CLIENT_ID` and `AZURE_CLIENT_SECRET`) either set in environment variables, or that the |
| 77 | + /// SDK is running in a machine that is able to assume an Azure AD role. |
| 78 | + pub async fn decrypt(&self, data: impl AsRef<[u8]>) -> CustomResult<String, AzureKeyVaultError> { |
| 79 | + let start = Instant::now(); |
| 80 | + |
| 81 | + let data = consts::BASE64_ENGINE |
| 82 | + .decode(data) |
| 83 | + .change_context(AzureKeyVaultError::Base64DecodingFailed)?; |
| 84 | + |
| 85 | + let decrypt_params = KeyOperationsParameters { |
| 86 | + algorithm: Some(JsonWebKeyEncryptionAlgorithm::RsaOaep), |
| 87 | + value: Some(data), |
| 88 | + ..Default::default() |
| 89 | + }; |
| 90 | + let decrypted_output = self.inner_client |
| 91 | + .decrypt(&self.key_name, &self.version , decrypt_params.clone().try_into().unwrap(), None) |
| 92 | + .await |
| 93 | + .inspect_err(|error| { |
| 94 | + logger::error!(azure_key_vault_error=?error, "Failed to Azure Key Vault decrypt data"); |
| 95 | + metrics::AZURE_KEY_VAULT_DECRYPTION_FAILURES.add(1, &[]); |
| 96 | + }) |
| 97 | + .change_context(AzureKeyVaultError::DecryptionFailed)? |
| 98 | + .into_body() |
| 99 | + .await |
| 100 | + .inspect_err(|error| { |
| 101 | + logger::error!(azure_key_vault_error=?error, "Failed to Azure Key Vault decrypt data"); |
| 102 | + metrics::AZURE_KEY_VAULT_DECRYPTION_FAILURES.add(1, &[]); |
| 103 | + }) |
| 104 | + .change_context(AzureKeyVaultError::DecryptionFailed)?; |
| 105 | + |
| 106 | + let output = decrypted_output |
| 107 | + .result |
| 108 | + .ok_or(report!(AzureKeyVaultError::MissingPlaintextDecryptionOutput)) |
| 109 | + .and_then(|bytes| |
| 110 | + String::from_utf8(bytes) |
| 111 | + .change_context(AzureKeyVaultError::Utf8DecodingFailed) |
| 112 | + )?; |
| 113 | + |
| 114 | + let time_taken = start.elapsed(); |
| 115 | + metrics::AZURE_KEY_VAULT_DECRYPT_TIME.record(time_taken.as_secs_f64(), &[]); |
| 116 | + |
| 117 | + Ok(output) |
| 118 | + } |
| 119 | + |
| 120 | + /// Encrypts the provided String using the AZURE KEY VAULT SDK and returns base64-encoded encrypted data. |
| 121 | + /// We assume that the SDK has the values required to interact with the AZURE KEY VAULT APIs (`AZURE_TENANT_ID`, |
| 122 | + /// `AZURE_CLIENT_ID` and `AZURE_CLIENT_SECRET`) either set in environment variables, or that the |
| 123 | + /// SDK is running in a machine that is able to assume an Azure AD role. |
| 124 | + pub async fn encrypt(&self, data: impl AsRef<[u8]>) -> CustomResult<String, AzureKeyVaultError> { |
| 125 | + let start = Instant::now(); |
| 126 | + |
| 127 | + let encrypt_params = KeyOperationsParameters { |
| 128 | + algorithm: Some(JsonWebKeyEncryptionAlgorithm::RsaOaep), |
| 129 | + value: Some(data.as_ref().to_vec()), |
| 130 | + ..Default::default() |
| 131 | + }; |
| 132 | + |
| 133 | + let encrypted_output = self |
| 134 | + .inner_client |
| 135 | + .encrypt(&self.key_name, &self.version, encrypt_params.clone().try_into().unwrap(), None) |
| 136 | + .await |
| 137 | + .inspect_err(|error| { |
| 138 | + logger::error!(azure_key_vault_error=?error, "Failed to Azure Key Vault decrypt data"); |
| 139 | + metrics::AZURE_KEY_VAULT_ENCRYPTION_FAILURES.add(1, &[]); |
| 140 | + }) |
| 141 | + .change_context(AzureKeyVaultError::EncryptionFailed)? |
| 142 | + .into_body() |
| 143 | + .await |
| 144 | + .inspect_err(|error| { |
| 145 | + logger::error!(azure_key_vault_error=?error, "Failed to Azure Key Vault decrypt data"); |
| 146 | + metrics::AZURE_KEY_VAULT_ENCRYPTION_FAILURES.add(1, &[]); |
| 147 | + }) |
| 148 | + .change_context(AzureKeyVaultError::EncryptionFailed)?; |
| 149 | + |
| 150 | + let output = encrypted_output |
| 151 | + .result |
| 152 | + .ok_or(AzureKeyVaultError::MissingCiphertextEncryptionOutput) |
| 153 | + .map(|bytes| consts::BASE64_ENGINE.encode(bytes))?; |
| 154 | + |
| 155 | + let time_taken = start.elapsed(); |
| 156 | + metrics::AZURE_KEY_VAULT_ENCRYPT_TIME.record(time_taken.as_secs_f64(), &[]); |
| 157 | + |
| 158 | + Ok(output) |
| 159 | + } |
| 160 | + |
| 161 | + |
| 162 | +} |
| 163 | + |
| 164 | + |
| 165 | +/// Errors that could occur during AZURE KEY VAULT operations. |
| 166 | +#[derive(Debug, thiserror::Error)] |
| 167 | +pub enum AzureKeyVaultError { |
| 168 | + /// An error occurred when base64 encoding input data. |
| 169 | + #[error("Failed to base64 encode input data")] |
| 170 | + Base64EncodingFailed, |
| 171 | + |
| 172 | + /// An error occurred when base64 decoding input data. |
| 173 | + #[error("Failed to base64 decode input data")] |
| 174 | + Base64DecodingFailed, |
| 175 | + |
| 176 | + /// An error occurred when AZURE KEY VAULT decrypting input data. |
| 177 | + #[error("Failed to Azure Key Vault decrypt input data")] |
| 178 | + DecryptionFailed, |
| 179 | + |
| 180 | + /// An error occurred when AZURE KEY VAULT encrypting input data. |
| 181 | + #[error("Failed to Azure Key Vault encrypt input data")] |
| 182 | + EncryptionFailed, |
| 183 | + |
| 184 | + /// The AZURE KEY VAULT decrypted output does not include a plaintext output. |
| 185 | + #[error("Missing plaintext AZURE KEY VAULT decryption output")] |
| 186 | + MissingPlaintextDecryptionOutput, |
| 187 | + |
| 188 | + /// The AZURE KEY VAULT encrypted output does not include a ciphertext output. |
| 189 | + #[error("Missing ciphertext AZURE KEY VAULT encryption output")] |
| 190 | + MissingCiphertextEncryptionOutput, |
| 191 | + |
| 192 | + /// An error occurred UTF-8 decoding AZURE KEY VAULT decrypted output. |
| 193 | + #[error("Failed to UTF-8 decode decryption output")] |
| 194 | + Utf8DecodingFailed, |
| 195 | + |
| 196 | + /// The AZURE KEY VAULT client has not been initialized. |
| 197 | + #[error("The AZURE KEY VAULT client has not been initialized")] |
| 198 | + AzureKeyVaultClientInitializationFailed, |
| 199 | +} |
| 200 | + |
| 201 | + |
| 202 | +#[cfg(test)] |
| 203 | +mod tests { |
| 204 | + #![allow(clippy::expect_used, clippy::print_stdout)] |
| 205 | + #[tokio::test] |
| 206 | + async fn check_azure_key_vault_encryption() { |
| 207 | + std::env::set_var("AZURE_CLIENT_ID", "YOUR-CLIENT-ID"); |
| 208 | + std::env::set_var("AZURE_TENANT_ID", "YOUR-TENANT-ID"); |
| 209 | + std::env::set_var("AZURE_CLIENT_SECRET", "YOUR-CLIENT-SECRET"); |
| 210 | + use super::*; |
| 211 | + let config = AzureKeyVaultConfig { |
| 212 | + key_name: "YOUR AZURE KEY VAULT KEY NAME".to_string(), |
| 213 | + vault_url: "YOUR AZURE KEY VAULT URL".to_string(), |
| 214 | + version: "".to_string(), |
| 215 | + }; |
| 216 | + |
| 217 | + let data = "hello".to_string(); |
| 218 | + let binding = data.as_bytes(); |
| 219 | + let encrypted_fingerprint = AzureKeyVaultClient::new(&config) |
| 220 | + .await |
| 221 | + .expect("azure key vault client initialization failed") |
| 222 | + .encrypt(binding) |
| 223 | + .await |
| 224 | + .expect("azure key vault encryption failed"); |
| 225 | + |
| 226 | + println!("{}", encrypted_fingerprint); |
| 227 | + } |
| 228 | + |
| 229 | + #[tokio::test] |
| 230 | + async fn check_azure_key_vault_decrypt() { |
| 231 | + std::env::set_var("AZURE_CLIENT_ID", "YOUR-CLIENT-ID"); |
| 232 | + std::env::set_var("AZURE_TENANT_ID", "YOUR-TENANT-ID"); |
| 233 | + std::env::set_var("AZURE_CLIENT_SECRET", "YOUR-CLIENT-SECRET"); |
| 234 | + use super::*; |
| 235 | + let config = AzureKeyVaultConfig { |
| 236 | + key_name: "YOUR AZURE KEY VAULT KEY NAME".to_string(), |
| 237 | + vault_url: "YOUR AZURE KEY VAULT URL".to_string(), |
| 238 | + version: "".to_string(), |
| 239 | + }; |
| 240 | + |
| 241 | + // Should decrypt to hello |
| 242 | + let data = "AZURE KEY VAULT ENCRYPTED CIPHER".to_string(); |
| 243 | + let binding = data.as_bytes(); |
| 244 | + let decrypted_fingerprint = AzureKeyVaultClient::new(&config) |
| 245 | + .await |
| 246 | + .expect("azure key vault client initialization failed") |
| 247 | + .encrypt(binding) |
| 248 | + .await |
| 249 | + .expect("azure key vault decryption failed"); |
| 250 | + |
| 251 | + println!("{}", decrypted_fingerprint); |
| 252 | + } |
| 253 | +} |
0 commit comments