Skip to content

Commit

Permalink
more logs
Browse files Browse the repository at this point in the history
  • Loading branch information
Divyanshu-Patel committed Oct 4, 2024
1 parent 22a956d commit f979b6d
Showing 1 changed file with 50 additions and 33 deletions.
83 changes: 50 additions & 33 deletions soda/redshift/soda/data_sources/redshift_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from soda.execution.data_source import DataSource

logger = logging.getLogger(__name__)

logger.setLevel(logging.DEBUG)

class RedshiftDataSource(DataSource):
TYPE = "redshift"
Expand All @@ -23,13 +23,12 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
self.username = data_source_properties.get("username")
self.password = data_source_properties.get("password")

self.logs.debug("Data Source Properties: %s", data_source_properties)
self.logs.debug("Initializing Redshift DataSource with properties: %s", data_source_properties)
self.logs.debug("Using external_id: %s", data_source_properties.get("external_id"))
self.logs.debug("Using role_arn: %s", data_source_properties.get("role_arn"))
self.logs.debug("Using access_key_id: %s", data_source_properties.get("access_key_id"))
self.logs.debug("Using secret_access_key: %s", data_source_properties.get("secret_access_key"))

if not self.username or not self.password:
self.logs.debug("Username or password not provided. Attempting to resolve credentials.")
aws_credentials = AwsCredentials(
access_key_id=data_source_properties.get("access_key_id"),
secret_access_key=data_source_properties.get("secret_access_key"),
Expand All @@ -39,45 +38,63 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
profile_name=data_source_properties.get("profile_name"),
external_id=data_source_properties.get("external_id")
)
self.username, self.password = self.__get_cluster_credentials(aws_credentials)
try:
self.username, self.password = self.__get_cluster_credentials(aws_credentials)
self.logs.debug("Successfully retrieved cluster credentials: username=%s", self.username)
except Exception as e:
self.logs.error("Failed to resolve cluster credentials: %s", str(e))
raise

def connect(self):
options = f"-c search_path={self.schema}" if self.schema else None

self.connection = psycopg2.connect(
user=self.username,
password=self.password,
host=self.host,
port=self.port,
connect_timeout=self.connect_timeout,
database=self.database,
options=options,
)
try:
self.connection = psycopg2.connect(
user=self.username,
password=self.password,
host=self.host,
port=self.port,
connect_timeout=self.connect_timeout,
database=self.database,
options=options,
)
self.logs.debug("Successfully connected to Redshift database at %s:%s", self.host, self.port)
except Exception as e:
self.logs.error("Failed to connect to Redshift database: %s", str(e))
raise

def __get_cluster_credentials(self, aws_credentials: AwsCredentials):
resolved_aws_credentials = aws_credentials.resolve_role(
role_session_name="soda_redshift_get_cluster_credentials"
)
try:
resolved_aws_credentials = aws_credentials.resolve_role(
role_session_name="soda_redshift_get_cluster_credentials"
)

self.logs.debug(f"Resolved AWS Credentials: {resolved_aws_credentials}")
self.logs.debug(f"Region Name after resolve_role: {resolved_aws_credentials.region_name}")
self.logs.debug("Resolved AWS Credentials: %s", resolved_aws_credentials)
self.logs.debug("Region Name after resolve_role: %s", resolved_aws_credentials.region_name)

client = boto3.client(
"redshift",
region_name=resolved_aws_credentials.region_name,
aws_access_key_id=resolved_aws_credentials.access_key_id,
aws_secret_access_key=resolved_aws_credentials.secret_access_key,
aws_session_token=resolved_aws_credentials.session_token,
)
client = boto3.client(
"redshift",
region_name=resolved_aws_credentials.region_name,
aws_access_key_id=resolved_aws_credentials.access_key_id,
aws_secret_access_key=resolved_aws_credentials.secret_access_key,
aws_session_token=resolved_aws_credentials.session_token,
)

cluster_name = self.host.split(".")[0]
username = self.username
db_name = self.database
cluster_creds = client.get_cluster_credentials(
DbUser=username, DbName=db_name, ClusterIdentifier=cluster_name, AutoCreate=False, DurationSeconds=3600
)
cluster_name = self.host.split(".")[0]
username = self.username
db_name = self.database

return cluster_creds["DbUser"], cluster_creds["DbPassword"]
self.logs.debug("Requesting cluster credentials for cluster: %s, db: %s", cluster_name, db_name)
cluster_creds = client.get_cluster_credentials(
DbUser=username, DbName=db_name, ClusterIdentifier=cluster_name, AutoCreate=False, DurationSeconds=3600
)

self.logs.debug("Cluster credentials retrieved successfully.")
return cluster_creds["DbUser"], cluster_creds["DbPassword"]

except Exception as e:
self.logs.error("Failed to get cluster credentials: %s", str(e))
raise

def sql_get_table_names_with_count(
self, include_tables: Optional[List[str]] = None, exclude_tables: Optional[List[str]] = None
Expand Down

0 comments on commit f979b6d

Please sign in to comment.