diff --git a/.env.docker b/.env.docker
index 59ca01df..03b510d4 100644
--- a/.env.docker
+++ b/.env.docker
@@ -76,6 +76,14 @@ TEST_ORACLE_USER=syncmaster
 TEST_ORACLE_PASSWORD=changeme
 TEST_ORACLE_SERVICE_NAME=XEPDB1
 
+TEST_CLICKHOUSE_HOST_FOR_CONFTEST=test-clickhouse
+TEST_CLICKHOUSE_PORT_FOR_CONFTEST=8123
+TEST_CLICKHOUSE_HOST_FOR_WORKER=test-clickhouse
+TEST_CLICKHOUSE_PORT_FOR_WORKER=8123
+TEST_CLICKHOUSE_USER=default
+TEST_CLICKHOUSE_PASSWORD=
+TEST_CLICKHOUSE_DB=default
+
 TEST_HIVE_CLUSTER=test-hive
 TEST_HIVE_USER=syncmaster
 TEST_HIVE_PASSWORD=changeme
diff --git a/.env.local b/.env.local
index fc0e0be3..fdf80541 100644
--- a/.env.local
+++ b/.env.local
@@ -62,6 +62,14 @@ export TEST_ORACLE_USER=syncmaster
 export TEST_ORACLE_PASSWORD=changeme
 export TEST_ORACLE_SERVICE_NAME=XEPDB1
 
+export TEST_CLICKHOUSE_HOST_FOR_CONFTEST=localhost
+export TEST_CLICKHOUSE_PORT_FOR_CONFTEST=8123
+export TEST_CLICKHOUSE_HOST_FOR_WORKER=test-clickhouse
+export TEST_CLICKHOUSE_PORT_FOR_WORKER=8123
+export TEST_CLICKHOUSE_USER=default
+export TEST_CLICKHOUSE_PASSWORD=
+export TEST_CLICKHOUSE_DB=default
+
 export TEST_HIVE_CLUSTER=test-hive
 export TEST_HIVE_USER=syncmaster
 export TEST_HIVE_PASSWORD=changeme
diff --git a/.github/workflows/clickhouse-tests.yml b/.github/workflows/clickhouse-tests.yml
new file mode 100644
index 00000000..33a9717f
--- /dev/null
+++ b/.github/workflows/clickhouse-tests.yml
@@ -0,0 +1,79 @@
+name: Clickhouse Tests
+on:
+  workflow_call:
+
+env:
+  DEFAULT_PYTHON: '3.12'
+
+jobs:
+  tests:
+    name: Run Clickhouse tests
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Cache jars
+        uses: actions/cache@v4
+        with:
+          path: ./cached_jars
+          key: ${{ runner.os }}-python-${{ env.DEFAULT_PYTHON }}-test-clickhouse
+          restore-keys: |
+            ${{ runner.os }}-python-${{ env.DEFAULT_PYTHON }}-test-clickhouse
+            ${{ runner.os }}-python-
+
+      - name: Build Worker Image
+        uses: docker/build-push-action@v6
+        with:
+          context: .
+          tags: mtsrus/syncmaster-worker:${{ github.sha }}
+          target: test
+          file: docker/Dockerfile.worker
+          load: true
+          cache-from: mtsrus/syncmaster-worker:develop
+
+      - name: Docker compose up
+        run: |
+          docker compose -f docker-compose.test.yml --profile all down -v --remove-orphans
+          docker compose -f docker-compose.test.yml --profile clickhouse up -d --wait --wait-timeout 200
+        env:
+          WORKER_IMAGE_TAG: ${{ github.sha }}
+
+      - name: Run Clickhouse Tests
+        run: |
+          docker compose -f ./docker-compose.test.yml --profile clickhouse exec -T worker coverage run -m pytest -vvv -s -m "worker and clickhouse"
+
+      - name: Dump worker logs on failure
+        if: failure()
+        uses: jwalton/gh-docker-logs@v2
+        with:
+          images: mtsrus/syncmaster-worker
+          dest: ./logs
+
+      # This is important, as coverage is exported after receiving SIGTERM
+      - name: Shutdown
+        if: always()
+        run: |
+          docker compose -f docker-compose.test.yml --profile all down -v --remove-orphans
+
+      - name: Upload worker logs
+        uses: actions/upload-artifact@v4
+        if: failure()
+        with:
+          name: worker-logs-clickhouse
+          path: logs/*
+
+      - name: Upload coverage results
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-clickhouse
+          path: reports/*
+          # https://github.com/actions/upload-artifact/issues/602
+          include-hidden-files: true
diff --git a/.github/workflows/hdfs-tests.yml b/.github/workflows/hdfs-tests.yml
index 12aeb146..ce7ceab2 100644
--- a/.github/workflows/hdfs-tests.yml
+++ b/.github/workflows/hdfs-tests.yml
@@ -46,7 +46,6 @@ jobs:
         env:
           WORKER_IMAGE_TAG: ${{ github.sha }}
 
-      # This is important, as coverage is exported after receiving SIGTERM
       - name: Run HDFS Tests
         run: |
           docker compose -f ./docker-compose.test.yml --profile hdfs exec -T worker coverage run -m pytest -vvv -s -m "worker and hdfs"
@@ -58,6 +57,7 @@ jobs:
           images: mtsrus/syncmaster-worker
           dest: ./logs
 
+      # This is important, as coverage is exported after receiving SIGTERM
       - name: Shutdown
         if: always()
         run: |
diff --git a/.github/workflows/hive-tests.yml b/.github/workflows/hive-tests.yml
index 804b7846..7025590a 100644
--- a/.github/workflows/hive-tests.yml
+++ b/.github/workflows/hive-tests.yml
@@ -46,7 +46,6 @@ jobs:
         env:
           WORKER_IMAGE_TAG: ${{ github.sha }}
 
-      # This is important, as coverage is exported after receiving SIGTERM
      - name: Run Hive Tests
         run: |
           docker compose -f ./docker-compose.test.yml --profile hive exec -T worker coverage run -m pytest -vvv -s -m "worker and hive"
@@ -58,6 +57,7 @@ jobs:
           images: mtsrus/syncmaster-worker
           dest: ./logs
 
+      # This is important, as coverage is exported after receiving SIGTERM
       - name: Shutdown
         if: always()
         run: |
diff --git a/.github/workflows/oracle-tests.yml b/.github/workflows/oracle-tests.yml
index a634965d..8484fa57 100644
--- a/.github/workflows/oracle-tests.yml
+++ b/.github/workflows/oracle-tests.yml
@@ -46,7 +46,6 @@ jobs:
         env:
           WORKER_IMAGE_TAG: ${{ github.sha }}
 
-      # This is important, as coverage is exported after receiving SIGTERM
       - name: Run Oracle Tests
         run: |
           docker compose -f ./docker-compose.test.yml --profile oracle exec -T worker coverage run -m pytest -vvv -s -m "worker and oracle"
@@ -58,6 +57,7 @@ jobs:
           images: mtsrus/syncmaster-worker
           dest: ./logs
 
+      # This is important, as coverage is exported after receiving SIGTERM
       - name: Shutdown
         if: always()
         run: |
diff --git a/.github/workflows/scheduler-tests.yml b/.github/workflows/scheduler-tests.yml
index db1e1bc7..b4d84fb7 100644
--- a/.github/workflows/scheduler-tests.yml
+++ b/.github/workflows/scheduler-tests.yml
@@ -46,7 +46,6 @@ jobs:
         env:
           WORKER_IMAGE_TAG: ${{ github.sha }}
 
-      # This is important, as coverage is exported after receiving SIGTERM
       - name: Run Scheduler Tests
         run: |
           docker compose -f ./docker-compose.test.yml --profile worker exec -T worker coverage run -m pytest -vvv -s -m "worker and scheduler_integration"
@@ -58,6 +57,7 @@ jobs:
           images: mtsrus/syncmaster-worker
           dest: ./logs
 
+      # This is important, as coverage is exported after receiving SIGTERM
       - name: Shutdown
         if: always()
         run: |
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index ef3d9af9..74452b46 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -20,6 +20,10 @@ jobs:
     name: Oracle tests
     uses: ./.github/workflows/oracle-tests.yml
 
+  clickhouse_tests:
+    name: Clickhouse tests
+    uses: ./.github/workflows/clickhouse-tests.yml
+
   hdfs_tests:
     name: HDFS tests
     uses: ./.github/workflows/hdfs-tests.yml
@@ -44,7 +48,7 @@ jobs:
     name: Tests done
     runs-on: ubuntu-latest
 
-    needs: [oracle_tests, hive_tests, hdfs_tests, s3_tests, unit_tests]
+    needs: [oracle_tests, clickhouse_tests, hive_tests, hdfs_tests, s3_tests, unit_tests]
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
diff --git a/Makefile b/Makefile
index 17a0d36d..cc083eb7 100644
--- a/Makefile
+++ b/Makefile
@@ -65,7 +65,7 @@ broker-start: ##Broker Start broker
 
-test: test-db test-broker ##@Test Run tests
+test: test-db test-broker ##@Test Run tests
 	${POETRY} run pytest $(PYTEST_ARGS)
 
 test-db: test-db-start db-upgrade ##@TestDB Prepare database (in docker)
@@ -78,33 +78,37 @@ test-broker: test-broker-start ##@TestBroker Prepare broker (in docker)
 
 test-broker-start: ##@TestBroker Start broker
 	docker compose -f docker-compose.test.yml up -d --wait rabbitmq $(DOCKER_COMPOSE_ARGS)
 
-test-unit: test-db ##@Test Run unit tests
+test-unit: test-db ##@Test Run unit tests
 	${POETRY} run pytest ./tests/test_unit ./tests/test_database $(PYTEST_ARGS)
 
-test-integration-hdfs: test-db ##@Test Run integration tests for HDFS
+test-integration-hdfs: test-db ##@Test Run integration tests for HDFS
 	docker compose -f docker-compose.test.yml --profile hdfs up -d --wait $(DOCKER_COMPOSE_ARGS)
 	${POETRY} run pytest ./tests/test_integration -m hdfs $(PYTEST_ARGS)
 
-test-integration-hive: test-db ##@Test Run integration tests for Hive
+test-integration-hive: test-db ##@Test Run integration tests for Hive
 	docker compose -f docker-compose.test.yml --profile hive up -d --wait $(DOCKER_COMPOSE_ARGS)
 	${POETRY} run pytest ./tests/test_integration -m hive $(PYTEST_ARGS)
 
-test-integration-oracle: test-db ##@Test Run integration tests for Oracle
+test-integration-clickhouse: test-db ##@Test Run integration tests for Clickhouse
+	docker compose -f docker-compose.test.yml --profile clickhouse up -d --wait $(DOCKER_COMPOSE_ARGS)
+	${POETRY} run pytest ./tests/test_integration -m clickhouse $(PYTEST_ARGS)
+
+test-integration-oracle: test-db ##@Test Run integration tests for Oracle
 	docker compose -f docker-compose.test.yml --profile oracle up -d --wait $(DOCKER_COMPOSE_ARGS)
 	${POETRY} run pytest ./tests/test_integration -m oracle $(PYTEST_ARGS)
 
-test-integration-s3: test-db ##@Test Run integration tests for S3
+test-integration-s3: test-db ##@Test Run integration tests for S3
 	docker compose -f docker-compose.test.yml --profile s3 up -d --wait $(DOCKER_COMPOSE_ARGS)
 	${POETRY} run pytest ./tests/test_integration -m s3 $(PYTEST_ARGS)
 
-test-integration: test-db ##@Test Run all integration tests
+test-integration: test-db ##@Test Run all integration tests
 	docker compose -f docker-compose.test.yml --profile all up -d --wait $(DOCKER_COMPOSE_ARGS)
 	${POETRY} run pytest ./tests/test_integration $(PYTEST_ARGS)
 
-test-check-fixtures: ##@Test Check declared fixtures
+test-check-fixtures: ##@Test Check declared fixtures
 	${POETRY} run pytest --dead-fixtures $(PYTEST_ARGS)
 
-test-cleanup: ##@Test Cleanup tests dependencies
+test-cleanup: ##@Test Cleanup tests dependencies
 	docker compose -f docker-compose.test.yml --profile all down $(ARGS)
diff --git a/docker-compose.test.yml b/docker-compose.test.yml
index ce1343c0..54d294af 100644
--- a/docker-compose.test.yml
+++ b/docker-compose.test.yml
@@ -95,7 +95,7 @@ services:
         condition: service_healthy
       rabbitmq:
         condition: service_healthy
-    profiles: [worker, scheduler, s3, oracle, hdfs, hive, all]
+    profiles: [worker, scheduler, s3, oracle, hdfs, hive, all, clickhouse]
 
   test-postgres:
     image: postgres
@@ -109,7 +109,7 @@ services:
       interval: 30s
      timeout: 5s
       retries: 3
-    profiles: [s3, oracle, hdfs, hive, all]
+    profiles: [s3, oracle, clickhouse, hdfs, hive, all]
 
   test-s3:
     image: bitnami/minio:latest
@@ -140,6 +140,14 @@ services:
       APP_USER_PASSWORD: changeme
     profiles: [oracle, all]
 
+  test-clickhouse:
+    image: clickhouse/clickhouse-server
+    restart: unless-stopped
+    ports:
+      - 8123:8123
+      - 9001:9000
+    profiles: [clickhouse, all]
+
   metastore-hive:
     image: postgres
     restart: unless-stopped
diff --git a/syncmaster/dto/connections.py b/syncmaster/dto/connections.py
index 5402a614..6251fa2b 100644
--- a/syncmaster/dto/connections.py
+++ b/syncmaster/dto/connections.py
@@ -20,6 +20,17 @@ class PostgresConnectionDTO(ConnectionDTO):
     type: ClassVar[str] = "postgres"
 
 
+@dataclass
+class ClickhouseConnectionDTO(ConnectionDTO):
+    host: str
+    port: int
+    user: str
+    password: str
+    database_name: str
+    additional_params: dict
+    type: ClassVar[str] = "clickhouse"
+
+
 @dataclass
 class OracleConnectionDTO(ConnectionDTO):
     host: str
diff --git a/syncmaster/dto/transfers.py b/syncmaster/dto/transfers.py
index dc8fe93d..85294802 100644
--- a/syncmaster/dto/transfers.py
+++ b/syncmaster/dto/transfers.py
@@ -51,6 +51,11 @@ class OracleTransferDTO(DBTransferDTO):
     type: ClassVar[str] = "oracle"
 
 
+@dataclass
+class ClickhouseTransferDTO(DBTransferDTO):
+    type: ClassVar[str] = "clickhouse"
+
+
 @dataclass
 class HiveTransferDTO(DBTransferDTO):
     type: ClassVar[str] = "hive"
diff --git a/syncmaster/worker/controller.py b/syncmaster/worker/controller.py
index 64f344cd..86e8619f 100644
--- a/syncmaster/worker/controller.py
+++ b/syncmaster/worker/controller.py
@@ -5,6 +5,7 @@
 from syncmaster.db.models import Connection, Run
 from syncmaster.dto.connections import (
+    ClickhouseConnectionDTO,
     HDFSConnectionDTO,
     HiveConnectionDTO,
     OracleConnectionDTO,
@@ -12,6 +13,7 @@
     S3ConnectionDTO,
 )
 from syncmaster.dto.transfers import (
+    ClickhouseTransferDTO,
     HDFSTransferDTO,
     HiveTransferDTO,
     OracleTransferDTO,
@@ -20,6 +22,7 @@
 )
 from syncmaster.exceptions.connection import ConnectionTypeNotRecognizedError
 from syncmaster.worker.handlers.base import Handler
+from syncmaster.worker.handlers.db.clickhouse import ClickhouseHandler
 from syncmaster.worker.handlers.db.hive import HiveHandler
 from syncmaster.worker.handlers.db.oracle import OracleHandler
 from syncmaster.worker.handlers.db.postgres import PostgresHandler
@@ -41,6 +44,11 @@
         OracleConnectionDTO,
         OracleTransferDTO,
     ),
+    "clickhouse": (
+        ClickhouseHandler,
+        ClickhouseConnectionDTO,
+        ClickhouseTransferDTO,
+    ),
     "postgres": (
         PostgresHandler,
         PostgresConnectionDTO,
diff --git a/syncmaster/worker/handlers/db/clickhouse.py b/syncmaster/worker/handlers/db/clickhouse.py
new file mode 100644
index 00000000..c033a431
--- /dev/null
+++ b/syncmaster/worker/handlers/db/clickhouse.py
@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from onetl.connection import Clickhouse
+from onetl.db import DBWriter
+
+from syncmaster.dto.connections import ClickhouseConnectionDTO
+from syncmaster.dto.transfers import ClickhouseTransferDTO
+from syncmaster.worker.handlers.db.base import DBHandler
+
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+    from pyspark.sql.dataframe import DataFrame
+
+
+class ClickhouseHandler(DBHandler):
+    connection: Clickhouse
+    connection_dto: ClickhouseConnectionDTO
+    transfer_dto: ClickhouseTransferDTO
+
+    def connect(self, spark: SparkSession):
+        self.connection = Clickhouse(
+            host=self.connection_dto.host,
+            port=self.connection_dto.port,
+            user=self.connection_dto.user,
+            password=self.connection_dto.password,
+            extra=self.connection_dto.additional_params,
+            spark=spark,
+        ).check()
+
+    def write(self, df: DataFrame) -> None:
+        normalized_df = self.normalize_column_names(df)
+        sort_column = next(
+            (col for col in normalized_df.columns if col.lower().endswith("id")),
+            normalized_df.columns[0],  # if there is no column with "id", take the first column
+        )
+        quoted_sort_column = f'"{sort_column}"'
+
+        writer = DBWriter(
+            connection=self.connection,
+            table=self.transfer_dto.table_name,
+            options=(
+                Clickhouse.WriteOptions(createTableOptions=f"ENGINE = MergeTree() ORDER BY {quoted_sort_column}")
+                if self.transfer_dto.type == "clickhouse"
+                else None
+            ),
+        )
+        return writer.run(df=normalized_df)
+
+    def normalize_column_names(self, df: DataFrame) -> DataFrame:
+        for column_name in df.columns:
+            df = df.withColumnRenamed(column_name, column_name.lower())
+        return df
diff --git a/syncmaster/worker/spark.py b/syncmaster/worker/spark.py
index 9bfc46fd..0aa051c7 100644
--- a/syncmaster/worker/spark.py
+++ b/syncmaster/worker/spark.py
@@ -36,12 +36,15 @@ def get_worker_spark_session(
 
 
 def get_packages(db_type: str) -> list[str]:
-    from onetl.connection import Oracle, Postgres, SparkS3
+    from onetl.connection import Clickhouse, Oracle, Postgres, SparkS3
 
     if db_type == "postgres":
         return Postgres.get_packages()
     if db_type == "oracle":
         return Oracle.get_packages()
+    if db_type == "clickhouse":
+        # TODO: add https://github.com/MobileTeleSystems/spark-dialect-extension/ to spark jars
+        return Clickhouse.get_packages()
     if db_type == "s3":
         import pyspark
diff --git a/tests/settings.py b/tests/settings.py
index e2d98dc6..05cc64b1 100644
--- a/tests/settings.py
+++ b/tests/settings.py
@@ -22,6 +22,14 @@ class TestSettings(BaseSettings):
     TEST_ORACLE_SID: str | None = None
     TEST_ORACLE_SERVICE_NAME: str | None = None
 
+    TEST_CLICKHOUSE_HOST_FOR_CONFTEST: str
+    TEST_CLICKHOUSE_PORT_FOR_CONFTEST: int
+    TEST_CLICKHOUSE_HOST_FOR_WORKER: str
+    TEST_CLICKHOUSE_PORT_FOR_WORKER: int
+    TEST_CLICKHOUSE_USER: str
+    TEST_CLICKHOUSE_PASSWORD: str
+    TEST_CLICKHOUSE_DB: str
+
     TEST_HIVE_CLUSTER: str
     TEST_HIVE_USER: str
     TEST_HIVE_PASSWORD: str
diff --git a/tests/test_integration/test_run_transfer/conftest.py b/tests/test_integration/test_run_transfer/conftest.py
index 5927dd40..dac14564 100644
--- a/tests/test_integration/test_run_transfer/conftest.py
+++ b/tests/test_integration/test_run_transfer/conftest.py
@@ -8,7 +8,7 @@
 import pyspark
 import pytest
 import pytest_asyncio
-from onetl.connection import Hive, Oracle, Postgres, SparkS3
+from onetl.connection import Clickhouse, Hive, Oracle, Postgres, SparkS3
 from onetl.db import DBWriter
 from onetl.file.format import CSV, JSON, JSONLine
 from pyspark.sql import DataFrame, SparkSession
@@ -27,6 +27,7 @@
 from syncmaster.backend.settings import ServerAppSettings as Settings
 from syncmaster.db.models import Group
 from syncmaster.dto.connections import (
+    ClickhouseConnectionDTO,
     HDFSConnectionDTO,
     HiveConnectionDTO,
     OracleConnectionDTO,
@@ -72,6 +73,9 @@ def spark(settings: Settings, request: FixtureRequest) -> SparkSession:
     if "oracle" in markers:
         maven_packages.extend(Oracle.get_packages())
 
+    if "clickhouse" in markers:
+        maven_packages.extend(Clickhouse.get_packages())
+
     if "s3" in markers:
         maven_packages.extend(SparkS3.get_packages(spark_version=pyspark.__version__))
         excluded_packages.extend(
@@ -164,6 +168,36 @@ def oracle_for_worker(test_settings: TestSettings) -> OracleConnectionDTO:
     )
 
 
+@pytest.fixture(
+    scope="session",
+    params=[pytest.param("clickhouse", marks=[pytest.mark.clickhouse])],
+)
+def clickhouse_for_conftest(test_settings: TestSettings) -> ClickhouseConnectionDTO:
+    return ClickhouseConnectionDTO(
+        host=test_settings.TEST_CLICKHOUSE_HOST_FOR_CONFTEST,
+        port=test_settings.TEST_CLICKHOUSE_PORT_FOR_CONFTEST,
+        user=test_settings.TEST_CLICKHOUSE_USER,
+        password=test_settings.TEST_CLICKHOUSE_PASSWORD,
+        database_name=test_settings.TEST_CLICKHOUSE_DB,
+        additional_params={},
+    )
+
+
+@pytest.fixture(
+    scope="session",
+    params=[pytest.param("clickhouse", marks=[pytest.mark.clickhouse])],
+)
+def clickhouse_for_worker(test_settings: TestSettings) -> ClickhouseConnectionDTO:
+    return ClickhouseConnectionDTO(
+        host=test_settings.TEST_CLICKHOUSE_HOST_FOR_WORKER,
+        port=test_settings.TEST_CLICKHOUSE_PORT_FOR_WORKER,
+        user=test_settings.TEST_CLICKHOUSE_USER,
+        password=test_settings.TEST_CLICKHOUSE_PASSWORD,
+        database_name=test_settings.TEST_CLICKHOUSE_DB,
+        additional_params={},
+    )
+
+
 @pytest.fixture(
     scope="session",
     params=[pytest.param("postgres", marks=[pytest.mark.postgres])],
@@ -509,6 +543,50 @@ def fill_with_data(df: DataFrame):
         pass
 
 
+@pytest.fixture
+def prepare_clickhouse(
+    clickhouse_for_conftest: ClickhouseConnectionDTO,
+    spark: SparkSession,
+):
+    clickhouse = clickhouse_for_conftest
+    onetl_conn = Clickhouse(
+        host=clickhouse.host,
+        port=clickhouse.port,
+        user=clickhouse.user,
+        password=clickhouse.password,
+        spark=spark,
+    ).check()
+    try:
+        onetl_conn.execute(f"DROP TABLE {clickhouse.user}.source_table")
+    except Exception:
+        pass
+    try:
+        onetl_conn.execute(f"DROP TABLE {clickhouse.user}.target_table")
+    except Exception:
+        pass
+
+    def fill_with_data(df: DataFrame):
+        logger.info("START PREPARE CLICKHOUSE")
+        db_writer = DBWriter(
+            connection=onetl_conn,
+            target=f"{clickhouse.user}.source_table",
+            options=Clickhouse.WriteOptions(createTableOptions="ENGINE = Memory"),
+        )
+        db_writer.run(df)
+        logger.info("END PREPARE CLICKHOUSE")
+
+    yield onetl_conn, fill_with_data
+
+    try:
+        onetl_conn.execute(f"DROP TABLE {clickhouse.user}.source_table")
+    except Exception:
+        pass
+    try:
+        onetl_conn.execute(f"DROP TABLE {clickhouse.user}.target_table")
+    except Exception:
+        pass
+
+
 @pytest.fixture(params=[("csv", {}), ("jsonline", {}), ("json", {})])
 def source_file_format(request: FixtureRequest):
     name, params = request.param
@@ -716,6 +794,43 @@ async def oracle_connection(
     await session.commit()
 
 
+@pytest_asyncio.fixture
+async def clickhouse_connection(
+    clickhouse_for_worker: ClickhouseConnectionDTO,
+    settings: Settings,
+    session: AsyncSession,
+    group: Group,
+):
+    clickhouse = clickhouse_for_worker
+    syncmaster_conn = await create_connection(
+        session=session,
+        name=secrets.token_hex(5),
+        data=dict(
+            type=clickhouse.type,
+            host=clickhouse.host,
+            port=clickhouse.port,
+            database_name=clickhouse.database_name,
+            additional_params={},
+        ),
+        group_id=group.id,
+    )
+
+    await create_credentials(
+        session=session,
+        settings=settings,
+        connection_id=syncmaster_conn.id,
+        auth_data=dict(
+            type="clickhouse",
+            user=clickhouse.user,
+            password=clickhouse.password,
+        ),
+    )
+
+    yield syncmaster_conn
+    await session.delete(syncmaster_conn)
+    await session.commit()
+
+
 @pytest_asyncio.fixture
 async def hdfs_connection(
     hdfs: HDFSConnectionDTO,
diff --git a/tests/test_integration/test_run_transfer/test_clickhouse.py b/tests/test_integration/test_run_transfer/test_clickhouse.py
new file mode 100644
index 00000000..d76a5e7e
--- /dev/null
+++ b/tests/test_integration/test_run_transfer/test_clickhouse.py
@@ -0,0 +1,286 @@
+import secrets
+
+import pytest
+import pytest_asyncio
+from httpx import AsyncClient
+from onetl.connection import Clickhouse
+from onetl.db import DBReader
+from pyspark.sql import DataFrame
+from pyspark.sql.functions import col, date_trunc
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from syncmaster.db.models import Connection, Group, Queue, Status, Transfer
+from tests.mocks import MockUser
+from tests.test_unit.utils import create_transfer
+from tests.utils import get_run_on_end
+
+pytestmark = [pytest.mark.asyncio, pytest.mark.worker]
+
+
+@pytest_asyncio.fixture
+async def postgres_to_clickhouse(
+    session: AsyncSession,
+    group: Group,
+    queue: Queue,
+    clickhouse_for_conftest: Clickhouse,
+    clickhouse_connection: Connection,
+    postgres_connection: Connection,
+):
+    result = await create_transfer(
+        session=session,
+        group_id=group.id,
+        name=f"postgres2clickhouse_{secrets.token_hex(5)}",
+        source_connection_id=postgres_connection.id,
+        target_connection_id=clickhouse_connection.id,
+        source_params={
+            "type": "postgres",
+            "table_name": "public.source_table",
+        },
+        target_params={
+            "type": "clickhouse",
+            "table_name": f"{clickhouse_for_conftest.user}.target_table",
+        },
+        queue_id=queue.id,
+    )
+    yield result
+    await session.delete(result)
+    await session.commit()
+
+
+@pytest_asyncio.fixture
+async def clickhouse_to_postgres(
+    session: AsyncSession,
+    group: Group,
+    queue: Queue,
+    clickhouse_for_conftest: Clickhouse,
+    clickhouse_connection: Connection,
+    postgres_connection: Connection,
+):
+    result = await create_transfer(
+        session=session,
+        group_id=group.id,
+        name=f"clickhouse2postgres_{secrets.token_hex(5)}",
+        source_connection_id=clickhouse_connection.id,
+        target_connection_id=postgres_connection.id,
+        source_params={
+            "type": "clickhouse",
+            "table_name": f"{clickhouse_for_conftest.user}.source_table",
+        },
+        target_params={
+            "type": "postgres",
+            "table_name": "public.target_table",
+        },
+        queue_id=queue.id,
+    )
+    yield result
+    await session.delete(result)
+    await session.commit()
+
+
+async def test_run_transfer_postgres_to_clickhouse(
+    client: AsyncClient,
+    group_owner: MockUser,
+    prepare_postgres,
+    prepare_clickhouse,
+    init_df: DataFrame,
+    postgres_to_clickhouse: Transfer,
+):
+    # Arrange
+    _, fill_with_data = prepare_postgres
+    fill_with_data(init_df)
+    clickhouse, _ = prepare_clickhouse
+
+    # Act
+    result = await client.post(
+        "v1/runs",
+        headers={"Authorization": f"Bearer {group_owner.token}"},
+        json={"transfer_id": postgres_to_clickhouse.id},
+    )
+    # Assert
+    assert result.status_code == 200
+
+    run_data = await get_run_on_end(
+        client=client,
+        run_id=result.json()["id"],
+        token=group_owner.token,
+    )
+    source_auth_data = run_data["transfer_dump"]["source_connection"]["auth_data"]
+    target_auth_data = run_data["transfer_dump"]["target_connection"]["auth_data"]
+
+    assert run_data["status"] == Status.FINISHED.value
+    assert source_auth_data["user"]
+    assert "password" not in source_auth_data
+    assert target_auth_data["user"]
+    assert "password" not in target_auth_data
+    reader = DBReader(
+        connection=clickhouse,
+        table=f"{clickhouse.user}.target_table",
+    )
+    df = reader.run()
+    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
+    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+    for field in init_df.schema:
+        df = df.withColumn(field.name, df[field.name].cast(field.dataType))
+
+    assert df.sort("ID").collect() == init_df.sort("ID").collect()
+
+
+async def test_run_transfer_postgres_to_clickhouse_mixed_naming(
+    client: AsyncClient,
+    group_owner: MockUser,
+    prepare_postgres,
+    prepare_clickhouse,
+    init_df_with_mixed_column_naming: DataFrame,
+    postgres_to_clickhouse: Transfer,
+):
+    # Arrange
+    _, fill_with_data = prepare_postgres
+    fill_with_data(init_df_with_mixed_column_naming)
+    clickhouse, _ = prepare_clickhouse
+
+    # Act
+    result = await client.post(
+        "v1/runs",
+        headers={"Authorization": f"Bearer {group_owner.token}"},
+        json={"transfer_id": postgres_to_clickhouse.id},
+    )
+    # Assert
+    assert result.status_code == 200
+
+    run_data = await get_run_on_end(
+        client=client,
+        run_id=result.json()["id"],
+        token=group_owner.token,
+    )
+    source_auth_data = run_data["transfer_dump"]["source_connection"]["auth_data"]
+    target_auth_data = run_data["transfer_dump"]["target_connection"]["auth_data"]
+
+    assert run_data["status"] == Status.FINISHED.value
+    assert source_auth_data["user"]
+    assert "password" not in source_auth_data
+    assert target_auth_data["user"]
+    assert "password" not in target_auth_data
+
+    reader = DBReader(
+        connection=clickhouse,
+        table=f"{clickhouse.user}.target_table",
+    )
+    df = reader.run()
+    assert df.columns != init_df_with_mixed_column_naming.columns
+    assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]
+
+    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
+    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
+        "Registered At",
+        date_trunc("second", col("Registered At")),
+    )
+    for field in init_df_with_mixed_column_naming.schema:
+        df = df.withColumn(field.name, df[field.name].cast(field.dataType))
+
+    assert df.collect() == init_df_with_mixed_column_naming.collect()
+
+
+async def test_run_transfer_clickhouse_to_postgres(
+    client: AsyncClient,
+    group_owner: MockUser,
+    prepare_clickhouse,
+    prepare_postgres,
+    init_df: DataFrame,
+    clickhouse_to_postgres: Transfer,
+):
+    # Arrange
+    _, fill_with_data = prepare_clickhouse
+    fill_with_data(init_df)
+    postgres, _ = prepare_postgres
+
+    # Act
+    result = await client.post(
+        "v1/runs",
+        headers={"Authorization": f"Bearer {group_owner.token}"},
+        json={"transfer_id": clickhouse_to_postgres.id},
+    )
+    # Assert
+    assert result.status_code == 200
+
+    run_data = await get_run_on_end(
+        client=client,
+        run_id=result.json()["id"],
+        token=group_owner.token,
+    )
+    source_auth_data = run_data["transfer_dump"]["source_connection"]["auth_data"]
+    target_auth_data = run_data["transfer_dump"]["target_connection"]["auth_data"]
+
+    assert run_data["status"] == Status.FINISHED.value
+    assert source_auth_data["user"]
+    assert "password" not in source_auth_data
+    assert target_auth_data["user"]
+    assert "password" not in target_auth_data
+
+    reader = DBReader(
+        connection=postgres,
+        table="public.target_table",
+    )
+    df = reader.run()
+
+    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
+    init_df = init_df.withColumn("REGISTERED_AT", date_trunc("second", col("REGISTERED_AT")))
+    for field in init_df.schema:
+        df = df.withColumn(field.name, df[field.name].cast(field.dataType))
+
+    assert df.sort("ID").collect() == init_df.sort("ID").collect()
+
+
+async def test_run_transfer_clickhouse_to_postgres_mixed_naming(
+    client: AsyncClient,
+    group_owner: MockUser,
+    prepare_clickhouse,
+    prepare_postgres,
+    init_df_with_mixed_column_naming: DataFrame,
+    clickhouse_to_postgres: Transfer,
+):
+    # Arrange
+    _, fill_with_data = prepare_clickhouse
+    fill_with_data(init_df_with_mixed_column_naming)
+    postgres, _ = prepare_postgres
+
+    # Act
+    result = await client.post(
+        "v1/runs",
+        headers={"Authorization": f"Bearer {group_owner.token}"},
+        json={"transfer_id": clickhouse_to_postgres.id},
+    )
+    # Assert
+    assert result.status_code == 200
+
+    run_data = await get_run_on_end(
+        client=client,
+        run_id=result.json()["id"],
+        token=group_owner.token,
+    )
+    source_auth_data = run_data["transfer_dump"]["source_connection"]["auth_data"]
+    target_auth_data = run_data["transfer_dump"]["target_connection"]["auth_data"]
+
+    assert run_data["status"] == Status.FINISHED.value
+    assert source_auth_data["user"]
+    assert "password" not in source_auth_data
+    assert target_auth_data["user"]
+    assert "password" not in target_auth_data
+
+    reader = DBReader(
+        connection=postgres,
+        table="public.target_table",
+    )
+    df = reader.run()
+
+    assert df.columns != init_df_with_mixed_column_naming.columns
+    assert df.columns == [column.lower() for column in init_df_with_mixed_column_naming.columns]
+
+    # as spark truncates milliseconds while writing to clickhouse: https://onetl.readthedocs.io/en/latest/connection/db_connection/clickhouse/types.html#id10
+    init_df_with_mixed_column_naming = init_df_with_mixed_column_naming.withColumn(
+        "Registered At",
+        date_trunc("second", col("Registered At")),
+    )
+    for field in init_df_with_mixed_column_naming.schema:
+        df = df.withColumn(field.name, df[field.name].cast(field.dataType))
+
+    assert df.collect() == init_df_with_mixed_column_naming.collect()
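
To exercise the new ClickHouse profile locally, the Makefile target added above should be enough, assuming Docker is available and the variables from .env.local are exported into the shell:

    # start test-clickhouse (plus db/rabbitmq) and run the marked tests
    source .env.local
    make test-integration-clickhouse

or, mirroring the CI job more closely:

    docker compose -f docker-compose.test.yml --profile clickhouse up -d --wait
    poetry run pytest ./tests/test_integration -m clickhouse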