From d45e20e74dec61f0f0f233ddaca348849f06adbd Mon Sep 17 00:00:00 2001 From: lingzhq Date: Tue, 22 Jul 2025 15:39:35 +0800 Subject: [PATCH 1/2] Add an op GroupDiversityFilter --- data_juicer/config/config_all.yaml | 13 +- data_juicer/ops/filter/__init__.py | 4 +- .../ops/filter/group_diversity_filter.py | 140 ++++++++++++++++++ data_juicer/utils/constant.py | 2 + .../ops/filter/test_group_diversity_filter.py | 76 ++++++++++ 5 files changed, 232 insertions(+), 3 deletions(-) create mode 100644 data_juicer/ops/filter/group_diversity_filter.py create mode 100644 tests/ops/filter/test_group_diversity_filter.py diff --git a/data_juicer/config/config_all.yaml b/data_juicer/config/config_all.yaml index 13ddb849e3a..9e3d52109fc 100644 --- a/data_juicer/config/config_all.yaml +++ b/data_juicer/config/config_all.yaml @@ -723,8 +723,17 @@ process: use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese words_aug_group_sizes: [2] # the group size of words to augment words_aug_join_char: "" # the join char between words to augment - - general_field_filter: # Filter to keep samples based on a general field filter condition. - filter_condition: "" # The filter condition as a string. It can include logical operators (and/or) and chain comparisons. + - general_field_filter: # Filter to keep samples based on a general field filter condition. + filter_condition: "" # The filter condition as a string. It can include logical operators (and/or) and chain comparisons. + - group_diversity_filter: # filter samples based on their semantic diversity within a group. + api_or_hf_model: "text-embedding-v3" # API or huggingface embedding model name. + is_hf_model: false # indicates if the model is from HuggingFace. + api_endpoint: "/embeddings" # embedding URL endpoint for the API. + response_path: "data.0.embedding" # path to extract content from the API response. + ebd_dim: 512 # the embedding's dimension via API. + min_score: 0.0 # the min score of filter range + max_score: 1.0 # the max score of filter range + norm_ratio: 0.5 # ratio to normalize the score. - image_aesthetics_filter: # filter samples according to the aesthetics score of images. hf_scorer_model: shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE # Huggingface model name for the aesthetics predictor min_score: 0.3 # the min aesthetics score of filter range diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index 2825ed1c01a..43b8ae6330d 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -6,6 +6,7 @@ from .character_repetition_filter import CharacterRepetitionFilter from .flagged_words_filter import FlaggedWordFilter from .general_field_filter import GeneralFieldFilter +from .group_diversity_filter import GroupDiversityFilter from .image_aesthetics_filter import ImageAestheticsFilter from .image_aspect_ratio_filter import ImageAspectRatioFilter from .image_face_count_filter import ImageFaceCountFilter @@ -65,6 +66,8 @@ "AverageLineLengthFilter", "CharacterRepetitionFilter", "FlaggedWordFilter", + "GeneralFieldFilter", + "GroupDiversityFilter", "ImageAestheticsFilter", "ImageAspectRatioFilter", "ImageFaceCountFilter", @@ -113,7 +116,6 @@ "VideoWatermarkFilter", "WordRepetitionFilter", "WordsNumFilter", - "GeneralFieldFilter", ] NON_STATS_FILTERS = [ diff --git a/data_juicer/ops/filter/group_diversity_filter.py b/data_juicer/ops/filter/group_diversity_filter.py new file mode 100644 index 00000000000..0fd81f67b85 --- /dev/null +++ b/data_juicer/ops/filter/group_diversity_filter.py @@ -0,0 +1,140 @@ +import sys +from typing import Dict, List + +import numpy as np +from jsonargparse.typing import NonNegativeFloat, PositiveInt +from tqdm import tqdm + +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Filter + +# Lazy load torch to improve startup time +torch = LazyLoader("torch") + + +@OPERATORS.register_module("group_diversity_filter") +class GroupDiversityFilter(Filter): + """ + Filter samples based on their semantic diversity within a group. + """ + + _accelerator = "cuda" + _batched_op = True + + def __init__( + self, + api_or_hf_model: str = "text-embedding-v3", + is_hf_model: bool = False, + api_endpoint: str = "/embeddings", + response_path: str = "data.0.embedding", + model_params: Dict = {}, + ebd_dim: PositiveInt = 512, + min_score: NonNegativeFloat = 0.0, + max_score: NonNegativeFloat = 1.0, + norm_ratio: NonNegativeFloat = 0.5, + *args, + **kwargs, + ): + """ + Initialization method. + + :param api_or_hf_model: API or huggingface embedding model name. + :param is_hf_model: Indicates if the model is from HuggingFace. + :param api_endpoint: Embedding URL endpoint for the API. + :param response_path: Path to extract content from the API response. + Defaults to 'data.0.embedding' for embedding model. + :param model_params: Parameters for initializing the API model. + :param ebd_dim: The embedding's dimension via API. + :param min_score: Minimum score for filtering. + :param max_score: Maximum score for filtering. + :param norm_ratio: Ratio to normalize the score. + :param args: extra args + :param kwargs: extra args + """ + kwargs.setdefault("mem_required", "20GB") + super().__init__(*args, **kwargs) + + self.min_score = min_score + self.max_score = max_score + self.norm_ratio = norm_ratio + self.is_hf_model = is_hf_model + self.ebd_dim = ebd_dim + + if self.is_hf_model: + self.model_key = prepare_model(model_type="embedding", model_path=api_or_hf_model, **model_params) + else: + self.model_key = prepare_model( + model_type="api", + model=api_or_hf_model, + endpoint=api_endpoint, + response_path=response_path, + **model_params, + ) + + def _embed_texts(self, texts: List[str], rank: int) -> np.ndarray: + # Embed a list of texts using the initialized model + embeddings = [] + model = get_model(self.model_key, rank, self.use_cuda()) + + for text in tqdm(texts, desc="Embedding texts", leave=False): + try: + if self.is_hf_model: + embedding = model.encode(text) + else: + embedding = model(text, dimensions=self.ebd_dim, encoding_format="float") + embeddings.append(np.array(embedding, dtype=np.float32)) + except Exception as e: + dim = model.get_sentence_embedding_dimension() if self.is_hf_model else self.ebd_dim + embeddings.append(np.zeros(dim, dtype=np.float32)) + print(f"Failed to embed text: '{text}'. Error: {e}. Using zero vector.", file=sys.stderr) + + return np.array(embeddings) + + def compute_stats_batched(self, samples: Dict, rank: int = 0) -> Dict: + stats_list = samples[Fields.stats] + if stats_list and StatsKeys.text_ebd_diversity_score in stats_list[0]: + return samples + + texts_to_embed = samples[self.text_key] + if not texts_to_embed: + for stat in stats_list: + stat[StatsKeys.text_ebd_diversity] = 0.0 + stat[StatsKeys.text_ebd_diversity_score] = 0.0 + return samples + + embeddings_array = self._embed_texts(texts_to_embed, rank=rank) + + avg_embedding = np.mean(embeddings_array, axis=0) + + cos_sims = ( + torch.nn.functional.cosine_similarity( + torch.from_numpy(embeddings_array), torch.from_numpy(avg_embedding).unsqueeze(0), dim=1 + ) + .cpu() + .numpy() + .tolist() + ) + + min_sim, max_sim = min(cos_sims), max(cos_sims) + range_sim = max_sim - min_sim + + normalized_scores = [] + if range_sim < 1e-8: + normalized_scores = [0.0] * len(cos_sims) + else: + for sim in cos_sims: + normalized_sim = self.norm_ratio * (max_sim - sim) / range_sim + normalized_scores.append(normalized_sim) + + for i, stat in enumerate(stats_list): + stat[StatsKeys.text_ebd_diversity] = cos_sims[i] + stat[StatsKeys.text_ebd_diversity_score] = normalized_scores[i] + + return samples + + def process_batched(self, samples: Dict) -> List[bool]: + stats_list = samples[Fields.stats] + return [self.min_score <= stat[StatsKeys.text_ebd_diversity_score] <= self.max_score for stat in stats_list] diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 73fd3c93e37..bf84d4e794b 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -267,6 +267,8 @@ class StatsKeysConstant(object): llm_perplexity = "llm_perplexity" llm_task_relevance = "llm_task_relevance" llm_task_relevance_record = "llm_task_relevance_record" + text_ebd_diversity = "text_ebd_diversity" + text_ebd_diversity_score = "text_ebd_diversity_score" # === image === aspect_ratios = "aspect_ratios" diff --git a/tests/ops/filter/test_group_diversity_filter.py b/tests/ops/filter/test_group_diversity_filter.py new file mode 100644 index 00000000000..94e520ebe4d --- /dev/null +++ b/tests/ops/filter/test_group_diversity_filter.py @@ -0,0 +1,76 @@ +import os +import unittest +from loguru import logger +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.filter.group_diversity_filter import GroupDiversityFilter +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, FROM_FORK + +@unittest.skipIf(FROM_FORK, "Skipping the test because running from a fork repo") +class GroupDiversityFilterTest(DataJuicerTestCaseBase): + # before running this test, set below environment variables: + # export OPENAI_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1/ + # export OPENAI_API_KEY=your_dashscope_key + api_model_name = 'text-embedding-v3' + api_ebd_dim = 512 + + # For local Hugging Face model test + hf_model_path = 'iic/gte_Qwen2-1.5B-instruct' + def setUp(self): + self.ds_list = [{ + 'text': "A cute cat is playing in the garden." + }, { + 'text': "A lovely dog is running on the grass." + }, { + 'text': "A beautiful bird is singing on the tree." + }, { + 'text': "Quantum computing is a complex field of physics." # The outlier + }] + self.dataset = Dataset.from_list(self.ds_list) + + def test_api_based_diversity_logic(self): + if not os.getenv('OPENAI_API_KEY'): + self.skipTest("OPENAI_API_KEY environment variable is not set. " + "Skipping API-based integration test.") + + logger.info(f"Running diversity test with API model: {self.api_model_name}") + op = GroupDiversityFilter( + api_or_hf_model=self.api_model_name, + is_hf_model=False, + ebd_dim=self.api_ebd_dim + ) + self._run_and_assert_diversity(op) + + def test_hf_based_diversity_logic(self): + logger.info(f"Running diversity test with HF model: {self.hf_model_path}") + op = GroupDiversityFilter( + api_or_hf_model=self.hf_model_path, + is_hf_model=True, + ) + self._run_and_assert_diversity(op) + + def _run_and_assert_diversity(self, op: GroupDiversityFilter): + dataset = self.dataset.add_column(name=Fields.stats, column=[{}] * len(self.dataset)) + dataset = dataset.map(op.compute_stats_batched, + with_rank=True, + batched=True, + batch_size=len(self.dataset)) + + stats_list = dataset.to_list() + for sample in stats_list: + logger.info(f"Text: '{sample['text']}', " + f"Score: {sample[Fields.stats].get(StatsKeys.text_ebd_diversity_score, 'N/A')}") + + scores = [d[Fields.stats][StatsKeys.text_ebd_diversity_score] for d in stats_list] + + outlier_score = scores[-1] + other_scores = scores[:-1] + + self.assertTrue(all(outlier_score > score for score in other_scores), + "The outlier sample did not receive the highest diversity score.") + + logger.info("Test passed: The outlier sample correctly received the highest diversity score.") + + +if __name__ == '__main__': + unittest.main() From b195fb95c0e9e2f9f640b952bbdcb0977435c02c Mon Sep 17 00:00:00 2001 From: lingzhq Date: Mon, 9 Mar 2026 20:54:20 +0800 Subject: [PATCH 2/2] [Fix] Rebase and refactor group_diversity_filter --- .../ops/filter/group_diversity_filter.py | 39 +++++++------- data_juicer/utils/constant.py | 3 +- docs/Operators.md | 3 +- .../filter/group_diversity_filter.md | 51 +++++++++++++++++++ .../ops/filter/test_group_diversity_filter.py | 10 ++-- 5 files changed, 81 insertions(+), 25 deletions(-) create mode 100644 docs/operators/filter/group_diversity_filter.md diff --git a/data_juicer/ops/filter/group_diversity_filter.py b/data_juicer/ops/filter/group_diversity_filter.py index 0fd81f67b85..f054dac3572 100644 --- a/data_juicer/ops/filter/group_diversity_filter.py +++ b/data_juicer/ops/filter/group_diversity_filter.py @@ -95,19 +95,26 @@ def _embed_texts(self, texts: List[str], rank: int) -> np.ndarray: def compute_stats_batched(self, samples: Dict, rank: int = 0) -> Dict: stats_list = samples[Fields.stats] - if stats_list and StatsKeys.text_ebd_diversity_score in stats_list[0]: + if stats_list and StatsKeys.llm_embd_diversity in stats_list[0]: return samples texts_to_embed = samples[self.text_key] if not texts_to_embed: for stat in stats_list: - stat[StatsKeys.text_ebd_diversity] = 0.0 - stat[StatsKeys.text_ebd_diversity_score] = 0.0 + stat[StatsKeys.llm_embd_diversity] = 0.0 return samples embeddings_array = self._embed_texts(texts_to_embed, rank=rank) - avg_embedding = np.mean(embeddings_array, axis=0) + valid_mask = ~np.all(embeddings_array == 0, axis=1) + valid_embeddings = embeddings_array[valid_mask] + + if len(valid_embeddings) == 0: + for stat in stats_list: + stat[StatsKeys.llm_embd_diversity] = 0.0 + return samples + + avg_embedding = np.mean(valid_embeddings, axis=0) cos_sims = ( torch.nn.functional.cosine_similarity( @@ -118,23 +125,21 @@ def compute_stats_batched(self, samples: Dict, rank: int = 0) -> Dict: .tolist() ) + for i, stat in enumerate(stats_list): + stat[StatsKeys.llm_embd_diversity] = cos_sims[i] + + return samples + + def process_batched(self, samples: Dict) -> List[bool]: + stats_list = samples[Fields.stats] + cos_sims = [stat[StatsKeys.llm_embd_diversity] for stat in stats_list] + min_sim, max_sim = min(cos_sims), max(cos_sims) range_sim = max_sim - min_sim - normalized_scores = [] if range_sim < 1e-8: normalized_scores = [0.0] * len(cos_sims) else: - for sim in cos_sims: - normalized_sim = self.norm_ratio * (max_sim - sim) / range_sim - normalized_scores.append(normalized_sim) - - for i, stat in enumerate(stats_list): - stat[StatsKeys.text_ebd_diversity] = cos_sims[i] - stat[StatsKeys.text_ebd_diversity_score] = normalized_scores[i] + normalized_scores = [self.norm_ratio * (max_sim - sim) / range_sim for sim in cos_sims] - return samples - - def process_batched(self, samples: Dict) -> List[bool]: - stats_list = samples[Fields.stats] - return [self.min_score <= stat[StatsKeys.text_ebd_diversity_score] <= self.max_score for stat in stats_list] + return [self.min_score <= score <= self.max_score for score in normalized_scores] diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index bf84d4e794b..1b808810cd4 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -267,8 +267,7 @@ class StatsKeysConstant(object): llm_perplexity = "llm_perplexity" llm_task_relevance = "llm_task_relevance" llm_task_relevance_record = "llm_task_relevance_record" - text_ebd_diversity = "text_ebd_diversity" - text_ebd_diversity_score = "text_ebd_diversity_score" + llm_embd_diversity = "llm_embd_diversity" # === image === aspect_ratios = "aspect_ratios" diff --git a/docs/Operators.md b/docs/Operators.md index 0c2bb9262d8..7576ff2c901 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -43,7 +43,7 @@ Data-Juicer 中的算子分为以下 8 种类型。 |------|:------:|-------------| | [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 | | [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 | -| [filter](#filter) | 56 | Filters out low-quality samples. 过滤低质量样本。 | +| [filter](#filter) | 57 | Filters out low-quality samples. 过滤低质量样本。 | | [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | | [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | | [mapper](#mapper) | 104 | Edits and transforms samples. 对数据样本进行编辑和转换。 | @@ -106,6 +106,7 @@ All the specific operators are listed below, each featured with several capabili | character_repetition_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with character-level n-gram repetition ratio within a specific range. 过滤器将具有字符级n-gram重复比的样本保持在特定范围内。 | [info](operators/filter/character_repetition_filter.md) | - | | flagged_words_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with flagged-word ratio in a specified range. 过滤器将标记词比率的样本保留在指定范围内。 | [info](operators/filter/flagged_words_filter.md) | - | | general_field_filter | 💻CPU 🟡Beta | Filter to keep samples based on a general field filter condition. 根据常规字段筛选条件保留样本。 | [info](operators/filter/general_field_filter.md) | - | +| group_diversity_filter | 🔤Text 🚀GPU 🔗API 🟡Beta | Filter samples based on their semantic diversity within a group. 基于样本在组内的语义多样性来过滤样本。 | [info](operators/filter/group_diversity_filter.md) | - | | image_aesthetics_filter | 🏞Image 🚀GPU 🧩HF 🟢Stable | Filter to keep samples with aesthetics scores within a specific range. 过滤以保持美学分数在特定范围内的样品。 | [info](operators/filter/image_aesthetics_filter.md) | - | | image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持样本的图像纵横比在特定范围内。 | [info](operators/filter/image_aspect_ratio_filter.md) | - | | image_face_count_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with the number of faces within a specific range. 过滤以保持样本的面数在特定范围内。 | [info](operators/filter/image_face_count_filter.md) | - | diff --git a/docs/operators/filter/group_diversity_filter.md b/docs/operators/filter/group_diversity_filter.md new file mode 100644 index 00000000000..af6e5a1bd45 --- /dev/null +++ b/docs/operators/filter/group_diversity_filter.md @@ -0,0 +1,51 @@ +# group_diversity_filter + +Filter to keep samples based on their semantic diversity within a group. + +This operator computes the semantic diversity of each sample relative to the group by embedding all input texts into vectors using either a HuggingFace model or an API-based embedding model. It calculates the cosine similarity between each sample's embedding and the mean embedding of the entire group. A lower cosine similarity indicates that the sample is more semantically distinct from the group average, i.e., higher diversity. The raw cosine similarity is stored as `llm_embd_diversity` in the sample's stats. During filtering, the similarities are normalized within the group using `norm_ratio` to produce a diversity score, and only samples whose score falls within [`min_score`, `max_score`] are kept. Note that `max_score` should not exceed `norm_ratio`, as `norm_ratio` defines the upper bound of the diversity score. + + +用于根据样本在组内语义多样性进行过滤的算子。 + +该算子通过HuggingFace模型或API嵌入模型将所有输入文本转换为向量,计算每个样本的嵌入向量与组内平均嵌入向量之间的余弦相似度。余弦相似度越低,说明该样本与组平均语义差异越大,即多样性越高。原始余弦相似度存储在样本统计信息的`llm_embd_diversity`字段中。过滤时,通过`norm_ratio`对组内相似度进行归一化得到多样性得分,只保留得分在[`min_score`, `max_score`]范围内的样本。注意`max_score`不应超过`norm_ratio`,因为`norm_ratio`定义了多样性得分的上限。 + +Type 算子类型: **filter** + +Tags 标签: gpu, hf, text + +## 🔧 Parameter Configuration 参数配置 +| name 参数名 | type 类型 | default 默认值 | desc 说明 | +|--------|------|--------|------| +| `api_or_hf_model` | `str` | `'text-embedding-v3'` | API or HuggingFace embedding model name or local path. | +| `is_hf_model` | `bool` | `False` | Whether the model is a HuggingFace local model. If False, uses API mode. | +| `api_endpoint` | `str` | `'/embeddings'` | Embedding URL endpoint for the API. | +| `response_path` | `str` | `'data.0.embedding'` | Path to extract embedding from the API response. Defaults to 'data.0.embedding' for embedding model. | +| `model_params` | `dict` | `{}` | Extra parameters for initializing the model. | +| `ebd_dim` | `` | `512` | Embedding dimension, only effective in API mode (`is_hf_model=False`). | +| `min_score` | `` | `0.0` | Minimum diversity score to keep samples. | +| `max_score` | `` | `1.0` | Maximum diversity score to keep samples. Should not exceed `norm_ratio`. | +| `norm_ratio` | `` | `0.5` | Normalization ratio controlling the upper bound of diversity score. The valid range of score is `[0, norm_ratio]`. | +| `args` | | `''` | extra args | +| `kwargs` | | `''` | extra args | + +## 📊 Effect demonstration 效果演示 +### test_group_diversity_filter +```python +GroupDiversityFilter(api_or_hf_model='iic/gte_Qwen2-1.5B-instruct', is_hf_model=True, min_score=0.3, max_score=0.5, norm_ratio=0.5) +``` + +#### 📥 input data 输入数据 +
Sample 1: text
The little cat is playing with a ball in the garden.
Sample 2: text
小猫正在花园里开心地玩着球。
Sample 3: text
A kitten is chasing a colorful ball on the green grass.
Sample 4: text
花园里有一只可爱的小猫在追着小球跑。
Sample 5: text
Quantum entanglement is a fundamental concept in quantum mechanics.
Sample 6: text
量子纠缠是量子力学中描述粒子间关联性的核心理论。
+ +#### 📤 output data 输出数据 +
Sample 1: text
Quantum entanglement is a fundamental concept in quantum mechanics.
Sample 2: text
量子纠缠是量子力学中描述粒子间关联性的核心理论。
+ +#### ✨ explanation 解释 +The operator filters the input data to keep only those samples with a diversity score between 0.3 and 0.5. It uses a local HuggingFace embedding model (`gte_Qwen2-1.5B`) to compute the cosine similarity between each sample's embedding and the group mean embedding, stored as `llm_embd_diversity`. The 4 cat-and-ball samples are semantically similar to each other, resulting in high cosine similarities (0.864~0.908) close to the group mean, and thus low diversity scores (0.000~0.102) that fall below `min_score=0.3`. The 2 quantum mechanics samples are semantically distant from the group mean, with lower cosine similarities (0.693~0.765) and higher diversity scores (0.333~0.500), so they are kept. + +算子过滤输入数据,只保留多样性得分在0.3到0.5之间的样本。它使用本地HuggingFace嵌入模型(`gte_Qwen2-1.5B`)计算每个样本的嵌入向量与组平均嵌入向量之间的余弦相似度,存储为`llm_embd_diversity`。4条猫玩球的样本语义相近,余弦相似度较高(0.864~0.908),接近组平均,多样性得分较低(0.000~0.102),低于`min_score=0.3`因此被过滤。2条量子力学样本与组平均语义差异较大,余弦相似度较低(0.693~0.765),多样性得分较高(0.333~0.500),因此被保留。 + +## 🔗 related links 相关链接 +- [source code 源代码](../../../data_juicer/ops/filter/group_diversity_filter.py) +- [unit test 单元测试](../../../tests/ops/filter/test_group_diversity_filter.py) +- [Return operator list 返回算子列表](../../Operators.md) diff --git a/tests/ops/filter/test_group_diversity_filter.py b/tests/ops/filter/test_group_diversity_filter.py index 94e520ebe4d..9a7e4c67830 100644 --- a/tests/ops/filter/test_group_diversity_filter.py +++ b/tests/ops/filter/test_group_diversity_filter.py @@ -59,17 +59,17 @@ def _run_and_assert_diversity(self, op: GroupDiversityFilter): stats_list = dataset.to_list() for sample in stats_list: logger.info(f"Text: '{sample['text']}', " - f"Score: {sample[Fields.stats].get(StatsKeys.text_ebd_diversity_score, 'N/A')}") + f"Score: {sample[Fields.stats].get(StatsKeys.llm_embd_diversity, 'N/A')}") - scores = [d[Fields.stats][StatsKeys.text_ebd_diversity_score] for d in stats_list] + scores = [d[Fields.stats][StatsKeys.llm_embd_diversity] for d in stats_list] outlier_score = scores[-1] other_scores = scores[:-1] - self.assertTrue(all(outlier_score > score for score in other_scores), - "The outlier sample did not receive the highest diversity score.") + self.assertTrue(all(outlier_score < score for score in other_scores), + "The outlier sample did not receive the lowest diversity value.") - logger.info("Test passed: The outlier sample correctly received the highest diversity score.") + logger.info("Test passed: The outlier sample correctly received the lowest diversity value.") if __name__ == '__main__':