diff --git a/README.rst b/README.rst index a2a893a5..6f7477ca 100644 --- a/README.rst +++ b/README.rst @@ -36,6 +36,7 @@ List of currently supported connections: * Clickhouse * Postgres * Oracle +* MSSQL * HDFS * S3 diff --git a/docs/changelog/next_release/125.feature.rst b/docs/changelog/next_release/125.feature.rst new file mode 100644 index 00000000..0aed705b --- /dev/null +++ b/docs/changelog/next_release/125.feature.rst @@ -0,0 +1 @@ +Add MSSQL API schema \ No newline at end of file diff --git a/syncmaster/backend/api/v1/connections.py b/syncmaster/backend/api/v1/connections.py index 47296f4d..678656a1 100644 --- a/syncmaster/backend/api/v1/connections.py +++ b/syncmaster/backend/api/v1/connections.py @@ -21,6 +21,7 @@ CLICKHOUSE_TYPE, HDFS_TYPE, HIVE_TYPE, + MSSQL_TYPE, ORACLE_TYPE, POSTGRES_TYPE, S3_TYPE, @@ -38,7 +39,7 @@ router = APIRouter(tags=["Connections"], responses=get_error_responses()) -CONNECTION_TYPES = ORACLE_TYPE, POSTGRES_TYPE, CLICKHOUSE_TYPE, HIVE_TYPE, S3_TYPE, HDFS_TYPE +CONNECTION_TYPES = ORACLE_TYPE, POSTGRES_TYPE, CLICKHOUSE_TYPE, HIVE_TYPE, MSSQL_TYPE, S3_TYPE, HDFS_TYPE @router.get("/connections") diff --git a/syncmaster/schemas/v1/connection_types.py b/syncmaster/schemas/v1/connection_types.py index f21774d9..924493a3 100644 --- a/syncmaster/schemas/v1/connection_types.py +++ b/syncmaster/schemas/v1/connection_types.py @@ -7,6 +7,7 @@ ORACLE_TYPE = Literal["oracle"] POSTGRES_TYPE = Literal["postgres"] CLICKHOUSE_TYPE = Literal["clickhouse"] +MSSQL_TYPE = Literal["mssql"] S3_TYPE = Literal["s3"] HDFS_TYPE = Literal["hdfs"] @@ -16,5 +17,6 @@ class ConnectionType(str, Enum): HIVE = "hive" ORACLE = "oracle" CLICKHOUSE = "clickhouse" + MSSQL = "mssql" S3 = "s3" HDFS = "hdfs" diff --git a/syncmaster/schemas/v1/connections/connection.py b/syncmaster/schemas/v1/connections/connection.py index b9b86aaa..ff597940 100644 --- a/syncmaster/schemas/v1/connections/connection.py +++ b/syncmaster/schemas/v1/connections/connection.py @@ -27,6 +27,14 @@ UpdateHiveAuthSchema, UpdateHiveConnectionSchema, ) +from syncmaster.schemas.v1.connections.mssql import ( + CreateMSSQLAuthSchema, + CreateMSSQLConnectionSchema, + ReadMSSQLAuthSchema, + ReadMSSQLConnectionSchema, + UpdateMSSQLAuthSchema, + UpdateMSSQLConnectionSchema, +) from syncmaster.schemas.v1.connections.oracle import ( CreateOracleAuthSchema, CreateOracleConnectionSchema, @@ -60,6 +68,7 @@ | ReadOracleConnectionSchema | ReadPostgresConnectionSchema | ReadClickhouseConnectionSchema + | ReadMSSQLConnectionSchema | S3ReadConnectionSchema ) CreateConnectionDataSchema = ( @@ -67,6 +76,7 @@ | CreateOracleConnectionSchema | CreatePostgresConnectionSchema | CreateClickhouseConnectionSchema + | CreateMSSQLConnectionSchema | HDFSCreateConnectionSchema | S3CreateConnectionSchema ) @@ -77,12 +87,14 @@ | UpdateOracleConnectionSchema | UpdatePostgresConnectionSchema | UpdateClickhouseConnectionSchema + | UpdateMSSQLConnectionSchema ) ReadConnectionAuthDataSchema = ( ReadHiveAuthSchema | ReadOracleAuthSchema | ReadPostgresAuthSchema | ReadClickhouseAuthSchema + | ReadMSSQLAuthSchema | S3ReadAuthSchema | HDFSReadAuthSchema ) @@ -91,6 +103,7 @@ | CreateOracleAuthSchema | CreatePostgresAuthSchema | CreateClickhouseAuthSchema + | CreateMSSQLAuthSchema | S3CreateAuthSchema | HDFSCreateAuthSchema ) @@ -99,6 +112,7 @@ | UpdateOracleAuthSchema | UpdatePostgresAuthSchema | UpdateClickhouseAuthSchema + | UpdateMSSQLAuthSchema | S3UpdateAuthSchema | HDFSUpdateAuthSchema ) diff --git a/syncmaster/schemas/v1/connections/mssql.py b/syncmaster/schemas/v1/connections/mssql.py new file mode 100644 index 00000000..b560bd00 --- /dev/null +++ b/syncmaster/schemas/v1/connections/mssql.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: 2023-2024 MTS PJSC +# SPDX-License-Identifier: Apache-2.0 +from pydantic import BaseModel, Field, SecretStr + +from syncmaster.schemas.v1.connection_types import MSSQL_TYPE + + +class MSSQLBaseSchema(BaseModel): + type: MSSQL_TYPE + + class Config: + from_attributes = True + + +class ReadMSSQLConnectionSchema(MSSQLBaseSchema): + host: str + port: int + database: str + additional_params: dict = Field(default_factory=dict) + + +class ReadMSSQLAuthSchema(MSSQLBaseSchema): + user: str + + +class UpdateMSSQLConnectionSchema(MSSQLBaseSchema): + host: str | None = None + port: int | None = None + database: str | None = None + additional_params: dict | None = Field(default_factory=dict) + + +class UpdateMSSQLAuthSchema(MSSQLBaseSchema): + user: str | None = None # noqa: F722 + password: SecretStr | None = None + + +class CreateMSSQLConnectionSchema(MSSQLBaseSchema): + host: str + port: int + database: str + additional_params: dict = Field(default_factory=dict) + + +class CreateMSSQLAuthSchema(MSSQLBaseSchema): + user: str + password: SecretStr diff --git a/syncmaster/schemas/v1/transfers/__init__.py b/syncmaster/schemas/v1/transfers/__init__.py index d6c7d662..b998e3dd 100644 --- a/syncmaster/schemas/v1/transfers/__init__.py +++ b/syncmaster/schemas/v1/transfers/__init__.py @@ -9,6 +9,7 @@ from syncmaster.schemas.v1.transfers.db import ( ClickhouseReadTransferSourceAndTarget, HiveReadTransferSourceAndTarget, + MSSQLReadTransferSourceAndTarget, OracleReadTransferSourceAndTarget, PostgresReadTransferSourceAndTarget, ) @@ -33,6 +34,7 @@ | HiveReadTransferSourceAndTarget | OracleReadTransferSourceAndTarget | ClickhouseReadTransferSourceAndTarget + | MSSQLReadTransferSourceAndTarget | S3ReadTransferSource ) @@ -42,6 +44,7 @@ | HiveReadTransferSourceAndTarget | OracleReadTransferSourceAndTarget | ClickhouseReadTransferSourceAndTarget + | MSSQLReadTransferSourceAndTarget | S3ReadTransferTarget ) @@ -51,6 +54,7 @@ | HiveReadTransferSourceAndTarget | OracleReadTransferSourceAndTarget | ClickhouseReadTransferSourceAndTarget + | MSSQLReadTransferSourceAndTarget | S3CreateTransferSource ) @@ -60,6 +64,7 @@ | HiveReadTransferSourceAndTarget | OracleReadTransferSourceAndTarget | ClickhouseReadTransferSourceAndTarget + | MSSQLReadTransferSourceAndTarget | S3CreateTransferTarget ) @@ -69,6 +74,7 @@ | HiveReadTransferSourceAndTarget | OracleReadTransferSourceAndTarget | ClickhouseReadTransferSourceAndTarget + | MSSQLReadTransferSourceAndTarget | S3CreateTransferSource | None ) @@ -79,6 +85,7 @@ | HiveReadTransferSourceAndTarget | OracleReadTransferSourceAndTarget | ClickhouseReadTransferSourceAndTarget + | MSSQLReadTransferSourceAndTarget | S3CreateTransferTarget | None ) diff --git a/syncmaster/schemas/v1/transfers/db.py b/syncmaster/schemas/v1/transfers/db.py index 5618cb71..751bf57a 100644 --- a/syncmaster/schemas/v1/transfers/db.py +++ b/syncmaster/schemas/v1/transfers/db.py @@ -7,6 +7,7 @@ from syncmaster.schemas.v1.connection_types import ( CLICKHOUSE_TYPE, HIVE_TYPE, + MSSQL_TYPE, ORACLE_TYPE, POSTGRES_TYPE, ) @@ -30,3 +31,7 @@ class PostgresReadTransferSourceAndTarget(ReadDBTransfer): class ClickhouseReadTransferSourceAndTarget(ReadDBTransfer): type: CLICKHOUSE_TYPE + + +class MSSQLReadTransferSourceAndTarget(ReadDBTransfer): + type: MSSQL_TYPE diff --git a/tests/test_unit/test_connections/connection_fixtures/group_connections_fixture.py b/tests/test_unit/test_connections/connection_fixtures/group_connections_fixture.py index ce188e10..b698ffa1 100644 --- a/tests/test_unit/test_connections/connection_fixtures/group_connections_fixture.py +++ b/tests/test_unit/test_connections/connection_fixtures/group_connections_fixture.py @@ -1,3 +1,5 @@ +from collections.abc import AsyncGenerator + import pytest_asyncio from sqlalchemy.ext.asyncio import AsyncSession @@ -7,7 +9,10 @@ @pytest_asyncio.fixture -async def group_connections(group_connection: MockConnection, session: AsyncSession) -> list[MockConnection]: +async def group_connections( + group_connection: MockConnection, + session: AsyncSession, +) -> AsyncGenerator[list[MockConnection], None]: connection = group_connection.connection # start with the connection from group_connection fixture @@ -17,13 +22,36 @@ async def group_connections(group_connection: MockConnection, session: AsyncSess # since group_connection already created a connection, we start from index 1 for conn_type in connection_types[1:]: - new_data = { # TODO: create different dicts + new_data = { **connection.data, "type": conn_type.value, - "cluster": "cluster", - "bucket": "bucket", } + if conn_type in [ConnectionType.HDFS, ConnectionType.HIVE]: + new_data.update( + { + "cluster": "cluster", + }, + ) + elif conn_type == ConnectionType.S3: + new_data.update( + { + "bucket": "bucket", + }, + ) + elif conn_type == ConnectionType.POSTGRES: + new_data.update( + { + "database_name": "database", + }, + ) + elif conn_type in [ConnectionType.ORACLE, ConnectionType.CLICKHOUSE, ConnectionType.MSSQL]: + new_data.update( + { + "database": "database", + }, + ) + new_connection = Connection( group_id=connection.group_id, name=f"{connection.name}_{conn_type.value}", diff --git a/tests/test_unit/test_connections/test_create_connection.py b/tests/test_unit/test_connections/test_create_connection.py index e4439faf..e0a43ec8 100644 --- a/tests/test_unit/test_connections/test_create_connection.py +++ b/tests/test_unit/test_connections/test_create_connection.py @@ -282,7 +282,7 @@ async def test_check_fields_validation_on_create_connection( "context": { "discriminator": "'type'", "tag": "POSTGRESQL", - "expected_tags": "'hive', 'oracle', 'postgres', 'clickhouse', 'hdfs', 's3'", + "expected_tags": "'hive', 'oracle', 'postgres', 'clickhouse', 'mssql', 'hdfs', 's3'", }, "input": { "type": "POSTGRESQL", @@ -292,7 +292,7 @@ async def test_check_fields_validation_on_create_connection( "database_name": "postgres", }, "location": ["body", "connection_data"], - "message": "Input tag 'POSTGRESQL' found using 'type' does not match any of the expected tags: 'hive', 'oracle', 'postgres', 'clickhouse', 'hdfs', 's3'", + "message": "Input tag 'POSTGRESQL' found using 'type' does not match any of the expected tags: 'hive', 'oracle', 'postgres', 'clickhouse', 'mssql', 'hdfs', 's3'", "code": "union_tag_invalid", }, ], diff --git a/tests/test_unit/test_connections/test_db_connection/test_create_mssql_connection.py b/tests/test_unit/test_connections/test_db_connection/test_create_mssql_connection.py new file mode 100644 index 00000000..62d4f3e1 --- /dev/null +++ b/tests/test_unit/test_connections/test_db_connection/test_create_mssql_connection.py @@ -0,0 +1,80 @@ +import pytest +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from syncmaster.db.models import AuthData, Connection +from syncmaster.db.repositories.utils import decrypt_auth_data +from syncmaster.settings import Settings +from tests.mocks import MockGroup, UserTestRoles + +pytestmark = [pytest.mark.asyncio, pytest.mark.backend, pytest.mark.mssql] + + +async def test_developer_plus_can_create_mssql_connection( + client: AsyncClient, + group: MockGroup, + session: AsyncSession, + settings: Settings, + role_developer_plus: UserTestRoles, +): + # Arrange + user = group.get_member_of_role(role_developer_plus) + + # Act + result = await client.post( + "v1/connections", + headers={"Authorization": f"Bearer {user.token}"}, + json={ + "group_id": group.id, + "name": "New connection", + "description": "", + "connection_data": { + "type": "mssql", + "host": "127.0.0.1", + "port": 1433, + "database": "database_name", + }, + "auth_data": { + "type": "mssql", + "user": "user", + "password": "secret", + }, + }, + ) + connection = ( + await session.scalars( + select(Connection).filter_by( + name="New connection", + ), + ) + ).first() + + creds = ( + await session.scalars( + select(AuthData).filter_by( + connection_id=connection.id, + ), + ) + ).one() + + # Assert + decrypted = decrypt_auth_data(creds.value, settings=settings) + assert result.status_code == 200 + assert result.json() == { + "id": connection.id, + "name": connection.name, + "description": connection.description, + "group_id": connection.group_id, + "connection_data": { + "type": connection.data["type"], + "host": connection.data["host"], + "port": connection.data["port"], + "database": connection.data["database"], + "additional_params": connection.data["additional_params"], + }, + "auth_data": { + "type": decrypted["type"], + "user": decrypted["user"], + }, + } diff --git a/tests/test_unit/test_connections/test_db_connection/test_update_mssql_connection.py b/tests/test_unit/test_connections/test_db_connection/test_update_mssql_connection.py new file mode 100644 index 00000000..f2942ad8 --- /dev/null +++ b/tests/test_unit/test_connections/test_db_connection/test_update_mssql_connection.py @@ -0,0 +1,72 @@ +import pytest +from httpx import AsyncClient + +from tests.mocks import MockConnection, UserTestRoles + +pytestmark = [pytest.mark.asyncio, pytest.mark.backend, pytest.mark.mssql] + + +@pytest.mark.parametrize( + "create_connection_data,create_connection_auth_data", + [ + ( + { + "type": "mssql", + "host": "127.0.0.1", + "port": 1433, + "database": "name", + }, + { + "type": "mssql", + "user": "user", + "password": "secret", + }, + ), + ], + indirect=True, +) +async def test_developer_plus_can_update_mssql_connection( + client: AsyncClient, + group_connection: MockConnection, + role_developer_plus: UserTestRoles, +): + # Arrange + user = group_connection.owner_group.get_member_of_role(role_developer_plus) + group_connection.connection.group.id + + # Act + result = await client.patch( + f"v1/connections/{group_connection.id}", + headers={"Authorization": f"Bearer {user.token}"}, + json={ + "connection_data": { + "type": "mssql", + "host": "127.0.1.1", + "database": "new_name", + }, + "auth_data": { + "type": "mssql", + "user": "new_user", + }, + }, + ) + + # Assert + assert result.status_code == 200 + assert result.json() == { + "id": group_connection.id, + "name": group_connection.name, + "description": group_connection.description, + "group_id": group_connection.group_id, + "connection_data": { + "type": group_connection.data["type"], + "host": "127.0.1.1", + "port": group_connection.data["port"], + "database": "new_name", + "additional_params": {}, + }, + "auth_data": { + "type": group_connection.credentials.value["type"], + "user": "new_user", + }, + } diff --git a/tests/test_unit/test_connections/test_read_connections.py b/tests/test_unit/test_connections/test_read_connections.py index 455a223a..3cb6e124 100644 --- a/tests/test_unit/test_connections/test_read_connections.py +++ b/tests/test_unit/test_connections/test_read_connections.py @@ -307,10 +307,10 @@ async def test_search_connections_with_nonexistent_query( @pytest.mark.parametrize( "filter_params, expected_total", [ - ({}, 6), # No filters applied, expecting all connections + ({}, 7), # No filters applied, expecting all connections ({"type": ["oracle"]}, 1), ({"type": ["postgres", "hive"]}, 2), - ({"type": ["postgres", "hive", "oracle", "clickhouse", "hdfs", "s3"]}, 6), + ({"type": ["postgres", "hive", "oracle", "clickhouse", "mssql", "hdfs", "s3"]}, 7), ], ids=[ "no_filters", diff --git a/tests/test_unit/test_transfers/test_create_transfer.py b/tests/test_unit/test_transfers/test_create_transfer.py index ae393f12..67a3bb5d 100644 --- a/tests/test_unit/test_transfers/test_create_transfer.py +++ b/tests/test_unit/test_transfers/test_create_transfer.py @@ -370,12 +370,12 @@ async def test_superuser_can_create_transfer( "location": ["body", "source_params"], "message": ( "Input tag 'new some connection type' found using 'type' " - "does not match any of the expected tags: 'postgres', 'hdfs', 'hive', 'oracle', 'clickhouse', 's3'" + "does not match any of the expected tags: 'postgres', 'hdfs', 'hive', 'oracle', 'clickhouse', 'mssql', 's3'" ), "code": "union_tag_invalid", "context": { "discriminator": "'type'", - "expected_tags": "'postgres', 'hdfs', 'hive', 'oracle', 'clickhouse', 's3'", + "expected_tags": "'postgres', 'hdfs', 'hive', 'oracle', 'clickhouse', 'mssql', 's3'", "tag": "new some connection type", }, "input": { diff --git a/tests/test_unit/test_transfers/transfer_fixtures/transfers_fixture.py b/tests/test_unit/test_transfers/transfer_fixtures/transfers_fixture.py index 17063024..f02d1998 100644 --- a/tests/test_unit/test_transfers/transfer_fixtures/transfers_fixture.py +++ b/tests/test_unit/test_transfers/transfer_fixtures/transfers_fixture.py @@ -81,6 +81,7 @@ async def group_transfers( ConnectionType.POSTGRES, ConnectionType.ORACLE, ConnectionType.CLICKHOUSE, + ConnectionType.MSSQL, ]: source_params["table_name"] = "source_table" target_params["table_name"] = "target_table"