Skip to content

Commit

Permalink
[CHORE] test account github workflow (#146)
Browse files Browse the repository at this point in the history
* test account github workflow

---------

Co-authored-by: TJ Murphy <[email protected]>
  • Loading branch information
teej and teej authored Nov 7, 2024
1 parent 69ec316 commit c3263a4
Show file tree
Hide file tree
Showing 14 changed files with 328 additions and 125 deletions.
63 changes: 63 additions & 0 deletions .github/workflows/reset-test-account.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Resets the Snowflake test accounts used by the test suite.
# Runs on demand (workflow_dispatch) or as a reusable workflow (workflow_call).
name: Reset test account

on:
  workflow_dispatch:
  workflow_call:
    # When called from another workflow, the caller must pass these secrets through.
    secrets:
      TEST_SNOWFLAKE_ACCOUNT:
        required: true
      TEST_SNOWFLAKE_USER:
        required: true
      TEST_SNOWFLAKE_PASSWORD:
        required: true
      VAR_STORAGE_BASE_URL:
        required: true
      VAR_STORAGE_ROLE_ARN:
        required: true
      VAR_STORAGE_AWS_EXTERNAL_ID:
        required: true
      STATIC_USER_RSA_PUBLIC_KEY:
        required: true
      STATIC_USER_MFA_PASSWORD:
        required: true

jobs:
  reset-test-account:
    runs-on: ubuntu-latest
    strategy:
      # Keep resetting the remaining accounts even if one environment fails.
      fail-fast: false
      matrix:
        include:
          - environment: snowflake-gcp-standard
          - environment: snowflake-aws-standard
          - environment: snowflake-aws-enterprise
    # Each matrix entry maps to a GitHub environment holding that account's secrets.
    environment: ${{ matrix.environment }}
    steps:
      - name: actions/checkout
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
      - name: Create a virtual environment
        run: |
          python -m venv .venv
      - name: Install dependencies
        run: |
          source ./.venv/bin/activate
          python -m pip install --upgrade pip
          make install-dev
      - name: Reset test account
        run: |
          source ./.venv/bin/activate
          python tools/reset_test_account.py
        env:
          SNOWFLAKE_ACCOUNT: ${{ secrets.TEST_SNOWFLAKE_ACCOUNT }}
          SNOWFLAKE_USER: ${{ secrets.TEST_SNOWFLAKE_USER }}
          SNOWFLAKE_PASSWORD: ${{ secrets.TEST_SNOWFLAKE_PASSWORD }}
          SNOWFLAKE_ROLE: ACCOUNTADMIN
          TITAN_VAR_STORAGE_BASE_URL: ${{ secrets.VAR_STORAGE_BASE_URL }}
          TITAN_VAR_STORAGE_ROLE_ARN: ${{ secrets.VAR_STORAGE_ROLE_ARN }}
          TITAN_VAR_STORAGE_AWS_EXTERNAL_ID: ${{ secrets.VAR_STORAGE_AWS_EXTERNAL_ID }}
          TITAN_VAR_STATIC_USER_RSA_PUBLIC_KEY: ${{ secrets.STATIC_USER_RSA_PUBLIC_KEY }}
          TITAN_VAR_STATIC_USER_MFA_PASSWORD: ${{ secrets.STATIC_USER_MFA_PASSWORD }}
48 changes: 48 additions & 0 deletions tests/test_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from titan.parse import parse_region


def test_parse_region():
    """parse_region splits a Snowflake region identifier into its parts."""
    # Identifiers without a region group: (input, cloud, cloud_region).
    plain_cases = [
        ("AWS_US_WEST_2", "AWS", "US_WEST_2"),
        ("AZURE_WESTUS2", "AZURE", "WESTUS2"),
        ("GCP_EUROPE_WEST4", "GCP", "EUROPE_WEST4"),
        ("AWS_US_GOV_WEST_1_FHPLUS", "AWS", "US_GOV_WEST_1_FHPLUS"),
        ("AWS_US_GOV_WEST_1_DOD", "AWS", "US_GOV_WEST_1_DOD"),
        ("AWS_AP_SOUTHEAST_1", "AWS", "AP_SOUTHEAST_1"),
        ("AWS_EU_CENTRAL_1", "AWS", "EU_CENTRAL_1"),
        ("AZURE_CANADACENTRAL", "AZURE", "CANADACENTRAL"),
        ("AZURE_NORTHEUROPE", "AZURE", "NORTHEUROPE"),
        ("AZURE_SWITZERLANDNORTH", "AZURE", "SWITZERLANDNORTH"),
        ("AZURE_USGOVVIRGINIA", "AZURE", "USGOVVIRGINIA"),
        ("GCP_US_CENTRAL1", "GCP", "US_CENTRAL1"),
        ("GCP_EUROPE_WEST2", "GCP", "EUROPE_WEST2"),
        ("GCP_EUROPE_WEST3", "GCP", "EUROPE_WEST3"),
    ]
    for raw, cloud, cloud_region in plain_cases:
        assert parse_region(raw) == {"cloud": cloud, "cloud_region": cloud_region}

    # Identifiers with a region-group prefix: (input, group, cloud, cloud_region).
    grouped_cases = [
        ("PUBLIC.AWS_US_WEST_2", "PUBLIC", "AWS", "US_WEST_2"),
        ("PUBLIC.AWS_EU_CENTRAL_1", "PUBLIC", "AWS", "EU_CENTRAL_1"),
        ("PUBLIC.AZURE_WESTEUROPE", "PUBLIC", "AZURE", "WESTEUROPE"),
        ("PUBLIC.GCP_US_CENTRAL1", "PUBLIC", "GCP", "US_CENTRAL1"),
        ("SOME_VALUE.GCP_US_CENTRAL1", "SOME_VALUE", "GCP", "US_CENTRAL1"),
    ]
    for raw, group, cloud, cloud_region in grouped_cases:
        assert parse_region(raw) == {
            "region_group": group,
            "cloud": cloud,
            "cloud_region": cloud_region,
        }
3 changes: 3 additions & 0 deletions titan/blueprint_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def __post_init__(self):
if not isinstance(self.run_mode, RunMode):
raise ValueError(f"Invalid run_mode: {self.run_mode}")

if not isinstance(self.vars, dict):
raise ValueError(f"vars must be a dictionary, got: {self.vars=}")

if self.scope is not None and not isinstance(self.scope, BlueprintScope):
raise ValueError(f"Invalid scope: {self.scope}")

Expand Down
7 changes: 7 additions & 0 deletions titan/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
_parse_dynamic_table_text,
parse_view_ddl,
parse_collection_string,
parse_region,
)
from .privs import GrantedPrivilege
from .resource_name import ResourceName, attribute_is_resource_name, resource_name_from_snowflake_metadata
Expand All @@ -44,6 +45,8 @@ class SessionContext(TypedDict):
account_locator: str
account: str
available_roles: list[ResourceName]
cloud: str
cloud_region: str
database: str
role: ResourceName
schemas: list[str]
Expand Down Expand Up @@ -571,18 +574,22 @@ def fetch_session(session: SnowflakeConnection) -> SessionContext:
CURRENT_SCHEMAS() as schemas,
CURRENT_WAREHOUSE() as warehouse,
CURRENT_VERSION() as version,
CURRENT_REGION() as region,
SYSTEM$BOOTSTRAP_DATA_REQUEST('ACCOUNT') as account_data
""",
)[0]

account_data = json.loads(session_obj["ACCOUNT_DATA"])
available_roles = [ResourceName(role) for role in json.loads(session_obj["AVAILABLE_ROLES"])]
region = parse_region(session_obj["REGION"])

return {
"account_edition": AccountEdition(account_data["accountInfo"]["serviceLevelName"]),
"account_locator": session_obj["ACCOUNT_LOCATOR"],
"account": session_obj["ACCOUNT"],
"available_roles": available_roles,
"cloud": region["cloud"],
"cloud_region": region["cloud_region"],
"database": session_obj["DATABASE"],
"role": ResourceName(session_obj["ROLE"]),
"schemas": json.loads(session_obj["SCHEMAS"]),
Expand Down
9 changes: 6 additions & 3 deletions titan/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,13 @@ class AccountEdition(ParseableEnum):
@classmethod
def synonyms(cls):
"""Override to provide a dictionary of synonyms for the enum values"""
return {
"BUSINESS-CRITICAL" : "BUSINESS_CRITICAL"
}
return {"BUSINESS-CRITICAL": "BUSINESS_CRITICAL"}


class AccountCloud(ParseableEnum):
    """Cloud provider hosting a Snowflake account."""

    AWS = "AWS"
    GCP = "GCP"
    AZURE = "AZURE"


class DataType(ParseableEnum):
Expand Down
27 changes: 27 additions & 0 deletions titan/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,3 +719,30 @@ def parse_collection_string(collection: str):

def format_collection_string(collection: dict):
return f"{collection['in_name']}.<{collection['on_type']}>"


def parse_region(region_str: str) -> dict[str, str]:
    """Parse a Snowflake region identifier into its components.

    A region identifier is "<CLOUD>_<CLOUD_REGION>", optionally prefixed
    with a region group: "<REGION_GROUP>.<CLOUD>_<CLOUD_REGION>".

    Examples:
        AWS_US_WEST_2        -> {'cloud': 'AWS', 'cloud_region': 'US_WEST_2'}
        PUBLIC.AWS_US_WEST_2 -> {'region_group': 'PUBLIC', 'cloud': 'AWS', 'cloud_region': 'US_WEST_2'}
        AZURE_WESTUS2        -> {'cloud': 'AZURE', 'cloud_region': 'WESTUS2'}
        GCP_EUROPE_WEST4     -> {'cloud': 'GCP', 'cloud_region': 'EUROPE_WEST4'}

    Raises:
        ValueError: if ``region_str`` does not match the expected format.
    """
    import re

    # Optional "<GROUP>." prefix, then the cloud name, an underscore, and the rest.
    matched = re.match(
        r"^(?:(?P<region_group>[A-Z_]+)\.)?(?P<cloud>[A-Z]+)_(?P<cloud_region>.+?)$",
        region_str,
    )
    if matched is None:
        raise ValueError(f"Invalid region format: {region_str}")

    # Drop the region_group key entirely when no prefix was present.
    return {key: value for key, value in matched.groupdict().items() if value is not None}
138 changes: 138 additions & 0 deletions tools/__reset_test_account.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import os

import snowflake.connector
import yaml
from dotenv import dotenv_values

from titan import resources as res
from titan.blueprint import Blueprint, print_plan
from titan.data_provider import fetch_session
from titan.enums import AccountEdition
from titan.gitops import collect_blueprint_config


def read_config(file) -> dict:
    """Load a YAML config file that lives in the same directory as this script."""
    path = os.path.join(os.path.dirname(__file__), file)
    with open(path, "r") as fp:
        return yaml.safe_load(fp)


def merge_configs(config1: dict, config2: dict) -> dict:
    """Shallow-merge ``config2`` into a copy of ``config1``.

    For keys present in both configs: list values are concatenated with
    ``config1``'s entries first; any other value (including None) is
    overwritten by ``config2``'s value. Keys only in ``config2`` are added.
    Neither input dict is mutated.

    Args:
        config1: Base configuration.
        config2: Overlay configuration whose entries win on conflict.

    Returns:
        A new merged dict.
    """
    merged = config1.copy()
    for key, value in config2.items():
        # The original had three branches, two of which both plainly
        # assigned `value`; only the list case needs special handling.
        if isinstance(merged.get(key), list):
            # Concatenate so resources from both configs survive the merge.
            merged[key] = merged[key] + value
        else:
            merged[key] = value
    return merged


def configure_test_account(conn, cloud: str):
    """Plan and apply the test-account blueprint for one Snowflake account.

    Layers config files on top of the base `test_account.yml`: enterprise
    extras when the session reports an Enterprise edition, and cloud-specific
    extras for AWS. Variables come from the `env/.vars.test_account` dotenv file.

    Args:
        conn: An open Snowflake connection with sufficient privileges.
        cloud: Cloud identifier, e.g. "aws" or "gcp".
    """
    session_ctx = fetch_session(conn)
    config = read_config("test_account.yml")
    # Renamed from `vars` (shadowed the builtin) and dropped the leftover
    # debug print of the variable values.
    account_vars = dotenv_values("env/.vars.test_account")

    # Enterprise accounts get additional resources layered on top.
    if session_ctx["account_edition"] == AccountEdition.ENTERPRISE:
        config = merge_configs(config, read_config("test_account_enterprise.yml"))

    if cloud == "aws":
        config = merge_configs(config, read_config("test_account_aws.yml"))
    # elif cloud == "gcp":
    #     config = merge_configs(config, read_config("test_account_gcp.yml"))

    blueprint_config = collect_blueprint_config(config, {"vars": account_vars})

    bp = Blueprint.from_config(blueprint_config)
    plan = bp.plan(conn)
    print_plan(plan)
    bp.apply(conn, plan)


def configure_aws_heavy(conn):
    """Stress-style setup: create many roles/databases/schemas/tables, then grants.

    First blueprint creates 50 roles, 10 databases x 10 schemas x 5 tables;
    a second blueprint then layers USAGE/SELECT grants on top. The inline
    grant loops are intentionally disabled (commented out) — grants are
    applied in bulk in the second pass instead.

    NOTE(review): nesting of the loops below was reconstructed from a
    whitespace-stripped source — confirm against the original file.
    """
    bp = Blueprint(
        name="reset-test-account",
        run_mode="CREATE-OR-UPDATE",
    )

    roles = [res.Role(name=f"ROLE_{i}") for i in range(50)]
    databases = []
    for i in range(10):
        database = res.Database(name=f"DATABASE_{i}")
        bp.add(database)
        databases.append(database)
        for role in roles:
            # bp.add(res.Grant(priv="USAGE", to=role, on=database))
            pass
        for j in range(10):
            schema = res.Schema(name=f"SCHEMA_{j}", database=database)
            bp.add(schema)
            for role in roles:
                # bp.add(res.Grant(priv="USAGE", to=role, on=schema))
                pass
            for k in range(5):
                table = res.Table(
                    name=f"TABLE_{k}", columns=[{"name": "ID", "data_type": "NUMBER(38,0)"}], schema=schema
                )
                bp.add(table)
                for role in roles:
                    # bp.add(res.Grant(priv="SELECT", to=role, on=table))
                    pass

    bp.add(roles)

    # Uses the private `_staged` attribute to report how many resources were queued.
    staged_count = len(bp._staged)

    plan = bp.plan(conn)
    print_plan(plan[:10])  # only preview the first 10 changes to keep output readable
    print("Changes in plan:", len(plan))
    print("Staged resources:", staged_count)
    bp.apply(conn, plan)

    # Second pass: grant roles access to everything created above.
    bp = Blueprint(
        name="reset-test-account",
        run_mode="CREATE-OR-UPDATE",
    )
    for database in databases:
        for role in roles:
            bp.add(res.Grant(priv="USAGE", to=role.name, on_database=database.name))
            bp.add(res.GrantOnAll(priv="USAGE", to=role.name, on_all_schemas_in_database=database.name))
            bp.add(res.GrantOnAll(priv="SELECT", to=role.name, on_all_tables_in_database=database.name))
    bp.apply(conn)


def get_connection(env_vars):
    """Open a Snowflake connection from SNOWFLAKE_* entries in *env_vars*."""
    connect_args = {
        "account": env_vars["SNOWFLAKE_ACCOUNT"],
        "user": env_vars["SNOWFLAKE_USER"],
        "password": env_vars["SNOWFLAKE_PASSWORD"],
        "role": env_vars["SNOWFLAKE_ROLE"],
        # "warehouse": env_vars["SNOWFLAKE_WAREHOUSE"],
    }
    return snowflake.connector.connect(**connect_args)


def configure_test_accounts():
    """Reset every configured test account (AWS standard/enterprise, GCP standard).

    Credentials for each account come from `env/.env.<cloud>.<edition>` dotenv
    files. Each connection is always closed, even if configuration raises.
    """

    for account in ["aws.standard", "aws.enterprise", "gcp.standard"]:
        print(">>>>>>>>>>>>>>>>", account)
        cloud = account.split(".")[0]  # "aws.standard" -> "aws"
        env_vars = dotenv_values(f"env/.env.{account}")
        conn = get_connection(env_vars)
        try:
            configure_test_account(conn, cloud)
        # except Exception as e:
        #     print(f"Error configuring {account}: {e}")
        finally:
            conn.close()

    # now = time.time()
    # configure_aws_heavy(get_connection(dotenv_values(".env.aws.heavy")))
    # print(f"done in {time.time() - now:.2f}s")


# Allow running directly: `python tools/__reset_test_account.py`.
if __name__ == "__main__":
    configure_test_accounts()
Loading

0 comments on commit c3263a4

Please sign in to comment.