-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathbase.py
274 lines (214 loc) · 7.9 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
import abc
import asyncio
import enum
import functools
from collections import defaultdict
import typing
from typing import (
Optional,
)
from .._grpc.grpcwrapper import ydb_query
from .._grpc.grpcwrapper.ydb_query_public_types import (
BaseQueryTxMode,
)
from ..connection import _RpcState as RpcState
from .. import convert
from .. import issues
from .. import _utilities
from .. import _apis
from ydb._topic_common.common import CallFromSyncToAsync, _get_shared_event_loop
from ydb._grpc.grpcwrapper.common_utils import to_thread
if typing.TYPE_CHECKING:
from .transaction import BaseQueryTxContext
from .session import BaseQuerySession
class QuerySyntax(enum.IntEnum):
    """Syntax the query text is written in.

    Passed as ``syntax`` to ``create_execute_query_request``, which replaces
    any falsy value (``None`` or ``UNSPECIFIED``) with ``YQL_V1``.
    """

    UNSPECIFIED = 0  # falsy: treated as "not given" by create_execute_query_request
    YQL_V1 = 1  # the default chosen by create_execute_query_request
    PG = 2  # presumably PostgreSQL-compatible syntax — confirm against YDB docs
class QueryExecMode(enum.IntEnum):
    """Execution mode requested for a query (sent as ``exec_mode``).

    The sparse numeric values presumably mirror the server-side protobuf enum —
    confirm against the YDB Query service proto. ``create_execute_query_request``
    replaces any falsy value (``None`` or ``UNSPECIFIED``) with ``EXECUTE``.
    """

    UNSPECIFIED = 0  # falsy: treated as "not given" by create_execute_query_request
    PARSE = 10
    VALIDATE = 20
    EXPLAIN = 30
    EXECUTE = 50  # the default chosen by create_execute_query_request
class QueryStatsMode(enum.IntEnum):
    """Level of execution statistics requested for a query (sent as ``stats_mode``).

    ``create_execute_query_request`` substitutes ``NONE`` only when the argument
    is ``None``; an explicit ``UNSPECIFIED`` is forwarded unchanged.
    """

    UNSPECIFIED = 0
    NONE = 10  # the default chosen when stats_mode is None
    BASIC = 20
    FULL = 30
    PROFILE = 40
class SyncResponseContextIterator(_utilities.SyncResponseIterator):
    """Synchronous response iterator that can also be used as a context manager.

    Leaving the ``with`` block reads and discards every remaining item, because
    the YDB stream is only closed once it has been scrolled to the end. Any
    exception that ended the ``with`` block is not suppressed.
    NOTE(review): if draining itself raises, that exception propagates out of
    ``__exit__`` (chained to the original one, if any).
    """

    def __enter__(self) -> "SyncResponseContextIterator":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Consume whatever is left so the server-side stream can complete.
        for _leftover in self:
            continue
        return None
class QueryClientSettings:
    """Client-side options passed to result-set conversion.

    ``wrap_execute_query_response`` hands these settings to
    ``convert.ResultSet.from_message``. Every ``_native_*_in_result_sets`` flag
    starts as ``True``. Presumably ``True`` means values of that type are returned
    as native Python objects (confirm in ``convert``). Each ``with_*`` method sets
    one flag and returns ``self``, so calls can be chained.
    """

    def __init__(self):
        self._native_datetime_in_result_sets = True
        self._native_date_in_result_sets = True
        self._native_json_in_result_sets = True
        self._native_interval_in_result_sets = True
        self._native_timestamp_in_result_sets = True

    def with_native_datetime_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
        """Set the ``_native_datetime_in_result_sets`` flag; return ``self``."""
        self._native_datetime_in_result_sets = enabled
        return self

    def with_native_date_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
        """Set the ``_native_date_in_result_sets`` flag; return ``self``."""
        self._native_date_in_result_sets = enabled
        return self

    def with_native_json_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
        """Set the ``_native_json_in_result_sets`` flag; return ``self``."""
        self._native_json_in_result_sets = enabled
        return self

    def with_native_interval_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
        """Set the ``_native_interval_in_result_sets`` flag; return ``self``."""
        self._native_interval_in_result_sets = enabled
        return self

    def with_native_timestamp_in_result_sets(self, enabled: bool) -> "QueryClientSettings":
        """Set the ``_native_timestamp_in_result_sets`` flag; return ``self``."""
        self._native_timestamp_in_result_sets = enabled
        return self
class IQuerySessionState(abc.ABC):
    """Abstract interface for the client-side state of a query session.

    The state tracks a session id, a node id and an "attached" flag. Setters are
    annotated to return the state object. ``reset`` is what
    ``bad_session_handler`` calls when a ``BadSession`` error is raised.
    """

    def __init__(self, settings: Optional[QueryClientSettings] = None):
        """Accept optional client settings; this base implementation ignores them."""

    @abc.abstractmethod
    def reset(self) -> None:
        """Clear the session state."""

    @property
    @abc.abstractmethod
    def session_id(self) -> Optional[str]:
        """Current session id, or ``None`` if none is set."""

    @abc.abstractmethod
    def set_session_id(self, session_id: str) -> "IQuerySessionState":
        """Store ``session_id``."""

    @property
    @abc.abstractmethod
    def node_id(self) -> Optional[int]:
        """Current node id, or ``None`` if none is set."""

    @abc.abstractmethod
    def set_node_id(self, node_id: int) -> "IQuerySessionState":
        """Store ``node_id``."""

    @property
    @abc.abstractmethod
    def attached(self) -> bool:
        """Whether the session is marked as attached."""

    @abc.abstractmethod
    def set_attached(self, attached: bool) -> "IQuerySessionState":
        """Store the attached flag."""
def create_execute_query_request(
    query: str,
    session_id: str,
    tx_id: Optional[str],
    commit_tx: Optional[bool],
    tx_mode: Optional[BaseQueryTxMode],
    syntax: Optional[QuerySyntax],
    exec_mode: Optional[QueryExecMode],
    stats_mode: Optional[QueryStatsMode],
    parameters: Optional[dict],
    concurrent_result_sets: Optional[bool],
) -> ydb_query.ExecuteQueryRequest:
    """Build an ``ExecuteQueryRequest`` for ``query`` in session ``session_id``.

    Defaults:
      * falsy ``syntax`` -> ``QuerySyntax.YQL_V1``;
      * falsy ``exec_mode`` -> ``QueryExecMode.EXECUTE``;
      * ``stats_mode is None`` -> ``QueryStatsMode.NONE``.
    Note that the first two also replace an explicit ``UNSPECIFIED`` (value 0),
    while ``stats_mode`` forwards ``UNSPECIFIED`` as-is.

    Transaction control:
      * truthy ``tx_id`` -> control referencing that transaction id;
      * else truthy ``tx_mode`` -> control that begins a transaction in that mode;
      * otherwise -> no ``tx_control``.
    ``commit_tx`` is forwarded in both transactional cases.
    """
    syntax = syntax or QuerySyntax.YQL_V1
    exec_mode = exec_mode or QueryExecMode.EXECUTE
    if stats_mode is None:
        stats_mode = QueryStatsMode.NONE

    if tx_id:
        tx_control = ydb_query.TransactionControl(
            tx_id=tx_id,
            begin_tx=None,
            commit_tx=commit_tx,
        )
    elif tx_mode:
        tx_control = ydb_query.TransactionControl(
            tx_id=None,
            begin_tx=ydb_query.TransactionSettings(tx_mode=tx_mode),
            commit_tx=commit_tx,
        )
    else:
        tx_control = None

    content = ydb_query.QueryContent.from_public(query=query, syntax=syntax)
    return ydb_query.ExecuteQueryRequest(
        session_id=session_id,
        query_content=content,
        tx_control=tx_control,
        exec_mode=exec_mode,
        parameters=parameters,
        concurrent_result_sets=concurrent_result_sets,
        stats_mode=stats_mode,
    )
def bad_session_handler(func):
    """Decorate a response handler so a ``BadSession`` error resets session state.

    The wrapped callable takes ``(rpc_state, response_pb, session_state, ...)``.
    If it raises ``issues.BadSession``, ``session_state.reset()`` is called and the
    same exception is re-raised. Any other outcome passes through unchanged.
    """

    @functools.wraps(func)
    def wrapper(rpc_state, response_pb, session_state: IQuerySessionState, *args, **kwargs):
        try:
            outcome = func(rpc_state, response_pb, session_state, *args, **kwargs)
        except issues.BadSession:
            # Drop local state so the session reported as bad is not reused.
            session_state.reset()
            raise
        else:
            return outcome

    return wrapper
@bad_session_handler
def wrap_execute_query_response(
    rpc_state: RpcState,
    response_pb: _apis.ydb_query.ExecuteQueryResponsePart,
    session_state: IQuerySessionState,
    tx: Optional["BaseQueryTxContext"] = None,
    session: Optional["BaseQuerySession"] = None,
    commit_tx: Optional[bool] = False,
    settings: Optional[QueryClientSettings] = None,
) -> Optional[convert.ResultSet]:
    """Process one ``ExecuteQueryResponsePart`` of a query response.

    Steps:
      1. ``issues._process_response`` checks the part's status. It presumably
         raises on a non-success status; confirm in ``issues``. A ``BadSession``
         error also resets ``session_state`` via ``bad_session_handler``.
      2. Transaction bookkeeping. If ``commit_tx`` is set, ``tx`` is moved to the
         committed state. Otherwise, if ``tx`` has no id yet, it is moved to the
         begun state with ``response_pb.tx_meta.id``.
      3. If the part has ``exec_stats``, they are stored as ``_last_query_stats``
         on ``tx`` and ``session`` (whichever are given).
      4. If the part has ``result_set``, it is converted with ``settings``.

    Returns:
        The converted ``convert.ResultSet``, or ``None`` when this part carries
        no ``result_set``.
    """
    issues._process_response(response_pb)

    if tx and commit_tx:
        tx._move_to_commited()
    # NOTE(review): ``response_pb.tx_meta`` is a sub-message. Protobuf messages are
    # generally truthy, so this check may not detect a missing ``tx_meta``.
    # ``response_pb.HasField("tx_meta")`` may be what was intended — confirm.
    elif tx and response_pb.tx_meta and not tx.tx_id:
        tx._move_to_beginned(response_pb.tx_meta.id)

    if response_pb.HasField("exec_stats"):
        if tx is not None:
            tx._last_query_stats = response_pb.exec_stats
        if session is not None:
            session._last_query_stats = response_pb.exec_stats

    if response_pb.HasField("result_set"):
        return convert.ResultSet.from_message(response_pb.result_set, settings)

    # Parts without a result set are valid; the annotation reflects this.
    return None
class TxEvent(enum.Enum):
    """Names of transaction events: before/after commit and before/after rollback.

    Each member's value is identical to its name. Presumably these are used as
    the event names of ``CallbackHandler``; confirm at the call sites.
    """

    BEFORE_COMMIT = "BEFORE_COMMIT"
    AFTER_COMMIT = "AFTER_COMMIT"
    BEFORE_ROLLBACK = "BEFORE_ROLLBACK"
    AFTER_ROLLBACK = "AFTER_ROLLBACK"
class CallbackHandlerMode(enum.Enum):
    """How ``CallbackHandler`` adapts registered callbacks.

    ``SYNC``: callbacks are turned into plain callables; coroutine functions are
    run on an event loop and their result is waited for.
    ``ASYNC``: callbacks are turned into coroutine functions; plain functions are
    run through ``to_thread``.
    """

    SYNC = "SYNC"
    ASYNC = "ASYNC"
def _get_sync_callback(method: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]):
if asyncio.iscoroutinefunction(method):
if loop is None:
loop = _get_shared_event_loop()
def async_to_sync_callback(*args, **kwargs):
caller = CallFromSyncToAsync(loop)
return caller.safe_call_with_result(method(*args, **kwargs), 10)
return async_to_sync_callback
return method
def _get_async_callback(method: typing.Callable):
if asyncio.iscoroutinefunction(method):
return method
async def sync_to_async_callback(*args, **kwargs):
return await to_thread(method, *args, **kwargs, executor=None)
return sync_to_async_callback
class CallbackHandler:
    """Mixin that stores callbacks per event name and runs them.

    A host class must call ``_init_callback_handler`` before registering or
    executing callbacks; that call creates the storage and fixes the mode.
    Each callback is invoked with the handler instance as its first argument.
    """

    def _init_callback_handler(self, mode: CallbackHandlerMode) -> None:
        """Create empty callback storage and remember how callbacks are adapted."""
        self._callbacks = defaultdict(list)
        self._callback_mode = mode

    def _execute_callbacks_sync(self, event_name: str, *args, **kwargs) -> None:
        """Call every callback for ``event_name`` in registration order.

        An exception from a callback propagates and skips the remaining ones.
        """
        registered = self._callbacks[event_name]
        for cb in registered:
            cb(self, *args, **kwargs)

    async def _execute_callbacks_async(self, event_name: str, *args, **kwargs) -> None:
        """Run every callback for ``event_name`` as a task and await them all.

        If any callback raises, ``asyncio.gather`` propagates the first exception.
        The other tasks are not cancelled.
        """
        pending = []
        for cb in self._callbacks[event_name]:
            pending.append(asyncio.create_task(cb(self, *args, **kwargs)))
        if pending:
            await asyncio.gather(*pending)

    def _prepare_callback(
        self, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]
    ) -> typing.Callable:
        """Adapt ``callback`` to the handler's mode (``loop`` is used only in SYNC mode)."""
        if self._callback_mode != CallbackHandlerMode.SYNC:
            return _get_async_callback(callback)
        return _get_sync_callback(callback, loop)

    def _add_callback(self, event_name: str, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]):
        """Adapt ``callback`` and append it to the list for ``event_name``."""
        bucket = self._callbacks[event_name]
        bucket.append(self._prepare_callback(callback, loop))