Skip to content

Commit 91afca7

Browse files
committed
prefer enums over ints
1 parent 89cf732 commit 91afca7

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

tests/test_text_multitask_embeddings.py

+11-21
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
CANONICAL_VECTOR_VALUES = {
1212
"jinaai/jina-embeddings-v3": [
1313
{
14-
"task_id": 0,
14+
"task_id": Task.RETRIEVAL_QUERY,
1515
"vectors": np.array(
1616
[
1717
[0.0623, -0.0402, 0.1706, -0.0143, 0.0617],
@@ -20,7 +20,7 @@
2020
),
2121
},
2222
{
23-
"task_id": 1,
23+
"task_id": Task.RETRIEVAL_PASSAGE,
2424
"vectors": np.array(
2525
[
2626
[0.0513, -0.0247, 0.1751, -0.0075, 0.0679],
@@ -29,7 +29,7 @@
2929
),
3030
},
3131
{
32-
"task_id": 2,
32+
"task_id": Task.SEPARATION,
3333
"vectors": np.array(
3434
[
3535
[0.094, -0.1065, 0.1305, 0.0547, 0.0556],
@@ -38,7 +38,7 @@
3838
),
3939
},
4040
{
41-
"task_id": 3,
41+
"task_id": Task.CLASSIFICATION,
4242
"vectors": np.array(
4343
[
4444
[0.0606, -0.0877, 0.1384, 0.0065, 0.0722],
@@ -47,7 +47,7 @@
4747
),
4848
},
4949
{
50-
"task_id": 4,
50+
"task_id": Task.TEXT_MATCHING,
5151
"vectors": np.array(
5252
[
5353
[0.0911, -0.0341, 0.1305, -0.026, 0.0576],
@@ -63,7 +63,7 @@
6363
def test_batch_embedding():
6464
is_ci = os.getenv("CI")
6565
docs_to_embed = docs * 10
66-
default_task = 4
66+
default_task = Task.TEXT_MATCHING
6767

6868
for model_desc in TextEmbedding.list_supported_models():
6969
if not is_ci and model_desc["size_in_GB"] > 1:
@@ -127,7 +127,7 @@ def test_single_embedding():
127127

128128
def test_single_embedding_query():
129129
is_ci = os.getenv("CI")
130-
task_id = 0
130+
task_id = Task.RETRIEVAL_QUERY
131131

132132
for model_desc in TextEmbedding.list_supported_models():
133133
if not is_ci and model_desc["size_in_GB"] > 1:
@@ -159,7 +159,7 @@ def test_single_embedding_query():
159159

160160
def test_single_embedding_passage():
161161
is_ci = os.getenv("CI")
162-
task_id = 1
162+
task_id = Task.RETRIEVAL_PASSAGE
163163

164164
for model_desc in TextEmbedding.list_supported_models():
165165
if not is_ci and model_desc["size_in_GB"] > 1:
@@ -202,19 +202,9 @@ def test_task_assignment():
202202

203203
model = TextEmbedding(model_name=model_name)
204204

205-
_ = list(model.embed(documents=docs, batch_size=1, task_id=2))
206-
assert model.model._current_task_id == Task.SEPARATION
207-
208-
_ = list(
209-
model.embed(documents=docs, batch_size=1, parallel=1, task_id=Task.CLASSIFICATION)
210-
)
211-
assert model.model._current_task_id == 3
212-
213-
_ = list(model.query_embed(query=docs))
214-
assert model.model._current_task_id == Task.RETRIEVAL_QUERY
215-
216-
_ = list(model.passage_embed(texts=docs))
217-
assert model.model._current_task_id == Task.RETRIEVAL_PASSAGE
205+
for i, task_id in enumerate(Task):
206+
_ = list(model.embed(documents=docs, batch_size=1, task_id=i))
207+
assert model.model._current_task_id == task_id
218208

219209
if is_ci:
220210
delete_model_cache(model.model._model_dir)

0 commit comments

Comments
 (0)