Skip to content

Commit

Permalink
Merge pull request #13 from atlanhq/soda-rex
Browse files Browse the repository at this point in the history
PES-3645 Soda <> Redshift IAM Role profiling
  • Loading branch information
bichitra95 authored Oct 15, 2024
2 parents 560117c + 653d537 commit a0203f9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion soda/core/soda/common/aws_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def assume_role(self, role_session_name: str):
aws_session_token=self.session_token,
)

assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, RoleSessionName=role_session_name)
assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, ExternalId=self.external_id, RoleSessionName=role_session_name)
credentials_dict = assumed_role_object["Credentials"]
return AwsCredentials(
region_name=self.region_name,
Expand Down
10 changes: 7 additions & 3 deletions soda/redshift/soda/data_sources/redshift_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
self.connect_timeout = data_source_properties.get("connection_timeout_sec")
self.username = data_source_properties.get("username")
self.password = data_source_properties.get("password")
self.dbuser = data_source_properties.get("dbuser")
self.dbname = data_source_properties.get("dbname")
self.cluster_id = data_source_properties.get("cluster_id")

if not self.username or not self.password:
aws_credentials = AwsCredentials(
Expand All @@ -31,6 +34,7 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
session_token=data_source_properties.get("session_token"),
region_name=data_source_properties.get("region", "eu-west-1"),
profile_name=data_source_properties.get("profile_name"),
external_id=data_source_properties.get("external_id"),
)
self.username, self.password = self.__get_cluster_credentials(aws_credentials)

Expand Down Expand Up @@ -60,9 +64,9 @@ def __get_cluster_credentials(self, aws_credentials: AwsCredentials):
aws_session_token=resolved_aws_credentials.session_token,
)

cluster_name = self.host.split(".")[0]
username = self.username
db_name = self.database
cluster_name = self.cluster_id if self.cluster_id else self.host.split(".")[0]
username = self.dbuser if self.dbuser else self.username
db_name = self.dbname if self.dbname else self.database
cluster_creds = client.get_cluster_credentials(
DbUser=username, DbName=db_name, ClusterIdentifier=cluster_name, AutoCreate=False, DurationSeconds=3600
)
Expand Down

0 comments on commit a0203f9

Please sign in to comment.