|
| 1 | +import functools |
1 | 2 | import logging
|
2 |
| -from typing import Any, Dict, List |
| 3 | +import textwrap |
| 4 | +from collections import defaultdict |
| 5 | +from typing import (Any, Callable, Dict, Iterable, List, Mapping, NamedTuple, |
| 6 | + Optional, Sequence, Tuple, TypeVar, Union) |
3 | 7 |
|
4 | 8 | from pydantic.dataclasses import dataclass
|
5 |
| -from pypika import Query, Schema, Table |
6 |
| -from pypika.enums import Dialects, JoinType |
7 |
| -from pypika.queries import QueryBuilder |
8 |
| -from pypika.utils import builder |
9 | 9 |
|
10 |
| -from app.services.query_builder.join_queries import JoinQuery |
| 10 | +from app.etl_exceptions import AutoETLException |
11 | 11 |
|
12 | 12 | logger = logging.getLogger(__name__)
|
13 | 13 |
|
14 | 14 |
|
15 |
| -class RedshiftQuery(Query): |
16 |
| - """ |
17 |
| - Query class for AWS Redshift |
18 |
| - """ |
# Generic element type for _FlagList; bound to QueryValue since that is
# the only element type the query builder stores.
_T = TypeVar('_T', bound='QueryValue')


# Self-type for BaseQuery's fluent methods, so chained .add(...) calls
# preserve the subclass type for callers.
_Q = TypeVar('_Q', bound='BaseQuery')
# Shorthand for a clause argument: either a bare SQL string or an
# (alias, value[, on_condition]) tuple.
_QArg = Union[str, Tuple[str, ...]]
| 21 | + |
class QueryValue(NamedTuple):
    """One value of a query clause, plus the metadata used to render it.

    All string fields are normalised via ``_clean_up`` (dedented and
    stripped of surrounding whitespace) when built through ``from_arg``.
    """

    # the SQL fragment itself (column expression, table name, subquery, ...)
    value: str
    # optional alias, rendered as "value AS alias" (inverted for WITH)
    alias: str = ''
    # optional ON condition; only meaningful for JOIN-style keywords
    on_condn: str = ''
    # the "fake" keyword to render (e.g. 'LEFT JOIN' for values stored
    # under FROM); empty for plain clause values
    keyword: str = ''
    # when True the value is rendered wrapped in parentheses (WITH bodies)
    is_subquery: bool = False

    @classmethod
    def from_arg(cls, arg: _QArg, **kwargs: Any) -> 'QueryValue':
        """Build a QueryValue from a shorthand clause argument.

        Args:
            arg (_QArg): a bare SQL string, an ``(alias, value)`` pair, or
                — for JOIN keywords only — an ``(alias, value, on_condn)``
                triple.
            **kwargs: extra constructor fields (``keyword``,
                ``is_subquery``) forwarded unchanged.

        Raises:
            ValueError: if *arg* is not one of the accepted shapes.

        Returns:
            QueryValue: the normalised clause value.
        """
        if isinstance(arg, str):
            alias, value, on_condn = '', arg, ''
        # BUGFIX: use kwargs.get() — previously a 3-tuple arriving without a
        # 'keyword' kwarg raised an undocumented KeyError instead of falling
        # through to the documented ValueError below.
        elif len(arg) == 3 and 'JOIN' in kwargs.get('keyword', ''):
            alias, value, on_condn = arg
        elif len(arg) == 2:
            alias, value = arg
            on_condn = ''
        else:
            raise ValueError(f"invalid arg: {arg!r}")
        return cls(_clean_up(value), _clean_up(alias), _clean_up(on_condn), **kwargs)
| 52 | + |
| 53 | + |
class _FlagList(List[_T]):
    """A list of clause values that additionally carries an optional
    keyword flag (e.g. ``DISTINCT``/``ALL`` for SELECT).

    The flag may be set at most once per clause; ``BaseQuery.add``
    enforces this.
    """

    # empty string means "no flag set for this clause yet"
    flag: str = ''
| 56 | + |
| 57 | + |
| 58 | +def _clean_up(thing: str) -> str: |
| 59 | + return textwrap.dedent(thing.rstrip()).strip() |
| 60 | + |
| 61 | + |
class BaseQuery:
    """Declarative SQL text builder that accumulates clause fragments.

    Clauses are added with :meth:`add`, or dynamically via uppercase
    attribute access (``q.SELECT(...)``, ``q.LEFT_JOIN(...)`` — see
    ``__getattr__``); ``str(q)`` renders the assembled SQL.
    """

    # clause rendering order; also the set of valid top-level keywords
    keywords = [
        'WITH',
        'SELECT',
        'FROM',
        'JOIN',
        'WHERE',
        'GROUP BY',
        'HAVING',
        'ORDER BY',
        'LIMIT',
    ]

    # separator placed between consecutive values of a clause;
    # keywords not listed here fall back to default_separator
    separators: Mapping[str, str] = dict(WHERE='AND', HAVING='AND')
    default_separator = ','

    # formats[0] renders un-aliased values, formats[1] aliased ones;
    # WITH inverts the alias order ("name AS (subquery)")
    formats: Tuple[Mapping[str, str], ...] = (
        defaultdict(lambda: '{value}'),
        defaultdict(lambda: '{value} AS {alias}', WITH='{alias} AS {value}'),
    )

    # keywords whose values are wrapped in parentheses when rendered
    subquery_keywords = {'WITH'}
    # "fake" keywords are stored under a real clause but rendered with
    # their own keyword text (e.g. 'LEFT JOIN' lives in the FROM bucket)
    fake_keywords = dict(JOIN='FROM')
    # optional flags permitted directly after a keyword (SELECT DISTINCT)
    flag_keywords = dict(SELECT={'DISTINCT', 'ALL'})

    def __init__(
        self,
        data: Optional[Mapping[str, Iterable[_QArg]]] = None,
        separators: Optional[Mapping[str, str]] = None,
    ) -> None:
        """Initialise the clause store.

        Args:
            data: optional initial mapping of keyword -> iterable of args;
                defaults to one empty bucket per entry in ``keywords``.
            separators: optional per-keyword separator overrides
                (shadows the class-level mapping for this instance).
        """
        self.data: Mapping[str, _FlagList[QueryValue]] = {}
        if data is None:
            data = dict.fromkeys(self.keywords, ())
        for keyword, args in data.items():
            self.data[keyword] = _FlagList()
            self.add(keyword, *args)

        if separators is not None:
            self.separators = separators

    def add(self: _Q, keyword: str, *args: _QArg) -> _Q:
        """Add values under *keyword* and return self for chaining.

        Args:
            self (_Q): current BaseQuery (or subclass) instance.
            keyword (str): clause keyword, optionally carrying a flag
                (``'SELECT DISTINCT'``) or a join variant (``'LEFT JOIN'``).
            *args (_QArg): clause values; see ``QueryValue.from_arg``.

        Raises:
            ValueError: invalid flag, or a flag set twice for one clause.
            KeyError: *keyword* does not resolve to a known clause.

        Returns:
            _Q: self, so calls can be chained fluently.
        """
        # 'LEFT JOIN' etc. resolve to the FROM bucket, remembering the
        # original join text; then split off any trailing flag.
        keyword, fake_keyword = self._resolve_fakes(keyword)
        keyword, flag = self._resolve_flags(keyword)
        target = self.data[keyword]

        if flag:
            if target.flag:  # pragma: no cover
                raise ValueError(f"{keyword} already has flag: {flag!r}")
            target.flag = flag

        kwargs: Dict[str, Any] = {}
        if fake_keyword:
            kwargs.update(keyword=fake_keyword)
        if keyword in self.subquery_keywords:
            kwargs.update(is_subquery=True)

        for arg in args:
            target.append(QueryValue.from_arg(arg, **kwargs))

        return self

    def _resolve_fakes(self, keyword: str) -> Tuple[str, str]:
        """Map a join variant to its real clause, keeping the original.

        Returns ``(real_keyword, fake_keyword)``; ``fake_keyword`` is ''
        when *keyword* is not a fake.
        """
        for part, real in self.fake_keywords.items():
            if part in keyword:
                return real, keyword
        return keyword, ''

    def _resolve_flags(self, keyword: str) -> Tuple[str, str]:
        """Split an optional flag off a keyword.

        Returns ``(keyword, flag)``; ``flag`` is '' when none is present.
        Raises ValueError for a flag not allowed on that keyword.
        """
        prefix, _, flag = keyword.partition(' ')
        if prefix in self.flag_keywords:
            if flag and flag not in self.flag_keywords[prefix]:
                raise ValueError(f"invalid flag for {prefix}: {flag!r}")
            return prefix, flag
        return keyword, ''

    def __getattr__(self: _Q, name: str) -> Callable[..., _Q]:
        """Expose uppercase attributes as clause adders.

        ``q.LEFT_JOIN(...)`` becomes ``q.add('LEFT JOIN', ...)``; any
        non-uppercase attribute falls through to normal lookup.
        """
        # conveniently, avoids shadowing dunder methods (e.g. __deepcopy__)
        if not name.isupper():
            return getattr(super(), name)  # type: ignore
        return functools.partial(self.add, name.replace('_', ' '))

    def __str__(self) -> str:
        """Render the accumulated clauses as SQL text."""
        return ''.join(self._lines())

    def _lines(self) -> Iterable[str]:
        """Yield the rendered query line-by-line, clause by clause."""
        for keyword, things in self.data.items():
            if not things:
                continue  # skip empty clauses entirely

            if things.flag:
                yield f'{keyword} {things.flag}\n'
            else:
                yield f'{keyword}\n'

            # render plain values (grouped[0]) before keyword-carrying
            # ones such as joins (grouped[1]); bool() indexes the tuple
            grouped: Tuple[List[QueryValue], ...] = ([], [])
            for thing in things:
                grouped[bool(thing.keyword)].append(thing)
            for group in grouped:
                yield from self._lines_keyword(keyword, group)

    def _lines_keyword(self, keyword: str, things: Sequence[QueryValue]) -> Iterable[str]:
        """Yield the rendered lines for one group of clause values."""
        for i, thing in enumerate(things):
            last = i + 1 == len(things)

            # joins print their own keyword line first (e.g. "LEFT JOIN")
            if thing.keyword:
                yield thing.keyword + '\n'

            _format = self.formats[bool(thing.alias)][keyword]
            value = thing.value

            if thing.is_subquery:
                value = f'(\n{textwrap.indent(text=value, prefix="    ")}\n)'

            yield textwrap.indent(text=_format.format(value=value, alias=thing.alias), prefix='    ')

            if thing.on_condn:
                yield '\n    ON '+thing.on_condn

            # separator between values; join entries print no separator
            if not last and not thing.keyword:
                try:
                    yield ' ' + self.separators[keyword]
                except KeyError:
                    yield self.default_separator

            yield '\n'
45 | 202 |
|
46 | 203 |
|
@dataclass()
class RedshiftDialect:
    """Builds a Redshift SQL statement from mapping-file configuration.

    Attributes:
        target_table_conf: target table configuration entries.
        joins_and_filters_conf: per-index join/filter configuration; the
            entry at index '0' also supplies the driving (FROM) table.
        select_sources: per-column select configuration, each entry
            carrying 'column_alias' and 'transformation' keys.
    """

    target_table_conf: List[Dict]
    joins_and_filters_conf: Dict[str, Dict]
    select_sources: List[Dict]

    def get_sql(self) -> None:
        """Method to trigger the Redshift query builder."""
        logger.info('building Redshift query from the mappings file')

        _query = BaseQuery()
        _query = self.get_select(_query, self.select_sources)
        _query = self.get_join(_query)
        print(str(_query))

    def get_select(self, _query: BaseQuery, select_sources: List[Dict]) -> BaseQuery:
        """Add one SELECT entry per configured source column.

        Args:
            _query (BaseQuery): query under construction.
            select_sources (List[Dict]): mapping-file entries; each must
                provide 'column_alias' and 'transformation'.

        Raises:
            AutoETLException: if any entry is malformed.

        Returns:
            BaseQuery: the query with SELECT clauses added.
        """
        try:
            for _select in select_sources:
                # rendered as "<transformation> AS <column_alias>"
                _query = _query.SELECT(
                    (_select['column_alias'], _select['transformation']))

            return _query
        except Exception as excep:
            logger.error("Error while generating SELECT sql")
            raise AutoETLException(
                "Error while generating SELECT sql", excep.args)

    def get_join(self, _query: BaseQuery) -> BaseQuery:
        """Add FROM, JOIN and WHERE clauses from the joins/filters config.

        Args:
            _query (BaseQuery): query under construction.

        Raises:
            AutoETLException: if any configuration entry is malformed
                (consistent with get_select; previously raw KeyErrors
                escaped from here).

        Returns:
            BaseQuery: the query with FROM/JOIN/WHERE clauses added.
        """
        # mapping-file join type -> BaseQuery clause keyword; replaces the
        # five duplicated match/case arms that differed only in method name
        join_keywords = {
            "left join": "LEFT JOIN",
            "right join": "RIGHT JOIN",
            "inner join": "INNER JOIN",
            "full outer join": "FULL OUTER JOIN",
            "cross join": "CROSS JOIN",
        }

        try:
            for _index in self.joins_and_filters_conf:

                _map = self.joins_and_filters_conf[_index]

                # the first entry also carries the driving (FROM) table
                if int(_index) == 0:
                    _query = _query.FROM(
                        (_map['driving_table_alias'], _map['driving_table']))

                # a reference subquery takes precedence over a plain table
                if _map['reference_subquery']:
                    ref_table = '(' + _map['reference_subquery'] + ')'
                else:
                    ref_table = _map['reference_table']

                join_keyword = join_keywords.get(_map['join_type'])
                if join_keyword is not None:
                    # equivalent to e.g. _query.LEFT_JOIN(...): __getattr__
                    # is just a partial over add() with the same keyword
                    _query = _query.add(
                        join_keyword,
                        (_map['reference_table_alias'], ref_table,
                         _map['join_condition']))
                else:
                    # unknown join types were previously skipped silently;
                    # surface them so configuration errors are visible
                    logger.warning(
                        "unsupported join_type %r skipped", _map['join_type'])

                if _map['filter_condition']:
                    _query = _query.WHERE(_map['filter_condition'])

            return _query
        except Exception as excep:
            logger.error("Error while generating JOIN sql")
            raise AutoETLException(
                "Error while generating JOIN sql", excep.args)
0 commit comments