diff --git a/.github/workflows/reset-test-account.yml b/.github/workflows/reset-test-account.yml
new file mode 100644
index 00000000..692d186e
--- /dev/null
+++ b/.github/workflows/reset-test-account.yml
@@ -0,0 +1,63 @@
+name: Reset test account
+
+on:
+  workflow_dispatch:
+  workflow_call:
+    secrets:
+      TEST_SNOWFLAKE_ACCOUNT:
+        required: true
+      TEST_SNOWFLAKE_USER:
+        required: true
+      TEST_SNOWFLAKE_PASSWORD:
+        required: true
+      VAR_STORAGE_BASE_URL:
+        required: true
+      VAR_STORAGE_ROLE_ARN:
+        required: true
+      VAR_STORAGE_AWS_EXTERNAL_ID:
+        required: true
+      STATIC_USER_RSA_PUBLIC_KEY:
+        required: true
+      STATIC_USER_MFA_PASSWORD:
+        required: true
+
+jobs:
+  reset-test-account:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - environment: snowflake-gcp-standard
+          - environment: snowflake-aws-standard
+          - environment: snowflake-aws-enterprise
+    environment: ${{ matrix.environment }}
+    steps:
+      - name: actions/checkout
+        uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.9'
+      - name: Create a virtual environment
+        run: |
+          python -m venv .venv
+      - name: Install dependencies
+        run: |
+          source ./.venv/bin/activate
+          python -m pip install --upgrade pip
+          make install-dev
+      - name: Reset test account
+        run: |
+          source ./.venv/bin/activate
+          python tools/reset_test_account.py
+        env:
+          SNOWFLAKE_ACCOUNT: ${{ secrets.TEST_SNOWFLAKE_ACCOUNT }}
+          SNOWFLAKE_USER: ${{ secrets.TEST_SNOWFLAKE_USER }}
+          SNOWFLAKE_PASSWORD: ${{ secrets.TEST_SNOWFLAKE_PASSWORD }}
+          SNOWFLAKE_ROLE: ACCOUNTADMIN
+          TITAN_VAR_STORAGE_BASE_URL: ${{ secrets.VAR_STORAGE_BASE_URL }}
+          TITAN_VAR_STORAGE_ROLE_ARN: ${{ secrets.VAR_STORAGE_ROLE_ARN }}
+          TITAN_VAR_STORAGE_AWS_EXTERNAL_ID: ${{ secrets.VAR_STORAGE_AWS_EXTERNAL_ID }}
+          TITAN_VAR_STATIC_USER_RSA_PUBLIC_KEY: ${{ secrets.STATIC_USER_RSA_PUBLIC_KEY }}
+          TITAN_VAR_STATIC_USER_MFA_PASSWORD: ${{ secrets.STATIC_USER_MFA_PASSWORD }}
\ No newline at end of file
diff --git 
a/tests/test_parse.py b/tests/test_parse.py
new file mode 100644
index 00000000..f49e640d
--- /dev/null
+++ b/tests/test_parse.py
@@ -0,0 +1,48 @@
+from titan.parse import parse_region
+
+
+def test_parse_region():
+    assert parse_region("AWS_US_WEST_2") == {"cloud": "AWS", "cloud_region": "US_WEST_2"}
+    assert parse_region("PUBLIC.AWS_US_WEST_2") == {
+        "region_group": "PUBLIC",
+        "cloud": "AWS",
+        "cloud_region": "US_WEST_2",
+    }
+    assert parse_region("AZURE_WESTUS2") == {"cloud": "AZURE", "cloud_region": "WESTUS2"}
+    assert parse_region("GCP_EUROPE_WEST4") == {"cloud": "GCP", "cloud_region": "EUROPE_WEST4"}
+
+    assert parse_region("AWS_US_GOV_WEST_1_FHPLUS") == {"cloud": "AWS", "cloud_region": "US_GOV_WEST_1_FHPLUS"}
+    assert parse_region("AWS_US_GOV_WEST_1_DOD") == {"cloud": "AWS", "cloud_region": "US_GOV_WEST_1_DOD"}
+    assert parse_region("AWS_AP_SOUTHEAST_1") == {"cloud": "AWS", "cloud_region": "AP_SOUTHEAST_1"}
+    assert parse_region("AWS_EU_CENTRAL_1") == {"cloud": "AWS", "cloud_region": "EU_CENTRAL_1"}
+
+    assert parse_region("AZURE_CANADACENTRAL") == {"cloud": "AZURE", "cloud_region": "CANADACENTRAL"}
+    assert parse_region("AZURE_NORTHEUROPE") == {"cloud": "AZURE", "cloud_region": "NORTHEUROPE"}
+    assert parse_region("AZURE_SWITZERLANDNORTH") == {"cloud": "AZURE", "cloud_region": "SWITZERLANDNORTH"}
+    assert parse_region("AZURE_USGOVVIRGINIA") == {"cloud": "AZURE", "cloud_region": "USGOVVIRGINIA"}
+
+    assert parse_region("GCP_US_CENTRAL1") == {"cloud": "GCP", "cloud_region": "US_CENTRAL1"}
+    assert parse_region("GCP_EUROPE_WEST2") == {"cloud": "GCP", "cloud_region": "EUROPE_WEST2"}
+    assert parse_region("GCP_EUROPE_WEST3") == {"cloud": "GCP", "cloud_region": "EUROPE_WEST3"}
+
+    assert parse_region("PUBLIC.AWS_EU_CENTRAL_1") == {
+        "region_group": "PUBLIC",
+        "cloud": "AWS",
+        "cloud_region": "EU_CENTRAL_1",
+    }
+    assert parse_region("PUBLIC.AZURE_WESTEUROPE") == {
+        "region_group": "PUBLIC",
+        "cloud": "AZURE",
+        "cloud_region": "WESTEUROPE",
+    }
+    assert 
parse_region("PUBLIC.GCP_US_CENTRAL1") == {
+        "region_group": "PUBLIC",
+        "cloud": "GCP",
+        "cloud_region": "US_CENTRAL1",
+    }
+
+    assert parse_region("SOME_VALUE.GCP_US_CENTRAL1") == {
+        "region_group": "SOME_VALUE",
+        "cloud": "GCP",
+        "cloud_region": "US_CENTRAL1",
+    }
diff --git a/titan/blueprint_config.py b/titan/blueprint_config.py
index ef851e94..e20e39c8 100644
--- a/titan/blueprint_config.py
+++ b/titan/blueprint_config.py
@@ -45,6 +45,9 @@ def __post_init__(self):
         if not isinstance(self.run_mode, RunMode):
             raise ValueError(f"Invalid run_mode: {self.run_mode}")
 
+        if not isinstance(self.vars, dict):
+            raise ValueError(f"vars must be a dictionary, got: {self.vars=}")
+
         if self.scope is not None and not isinstance(self.scope, BlueprintScope):
             raise ValueError(f"Invalid scope: {self.scope}")
 
diff --git a/titan/data_provider.py b/titan/data_provider.py
index 0e8760d0..da0cf817 100644
--- a/titan/data_provider.py
+++ b/titan/data_provider.py
@@ -30,6 +30,7 @@
     _parse_dynamic_table_text,
     parse_view_ddl,
     parse_collection_string,
+    parse_region,
 )
 from .privs import GrantedPrivilege
 from .resource_name import ResourceName, attribute_is_resource_name, resource_name_from_snowflake_metadata
@@ -44,6 +45,8 @@ class SessionContext(TypedDict):
     account_locator: str
     account: str
     available_roles: list[ResourceName]
+    cloud: str
+    cloud_region: str
     database: str
     role: ResourceName
     schemas: list[str]
@@ -571,18 +574,22 @@ def fetch_session(session: SnowflakeConnection) -> SessionContext:
             CURRENT_SCHEMAS() as schemas,
             CURRENT_WAREHOUSE() as warehouse,
             CURRENT_VERSION() as version,
+            CURRENT_REGION() as region,
             SYSTEM$BOOTSTRAP_DATA_REQUEST('ACCOUNT') as account_data
         """,
     )[0]
     account_data = json.loads(session_obj["ACCOUNT_DATA"])
     available_roles = [ResourceName(role) for role in json.loads(session_obj["AVAILABLE_ROLES"])]
+    region = parse_region(session_obj["REGION"])
     return {
         "account_edition": AccountEdition(account_data["accountInfo"]["serviceLevelName"]),
         "account_locator": 
session_obj["ACCOUNT_LOCATOR"],
         "account": session_obj["ACCOUNT"],
         "available_roles": available_roles,
+        "cloud": region["cloud"],
+        "cloud_region": region["cloud_region"],
         "database": session_obj["DATABASE"],
         "role": ResourceName(session_obj["ROLE"]),
         "schemas": json.loads(session_obj["SCHEMAS"]),
diff --git a/titan/enums.py b/titan/enums.py
index 8778b30f..cb299df1 100644
--- a/titan/enums.py
+++ b/titan/enums.py
@@ -128,10 +128,13 @@ class AccountEdition(ParseableEnum):
     @classmethod
     def synonyms(cls):
         """Override to provide a dictionary of synonyms for the enum values"""
-        return {
-            "BUSINESS-CRITICAL" : "BUSINESS_CRITICAL"
-        }
+        return {"BUSINESS-CRITICAL": "BUSINESS_CRITICAL"}
+
+
+class AccountCloud(ParseableEnum):
+    AWS = "AWS"
+    GCP = "GCP"
+    AZURE = "AZURE"
 
 
 class DataType(ParseableEnum):
diff --git a/titan/parse.py b/titan/parse.py
index 220edf4d..300d62ff 100644
--- a/titan/parse.py
+++ b/titan/parse.py
@@ -719,3 +719,30 @@ def parse_collection_string(collection: str):
 
 def format_collection_string(collection: dict):
     return f"{collection['in_name']}.<{collection['on_type']}>"
+
+
+def parse_region(region_str: str) -> dict[str, str]:
+    """Parse a Snowflake region identifier into its components. 
+
+    Examples:
+        AWS_US_WEST_2 -> {'cloud': 'AWS', 'cloud_region': 'US_WEST_2'}
+        PUBLIC.AWS_US_WEST_2 -> {'region_group': 'PUBLIC', 'cloud': 'AWS', 'cloud_region': 'US_WEST_2'}
+        AZURE_WESTUS2 -> {'cloud': 'AZURE', 'cloud_region': 'WESTUS2'}
+        GCP_EUROPE_WEST4 -> {'cloud': 'GCP', 'cloud_region': 'EUROPE_WEST4'}
+    """
+    import re
+
+    pattern = r"^(?:([A-Z_]+)\.)?([A-Z]+)_(.+?)$"
+    match = re.match(pattern, region_str)
+
+    if not match:
+        raise ValueError(f"Invalid region format: {region_str}")
+
+    region_group, cloud, cloud_region = match.groups()
+
+    result = {"cloud": cloud, "cloud_region": cloud_region}
+
+    if region_group:
+        result["region_group"] = region_group
+
+    return result
diff --git a/tools/__reset_test_account.py b/tools/__reset_test_account.py
new file mode 100644
index 00000000..87bbcff5
--- /dev/null
+++ b/tools/__reset_test_account.py
@@ -0,0 +1,138 @@
+import os
+
+import snowflake.connector
+import yaml
+from dotenv import dotenv_values
+
+from titan import resources as res
+from titan.blueprint import Blueprint, print_plan
+from titan.data_provider import fetch_session
+from titan.enums import AccountEdition
+from titan.gitops import collect_blueprint_config
+
+
+def read_config(file) -> dict:
+    config_path = os.path.join(os.path.dirname(__file__), file)
+    with open(config_path, "r") as f:
+        config = yaml.safe_load(f)
+    return config
+
+
+def merge_configs(config1: dict, config2: dict) -> dict:
+    merged = config1.copy()
+    for key, value in config2.items():
+        if key in merged:
+            if isinstance(merged[key], list):
+                merged[key] = merged[key] + value
+            elif merged[key] is None:
+                merged[key] = value
+        else:
+            merged[key] = value
+    return merged
+
+
+def configure_test_account(conn, cloud: str):
+    session_ctx = fetch_session(conn)
+    config = read_config("test_account.yml")
+    vars = dotenv_values("env/.vars.test_account")
+    print(vars)
+
+    if session_ctx["account_edition"] == AccountEdition.ENTERPRISE:
+        config = merge_configs(config, 
read_config("test_account_enterprise.yml"))
+
+    if cloud == "aws":
+        config = merge_configs(config, read_config("test_account_aws.yml"))
+    # elif cloud == "gcp":
+    #     config = merge_configs(config, read_config("test_account_gcp.yml"))
+
+    blueprint_config = collect_blueprint_config(config, {"vars": vars})
+
+    bp = Blueprint.from_config(blueprint_config)
+    plan = bp.plan(conn)
+    print_plan(plan)
+    bp.apply(conn, plan)
+
+
+def configure_aws_heavy(conn):
+    bp = Blueprint(
+        name="reset-test-account",
+        run_mode="CREATE-OR-UPDATE",
+    )
+
+    roles = [res.Role(name=f"ROLE_{i}") for i in range(50)]
+    databases = []
+    for i in range(10):
+        database = res.Database(name=f"DATABASE_{i}")
+        bp.add(database)
+        databases.append(database)
+        for role in roles:
+            # bp.add(res.Grant(priv="USAGE", to=role, on=database))
+            pass
+        for j in range(10):
+            schema = res.Schema(name=f"SCHEMA_{j}", database=database)
+            bp.add(schema)
+            for role in roles:
+                # bp.add(res.Grant(priv="USAGE", to=role, on=schema))
+                pass
+            for k in range(5):
+                table = res.Table(
+                    name=f"TABLE_{k}", columns=[{"name": "ID", "data_type": "NUMBER(38,0)"}], schema=schema
+                )
+                bp.add(table)
+                for role in roles:
+                    # bp.add(res.Grant(priv="SELECT", to=role, on=table))
+                    pass
+
+    bp.add(roles)
+
+    staged_count = len(bp._staged)
+
+    plan = bp.plan(conn)
+    print_plan(plan[:10])
+    print("Changes in plan:", len(plan))
+    print("Staged resources:", staged_count)
+    bp.apply(conn, plan)
+
+    bp = Blueprint(
+        name="reset-test-account",
+        run_mode="CREATE-OR-UPDATE",
+    )
+    for database in databases:
+        for role in roles:
+            bp.add(res.Grant(priv="USAGE", to=role.name, on_database=database.name))
+            bp.add(res.GrantOnAll(priv="USAGE", to=role.name, on_all_schemas_in_database=database.name))
+            bp.add(res.GrantOnAll(priv="SELECT", to=role.name, on_all_tables_in_database=database.name))
+    bp.apply(conn)
+
+
+def get_connection(env_vars):
+    return snowflake.connector.connect(
+        account=env_vars["SNOWFLAKE_ACCOUNT"],
user=env_vars["SNOWFLAKE_USER"],
+        password=env_vars["SNOWFLAKE_PASSWORD"],
+        role=env_vars["SNOWFLAKE_ROLE"],
+        # warehouse=env_vars["SNOWFLAKE_WAREHOUSE"],
+    )
+
+
+def configure_test_accounts():
+
+    for account in ["aws.standard", "aws.enterprise", "gcp.standard"]:
+        print(">>>>>>>>>>>>>>>>", account)
+        cloud = account.split(".")[0]
+        env_vars = dotenv_values(f"env/.env.{account}")
+        conn = get_connection(env_vars)
+        try:
+            configure_test_account(conn, cloud)
+        # except Exception as e:
+        #     print(f"Error configuring {account}: {e}")
+        finally:
+            conn.close()
+
+    # now = time.time()
+    # configure_aws_heavy(get_connection(dotenv_values(".env.aws.heavy")))
+    # print(f"done in {time.time() - now:.2f}s")
+
+
+if __name__ == "__main__":
+    configure_test_accounts()
diff --git a/tools/reset_test_account.py b/tools/reset_test_account.py
index 87bbcff5..444a044c 100644
--- a/tools/reset_test_account.py
+++ b/tools/reset_test_account.py
@@ -1,51 +1,54 @@
 import os
+import pathlib
 
 import snowflake.connector
-import yaml
-from dotenv import dotenv_values
 
-from titan import resources as res
 from titan.blueprint import Blueprint, print_plan
 from titan.data_provider import fetch_session
-from titan.enums import AccountEdition
-from titan.gitops import collect_blueprint_config
+from titan.enums import AccountEdition, AccountCloud
+from titan.gitops import collect_blueprint_config, collect_vars_from_environment, merge_configs, read_config
 
+SCRIPT_DIR = pathlib.Path(__file__).parent.resolve()
 
-def read_config(file) -> dict:
-    config_path = os.path.join(os.path.dirname(__file__), file)
-    with open(config_path, "r") as f:
-        config = yaml.safe_load(f)
-    return config
+
+def get_connection():
+    return snowflake.connector.connect(
+        account=os.environ["SNOWFLAKE_ACCOUNT"],
+        user=os.environ["SNOWFLAKE_USER"],
+        password=os.environ["SNOWFLAKE_PASSWORD"],
+        role=os.environ["SNOWFLAKE_ROLE"],
+    )
 
 
-def merge_configs(config1: dict, config2: dict) -> dict:
-    merged = config1.copy()
-    for key, 
value in config2.items():
-        if key in merged:
-            if isinstance(merged[key], list):
-                merged[key] = merged[key] + value
-            elif merged[key] is None:
-                merged[key] = value
-        else:
-            merged[key] = value
-    return merged
+def read_test_account_config(config_path: str):
+    return read_config(f"{SCRIPT_DIR}/test_account_configs/{config_path}")
 
 
-def configure_test_account(conn, cloud: str):
+def reset_test_account():
+    conn = get_connection()
     session_ctx = fetch_session(conn)
-    config = read_config("test_account.yml")
-    vars = dotenv_values("env/.vars.test_account")
-    print(vars)
+    config = read_test_account_config("base.yml")
+    # titan_vars = collect_vars_from_environment()
+    from dotenv import dotenv_values
+
+    titan_vars = dotenv_values("env/.vars.test_account")
+    print("\n".join([f"{k}={v}" for k, v in titan_vars.items()]))
 
     if session_ctx["account_edition"] == AccountEdition.ENTERPRISE:
-        config = merge_configs(config, read_config("test_account_enterprise.yml"))
+        config = merge_configs(config, read_test_account_config("enterprise.yml"))
+    elif session_ctx["account_edition"] == AccountEdition.BUSINESS_CRITICAL:
+        config = merge_configs(config, read_test_account_config("business_critical.yml"))
 
-    if cloud == "aws":
-        config = merge_configs(config, read_config("test_account_aws.yml"))
-    # elif cloud == "gcp":
-    #     config = merge_configs(config, read_config("test_account_gcp.yml"))
+    if session_ctx["cloud"] == AccountCloud.AWS:
+        config = merge_configs(config, read_test_account_config("aws.yml"))
+    elif session_ctx["cloud"] == AccountCloud.GCP:
+        config = merge_configs(config, read_test_account_config("gcp.yml"))
+    elif session_ctx["cloud"] == AccountCloud.AZURE:
+        config = merge_configs(config, read_test_account_config("azure.yml"))
+    else:
+        raise ValueError(f"Unknown cloud: {session_ctx['cloud']}")
 
-    blueprint_config = collect_blueprint_config(config, {"vars": vars})
+    blueprint_config = collect_blueprint_config(config, {"vars": titan_vars})
 
     bp = 
Blueprint.from_config(blueprint_config)
 
     plan = bp.plan(conn)
@@ -53,86 +56,5 @@ def configure_test_account(conn, cloud: str):
     bp.apply(conn, plan)
 
 
-def configure_aws_heavy(conn):
-    bp = Blueprint(
-        name="reset-test-account",
-        run_mode="CREATE-OR-UPDATE",
-    )
-
-    roles = [res.Role(name=f"ROLE_{i}") for i in range(50)]
-    databases = []
-    for i in range(10):
-        database = res.Database(name=f"DATABASE_{i}")
-        bp.add(database)
-        databases.append(database)
-        for role in roles:
-            # bp.add(res.Grant(priv="USAGE", to=role, on=database))
-            pass
-        for j in range(10):
-            schema = res.Schema(name=f"SCHEMA_{j}", database=database)
-            bp.add(schema)
-            for role in roles:
-                # bp.add(res.Grant(priv="USAGE", to=role, on=schema))
-                pass
-            for k in range(5):
-                table = res.Table(
-                    name=f"TABLE_{k}", columns=[{"name": "ID", "data_type": "NUMBER(38,0)"}], schema=schema
-                )
-                bp.add(table)
-                for role in roles:
-                    # bp.add(res.Grant(priv="SELECT", to=role, on=table))
-                    pass
-
-    bp.add(roles)
-
-    staged_count = len(bp._staged)
-
-    plan = bp.plan(conn)
-    print_plan(plan[:10])
-    print("Changes in plan:", len(plan))
-    print("Staged resources:", staged_count)
-    bp.apply(conn, plan)
-
-    bp = Blueprint(
-        name="reset-test-account",
-        run_mode="CREATE-OR-UPDATE",
-    )
-    for database in databases:
-        for role in roles:
-            bp.add(res.Grant(priv="USAGE", to=role.name, on_database=database.name))
-            bp.add(res.GrantOnAll(priv="USAGE", to=role.name, on_all_schemas_in_database=database.name))
-            bp.add(res.GrantOnAll(priv="SELECT", to=role.name, on_all_tables_in_database=database.name))
-    bp.apply(conn)
-
-
-def get_connection(env_vars):
-    return snowflake.connector.connect(
-        account=env_vars["SNOWFLAKE_ACCOUNT"],
-        user=env_vars["SNOWFLAKE_USER"],
-        password=env_vars["SNOWFLAKE_PASSWORD"],
-        role=env_vars["SNOWFLAKE_ROLE"],
-        # warehouse=env_vars["SNOWFLAKE_WAREHOUSE"],
-    )
-
-
-def configure_test_accounts():
-
-    for account in ["aws.standard", "aws.enterprise", "gcp.standard"]:
-        print(">>>>>>>>>>>>>>>>", 
account)
-        cloud = account.split(".")[0]
-        env_vars = dotenv_values(f"env/.env.{account}")
-        conn = get_connection(env_vars)
-        try:
-            configure_test_account(conn, cloud)
-        # except Exception as e:
-        #     print(f"Error configuring {account}: {e}")
-        finally:
-            conn.close()
-
-    # now = time.time()
-    # configure_aws_heavy(get_connection(dotenv_values(".env.aws.heavy")))
-    # print(f"done in {time.time() - now:.2f}s")
-
-
 if __name__ == "__main__":
-    configure_test_accounts()
+    reset_test_account()
diff --git a/tools/test_account_aws.yml b/tools/test_account_configs/aws.yml
similarity index 83%
rename from tools/test_account_aws.yml
rename to tools/test_account_configs/aws.yml
index f3d88ab6..7ab926ae 100644
--- a/tools/test_account_aws.yml
+++ b/tools/test_account_configs/aws.yml
@@ -1,5 +1,3 @@
-name: reset-test-account-aws
-run_mode: SYNC
 allowlist:
   - "compute pool"
 
diff --git a/tools/test_account_configs/azure.yml b/tools/test_account_configs/azure.yml
new file mode 100644
index 00000000..e69de29b
diff --git a/tools/test_account.yml b/tools/test_account_configs/base.yml
similarity index 98%
rename from tools/test_account.yml
rename to tools/test_account_configs/base.yml
index cc8968f9..59bdb5ad 100644
--- a/tools/test_account.yml
+++ b/tools/test_account_configs/base.yml
@@ -51,7 +51,7 @@ users:
     type: SERVICE
   - name: STATIC_USER_KEYPAIR
     type: SERVICE
-    rsa_public_key: "{{ var.rsa_public_key }}"
+    rsa_public_key: "{{ var.static_user_rsa_public_key }}"
   - name: STATIC_USER_MFA
     type: PERSON
     password: "{{ var.static_user_mfa_password }}"
@@ -256,10 +256,6 @@ warehouses:
 
 security_integrations:
-  # - name: SNOWSERVICES_INGRESS_OAUTH
-  #   type: OAUTH
-  #   oauth_client: snowservices_ingress
-  #   enabled: true
   - name: STATIC_SECURITY_INTEGRATION
     type: api_authentication
     auth_type: OAUTH2
diff --git a/tools/test_account_configs/business_critical.yml b/tools/test_account_configs/business_critical.yml
new file mode 100644
index 00000000..e69de29b
diff --git 
a/tools/test_account_enterprise.yml b/tools/test_account_configs/enterprise.yml
similarity index 93%
rename from tools/test_account_enterprise.yml
rename to tools/test_account_configs/enterprise.yml
index 30ae7e96..1215f055 100644
--- a/tools/test_account_enterprise.yml
+++ b/tools/test_account_configs/enterprise.yml
@@ -1,5 +1,3 @@
-name: reset-test-account
-run_mode: SYNC
 allowlist:
   - "tag"
   - "tag reference"
diff --git a/tools/test_account_configs/gcp.yml b/tools/test_account_configs/gcp.yml
new file mode 100644
index 00000000..e69de29b