diff --git a/mock_tests/conftest.py b/mock_tests/conftest.py index 236e500ba..c07dfae7b 100644 --- a/mock_tests/conftest.py +++ b/mock_tests/conftest.py @@ -12,6 +12,7 @@ from werkzeug.wrappers import Request, Response import weaviate +from mock_tests.mock_data import mock_class from weaviate.connect.base import ConnectionParams, ProtocolParams from weaviate.proto.v1 import ( batch_pb2, @@ -21,8 +22,6 @@ weaviate_pb2_grpc, ) -from mock_tests.mock_data import mock_class - MOCK_IP = "127.0.0.1" MOCK_PORT = 23536 MOCK_PORT_GRPC = 23537 @@ -105,18 +104,17 @@ def slow_post(request: Request) -> Response: yield weaviate_no_auth_mock +# Implement the health check service +class MockHealthServicer(HealthServicer): + def Check(self, request: HealthCheckRequest, context: ServicerContext) -> HealthCheckResponse: + return HealthCheckResponse(status=HealthCheckResponse.SERVING) + + @pytest.fixture(scope="function") def start_grpc_server() -> Generator[grpc.Server, None, None]: # Create a gRPC server server: grpc.Server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - # Implement the health check service - class MockHealthServicer(HealthServicer): - def Check( - self, request: HealthCheckRequest, context: ServicerContext - ) -> HealthCheckResponse: - return HealthCheckResponse(status=HealthCheckResponse.SERVING) - # Add the health check service to the server add_HealthServicer_to_server(MockHealthServicer(), server) diff --git a/mock_tests/test_ssl.py b/mock_tests/test_ssl.py new file mode 100644 index 000000000..fb0f356c5 --- /dev/null +++ b/mock_tests/test_ssl.py @@ -0,0 +1,123 @@ +import json +import ssl +from concurrent import futures +from typing import Iterable + +import grpc +import pytest +import trustme +from grpc_health.v1.health_pb2_grpc import add_HealthServicer_to_server +from pytest_httpserver import HTTPServer +from werkzeug.wrappers import Response + +import weaviate +from mock_tests.conftest import MockHealthServicer, MOCK_IP, MOCK_PORT_GRPC + +SERVER 
= "127.0.0.1" +MOCK_PORT_GRPC_SSL = 23538 +PORT = 23539 + + +@pytest.fixture(scope="session") +def httpserver_ssl_context(): + ca = trustme.CA() + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_cert = ca.issue_cert(SERVER) + server_cert.configure_cert(server_context) + + return server_context + + +@pytest.fixture(scope="session") +def make_httpserver(httpserver_ssl_context) -> Iterable[HTTPServer]: + server = HTTPServer(host=SERVER, port=PORT, ssl_context=httpserver_ssl_context) + server.start() + yield server + server.clear() + if server.is_running(): + server.stop() + + +@pytest.fixture(scope="module") +def start_grpc_server_ssl() -> Iterable[grpc.Server]: + # Create a gRPC server + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + # Add the health check service to the server + add_HealthServicer_to_server(MockHealthServicer(), server) + + # Create server credentials from a freshly issued test certificate + ca = trustme.CA() + server_cert = ca.issue_cert(SERVER) + server_credentials = grpc.ssl_server_credentials( + [(server_cert.private_key_pem.bytes(), server_cert.cert_chain_pems[0].bytes())] + ) + + # Listen on a specific port with SSL + server.add_secure_port(f"[::]:{MOCK_PORT_GRPC_SSL}", server_credentials) + server.start() + + yield server + + # Teardown - stop the server + server.stop(0) + + +def test_disable_ssl_verification( + make_httpserver: HTTPServer, start_grpc_server_ssl: grpc.Server, start_grpc_server: grpc.Server +): + make_httpserver.expect_request("/v1/.well-known/ready").respond_with_json({}) + make_httpserver.expect_request("/v1/meta").respond_with_json({"version": "1.24"}) + make_httpserver.expect_request("/v1/nodes").respond_with_json({"nodes": [{"gitHash": "ABC"}]}) + make_httpserver.expect_request("/v1/.well-known/openid-configuration").respond_with_response( + Response(json.dumps({}), status=404) + ) + + assert 
make_httpserver.port == PORT + assert make_httpserver.host == SERVER + + # test http connection with ssl + with pytest.raises(weaviate.exceptions.WeaviateConnectionError): + weaviate.connect_to_custom( + http_port=PORT, + http_host=SERVER, + grpc_port=MOCK_PORT_GRPC, + http_secure=True, + grpc_host=MOCK_IP, + grpc_secure=False, + ) + + # test grpc connection with ssl + with pytest.raises(weaviate.exceptions.WeaviateConnectionError): + weaviate.connect_to_custom( + http_port=PORT, + http_host=SERVER, + grpc_port=MOCK_PORT_GRPC_SSL, + http_secure=True, + grpc_host=SERVER, + grpc_secure=True, + ) + + # test http connection with ssl and verify disabled + weaviate.connect_to_custom( + http_port=PORT, + http_host=SERVER, + grpc_port=MOCK_PORT_GRPC, + http_secure=True, + grpc_host=MOCK_IP, + grpc_secure=False, + additional_config=weaviate.config.AdditionalConfig(disable_ssl_verification=True), + ).close() + + # test grpc connection with ssl and verify disabled + weaviate.connect_to_custom( + http_port=PORT, + http_host=SERVER, + grpc_port=MOCK_PORT_GRPC_SSL, + http_secure=True, + grpc_host=SERVER, + grpc_secure=True, + additional_config=weaviate.config.AdditionalConfig(disable_ssl_verification=True), + ).close() + + make_httpserver.check_assertions() diff --git a/requirements-devel.txt b/requirements-devel.txt index aae7aa9bd..ab24e1222 100644 --- a/requirements-devel.txt +++ b/requirements-devel.txt @@ -27,6 +27,7 @@ pytest-xdist==3.6.1 werkzeug==3.0.3 pytest-httpserver==1.0.12 py-spy==0.3.14 +trustme>=1.1.0 numpy>=1.24.4,<3.0.0 pandas>=2.0.3,<3.0.0 diff --git a/test/collection/conftest.py b/test/collection/conftest.py index 4915da9ec..a87c7d36c 100644 --- a/test/collection/conftest.py +++ b/test/collection/conftest.py @@ -1,4 +1,5 @@ import pytest + from weaviate.config import ConnectionConfig from weaviate.connect import ConnectionV4, ConnectionParams @@ -11,6 +12,7 @@ def connection() -> ConnectionV4: (10, 60), None, True, + False, None, ConnectionConfig(), None, diff --git 
a/weaviate/client.py b/weaviate/client.py index 39269e3df..8761c61b9 100644 --- a/weaviate/client.py +++ b/weaviate/client.py @@ -8,22 +8,20 @@ from httpx import HTTPError as HttpxError from requests.exceptions import ConnectionError as RequestsConnectionError +from weaviate import syncify from weaviate.backup.backup import _BackupAsync from weaviate.backup.sync import _Backup - - -from weaviate import syncify +from weaviate.event_loop import _EventLoopSingleton, _EventLoop from .auth import AuthCredentials from .backup import Backup from .batch import Batch from .classification import Classification - from .client_base import _WeaviateClientBase from .cluster import Cluster -from .collections.collections.async_ import _CollectionsAsync -from .collections.collections.sync import _Collections from .collections.batch.client import _BatchClientWrapper from .collections.cluster import _Cluster, _ClusterAsync +from .collections.collections.async_ import _CollectionsAsync +from .collections.collections.sync import _Collections from .config import AdditionalConfig, Config from .connect import Connection from .connect.base import ( @@ -40,7 +38,6 @@ ) from .gql import Query from .schema import Schema -from weaviate.event_loop import _EventLoopSingleton, _EventLoop from .types import NUMBER from .util import _get_valid_timeout_config, _type_request_response from .warnings import _Warnings diff --git a/weaviate/client_base.py b/weaviate/client_base.py index 2a3dbfbaa..1edea3b43 100644 --- a/weaviate/client_base.py +++ b/weaviate/client_base.py @@ -5,11 +5,8 @@ import asyncio from typing import Optional, Tuple, Union, Dict, Any - from weaviate.collections.classes.internal import _GQLEntryReturnType, _RawGQLReturn - from weaviate.integrations import _Integrations - from .auth import AuthCredentials from .config import AdditionalConfig from .connect import ConnectionV4 @@ -83,6 +80,7 @@ def __init__( proxies=config.proxies, trust_env=config.trust_env, loop=self._loop, + 
disable_ssl_verification=config.disable_ssl_verification, ) self.integrations = _Integrations(self._connection) diff --git a/weaviate/config.py b/weaviate/config.py index 8f9a55f95..0e6f9fe06 100644 --- a/weaviate/config.py +++ b/weaviate/config.py @@ -76,6 +76,7 @@ class AdditionalConfig(BaseModel): proxies: Union[str, Proxies, None] = Field(default=None) timeout_: Union[Tuple[int, int], Timeout] = Field(default_factory=Timeout, alias="timeout") trust_env: bool = Field(default=False) + disable_ssl_verification: bool = Field(default=False) @property def timeout(self) -> Timeout: diff --git a/weaviate/connect/base.py b/weaviate/connect/base.py index 7d84824a5..a067416ea 100644 --- a/weaviate/connect/base.py +++ b/weaviate/connect/base.py @@ -1,5 +1,7 @@ import datetime import os +import socket +import ssl import time from abc import ABC, abstractmethod from typing import Dict, Tuple, TypeVar, Union, cast @@ -8,14 +10,13 @@ import grpc # type: ignore from grpc import ssl_channel_credentials from grpc.aio import Channel # type: ignore - -# from grpclib.client import Channel - from pydantic import BaseModel, field_validator, model_validator from weaviate.config import Proxies from weaviate.types import NUMBER +# from grpclib.client import Channel + JSONPayload = Union[dict, list] TIMEOUT_TYPE_RETURN = Tuple[NUMBER, NUMBER] @@ -111,15 +112,40 @@ def _grpc_address(self) -> Tuple[str, int]: def _grpc_target(self) -> str: return f"{self.grpc.host}:{self.grpc.port}" - def _grpc_channel(self, proxies: Dict[str, str]) -> Channel: + def _grpc_channel(self, proxies: Dict[str, str], enable_ssl_verification: bool) -> Channel: if (p := proxies.get("grpc")) is not None: options: list = [*GRPC_DEFAULT_OPTIONS, ("grpc.http_proxy", p)] else: options = GRPC_DEFAULT_OPTIONS if self.grpc.secure: + if enable_ssl_verification: + creds = ssl_channel_credentials() + else: + # download certificate from server. 
This is super hacky, but the grpc library does NOT offer a way to + # disable certificate verification. There are probably a number of edge cases that this does not cover. + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + targets = self.grpc.host.replace("http://", "", 1).split(":") + + # SNI must be a bare hostname, so pass the host only and not the "host:port" target. + with socket.create_connection((targets[0], self.grpc.port)) as sock: + with context.wrap_socket( + sock, server_hostname=targets[0] + ) as secure_sock: + cert_binary = secure_sock.getpeercert(binary_form=True) + if cert_binary is None: + raise ValueError( + "Failed to retrieve the server certificate to bypass ssl verification." + ) + + cert = ssl.DER_cert_to_PEM_cert(cert_binary) + + creds = ssl_channel_credentials(root_certificates=cert.encode()) + return grpc.aio.secure_channel( target=self._grpc_target, - credentials=ssl_channel_credentials(), + credentials=creds, options=options, ) else: diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 9724986a1..5714251a4 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -98,6 +98,7 @@ def __init__( timeout_config: TimeoutConfig, proxies: Union[str, Proxies, None], trust_env: bool, + disable_ssl_verification: bool, additional_headers: Optional[Dict[str, Any]], connection_config: ConnectionConfig, loop: asyncio.AbstractEventLoop, # required for background token refresh @@ -115,6 +116,7 @@ def __init__( self.timeout_config = timeout_config self.__connection_config = connection_config self.__trust_env = trust_env + self.__enable_ssl_verification = not disable_ssl_verification self._weaviate_version = _ServerVersion.from_string("") self.__connected = False self.__loop = loop @@ -211,6 +213,7 @@ def __make_mounts(self) -> Dict[str, AsyncHTTPTransport]: proxy=Proxy(url=proxy), retries=self.__connection_config.session_pool_max_retries, trust_env=self.__trust_env, + 
verify=self.__enable_ssl_verification, ) for key, proxy in self._proxies.items() 
if key != "grpc" @@ -221,6 +224,7 @@ def __make_async_client(self) -> AsyncClient: headers=self._headers, mounts=self.__make_mounts(), trust_env=self.__trust_env, + verify=self.__enable_ssl_verification, ) def __make_clients(self) -> None: @@ -229,7 +233,9 @@ def __make_clients(self) -> None: async def _open_connections( self, auth_client_secret: Optional[AuthCredentials], skip_init_checks: bool ) -> None: - self._grpc_channel = self._connection_params._grpc_channel(proxies=self._proxies) + self._grpc_channel = self._connection_params._grpc_channel( + proxies=self._proxies, enable_ssl_verification=self.__enable_ssl_verification + ) assert self._grpc_channel is not None self._grpc_stub = weaviate_pb2_grpc.WeaviateStub(self._grpc_channel)