11 | 11 | CANONICAL_VECTOR_VALUES = {
12 | 12 |     "jinaai/jina-embeddings-v3": [
13 | 13 |         {
14 |    | -            "task_id": 0,
   | 14 | +            "task_id": Task.RETRIEVAL_QUERY,
15 | 15 |             "vectors": np.array(
16 | 16 |                 [
17 | 17 |                     [0.0623, -0.0402, 0.1706, -0.0143, 0.0617],
20 | 20 |             ),
21 | 21 |         },
22 | 22 |         {
23 |    | -            "task_id": 1,
   | 23 | +            "task_id": Task.RETRIEVAL_PASSAGE,
24 | 24 |             "vectors": np.array(
25 | 25 |                 [
26 | 26 |                     [0.0513, -0.0247, 0.1751, -0.0075, 0.0679],
29 | 29 |             ),
30 | 30 |         },
31 | 31 |         {
32 |    | -            "task_id": 2,
   | 32 | +            "task_id": Task.SEPARATION,
33 | 33 |             "vectors": np.array(
34 | 34 |                 [
35 | 35 |                     [0.094, -0.1065, 0.1305, 0.0547, 0.0556],
38 | 38 |             ),
39 | 39 |         },
40 | 40 |         {
41 |    | -            "task_id": 3,
   | 41 | +            "task_id": Task.CLASSIFICATION,
42 | 42 |             "vectors": np.array(
43 | 43 |                 [
44 | 44 |                     [0.0606, -0.0877, 0.1384, 0.0065, 0.0722],
47 | 47 |             ),
48 | 48 |         },
49 | 49 |         {
50 |    | -            "task_id": 4,
   | 50 | +            "task_id": Task.TEXT_MATCHING,
51 | 51 |             "vectors": np.array(
52 | 52 |                 [
53 | 53 |                     [0.0911, -0.0341, 0.1305, -0.026, 0.0576],
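This hunk replaces raw integer task ids with members of a Task enum. The enum's definition is not part of this diff; a minimal sketch consistent with the mapping implied here (0 through 4, in the order RETRIEVAL_QUERY, RETRIEVAL_PASSAGE, SEPARATION, CLASSIFICATION, TEXT_MATCHING) might look like the following, assuming an IntEnum so the members still compare equal to the old integer ids:

from enum import IntEnum


class Task(IntEnum):
    # Assumed definition: values mirror the integer task ids removed in this diff.
    RETRIEVAL_QUERY = 0
    RETRIEVAL_PASSAGE = 1
    SEPARATION = 2
    CLASSIFICATION = 3
    TEXT_MATCHING = 4

If Task is indeed an IntEnum, assertions that mix integers and enum members (as the removed test_task_assignment lines further down do with 3 and Task.CLASSIFICATION) keep passing unchanged.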
63 | 63 | def test_batch_embedding():
64 | 64 |     is_ci = os.getenv("CI")
65 | 65 |     docs_to_embed = docs * 10
66 |    | -    default_task = 4
   | 66 | +    default_task = Task.TEXT_MATCHING
67 | 67 | 
68 | 68 |     for model_desc in TextEmbedding.list_supported_models():
69 | 69 |         if not is_ci and model_desc["size_in_GB"] > 1:
@@ -127,7 +127,7 @@ def test_single_embedding():
127 | 127 | 
128 | 128 | def test_single_embedding_query():
129 | 129 |     is_ci = os.getenv("CI")
130 |     | -    task_id = 0
    | 130 | +    task_id = Task.RETRIEVAL_QUERY
131 | 131 | 
132 | 132 |     for model_desc in TextEmbedding.list_supported_models():
133 | 133 |         if not is_ci and model_desc["size_in_GB"] > 1:
@@ -159,7 +159,7 @@ def test_single_embedding_query():
159 | 159 | 
160 | 160 | def test_single_embedding_passage():
161 | 161 |     is_ci = os.getenv("CI")
162 |     | -    task_id = 1
    | 162 | +    task_id = Task.RETRIEVAL_PASSAGE
163 | 163 | 
164 | 164 |     for model_desc in TextEmbedding.list_supported_models():
165 | 165 |         if not is_ci and model_desc["size_in_GB"] > 1:
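For context, the query and passage tests presumably compare model output against the CANONICAL_VECTOR_VALUES entry whose task_id matches; that comparison code is not shown in this commit. A rough sketch of such a check, where the model setup, slicing, and tolerance are assumptions rather than lines from the diff:

# Sketch only: select the canonical entry by task_id (mirrors the table above);
# the atol value and the slice to the stored dimensions are assumed, not from the diff.
task_id = Task.RETRIEVAL_QUERY
canonical = next(
    entry["vectors"]
    for entry in CANONICAL_VECTOR_VALUES["jinaai/jina-embeddings-v3"]
    if entry["task_id"] == task_id
)
model = TextEmbedding(model_name="jinaai/jina-embeddings-v3")
embeddings = np.stack(list(model.embed(documents=docs, task_id=task_id)))
assert np.allclose(embeddings[: canonical.shape[0], : canonical.shape[1]], canonical, atol=1e-3)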
@@ -202,19 +202,9 @@ def test_task_assignment():
202 | 202 | 
203 | 203 |     model = TextEmbedding(model_name=model_name)
204 | 204 | 
205 |     | -    _ = list(model.embed(documents=docs, batch_size=1, task_id=2))
206 |     | -    assert model.model._current_task_id == Task.SEPARATION
207 |     | -
208 |     | -    _ = list(
209 |     | -        model.embed(documents=docs, batch_size=1, parallel=1, task_id=Task.CLASSIFICATION)
210 |     | -    )
211 |     | -    assert model.model._current_task_id == 3
212 |     | -
213 |     | -    _ = list(model.query_embed(query=docs))
214 |     | -    assert model.model._current_task_id == Task.RETRIEVAL_QUERY
215 |     | -
216 |     | -    _ = list(model.passage_embed(texts=docs))
217 |     | -    assert model.model._current_task_id == Task.RETRIEVAL_PASSAGE
    | 205 | +    for i, task_id in enumerate(Task):
    | 206 | +        _ = list(model.embed(documents=docs, batch_size=1, task_id=i))
    | 207 | +        assert model.model._current_task_id == task_id
218 | 208 | 
219 | 209 |     if is_ci:
220 | 210 |         delete_model_cache(model.model._model_dir)
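The refactor collapses the per-task calls into a single loop over the enum. For reference, the caller-facing surface visible in this hunk, sketched with illustrative inputs (every call and keyword below appears in the diff; the sample texts are made up):

# Sketch assembled from calls that appear in this diff; the sample texts are illustrative.
model = TextEmbedding(model_name="jinaai/jina-embeddings-v3")

# Explicit task selection, by enum member or by its integer value.
_ = list(model.embed(documents=["an example passage"], task_id=Task.CLASSIFICATION))

# The convenience wrappers seen in the removed lines select the retrieval tasks
# implicitly: query_embed -> Task.RETRIEVAL_QUERY, passage_embed -> Task.RETRIEVAL_PASSAGE.
_ = list(model.query_embed(query=["an example query"]))
_ = list(model.passage_embed(texts=["an example passage"]))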