# Standard-library modules used by the asynchronous pool classes below.
# ``AsyncTransaction`` and ``SenderPool`` are also exported via ``__all__``.
import concurrent.futures
import asyncio
from queue import Queue, Full, Empty
import threading


class AsyncTransaction:
    """
    A :class:`buffer <Buffer>` restricted to a single table,
    ensuring it can be flushed transactionally.

    Use in conjunction with :class:`SenderPool` to send data to QuestDB
    asynchronously.

    A transaction is single-use: once it has been committed or rolled back
    it no longer holds a buffer, and a new transaction must be obtained from
    the pool.
    """

    def __init__(self, pool: 'SenderPool', buffer: 'Buffer', table_name: str):
        """
        Bind an empty buffer to ``table_name``.

        :param pool: the pool that will flush (or recycle) the buffer.
        :param buffer: an empty buffer that this transaction takes over.
        :param table_name: the only table rows may be written to.
        :raises ValueError: if ``buffer`` is ``None`` or already holds data.
        """
        if buffer is None:
            raise ValueError('buffer cannot be None')
        if len(buffer) > 0:
            raise ValueError('buffer must be cleared')
        self._pool = pool  # TODO: weakref
        self._table_name = table_name
        self._buffer = buffer
        self._entered = False

    def _require_buffer(self):
        """Return the live buffer, or raise if the transaction is finished."""
        buffer = self._buffer
        if buffer is None:
            raise ValueError(
                'buffer has already been flushed, '
                'obtain a new one from the pool')
        return buffer

    def _detach(self):
        """
        Release the pool and buffer references, returning them as a pair.

        Both are ``None`` if the transaction was already committed or
        rolled back.
        """
        pool, self._pool = self._pool, None
        buffer, self._buffer = self._buffer, None
        return pool, buffer

    def dataframe(
            self,
            df,
            *,
            symbols: 'str | bool | List[int] | List[str]' = 'auto',
            at: 'ServerTimestamp | int | str | TimestampNanos | datetime'):
        """
        Append the rows of ``df`` to this transaction's table.

        The arguments are forwarded to the buffer's ``dataframe`` method
        with ``table_name`` fixed to this transaction's table.

        :return: ``self``, so calls can be chained.
        :raises ValueError: if the transaction is already finished.
        """
        self._require_buffer().dataframe(
            df, table_name=self._table_name, symbols=symbols, at=at)
        return self

    def row(
            self,
            *,
            symbols: 'Dict[str, str] | None' = None,
            columns: ('Dict[str, bool | int | float | str | '
                      'TimestampMicros | datetime] | None') = None,
            at: 'TimestampNanos | datetime | ServerTimestamp'
    ) -> 'AsyncTransaction':
        """
        Append a single row to this transaction's table.

        The arguments are forwarded to the buffer's ``row`` method with the
        table name fixed to this transaction's table.

        :return: ``self``, so calls can be chained.
        :raises ValueError: if the transaction is already finished.
        """
        self._require_buffer().row(
            self._table_name, symbols=symbols, columns=columns, at=at)
        return self

    def __str__(self) -> str:
        """Return the buffered text, or ``''`` once the buffer is released."""
        buffer = self._buffer
        return '' if buffer is None else str(buffer)

    def __len__(self) -> int:
        """Return the buffer length, or ``0`` once the buffer is released."""
        buffer = self._buffer
        return 0 if buffer is None else len(buffer)

    def commit_fut(self) -> 'concurrent.futures.Future[None]':
        """
        Submit the buffer to the pool's executor for flushing.

        :return: a ``concurrent.futures.Future`` for the flush.
        :raises ValueError: if the transaction no longer holds a pool, i.e.
            it was already committed or rolled back.
        """
        pool, buffer = self._detach()
        if pool is None:
            raise ValueError('transaction has already been committed')
        return pool._thread_pool.submit(pool._flush, buffer)

    # async - despite `async lacking in signature`
    def commit(self) -> 'asyncio.Future[None]':
        """
        Submit the buffer for flushing and return an awaitable future.

        The event loop is resolved *before* the flush is submitted, so a
        failure to obtain a loop cannot leave an untracked flush running.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Not called from a coroutine: keep the historical behaviour of
            # binding to the current thread's event loop.
            loop = asyncio.get_event_loop()
        return asyncio.wrap_future(self.commit_fut(), loop=loop)

    def rollback(self):
        """
        Discard the buffered rows and hand the buffer back to the pool.

        Does nothing if the transaction is already finished.
        """
        pool, buffer = self._detach()
        if pool is not None:
            pool._add_buffer_to_free_list(buffer)

    async def __aenter__(self):
        """Enter the transaction; it may only be entered once."""
        if self._entered:
            raise IngressError(
                IngressErrorCode.InvalidApiCall,
                'transaction already entered')
        self._entered = True
        return self

    async def __aexit__(self, exc_type, exc_val, _exc_tb):
        """
        Commit on a clean exit, roll back if the body raised.

        Returns ``False`` so that an exception raised inside the
        ``async with`` body propagates unchanged.
        """
        if not self._entered:
            raise IngressError(
                IngressErrorCode.InvalidApiCall,
                'transaction not entered')
        if exc_type is None:
            await self.commit()
        else:
            self.rollback()
        return False
class SenderPool:
    """
    A pool of Senders that can be used asynchronously to send data to QuestDB.

    .. code-block:: python

        import pandas as pd
        from questdb.ingress import (
            SenderPool, TimestampNanos, IngressError)

        async def main():
            with SenderPool('http::addr=localhost:9000;') as pool:
                txn1 = pool.transaction('my_table')
                txn1.row(columns={'a': 1, 'b': 2}, at=TimestampNanos.now())
                txn1.row(columns={'a': 3, 'b': 4}, at=TimestampNanos.now())

                df = pd.DataFrame({
                    'timestamp': pd.to_datetime([
                        '2021-01-01T00:00:00', '2021-01-01T00:00:01']),
                    'a': [1, 3],
                    'b': [2, 4]})
                txn2 = pool.transaction('another_table')
                txn2.dataframe(df, at='timestamp')

                # Send the buffers asynchronously in parallel
                f1 = txn1.commit()
                f2 = txn2.commit()

                # Wait for both to complete, raising any exceptions on error
                try:
                    await f1
                    await f2
                except IngressError as e:
                    ...

    If you don't have an async context, use `txn.commit_fut()` to get a
    `concurrent.futures.Future` instead of an `asyncio.Future`.

    Alternatively, the transaction itself can be an async context manager:

    .. code-block:: python

        with SenderPool('http::addr=localhost:9000;') as pool:
            async with pool.transaction('my_table') as txn:
                txn.row(columns={'a': 1, 'b': 2}, at=TimestampNanos.now())
                txn.row(columns={'a': 3, 'b': 4}, at=TimestampNanos.now())
    """

    def __init__(
            self,
            conf: str,
            max_workers: Optional[int] = None,
            max_free_buffers: Optional[int] = None):
        """
        Create a pool of Senders that can be used asynchronously to send
        data to QuestDB.

        No connection is made and no thread is started here; that happens
        in :func:`create` (or on entering the ``with`` block).

        :param conf: the configuration string for each Sender in the pool
        :param max_workers: the maximum number of workers in the pool,
            if None defaults to min(32, os.cpu_count() + 4)
        :param max_free_buffers: the maximum number of buffers to keep in
            the pool for reuse, if None defaults to 2 * max_workers
        :raises ValueError: if ``max_workers`` is below one or
            ``max_free_buffers`` is negative.
        :raises IngressError: if ``conf`` is not an HTTP(S) configuration.
        """
        self._conf = conf
        if max_workers is None:
            # Same logic as for ThreadPoolExecutor
            self._max_workers = min(32, (os.cpu_count() or 1) + 4)
        else:
            self._max_workers = int(max_workers)
        if self._max_workers < 1:
            raise ValueError(
                'SenderPool requires at least one worker')
        if max_free_buffers is None:
            self._max_free_buffers = 2 * self._max_workers
        else:
            self._max_free_buffers = int(max_free_buffers)
        if self._max_free_buffers < 0:
            raise ValueError(
                'SenderPool max_free_buffers can\'t be negative')

        # Match the protocol prefix exactly rather than anything that merely
        # begins with the letters "http".
        if not conf.startswith(('http::', 'https::')):
            raise IngressError(
                IngressErrorCode.ConfigError,
                'SenderPool only supports "http" and "https" protocols')
        self._thread_pool = None
        self._buffer_provisioner_sender = None
        self._buffer_free_list = None
        self._executor_thread_local = None

    def create(self):
        """
        Create the pool of Senders.

        The buffer-provisioning sender is established first, so that a bad
        configuration or an unreachable server leaves the pool untouched
        (no executor and no half-open sender are kept).
        """
        sender = Sender.from_conf(self._conf)
        try:
            sender.establish()
        except BaseException:
            sender.close()
            raise
        self._buffer_provisioner_sender = sender
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=self._max_workers)
        self._buffer_free_list = Queue(self._max_free_buffers)
        self._executor_thread_local = threading.local()

    def __enter__(self):
        """Create the pool and return it."""
        self.create()
        return self

    def close(self):
        """
        Shut the pool down and release its resources.

        The executor is shut down first (waiting for submitted flushes),
        then the provisioning sender is closed. Calling this on a pool
        that was never created, or was already closed, does nothing.
        """
        thread_pool = self._thread_pool
        if thread_pool is not None:
            thread_pool.shutdown()
            self._thread_pool = None
        provisioner = self._buffer_provisioner_sender
        if provisioner is not None:
            provisioner.close()
            self._buffer_provisioner_sender = None
        self._buffer_free_list = None
        self._executor_thread_local = None

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        """Close the pool; exceptions from the ``with`` body propagate."""
        self.close()

    def transaction(self, table_name: str):
        """
        Start a new :class:`AsyncTransaction` for ``table_name``.

        A recycled buffer is used when one is available, otherwise a new
        one is obtained from the provisioning sender.

        :raises IngressError: if the pool has not been created or has
            been closed.
        """
        # TODO: Work out the thread safety details of this method.
        free_list = self._buffer_free_list
        provisioner = self._buffer_provisioner_sender
        if free_list is None or provisioner is None:
            raise IngressError(
                IngressErrorCode.InvalidApiCall,
                'SenderPool is not open: call create() first '
                'or use it as a context manager')
        try:
            buf = free_list.get_nowait()
        except Empty:
            buf = provisioner.new_buffer()
        return AsyncTransaction(self, buf, table_name)

    def _add_buffer_to_free_list(self, buffer):
        """
        Clear ``buffer`` and keep it for reuse if there is room.

        The buffer is dropped when the pool is closed, when the free list
        is full, or when the pool was configured to keep no free buffers.
        """
        if buffer is None:
            return
        buffer.clear()
        free_list = self._buffer_free_list
        if free_list is None:
            return
        if self._max_free_buffers <= 0:
            # ``Queue(0)`` is unbounded, so a limit of zero has to be
            # enforced here or every buffer would be retained forever.
            return
        try:
            free_list.put_nowait(buffer)
        except Full:
            pass  # drop the buffer, too many in free list

    def _thread_sender(self):
        """Return the calling worker thread's sender, creating it lazily."""
        local = self._executor_thread_local
        sender = getattr(local, 'sender', None)
        if sender is None:
            sender = Sender.from_conf(self._conf)
            try:
                sender.establish()
            except BaseException:
                # Mirror create(): never keep a sender that failed to connect.
                sender.close()
                raise
            local.sender = sender  # will be closed by __del__
        return sender

    def _flush(self, buffer):
        """
        Flush ``buffer`` transactionally using the worker thread's sender.

        Runs on an executor thread. The buffer is always handed back to
        the free list afterwards, whether or not the flush succeeded.
        """
        try:
            sender = self._thread_sender()
            sender.flush(buffer, clear=False, transactional=True)
        finally:
            self._add_buffer_to_free_list(buffer)
class TestSenderPool(unittest.IsolatedAsyncioTestCase):
    # Exercises SenderPool / AsyncTransaction against the local HttpServer
    # test fixture, through both the concurrent.futures and asyncio APIs.

    @staticmethod
    def _conf(server):
        # Configuration string pointing the pool at the test server.
        return f'http::addr=localhost:{server.port};'

    def test_future(self):
        # Two independent transactions both complete via commit_fut().
        with HttpServer() as server, \
                qi.SenderPool(self._conf(server)) as pool:
            first = pool.transaction('tbl1')
            second = pool.transaction('tbl2')
            self.assertIsNot(first, second)
            for txn in (first, second):
                self.assertIsInstance(txn, qi.AsyncTransaction)
            first.row(symbols={'sym1': 'val1'}, at=qi.ServerTimestamp)
            second.row(symbols={'sym2': 'val2'}, at=qi.ServerTimestamp)
            pending = [first.commit_fut(), second.commit_fut()]
            for fut in pending:
                fut.result()

    def test_buffer_free_list(self):
        # Many short transactions, so buffers get recycled by the pool.
        with HttpServer() as server:
            pool_cm = qi.SenderPool(
                self._conf(server),
                max_workers=4,
                max_free_buffers=8)
            with pool_cm as pool:
                pending = []
                for _ in range(100):
                    txn = pool.transaction('tbl1')
                    txn.row(columns={'a': 1.5}, at=qi.TimestampNanos.now())
                    pending.append(txn.commit_fut())
                    time.sleep(0.001)
                for fut in pending:
                    fut.result()

    def test_future_error(self):
        # A server-side rejection surfaces from Future.result().
        with HttpServer() as server:
            for _ in range(2):
                server.responses.append(
                    (0, 403, 'text/plain', b'Forbidden'))
            with qi.SenderPool(self._conf(server)) as pool:
                first = pool.transaction('tbl1')
                second = pool.transaction('tbl2')
                first.row(symbols={'sym1': 'val1'}, at=qi.ServerTimestamp)
                second.row(symbols={'sym2': 'val2'}, at=qi.ServerTimestamp)
                pending = [first.commit_fut(), second.commit_fut()]
                for fut in pending:
                    with self.assertRaisesRegex(qi.IngressError, 'Forbidden'):
                        fut.result()

    async def test_async(self):
        # Two independent transactions both complete via awaitable commit().
        with HttpServer() as server, \
                qi.SenderPool(self._conf(server)) as pool:
            first = pool.transaction('tbl1')
            second = pool.transaction('tbl2')
            self.assertIsNot(first, second)
            for txn in (first, second):
                self.assertIsInstance(txn, qi.AsyncTransaction)
            first.row(symbols={'sym1': 'val1'}, at=qi.ServerTimestamp)
            second.row(symbols={'sym2': 'val2'}, at=qi.ServerTimestamp)
            pending = [first.commit(), second.commit()]
            for fut in pending:
                await fut

    async def test_async_error(self):
        # A server-side rejection surfaces when the commit is awaited.
        with HttpServer() as server:
            for _ in range(2):
                server.responses.append(
                    (0, 403, 'text/plain', b'Forbidden'))
            with qi.SenderPool(self._conf(server)) as pool:
                first = pool.transaction('tbl1')
                second = pool.transaction('tbl2')
                first.row(symbols={'sym1': 'val1'}, at=qi.ServerTimestamp)
                second.row(symbols={'sym2': 'val2'}, at=qi.ServerTimestamp)
                pending = [first.commit(), second.commit()]
                for fut in pending:
                    with self.assertRaisesRegex(qi.IngressError, 'Forbidden'):
                        await fut

    async def test_async_txn(self):
        # Transactions used as async context managers commit on exit,
        # in order, with exactly the buffered payload.
        with HttpServer() as server:
            with qi.SenderPool(self._conf(server)) as pool:
                async with pool.transaction('tbl1') as txn:
                    txn.row(symbols={'sym1': 'val1'}, at=qi.ServerTimestamp)
                    payload1 = str(txn)
                async with pool.transaction('tbl2') as txn:
                    txn.row(symbols={'sym2': 'val2'}, at=qi.ServerTimestamp)
                    payload2 = str(txn)
            self.assertEqual(payload1, 'tbl1,sym1=val1\n')
            self.assertEqual(payload2, 'tbl2,sym2=val2\n')
            self.assertEqual(len(server.requests), 2)
            self.assertEqual(server.requests[0], payload1.encode('utf-8'))
            self.assertEqual(server.requests[1], payload2.encode('utf-8'))

    async def test_async_txn_error(self):
        # A rejected commit raises out of the ``async with`` block.
        with HttpServer() as server:
            server.responses.append((0, 403, 'text/plain', b'Forbidden1'))
            server.responses.append((0, 403, 'text/plain', b'Forbidden2'))
            with qi.SenderPool(self._conf(server)) as pool:
                with self.assertRaisesRegex(qi.IngressError, 'Forbidden1'):
                    async with pool.transaction('tbl1') as txn:
                        txn.row(
                            symbols={'sym1': 'val1'},
                            at=qi.ServerTimestamp)
                with self.assertRaisesRegex(qi.IngressError, 'Forbidden2'):
                    async with pool.transaction('tbl2') as txn:
                        txn.row(
                            symbols={'sym2': 'val2'},
                            at=qi.ServerTimestamp)

    async def test_bad_reentrant_txn(self):
        # Entering the same transaction a second time is rejected.
        with HttpServer() as server, \
                qi.SenderPool(self._conf(server)) as pool:
            async with pool.transaction('tbl1') as txn:
                txn.row(symbols={'sym1': 'val1'}, at=qi.ServerTimestamp)
                expect_error = self.assertRaisesRegex(
                    qi.IngressError, 'transaction already entered')
                with expect_error:
                    async with txn:
                        pass

    def test_not_http(self):
        # Non-HTTP protocols are refused when the pool is constructed.
        expect_error = self.assertRaisesRegex(
            qi.IngressError,
            'SenderPool only supports "http" and "https" protocols')
        with expect_error:
            with qi.SenderPool('tcp::addr=localhost:1;') as _:
                pass

    def test_no_addr(self):
        # A configuration without an address fails when the pool is opened.
        expect_error = self.assertRaisesRegex(
            qi.IngressError,
            'Missing "addr" parameter in config string')
        with expect_error:
            with qi.SenderPool('http::') as _:
                pass