Skip to content

Commit e98b0aa

Browse files
committed
Refactor
1 parent 9f21009 commit e98b0aa

File tree

2 files changed

+67
-51
lines changed

2 files changed

+67
-51
lines changed

weaviate/collections/batch/grpc_batch_objects.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import struct
33
import time
44
import uuid as uuid_package
5-
from typing import Any, Dict, List, Optional, Union, cast
5+
from typing import Any, Dict, List, Optional, Tuple, Union, cast
66

77
from grpc.aio import AioRpcError # type: ignore
88
from google.protobuf.struct_pb2 import Struct
@@ -24,17 +24,7 @@
2424
WeaviateInvalidInputError,
2525
)
2626
from weaviate.proto.v1 import batch_pb2, base_pb2
27-
from weaviate.util import _datetime_to_string, _get_vector_v4
28-
29-
30-
def _pack_named_vectors(vectors: Dict[str, List[float]]) -> List[base_pb2.Vectors]:
31-
return [
32-
base_pb2.Vectors(
33-
name=name,
34-
vector_bytes=struct.pack("{}f".format(len(vector)), *vector),
35-
)
36-
for name, vector in vectors.items()
37-
]
27+
from weaviate.util import _datetime_to_string
3828

3929

4030
class _BatchGRPC(_BaseGRPC):
@@ -47,11 +37,10 @@ class _BatchGRPC(_BaseGRPC):
4737
def __init__(self, connection: ConnectionV4, consistency_level: Optional[ConsistencyLevel]):
4838
super().__init__(connection, consistency_level)
4939

50-
def __grpc_objects(self, objects: List[_BatchObject]) -> List[batch_pb2.BatchObject]:
51-
def pack_vector(vector: Any) -> bytes:
52-
vector_list = _get_vector_v4(vector)
53-
return struct.pack("{}f".format(len(vector_list)), *vector_list)
54-
40+
def __grpc_objects(
41+
self,
42+
objects: List[Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]],
43+
) -> List[batch_pb2.BatchObject]:
5544
return [
5645
batch_pb2.BatchObject(
5746
collection=obj.collection,
@@ -65,22 +54,16 @@ def pack_vector(vector: Any) -> bytes:
6554
else None
6655
),
6756
tenant=obj.tenant,
68-
vector_bytes=(
69-
pack_vector(obj.vector)
70-
if obj.vector is not None and not isinstance(obj.vector, dict)
71-
else None
72-
),
73-
vectors=(
74-
_pack_named_vectors(obj.vector)
75-
if obj.vector is not None and isinstance(obj.vector, dict)
76-
else None
77-
),
57+
vector_bytes=vector_bytes,
58+
vectors=vectors,
7859
)
79-
for obj in objects
60+
for obj, vector_bytes, vectors in objects
8061
]
8162

8263
async def objects(
83-
self, objects: List[_BatchObject], timeout: Union[int, float]
64+
self,
65+
objects: List[Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]],
66+
timeout: Union[int, float]
8467
) -> BatchObjectReturn:
8568
"""Insert multiple objects into Weaviate through the gRPC API.
8669
@@ -114,7 +97,7 @@ async def objects(
11497
return_errors: Dict[int, ErrorObject] = {}
11598

11699
for idx, weav_obj in enumerate(weaviate_objs):
117-
obj = objects[idx]
100+
obj = objects[idx][0]
118101
if idx in errors:
119102
error = ErrorObject(errors[idx], obj, original_uuid=obj.uuid)
120103
return_errors[obj.index] = error

weaviate/collections/data/data.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import datetime
3+
import struct
34
import uuid as uuid_package
45
from typing import (
56
Dict,
@@ -47,6 +48,7 @@
4748
from weaviate.connect import ConnectionV4
4849
from weaviate.connect.v4 import _ExpectedStatusCodes
4950
from weaviate.logger import logger
51+
from weaviate.proto.v1 import base_pb2
5052
from weaviate.types import BEACON, UUID, VECTORS
5153
from weaviate.util import _datetime_to_string, _get_vector_v4
5254
from weaviate.validator import _validate_input, _ValidateArgument
@@ -57,6 +59,21 @@
5759
from weaviate.exceptions import WeaviateInvalidInputError
5860

5961

62+
def _pack_named_vectors(vectors: Dict[str, List[float]]) -> List[base_pb2.Vectors]:
    """Convert a mapping of named vectors into gRPC ``Vectors`` messages.

    One ``base_pb2.Vectors`` message is produced per dictionary entry, in the
    dictionary's iteration order. The message ``name`` is the dictionary key and
    ``vector_bytes`` holds the values packed with ``struct.pack`` using the
    ``"<len>f"`` format (C floats, native byte order).

    Args:
        vectors: Mapping from vector name to its list of float values.

    Returns:
        The list of packed ``base_pb2.Vectors`` messages; empty when ``vectors``
        is empty.
    """
    packed: List[base_pb2.Vectors] = []
    for name, values in vectors.items():
        # The format string encodes the element count, e.g. "3f" for three floats.
        fmt = "{}f".format(len(values))
        packed.append(base_pb2.Vectors(name=name, vector_bytes=struct.pack(fmt, *values)))
    return packed
70+
71+
72+
def _pack_vector(vector: Any) -> bytes:
    """Pack a single (unnamed) vector into its byte representation.

    The input is first passed through ``_get_vector_v4`` (imported from
    ``weaviate.util``; presumably normalises supported vector inputs to a flat
    sequence of floats -- confirm against that helper). The result is then packed
    with ``struct.pack`` using the ``"<len>f"`` format (C floats, native byte
    order).

    Args:
        vector: The vector as supplied by the caller.

    Returns:
        The packed bytes of the normalised vector.
    """
    values = _get_vector_v4(vector)
    # The format string encodes the element count, e.g. "3f" for three floats.
    fmt = "{}f".format(len(values))
    return struct.pack(fmt, *values)
75+
76+
6077
class _DataBase:
6178
def __init__(
6279
self,
@@ -281,6 +298,42 @@ def with_data_model(self, data_model: Type[TProperties]) -> "_DataCollectionAsyn
281298
data_model,
282299
)
283300

301+
def __validate_vector(
    self,
    idx: int,
    obj: Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]
) -> Tuple[_BatchObject, Optional[bytes], Optional[List[base_pb2.Vectors]]]:
    """Build the batch entry for one input object and pre-pack its vector(s).

    Args:
        idx: Position of the object in the caller's input sequence; stored as
            the ``index`` of the resulting ``_BatchObject``.
        obj: Either a ``DataObject`` (properties plus optional uuid, vector and
            references) or a bare properties mapping.

    Returns:
        A 3-tuple ``(batch_object, vector_bytes, vectors)``:

        * ``batch_object`` -- the ``_BatchObject`` for this collection/tenant.
        * ``vector_bytes`` -- packed bytes when the object carries a single
          (non-dict) vector, otherwise ``None``.
        * ``vectors`` -- packed named vectors when the object carries a dict of
          vectors, otherwise ``None``.

        For a bare properties mapping both vector slots are ``None`` and a fresh
        random UUID is generated.
    """
    if not isinstance(obj, DataObject):
        # Plain properties: no vector, no references, always a new UUID.
        plain = _BatchObject(
            collection=self.name,
            vector=None,
            uuid=str(uuid_package.uuid4()),
            properties=cast(dict, obj),
            tenant=self._tenant,
            references=None,
            index=idx,
        )
        return plain, None, None

    # Pack before constructing the batch object so packing errors surface first.
    vector_bytes: Optional[bytes] = None
    vectors: Optional[List[base_pb2.Vectors]] = None
    if obj.vector is not None:
        if isinstance(obj.vector, dict):
            vectors = _pack_named_vectors(obj.vector)
        else:
            vector_bytes = _pack_vector(obj.vector)

    batch_obj = _BatchObject(
        collection=self.name,
        vector=obj.vector,
        # Keep the caller-supplied UUID when present; otherwise generate one.
        uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()),
        properties=cast(dict, obj.properties),
        tenant=self._tenant,
        references=obj.references,
        index=idx,
    )
    return batch_obj, vector_bytes, vectors
336+
284337
def __parse_vector(self, obj: Dict[str, Any], vector: VECTORS) -> Dict[str, Any]:
285338
if isinstance(vector, dict):
286339
obj["vectors"] = {key: _get_vector_v4(val) for key, val in vector.items()}
@@ -360,27 +413,7 @@ async def insert_many(
360413
If every object in the batch fails to be inserted. The exception message contains details about the failure.
361414
"""
362415
objs = [
363-
(
364-
_BatchObject(
365-
collection=self.name,
366-
vector=obj.vector,
367-
uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()),
368-
properties=cast(dict, obj.properties),
369-
tenant=self._tenant,
370-
references=obj.references,
371-
index=idx,
372-
)
373-
if isinstance(obj, DataObject)
374-
else _BatchObject(
375-
collection=self.name,
376-
vector=None,
377-
uuid=str(uuid_package.uuid4()),
378-
properties=cast(dict, obj),
379-
tenant=self._tenant,
380-
references=None,
381-
index=idx,
382-
)
383-
)
416+
self.__validate_vector(idx, obj)
384417
for idx, obj in enumerate(objects)
385418
]
386419
res = await self._batch_grpc.objects(objs, timeout=self._connection.timeout_config.insert)

0 commit comments

Comments
 (0)