Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
62188a1
first try
ryanraaschCDC Dec 3, 2025
91d2b9a
add a force_keyvault argument
ryanraaschCDC Dec 4, 2025
22a1255
update doc strings
ryanraaschCDC Dec 4, 2025
3ddb317
add kv function to cloudclient
ryanraaschCDC Dec 4, 2025
b6be2e0
fix credentials
ryanraaschCDC Dec 4, 2025
dce8336
Update docs/CloudClient/authentication.md
ryanraaschCDC Dec 4, 2025
b017261
Update cfa/cloudops/_cloudclient.py
ryanraaschCDC Dec 4, 2025
a134588
updated documentation
ryanraaschCDC Dec 5, 2025
670ad88
change keys to match key vault
ryanraaschCDC Dec 8, 2025
8beba30
add more logic
ryanraaschCDC Dec 9, 2025
6aa4d49
add a delete env
ryanraaschCDC Dec 10, 2025
8fa162d
add check if env var exists before del
ryanraaschCDC Dec 10, 2025
116db74
remove extra del loop
ryanraaschCDC Dec 10, 2025
5f391a9
remove del
ryanraaschCDC Dec 10, 2025
ef356e4
Merge branch 'main' into rr-65-add-key-vault
ryanraaschCDC Dec 12, 2025
102aebf
add print statements for debugging
ryanraaschCDC Dec 16, 2025
d7e7a91
pull kv from env again
ryanraaschCDC Dec 16, 2025
97bbb29
Merge branch 'rr-65-add-key-vault' of https://github.com/CDCgov/cfa-c…
ryanraaschCDC Dec 16, 2025
11b86f3
more prints
ryanraaschCDC Dec 16, 2025
6c8d2b5
more prints and change where pulling sub client from
ryanraaschCDC Dec 16, 2025
770b9c0
more print
ryanraaschCDC Dec 17, 2025
1546de5
print exceptions
ryanraaschCDC Dec 17, 2025
3cf6ef4
fix formatting of key name
ryanraaschCDC Dec 18, 2025
e6d3a87
remove extra prints
ryanraaschCDC Jan 5, 2026
d92fbe8
remove prints
ryanraaschCDC Jan 5, 2026
6ca8b2b
save keyvault to env var
ryanraaschCDC Jan 6, 2026
7f0a13c
change default dotenv path
ryanraaschCDC Jan 6, 2026
1d15c65
add dotenv check
ryanraaschCDC Jan 6, 2026
b2963bf
remove keyvault check
ryanraaschCDC Jan 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions cfa/cloudops/_cloudclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
from graphlib import CycleError, TopologicalSorter
from typing import Optional

import networkx as nx
import pandas as pd
Expand All @@ -13,6 +14,9 @@
OnAllTasksComplete,
OnTaskFailure,
)
from azure.keyvault.secrets import SecretClient

# from azure.batch.models import TaskAddParameter
from azure.mgmt.batch import models
from azure.mgmt.resource import SubscriptionClient

Expand Down Expand Up @@ -45,6 +49,7 @@ class CloudClient:
provides convenient methods for common batch operations.

Args:
keyvault (str, optional): Name of the Azure Key Vault to use for secrets.
dotenv_path (str, optional): Path to .env file containing environment variables.
If None, uses default .env file discovery. Default is None.
use_sp (bool, optional): Whether to use Service Principal authentication.
Expand Down Expand Up @@ -87,23 +92,47 @@ class CloudClient:

def __init__(
self,
keyvault: str = None,
dotenv_path: str = None,
use_sp: bool = False,
use_federated: bool = False,
force_keyvault: bool = False,
**kwargs,
):
logger.debug("Initializing CloudClient.")
if keyvault is None:
dotenv_path = dotenv_path or ".env"
if keyvault is None and force_keyvault:
logger.error(
"Keyvault information not found but force_keyvault set to True."
)
raise ValueError("Keyvault information is required but not found.")
# authenticate to get credentials
if not use_sp and not use_federated:
self.cred = EnvCredentialHandler(dotenv_path=dotenv_path, **kwargs)
self.cred = EnvCredentialHandler(
dotenv_path=dotenv_path,
keyvault=keyvault,
force_keyvault=force_keyvault,
**kwargs,
)
self.method = "env"
logger.info("Using environment-based credentials.")
logger.info("Using managed identity credentials.")
elif use_federated:
self.cred = DefaultCredentialHandler(dotenv_path=dotenv_path, **kwargs)
self.cred = DefaultCredentialHandler(
dotenv_path=dotenv_path,
keyvault=keyvault,
force_keyvault=force_keyvault,
**kwargs,
)
self.method = "default"
logger.info("Using default credentials.")
else:
self.cred = SPCredentialHandler(dotenv_path=dotenv_path, **kwargs)
self.cred = SPCredentialHandler(
dotenv_path=dotenv_path,
keyvault=keyvault,
force_keyvault=force_keyvault,
**kwargs,
)
self.method = "sp"
logger.info("Using service principal credentials.")
# get clients
Expand Down Expand Up @@ -1701,6 +1730,7 @@ def async_upload_folder(
location_in_blob="project")

Note:

The blob container must exist before uploading. Directory structure is
preserved in the container. Use filtering options to avoid uploading
unnecessary files like temporary files or build artifacts.
Expand Down Expand Up @@ -2180,3 +2210,33 @@ def run_dag(self, *args: batch_helpers.Task, job_name: str, **kwargs):
dlist.append(str(dp))
task_df.at[i, "deps"] = dlist
logger.info(f"Completed DAG run for job '{job_name}'.")

def get_kv_secret(self, secret_name: str, keyvault: str) -> Optional[str]:
"""Retrieve a secret from Azure Key Vault.

Args:
secret_name (str): The name of the secret to retrieve.
keyvault (str): The name of the Key Vault.

Returns:
Optional[str]: The value of the secret, or None if not found.
"""
if self.method == "env":
cred = self.cred.user_credential
elif self.method == "default":
cred = self.cred.user_credential
else:
cred = self.cred.client_secret_credential
try:
secret_client = SecretClient(
vault_url=f"https://{keyvault}.vault.azure.net/",
credential=cred,
)
secret = secret_client.get_secret(secret_name)
return secret.value
except Exception as e:
logger.error(
f"Failed to retrieve secret '{secret_name}' from Key Vault '{keyvault}': {e}"
)
print(f"Error retrieving secret '{secret_name}': {e}")
return None
177 changes: 169 additions & 8 deletions cfa/cloudops/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,18 +533,31 @@ class EnvCredentialHandler(CredentialHandler):
>>> handler = EnvCredentialHandler(dotenv_path="/path/to/.env")
"""

def __init__(self, dotenv_path: str = None, **kwargs) -> None:
def __init__(
self,
dotenv_path: str = ".env",
keyvault: str = None,
force_keyvault: bool = False,
**kwargs,
) -> None:
"""Initialize the EnvCredentialHandler.

Loads environment variables from .env file and populates credential attributes from them.

Args:
dotenv_path (str, optional): Path to .env file to load environment variables from.
If None, uses default .env file discovery.
keyvault (str, optional): Name of the Azure Key Vault to use for secrets.
force_keyvault (bool, optional): If True, forces loading of Key Vault secrets even if they are already set in the environment.
**kwargs: Additional keyword arguments to override specific credential attributes.
"""
logger.debug("Initializing EnvCredentialHandler.")
load_env_vars(dotenv_path=dotenv_path)
load_env_vars(
dotenv_path=dotenv_path,
keyvault_name=keyvault,
force_keyvault=force_keyvault,
)

get_conf = partial(get_config_val, config_dict=kwargs, try_env=True)

for key in self.__dataclass_fields__.keys():
Expand All @@ -556,29 +569,44 @@ def __init__(self, dotenv_path: str = None, **kwargs) -> None:
self.__setattr__("azure_batch_location", d.default_azure_batch_location)


def load_env_vars(dotenv_path=None):
def load_env_vars(
dotenv_path=None, keyvault_name: str = None, force_keyvault: bool = False
):
"""Load environment variables and Azure subscription information.

Loads variables from a .env file (if specified), retrieves Azure subscription
information using ManagedIdentityCredential, and sets default environment variables.

Args:
dotenv_path: Path to .env file to load. If None, uses default .env file discovery.
keyvault_name: Name of the Azure Key Vault to use for secrets.
force_keyvault: If True, forces loading of Key Vault secrets even if they are already set in the environment.

Example:
>>> load_env_vars() # Load from default .env
>>> load_env_vars("/path/to/.env") # Load from specific file
"""
# get ManagedIdentityCredential
mid_cred = ManagedIdentityCredential()

logger.debug("Loading environment variables.")
load_dotenv(dotenv_path=dotenv_path, override=True)
# get ManagedIdentityCredential to pull SubscriptionClient
mid_cred = ManagedIdentityCredential()

sub_c = SubscriptionClient(mid_cred)
# pull in account info and save to environment vars
account_info = list(sub_c.subscriptions.list())[0]
os.environ["AZURE_SUBSCRIPTION_ID"] = account_info.subscription_id
os.environ["AZURE_TENANT_ID"] = account_info.tenant_id
os.environ["AZURE_RESOURCE_GROUP_NAME"] = account_info.display_name

# get Key Vault secrets
if keyvault_name is not None:
get_keyvault_vars(
keyvault_name=keyvault_name,
credential=mid_cred,
force_keyvault=force_keyvault,
)

# save default values
d.set_env_vars()

Expand All @@ -590,7 +618,9 @@ def __init__(
azure_subscription_id: str = None,
azure_client_id: str = None,
azure_client_secret: str = None,
dotenv_path: str = None,
dotenv_path: str = ".env",
keyvault: str = None,
force_keyvault: bool = False,
**kwargs,
):
"""Initialize a Service Principal Credential Handler.
Expand All @@ -611,6 +641,8 @@ def __init__(
attempt to load from AZURE_CLIENT_SECRET environment variable.
dotenv_path: Path to .env file to load environment variables from.
If None, uses default .env file discovery.
keyvault: Name of the Azure Key Vault to use for secrets.
force_keyvault: If True, forces loading of Key Vault secrets even if they are already set in the environment.
**kwargs: Additional keyword arguments to override specific credential attributes.

Raises:
Expand Down Expand Up @@ -681,6 +713,18 @@ def __init__(
[x.lower() for x in mandatory_environment_variables],
goal="service principal credentials",
)
sp_cred = ClientSecretCredential(
tenant_id=self.azure_tenant_id,
client_id=self.azure_client_id,
client_secret=self.azure_client_secret,
)
# load keyvault secrets
if keyvault is not None:
get_keyvault_vars(
keyvault_name=keyvault,
credential=sp_cred,
force_keyvault=force_keyvault,
)

d.set_env_vars()

Expand All @@ -698,7 +742,9 @@ def __init__(
class DefaultCredentialHandler(CredentialHandler):
def __init__(
self,
dotenv_path: str | None = None,
dotenv_path: str | None = ".env",
keyvault: str = None,
force_keyvault: bool = False,
**kwargs,
) -> None:
"""Initialize a Default Credential Handler.
Expand All @@ -711,6 +757,8 @@ def __init__(
Args:
dotenv_path: Path to .env file to load environment variables from.
If None, uses default .env file discovery.
keyvault: Name of the Azure Key Vault to use for secrets.
force_keyvault: If True, forces loading of Key Vault secrets even if they are already set in the environment.
**kwargs: Additional keyword arguments to override specific credential attributes.

Raises:
Expand All @@ -731,7 +779,25 @@ def __init__(
"Retrieving Azure subscription information using DefaultCredential."
)
d_cred = DefaultCredential()
sub_c = SubscriptionClient(d_cred)

# load keyvault secrets
if keyvault is None:
try:
keyvault = os.environ["AZURE_KEYVAULT_NAME"]
except KeyError:
keyvault = None
if keyvault is not None:
get_keyvault_vars(
keyvault_name=keyvault,
credential=d_cred,
force_keyvault=force_keyvault,
)

try:
sub_c = SubscriptionClient(d_cred)
except Exception as e:
logger.error(f"Failed to create SubscriptionClient: {e}")
raise
sub_id = os.getenv("AZURE_SUBSCRIPTION_ID", None)
if sub_id is None:
logger.error("AZURE_SUBSCRIPTION_ID not found in environment variables.")
Expand Down Expand Up @@ -929,3 +995,98 @@ def get_compute_node_identity_reference(
ch = EnvCredentialHandler()
logger.debug("Retrieving compute_node_identity_reference from CredentialHandler.")
return ch.compute_node_identity_reference


def get_secret_client(keyvault: str, credential: object) -> SecretClient:
"""Get an Azure Key Vault SecretClient using a CredentialHandler.

Args:
keyvault: Name of the Azure Key Vault to connect to.
credential: Credential handler for connecting and authenticating to Azure resources.

Returns:
SecretClient: An authenticated SecretClient for the specified Key Vault.

Example:
>>> handler = CredentialHandler()
>>> secret_client = get_secret_client("myvault", handler)
"""
logger.debug("Creating SecretClient for Azure Key Vault.")
vault_url = f"https://{keyvault}.{d.default_azure_keyvault_endpoint_subdomain}"
secret_client = SecretClient(vault_url=vault_url, credential=credential)
logger.debug("Created SecretClient for Azure Key Vault.")
return secret_client


def load_keyvault_vars(
secret_client: SecretClient,
force_keyvault: bool = False,
):
"""Load secrets from an Azure Key Vault into environment variables.

Args:
secret_client: SecretClient for accessing the Azure Key Vault.
force_keyvault: If True, forces loading of Key Vault secrets even if they are already set in the environment.
"""
kv_keys = d.default_kv_keys

for key in kv_keys:
if force_keyvault:
logger.debug(
"Force Key Vault load enabled; loading secret regardless of existing environment variable."
)
try:
secret = secret_client.get_secret(key.replace("_", "-")).value
os.environ[key] = secret
logger.debug(
f"Loaded secret '{key}' from Key Vault into environment variable."
)
except Exception as e:
logger.warning(f"Could not load secret '{key}' from Key Vault: {e}")
print("Error loading secret: ", e)
else:
if key in os.environ:
logger.debug(
f"Environment variable '{key}' already set; skipping Key Vault load."
)
continue
else:
try:
secret = secret_client.get_secret(key.replace("_", "-")).value
os.environ[key] = secret
logger.debug(
f"Loaded secret '{key}' from Key Vault into environment variable."
)
except Exception as e:
logger.warning(f"Could not load secret '{key}' from Key Vault: {e}")
print(f"Error loading secret: {e}")


def get_keyvault_vars(
keyvault_name: str,
credential: object,
force_keyvault: bool = False,
):
"""Retrieve secrets from an Azure Key Vault and save to environment.

Args:
keyvault_name: Name of the Azure Key Vault to connect to.
credential: Credential handler for connecting and authenticating to Azure resources.
force_keyvault: If True, forces loading of Key Vault secrets even if they are already set in the environment.
"""
if keyvault_name is None:
logger.debug("No Key Vault name provided; skipping Key Vault variable loading.")
return None
else:
os.environ["AZURE_KEYVAULT_NAME"] = keyvault_name
logger.debug("Getting SecretClient for Azure Key Vault.")
try:
secret_client = get_secret_client(
keyvault=keyvault_name,
credential=credential,
)
except Exception as e:
logger.error(f"Failed to get SecretClient: {e}")
raise
logger.debug("Loading Key Vault secrets into environment variables.")
load_keyvault_vars(secret_client, force_keyvault=force_keyvault)
Loading
Loading