Skip to content

Commit 95a5450

Browse files
modified code base
1 parent 7cbd318 commit 95a5450

File tree

8 files changed

+285
-53
lines changed

8 files changed

+285
-53
lines changed

.pylintrc

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ confidence=
5959
#
6060
# Kubeflow disables string-interpolation because we are starting to use f
6161
# style strings
62-
disable=bad-indentation,unspecified-encoding,missing-class-docstring,missing-module-docstring,no-name-in-module,dangerous-default-value,broad-except,import-outside-toplevel,bare-except
62+
disable=bad-indentation,unspecified-encoding,missing-class-docstring,missing-module-docstring,no-name-in-module,dangerous-default-value,broad-except,import-outside-toplevel,bare-except,invalid-name
6363

6464

6565
[REPORTS]

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ openpyxl = "^3.1.2"
1616
pypika = "^0.48.9"
1717

1818

19+
[tool.poetry.group.dev.dependencies]
20+
ipykernel = "^6.23.1"
21+
1922
[build-system]
2023
requires = ["poetry-core"]
2124
build-backend = "poetry.core.masonry.api"

src/app/logger.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
def setup_logging(log_dir: str):
77
"""Load logging configuration"""
88

9-
log_file_name = log_dir + '/' + 'minimal-app-' + datetime.now().strftime("%Y-%m-%d") + '.log'
9+
log_file_name = log_dir + '/' + 'minimal-app-' + \
10+
datetime.now().strftime("%Y-%m-%d") + '.log'
1011

1112
loging_config = {
1213
'version': 1,
1314
'disable_existing_loggers': False,
1415
'loggers': {
1516
'root': {
16-
'level': 'INFO',
17+
'level': 'DEBUG',
1718
'handlers': ['debug_console_handler', 'info_rotating_file_handler'],
1819
},
1920
'src': {
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .dialect import RedshiftDialect, RedshiftQuery
1+
from .dialect import RedshiftDialect
22
from .query_builder import QueryBuilder
+243-43
Original file line numberDiff line numberDiff line change
@@ -1,84 +1,284 @@
1+
import functools
12
import logging
2-
from typing import Any, Dict, List
3+
import textwrap
4+
from collections import defaultdict
5+
from typing import (Any, Callable, Dict, Iterable, List, Mapping, NamedTuple,
6+
Optional, Sequence, Tuple, TypeVar, Union)
37

48
from pydantic.dataclasses import dataclass
5-
from pypika import Query, Schema, Table
6-
from pypika.enums import Dialects, JoinType
7-
from pypika.queries import QueryBuilder
8-
from pypika.utils import builder
99

10-
from app.services.query_builder.join_queries import JoinQuery
10+
from app.etl_exceptions import AutoETLException
1111

1212
logger = logging.getLogger(__name__)
1313

1414

15-
class RedshiftQuery(Query):
16-
"""
17-
Query class for AWS Redshift
18-
"""
15+
_T = TypeVar('_T', bound='QueryValue')
16+
17+
18+
_Q = TypeVar('_Q', bound='BaseQuery')
19+
_QArg = Union[str, Tuple[str, ...]]
20+
21+
22+
class QueryValue(NamedTuple):
23+
value: str
24+
alias: str = ''
25+
on_condn: str = ''
26+
keyword: str = ''
27+
is_subquery: bool = False
1928

2029
@classmethod
21-
def _builder(cls, **kwargs: Any) -> "RedshiftQueryBuilder":
22-
return RedshiftQueryBuilder(**kwargs)
30+
def from_arg(cls, arg: _QArg, **kwargs: Any) -> 'QueryValue':
31+
"""method to set the parameter for the QueryValue
32+
33+
Args:
34+
arg (_QArg): either a bare value string, or a tuple of (alias, value) or (alias, value, on_condition)
35+
36+
Raises:
37+
ValueError
38+
39+
Returns:
40+
class object 'QueryValue'
41+
"""
42+
if isinstance(arg, str):
43+
alias, value, on_condn = '', arg, ''
44+
elif len(arg) == 3 and 'JOIN' in kwargs['keyword']:
45+
alias, value, on_condn = arg
46+
elif len(arg) == 2:
47+
alias, value = arg
48+
on_condn = ''
49+
else: # pragma: no cover
50+
raise ValueError(f"invalid arg: {arg!r}")
51+
return cls(_clean_up(value), _clean_up(alias), _clean_up(on_condn), **kwargs)
52+
53+
54+
class _FlagList(List[_T]):
55+
flag: str = ''
56+
57+
58+
def _clean_up(thing: str) -> str:
59+
return textwrap.dedent(thing.rstrip()).strip()
60+
61+
62+
class BaseQuery:
63+
64+
keywords = [
65+
'WITH',
66+
'SELECT',
67+
'FROM',
68+
'JOIN',
69+
'WHERE',
70+
'GROUP BY',
71+
'HAVING',
72+
'ORDER BY',
73+
'LIMIT',
74+
]
75+
76+
separators: Mapping[str, str] = dict(WHERE='AND', HAVING='AND')
77+
default_separator = ','
78+
79+
formats: Tuple[Mapping[str, str], ...] = (
80+
defaultdict(lambda: '{value}'),
81+
defaultdict(lambda: '{value} AS {alias}', WITH='{alias} AS {value}'),
82+
)
83+
84+
subquery_keywords = {'WITH'}
85+
fake_keywords = dict(JOIN='FROM')
86+
flag_keywords = dict(SELECT={'DISTINCT', 'ALL'})
87+
88+
def __init__(
89+
self,
90+
data: Optional[Mapping[str, Iterable[_QArg]]] = None,
91+
separators: Optional[Mapping[str, str]] = None,
92+
) -> None:
93+
"""
94+
"""
95+
self.data: Mapping[str, _FlagList[QueryValue]] = {}
96+
if data is None:
97+
data = dict.fromkeys(self.keywords, ())
98+
for keyword, args in data.items():
99+
self.data[keyword] = _FlagList()
100+
self.add(keyword, *args)
101+
102+
if separators is not None:
103+
self.separators = separators
104+
105+
def add(self: _Q, keyword: str, *args: _QArg) -> _Q:
106+
"""method to add params to the query object
107+
108+
Args:
109+
self (_Q): current instance of BaseQuery
110+
keyword (str): keyword to be added to the query object
23111
112+
Raises:
113+
ValueError
24114
25-
class RedshiftQueryBuilder(QueryBuilder):
26-
QUERY_CLS = RedshiftQuery
115+
Returns:
116+
_Q: BaseQuery
117+
"""
118+
keyword, fake_keyword = self._resolve_fakes(keyword)
119+
keyword, flag = self._resolve_flags(keyword)
120+
target = self.data[keyword]
121+
122+
if flag:
123+
if target.flag: # pragma: no cover
124+
raise ValueError(f"{keyword} already has flag: {flag!r}")
125+
target.flag = flag
126+
127+
kwargs: Dict[str, Any] = {}
128+
if fake_keyword:
129+
kwargs.update(keyword=fake_keyword)
130+
if keyword in self.subquery_keywords:
131+
kwargs.update(is_subquery=True)
132+
133+
for arg in args:
134+
target.append(QueryValue.from_arg(arg, **kwargs))
135+
136+
return self
137+
138+
def _resolve_fakes(self, keyword: str) -> Tuple[str, str]:
139+
for part, real in self.fake_keywords.items():
140+
if part in keyword:
141+
return real, keyword
142+
return keyword, ''
143+
144+
def _resolve_flags(self, keyword: str) -> Tuple[str, str]:
145+
prefix, _, flag = keyword.partition(' ')
146+
if prefix in self.flag_keywords:
147+
if flag and flag not in self.flag_keywords[prefix]:
148+
raise ValueError(f"invalid flag for {prefix}: {flag!r}")
149+
return prefix, flag
150+
return keyword, ''
151+
152+
def __getattr__(self: _Q, name: str) -> Callable[..., _Q]:
153+
# conveniently, avoids shadowing dunder methods (e.g. __deepcopy__)
154+
if not name.isupper():
155+
return getattr(super(), name) # type: ignore
156+
return functools.partial(self.add, name.replace('_', ' '))
157+
158+
def __str__(self) -> str:
159+
return ''.join(self._lines())
160+
161+
def _lines(self) -> Iterable[str]:
162+
for keyword, things in self.data.items():
163+
if not things:
164+
continue
165+
166+
if things.flag:
167+
yield f'{keyword} {things.flag}\n'
168+
else:
169+
yield f'{keyword}\n'
27170

28-
def __init__(self, **kwargs: Any) -> None:
29-
super().__init__(dialect=Dialects.REDSHIFT, **kwargs)
171+
grouped: Tuple[List[QueryValue], ...] = ([], [])
172+
for thing in things:
173+
grouped[bool(thing.keyword)].append(thing)
174+
for group in grouped:
175+
yield from self._lines_keyword(keyword, group)
30176

31-
@builder
32-
def join(
33-
self, item: Table, how: JoinType = JoinType.inner
34-
) -> "JoinQuery":
35-
if isinstance(item, Table):
36-
return JoinQuery(self, item, how, label="table")
177+
def _lines_keyword(self, keyword: str, things: Sequence[QueryValue]) -> Iterable[str]:
178+
for i, thing in enumerate(things):
179+
last = i + 1 == len(things)
37180

38-
raise ValueError(f"Cannot join on type {type(item)}")
181+
if thing.keyword:
182+
yield thing.keyword + '\n'
39183

40-
def inner_join(self, item: Table) -> "JoinQuery":
41-
return self.join(item, JoinType.inner)
184+
_format = self.formats[bool(thing.alias)][keyword]
185+
value = thing.value
42186

43-
def left_join(self, item: Table) -> "JoinQuery":
44-
return self.join(item, JoinType.left)
187+
if thing.is_subquery:
188+
value = f'(\n{textwrap.indent(text=value, prefix=" ")}\n)'
189+
190+
yield textwrap.indent(text=_format.format(value=value, alias=thing.alias), prefix=' ')
191+
192+
if thing.on_condn:
193+
yield '\n ON '+thing.on_condn
194+
195+
if not last and not thing.keyword:
196+
try:
197+
yield ' ' + self.separators[keyword]
198+
except KeyError:
199+
yield self.default_separator
200+
201+
yield '\n'
45202

46203

47204
@dataclass()
48205
class RedshiftDialect:
49206
target_table_conf: List[Dict]
50207
joins_and_filters_conf: Dict[str, Dict]
208+
select_sources: List[Dict]
51209

52210
def get_sql(self) -> None:
53211
"""Method to trigger the Redshift query builder
54212
"""
55213
logger.info('building Redshift query from the mappings file')
56214

57-
_query = RedshiftQuery()
215+
_query = BaseQuery()
216+
_query = self.get_select(_query, self.select_sources)
217+
_query = self.get_join(_query)
218+
print(str(_query))
219+
220+
def get_select(self, _query: BaseQuery, select_sources: List[Dict]) -> BaseQuery:
221+
"""method to generate the select sql
222+
223+
Args:
224+
_query (BaseQuery)
225+
select_sources (List[Dict]): mapping file dictionary
58226
227+
Returns:
228+
BaseQuery
229+
"""
230+
try:
231+
for _select in select_sources:
232+
_query = _query.SELECT(
233+
(_select['column_alias'], _select['transformation']))
234+
235+
return _query
236+
except Exception as excep:
237+
logger.error("Error while generating SELECT sql")
238+
raise AutoETLException(
239+
"Error while generating SELECT sql", excep.args)
240+
241+
def get_join(self, _query: BaseQuery) -> BaseQuery:
242+
"""method to generate the join sql
243+
244+
Args:
245+
_query (BaseQuery)
246+
247+
Returns:
248+
BaseQuery
249+
"""
59250
for _index in self.joins_and_filters_conf:
60251

61252
_map = self.joins_and_filters_conf[_index]
62253

63254
if int(_index) == 0:
64-
schema1, schema2 = Schema(_map['driving_table'].split(
65-
'.')[0]), Schema(_map['reference_table'].split('.')[0])
66255

67-
table1, table2 = Table(_map['driving_table'].split('.')[1],
68-
schema=schema1, alias=_map['driving_table_alias']), \
69-
Table(_map['reference_table'].split('.')[1],
70-
schema=schema2, alias=_map['reference_table_alias'])
256+
_query = _query.FROM(
257+
(_map['driving_table_alias'], _map['driving_table']))
71258

72-
_query = _query.from_(table1).inner_join(
73-
table2).on(_map['join_condition'])
259+
if _map['reference_subquery']:
260+
ref_table = '('+_map['reference_subquery']+')'
74261
else:
75-
schema2 = Schema(_map['reference_table'].split(
76-
'.')[0])
262+
ref_table = _map['reference_table']
77263

78-
table2 = Table(_map['reference_table'].split('.')[1],
79-
schema=schema2, alias=_map['driving_table_alias'])
264+
match _map['join_type']:
265+
case "left join":
266+
_query = _query.LEFT_JOIN(
267+
(_map['reference_table_alias'], ref_table, _map['join_condition']))
268+
case "right join":
269+
_query = _query.RIGHT_JOIN(
270+
(_map['reference_table_alias'], ref_table, _map['join_condition']))
271+
case "inner join":
272+
_query = _query.INNER_JOIN(
273+
(_map['reference_table_alias'], ref_table, _map['join_condition']))
274+
case "full outer join":
275+
_query = _query.FULL_OUTER_JOIN(
276+
(_map['reference_table_alias'], ref_table, _map['join_condition']))
277+
case "cross join":
278+
_query = _query.CROSS_JOIN(
279+
(_map['reference_table_alias'], ref_table, _map['join_condition']))
80280

81-
_query = _query.left_join(table2).on(
82-
_map['join_condition'])
281+
if _map['filter_condition']:
282+
_query = _query.WHERE(_map['filter_condition'])
83283

84-
print(_query.select('*'))
284+
return _query

src/app/services/query_builder/query_builder.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import dataclasses
22
import logging
33
import os
4+
import sys
45

56
from pydantic.dataclasses import dataclass
67

@@ -47,11 +48,13 @@ def run(self) -> None:
4748
target_table_json = excel_to_json(meta_xls, 'target_table', 'records')
4849
joins_and_filters = excel_to_json(
4950
meta_xls, 'joins_and_filters', 'index')
51+
select_sources = excel_to_json(meta_xls, 'select_sources', 'records')
5052

5153
validate_joins_mapping(joins_and_filters)
5254
match _config['target']:
5355
case 'redshift':
54-
RedshiftDialect(target_table_json, joins_and_filters).get_sql()
56+
RedshiftDialect(target_table_json,
57+
joins_and_filters, select_sources).get_sql()
5558

5659
case _:
5760
logger.error(

0 commit comments

Comments
 (0)