Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add disable ssl verification to client instantiation #1260

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions integration/conftest.py
Original file line number Diff line number Diff line change
@@ -47,13 +47,14 @@ def __call__(
vector_index_config: Optional[_VectorIndexConfigCreate] = None,
description: Optional[str] = None,
reranker_config: Optional[_RerankerConfigCreate] = None,
) -> Collection[Any, Any]:
return_client: bool = False
) -> Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]]:
"""Typing for fixture."""
...


@pytest.fixture
def collection_factory(request: SubRequest) -> Generator[CollectionFactory, None, None]:
def collection_factory(request: SubRequest) -> Generator[Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]], None, None]:
name_fixture: Optional[str] = None
client_fixture: Optional[weaviate.WeaviateClient] = None

@@ -75,6 +76,7 @@ def _factory(
vector_index_config: Optional[_VectorIndexConfigCreate] = None,
description: Optional[str] = None,
reranker_config: Optional[_RerankerConfigCreate] = None,
return_client: bool = False
) -> Collection[Any, Any]:
nonlocal client_fixture, name_fixture
name_fixture = _sanitize_collection_name(request.node.name) + name
@@ -101,7 +103,10 @@ def _factory(
vector_index_config=vector_index_config,
reranker_config=reranker_config,
)
return collection
if return_client:
return collection, client_fixture
else:
return collection

try:
yield _factory
72 changes: 71 additions & 1 deletion integration/test_collection_config.py
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@
GenerativeSearches,
Rerankers,
_RerankerConfigCreate,
Tokenization
)
from weaviate.collections.classes.tenants import Tenant

@@ -589,7 +590,76 @@ def test_collection_config_get_shards_multi_tenancy(collection_factory: Collecti
assert "tenant1" in [shard.name for shard in shards]
assert "tenant2" in [shard.name for shard in shards]


def test_collection_config_create_from_dict(collection_factory: CollectionFactory) -> None:
collection, client = collection_factory(
inverted_index_config=Configure.inverted_index(bm25_b=0.8, bm25_k1=1.3),
multi_tenancy_config=Configure.multi_tenancy(enabled=True),
generative_config=Configure.Generative.openai(model="gpt-4"),
vectorizer_config=Configure.Vectorizer.text2vec_openai(
model="text-embedding-3-small",
base_url="http://weaviate.io",
vectorize_collection_name=False,
dimensions=512
),
vector_index_config=Configure.VectorIndex.flat(
vector_cache_max_objects=234,
quantizer=Configure.VectorIndex.Quantizer.bq(rescore_limit=456),
),
description="Some description",
reranker_config=Configure.Reranker.cohere(model="rerank-english-v2.0"),
properties=[
Property(name="field_tokenization", data_type=DataType.TEXT, tokenization=Tokenization.FIELD),
Property(name="field_description", data_type=DataType.TEXT,
tokenization=Tokenization.FIELD, description="field desc"),
Property(name="field_index_filterable", data_type=DataType.TEXT,
index_filterable=False),
Property(name="field_skip_vectorization", data_type=DataType.TEXT,
skip_vectorization=True),
Property(name="text", data_type=DataType.TEXT),
Property(name="texts", data_type=DataType.TEXT_ARRAY),
Property(name="number", data_type=DataType.NUMBER),
Property(name="numbers", data_type=DataType.NUMBER_ARRAY),
Property(name="int", data_type=DataType.INT),
Property(name="ints", data_type=DataType.INT_ARRAY),
Property(name="date", data_type=DataType.DATE),
Property(name="dates", data_type=DataType.DATE_ARRAY),
Property(name="boolean", data_type=DataType.BOOL),
Property(name="booleans", data_type=DataType.BOOL_ARRAY),
Property(name="geo", data_type=DataType.GEO_COORDINATES),
Property(name="phone", data_type=DataType.PHONE_NUMBER),
Property(name="vectorize_property_name", data_type=DataType.TEXT,
vectorize_property_name=False),
Property(name="field_index_searchable", data_type=DataType.TEXT,
index_searchable=False),
# TODO: this will fail
# Property(
# name="name",
# data_type=DataType.OBJECT,
# nested_properties=[
# Property(name="first", data_type=DataType.TEXT),
# Property(name="last", data_type=DataType.TEXT),
# ],
# ),
],
return_client=True
)
old_dict = collection.config.get().to_dict()
new_dict = old_dict
new_collection_name = collection.name + "_FROM_DICT"
client.collections.delete(new_collection_name)
new_dict["class"] = new_collection_name
new_collection = client.collections.create_from_dict(new_dict)
new_collection_dict = new_collection.config.get().to_dict()
# make the same name for collections
new_collection_dict["class"] = collection.name
old_dict["class"] = collection.name
# check if both dict are the same
#print("old", old_dict)
#print("new", new_collection_dict)
assert new_collection_dict == old_dict
# remove the created collection
client.collections.delete(new_collection_name)

def test_config_vector_index_flat_and_quantizer_bq(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
vector_index_config=Configure.VectorIndex.flat(
4 changes: 3 additions & 1 deletion weaviate/client.py
Original file line number Diff line number Diff line change
@@ -157,6 +157,7 @@ def __init__(
additional_headers: Optional[dict] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> None:
"""Initialise a WeaviateClient class instance to use when interacting with Weaviate.

@@ -191,6 +192,7 @@ def __init__(
config = additional_config or AdditionalConfig()

self.__skip_init_checks = skip_init_checks
self.__disable_ssl_verification = disable_ssl_verification

self._connection = ConnectionV4( # pyright: ignore reportIncompatibleVariableOverride
connection_params=connection_params,
@@ -284,7 +286,7 @@ def connect(self) -> None:
"""
if self._connection.is_connected():
return
self._connection.connect(self.__skip_init_checks)
self._connection.connect(self.__skip_init_checks, self.__disable_ssl_verification)

def is_connected(self) -> bool:
"""Check if the client is connected to Weaviate.
1 change: 1 addition & 0 deletions weaviate/collections/batch/base.py
Original file line number Diff line number Diff line change
@@ -285,6 +285,7 @@ def _shutdown(self) -> None:

def __batch_send(self) -> None:
loop = self.__start_new_event_loop()
# TODO: figure a way to pass disable_verification_process to aopen
future = asyncio.run_coroutine_threadsafe(self.__connection.aopen(), loop)
future.result() # Wait for self._connection.aopen() to finish
refresh_time: float = 0.01
4 changes: 2 additions & 2 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
@@ -1016,8 +1016,8 @@ def to_dict(self) -> Dict[str, Any]:
out = super().to_dict()
out["dataType"] = [self.data_type.value]
out["indexFilterable"] = self.index_filterable
out["indexVector"] = self.index_searchable
out["tokenizer"] = self.tokenization.value if self.tokenization else None
out["indexSearchable"] = self.index_searchable
out["tokenization"] = self.tokenization.value if self.tokenization else None

module_config: Dict[str, Any] = {}
if self.vectorizer is not None:
14 changes: 13 additions & 1 deletion weaviate/connect/helpers.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ def connect_to_weaviate_cloud(
headers: Optional[Dict[str, str]] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> WeaviateClient:
"""
Connect to a Weaviate Cloud (WCD) instance.
@@ -81,6 +82,7 @@ def connect_to_weaviate_cloud(
additional_headers=headers,
additional_config=additional_config,
skip_init_checks=skip_init_checks,
disable_ssl_verification=disable_ssl_verification,
)
return __connect(client)

@@ -91,6 +93,7 @@ def connect_to_wcs(
headers: Optional[Dict[str, str]] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> WeaviateClient:
"""
Connect to a Weaviate Cloud (WCD) instance.
@@ -137,7 +140,12 @@ def connect_to_wcs(
>>> # The connection is automatically closed when the context is exited.
"""
return connect_to_weaviate_cloud(
cluster_url, auth_credentials, headers, additional_config, skip_init_checks
cluster_url,
auth_credentials,
headers,
additional_config,
skip_init_checks,
disable_ssl_verification,
)


@@ -148,6 +156,7 @@ def connect_to_local(
headers: Optional[Dict[str, str]] = None,
additional_config: Optional[AdditionalConfig] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
auth_credentials: Optional[AuthCredentials] = None,
) -> WeaviateClient:
"""
@@ -208,6 +217,7 @@ def connect_to_local(
additional_headers=headers,
additional_config=additional_config,
skip_init_checks=skip_init_checks,
disable_ssl_verification=disable_ssl_verification,
auth_client_secret=auth_credentials,
)
return __connect(client)
@@ -310,6 +320,7 @@ def connect_to_custom(
additional_config: Optional[AdditionalConfig] = None,
auth_credentials: Optional[AuthCredentials] = None,
skip_init_checks: bool = False,
disable_ssl_verification: bool = False,
) -> WeaviateClient:
"""
Connect to a Weaviate instance with custom connection parameters.
@@ -388,6 +399,7 @@ def connect_to_custom(
additional_headers=headers,
additional_config=additional_config,
skip_init_checks=skip_init_checks,
disable_ssl_verification=disable_ssl_verification,
)
return __connect(client)

43 changes: 24 additions & 19 deletions weaviate/connect/v4.py
Original file line number Diff line number Diff line change
@@ -132,10 +132,10 @@ def __init__(
if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey):
self._headers["authorization"] = "Bearer " + auth_client_secret.api_key

def connect(self, skip_init_checks: bool) -> None:
def connect(self, skip_init_checks: bool, disable_ssl_verification: bool) -> None:
if self.embedded_db is not None:
self.embedded_db.start()
self._create_clients(self._auth, skip_init_checks)
self._create_clients(self._auth, skip_init_checks, disable_ssl_verification)
self.__connected = True
if self.embedded_db is not None:
try:
@@ -214,46 +214,51 @@ def __make_mounts(
if key != "grpc"
}

def __make_sync_client(self) -> Client:
def __make_sync_client(self, disable_ssl_verification: bool) -> Client:
return Client(
headers=self._headers,
timeout=Timeout(
None, connect=self.timeout_config.query, read=self.timeout_config.insert
),
mounts=self.__make_mounts("sync"),
verify=not disable_ssl_verification,
)

def __make_async_client(self) -> AsyncClient:
def __make_async_client(self, disable_ssl_verification: bool) -> AsyncClient:
return AsyncClient(
headers=self._headers,
timeout=Timeout(
None, connect=self.timeout_config.query, read=self.timeout_config.insert
),
mounts=self.__make_mounts("async"),
verify=not disable_ssl_verification,
)

def __make_clients(self) -> None:
self._client = self.__make_sync_client()
def __make_clients(self, disable_ssl_verification: bool) -> None:
self._client = self.__make_sync_client(disable_ssl_verification)

def _create_clients(
self, auth_client_secret: Optional[AuthCredentials], skip_init_checks: bool
self,
auth_client_secret: Optional[AuthCredentials],
skip_init_checks: bool,
disable_ssl_verification: bool,
) -> None:
# API keys are separate from OIDC and do not need any config from weaviate
if auth_client_secret is not None and isinstance(auth_client_secret, AuthApiKey):
self.__make_clients()
self.__make_clients(disable_ssl_verification)
return

if "authorization" in self._headers and auth_client_secret is None:
self.__make_clients()
self.__make_clients(disable_ssl_verification)
return

# no need to check OIDC if no auth is provided and users dont want any checks at initialization time
if skip_init_checks and auth_client_secret is None:
self.__make_clients()
self.__make_clients(disable_ssl_verification)
return

oidc_url = self.url + self._api_version_path + "/.well-known/openid-configuration"
with self.__make_sync_client() as client:
with self.__make_sync_client(disable_ssl_verification=disable_ssl_verification) as client:
try:
response = client.get(oidc_url)
except Exception as e:
@@ -269,7 +274,7 @@ def _create_clients(
resp = response.json()
except Exception:
_Warnings.auth_cannot_parse_oidc_config(oidc_url)
self.__make_clients()
self.__make_clients(disable_ssl_verification=disable_ssl_verification)
return

if auth_client_secret is not None:
@@ -309,9 +314,9 @@ def _create_clients(
raise AuthenticationFailedError(msg)
elif response.status_code == 404 and auth_client_secret is not None:
_Warnings.auth_with_anon_weaviate()
self.__make_clients()
self.__make_clients(disable_ssl_verification)
else:
self.__make_clients()
self.__make_clients(disable_ssl_verification)

def get_current_bearer_token(self) -> str:
if not self.is_connected():
@@ -376,9 +381,9 @@ def periodic_refresh_token(refresh_time: int, _auth: Optional[_Auth[OAuth2Client
)
demon.start()

async def aopen(self) -> None:
async def aopen(self, disable_ssl_verification: bool = False) -> None:
if self._aclient is None:
self._aclient = await self.__make_async_client().__aenter__()
self._aclient = await self.__make_async_client(disable_ssl_verification).__aenter__()
if self._grpc_stub_async is None:
self._grpc_channel_async = self._connection_params._grpc_channel(
async_channel=True, proxies=self._proxies
@@ -453,7 +458,7 @@ def __send(
except RuntimeError as e:
raise WeaviateClosedClientError() from e
except ConnectError as conn_err:
raise WeaviateConnectionError(error_msg) from conn_err
raise WeaviateConnectionError(f"{conn_err} {error_msg}")

def delete(
self,
@@ -707,8 +712,8 @@ def _ping_grpc(self) -> None:
f"v{self.server_version}", self._connection_params._grpc_address
) from e

def connect(self, skip_init_checks: bool) -> None:
super().connect(skip_init_checks)
def connect(self, skip_init_checks: bool, disable_ssl_verification: bool) -> None:
super().connect(skip_init_checks, disable_ssl_verification)
# create GRPC channel. If Weaviate does not support GRPC then error now.
self._grpc_channel = self._connection_params._grpc_channel(
async_channel=False, proxies=self._proxies
Loading