This repository was archived by the owner on Nov 30, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy path warehouse_fixture.py
122 lines (110 loc) · 5.23 KB
/
warehouse_fixture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2020 Soda
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import re
import socket
import string
from typing import Optional, List
import yaml
from sodasql.scan.db import sql_update, sql_updates
from sodasql.scan.dialect import Dialect
from sodasql.scan.dialect_parser import DialectParser
from sodasql.scan.warehouse import Warehouse
from sodasql.scan.warehouse_yml import WarehouseYml
class WarehouseFixture:
    """Test fixture that provisions a scratch database on a target warehouse.

    Use ``WarehouseFixture.create(target)`` to obtain the target-specific
    fixture subclass. The base implementations below issue
    ``CREATE DATABASE IF NOT EXISTS`` / ``USE DATABASE`` /
    ``DROP DATABASE IF EXISTS ... CASCADE``.
    """

    @classmethod
    def create(cls, target: str):
        """Return the fixture instance matching ``target``.

        Each fixture module is imported only inside its own branch, so only
        the module for the selected target is loaded.

        :param target: one of the ``TARGET_*`` constants from
            ``tests.common.sql_test_case``.
        :raises RuntimeError: if ``target`` is not a known target.
        """
        from tests.common.sql_test_case import TARGET_SNOWFLAKE, TARGET_POSTGRES, TARGET_REDSHIFT, TARGET_ATHENA, \
            TARGET_BIGQUERY, TARGET_HIVE, TARGET_MYSQL, TARGET_SPARK, TARGET_SQLSERVER, TARGET_TRINO, TARGET_DENODO
        if target == TARGET_POSTGRES:
            from tests.warehouses.postgres_fixture import PostgresFixture
            return PostgresFixture(target)
        elif target == TARGET_SNOWFLAKE:
            from tests.warehouses.snowflake_fixture import SnowflakeFixture
            return SnowflakeFixture(target)
        elif target == TARGET_REDSHIFT:
            from tests.warehouses.redshift_fixture import RedshiftFixture
            return RedshiftFixture(target)
        elif target == TARGET_ATHENA:
            from tests.warehouses.athena_fixture import AthenaFixture
            return AthenaFixture(target)
        elif target == TARGET_BIGQUERY:
            from tests.warehouses.bigquery_fixture import BigQueryFixture
            return BigQueryFixture(target)
        elif target == TARGET_HIVE:
            from tests.warehouses.hive_fixture import HiveFixture
            return HiveFixture(target)
        elif target == TARGET_MYSQL:
            from tests.warehouses.mysql_fixture import MySQLFixture
            return MySQLFixture(target)
        elif target == TARGET_SQLSERVER:
            from tests.warehouses.sqlserver_fixture import SQLServerFixture
            return SQLServerFixture(target)
        elif target == TARGET_SPARK:
            from tests.warehouses.spark_fixture import SparkFixture
            return SparkFixture(target)
        elif target == TARGET_TRINO:
            from tests.warehouses.trino_fixture import TrinoFixture
            return TrinoFixture(target)
        elif target == TARGET_DENODO:
            from tests.warehouses.denodo_fixture import DenodoFixture
            return DenodoFixture(target)
        raise RuntimeError(f'Invalid target {target}')

    def __init__(self, target: str) -> None:
        """Build the dialect and warehouse for ``target``, then create a scratch database.

        Side effect: calls ``create_database``, which executes SQL on the
        warehouse connection and commits.
        """
        super().__init__()
        self.target: str = target
        self.dialect = self.create_dialect(self.target)
        self.warehouse_yml = WarehouseYml(dialect=self.dialect, name='test_warehouse')
        self.warehouse: Optional[Warehouse] = Warehouse(self.warehouse_yml)
        # Set by create_database() below.
        self.database: Optional[str] = None
        self.create_database()

    @classmethod
    def create_dialect(cls, target: str) -> 'Dialect':
        """Parse ``<tests dir>/warehouses/<target>_cfg.yml`` into a dialect.

        ``<tests dir>`` is the parent of the directory containing this file.

        :raises FileNotFoundError: if the config file for ``target`` does not exist.
        """
        tests_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        test_warehouse_cfg_path = f'{tests_dir}/warehouses/{target}_cfg.yml'
        with open(test_warehouse_cfg_path, encoding='utf-8') as f:
            # SafeLoader: build plain Python objects only, never arbitrary ones.
            warehouse_configuration_dict = yaml.load(f, Loader=yaml.SafeLoader)
        dialect_parser = DialectParser(warehouse_configuration_dict)
        # Fail fast on a malformed test config (presumably raises; see DialectParser).
        dialect_parser.assert_no_warnings_or_errors()
        return dialect_parser.dialect

    def create_database(self):
        """Create a uniquely named database, make it current, and commit.

        The name is stored on ``self.database`` and on the warehouse dialect.
        """
        self.database = self.create_unique_database_name()
        self.warehouse.dialect.database = self.database
        sql_updates(self.warehouse.connection, [
            f'CREATE DATABASE IF NOT EXISTS {self.database}',
            f'USE DATABASE {self.database}'])
        self.warehouse.connection.commit()

    def drop_database(self):
        """Drop the scratch database (with CASCADE) if it exists, and commit."""
        sql_update(self.warehouse.connection,
                   f'DROP DATABASE IF EXISTS {self.database} CASCADE')
        self.warehouse.connection.commit()

    def sql_create_table(self, columns: List[str], table_name: str):
        """Return a ``CREATE TABLE`` statement for ``table_name``.

        :param columns: column definitions, e.g. ``["id INT", "name VARCHAR(10)"]``,
            joined with ``", "``.
        :param table_name: unqualified name; qualified via the dialect's
            ``qualify_writable_table_name``.
        """
        columns_sql = ", ".join(columns)
        qualified_table_name = self.warehouse.dialect.qualify_writable_table_name(table_name)
        return f"CREATE TABLE {qualified_table_name} ( \n {columns_sql} );"

    @classmethod
    def create_unique_database_name(cls, max_hostname_length: int = 32) -> str:
        """Return a database name unique to this host and run.

        The format is ``soda_test_<hostname>_<suffix>``, where:

        - ``<hostname>`` is ``socket.gethostname()`` with every non-alphanumeric
          character replaced by ``_``, lower-cased, and truncated to
          ``max_hostname_length`` characters;
        - ``<suffix>`` is 10 random characters from ``[a-z0-9]``.

        Truncating the hostname keeps the random suffix inside common
        identifier-length limits. For example, PostgreSQL silently truncates
        identifiers to 63 bytes, which would otherwise drop the suffix and
        make names from different runs on the same host collide.
        """
        prefix: str = 'soda_test'
        normalized_hostname = re.sub(r"[^a-zA-Z0-9]", "_", socket.gethostname()).lower()
        normalized_hostname = normalized_hostname[:max_hostname_length]
        alphabet = string.ascii_lowercase + string.digits
        random_suffix = ''.join(random.choice(alphabet) for _ in range(10))
        return f"{prefix}_{normalized_hostname}_{random_suffix}"

    def tear_down(self):
        """Roll back any pending transaction on the warehouse connection.

        This base implementation neither drops the scratch database nor
        closes the connection.
        """
        logging.debug('Rolling back transaction on warehouse connection')
        self.warehouse.connection.rollback()