Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions data_juicer/config/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -723,8 +723,17 @@ process:
use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese
words_aug_group_sizes: [2] # the group size of words to augment
words_aug_join_char: "" # the join char between words to augment
- general_field_filter: # Filter to keep samples based on a general field filter condition.
filter_condition: "" # The filter condition as a string. It can include logical operators (and/or) and chain comparisons.
- general_field_filter: # Filter to keep samples based on a general field filter condition.
filter_condition: "" # The filter condition as a string. It can include logical operators (and/or) and chain comparisons.
- group_diversity_filter: # filter samples based on their semantic diversity within a group.
api_or_hf_model: "text-embedding-v3" # API or huggingface embedding model name.
is_hf_model: false # indicates if the model is from HuggingFace.
api_endpoint: "/embeddings" # embedding URL endpoint for the API.
response_path: "data.0.embedding" # path to extract content from the API response.
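ebd_dim: 512 # the embedding dimension requested from the API model; also the size of the zero vector used when embedding fails.
min_score: 0.0 # the minimum normalized diversity score for a sample to be kept.
max_score: 1.0 # the maximum normalized diversity score for a sample to be kept.
norm_ratio: 0.5 # scale factor for the normalized diversity score, so scores fall in [0, norm_ratio].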
- image_aesthetics_filter: # filter samples according to the aesthetics score of images.
hf_scorer_model: shunk031/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE # Huggingface model name for the aesthetics predictor
min_score: 0.3 # the min aesthetics score of filter range
Expand Down
4 changes: 3 additions & 1 deletion data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .character_repetition_filter import CharacterRepetitionFilter
from .flagged_words_filter import FlaggedWordFilter
from .general_field_filter import GeneralFieldFilter
from .group_diversity_filter import GroupDiversityFilter
from .image_aesthetics_filter import ImageAestheticsFilter
from .image_aspect_ratio_filter import ImageAspectRatioFilter
from .image_face_count_filter import ImageFaceCountFilter
Expand Down Expand Up @@ -65,6 +66,8 @@
"AverageLineLengthFilter",
"CharacterRepetitionFilter",
"FlaggedWordFilter",
"GeneralFieldFilter",
"GroupDiversityFilter",
"ImageAestheticsFilter",
"ImageAspectRatioFilter",
"ImageFaceCountFilter",
Expand Down Expand Up @@ -113,7 +116,6 @@
"VideoWatermarkFilter",
"WordRepetitionFilter",
"WordsNumFilter",
"GeneralFieldFilter",
]

NON_STATS_FILTERS = [
Expand Down
145 changes: 145 additions & 0 deletions data_juicer/ops/filter/group_diversity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import sys
from typing import Dict, List

import numpy as np
from jsonargparse.typing import NonNegativeFloat, PositiveInt
from tqdm import tqdm

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter

# Lazy load torch to improve startup time
torch = LazyLoader("torch")


@OPERATORS.register_module("group_diversity_filter")
class GroupDiversityFilter(Filter):
    """
    Filter samples based on their semantic diversity within a group.

    Every text in a batch is embedded, and the cosine similarity between
    each embedding and the mean of the successfully embedded texts is stored
    under ``StatsKeys.llm_embd_diversity``. At filtering time the
    similarities are min-max normalized and inverted, so the sample closest
    to the mean gets score ``0`` and the farthest gets ``norm_ratio``.
    Samples whose score lies within ``[min_score, max_score]`` are kept.
    """

    _accelerator = "cuda"
    _batched_op = True

    def __init__(
        self,
        api_or_hf_model: str = "text-embedding-v3",
        is_hf_model: bool = False,
        api_endpoint: str = "/embeddings",
        response_path: str = "data.0.embedding",
        model_params: Dict = {},
        ebd_dim: PositiveInt = 512,
        min_score: NonNegativeFloat = 0.0,
        max_score: NonNegativeFloat = 1.0,
        norm_ratio: NonNegativeFloat = 0.5,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param api_or_hf_model: API or huggingface embedding model name.
        :param is_hf_model: Indicates if the model is from HuggingFace.
        :param api_endpoint: Embedding URL endpoint for the API.
        :param response_path: Path to extract content from the API response.
            Defaults to 'data.0.embedding' for embedding model.
        :param model_params: Parameters for initializing the model. The
            dict is only unpacked as keyword arguments, never mutated.
        :param ebd_dim: The embedding's dimension via API. It is sent with
            every API request and used as the zero-vector size when no
            text could be embedded.
        :param min_score: Minimum normalized score for a sample to be kept.
        :param max_score: Maximum normalized score for a sample to be kept.
        :param norm_ratio: Scale factor applied to the normalized score, so
            scores fall in ``[0, norm_ratio]``.
        :param args: extra args
        :param kwargs: extra args
        """
        kwargs.setdefault("mem_required", "20GB")
        super().__init__(*args, **kwargs)

        self.min_score = min_score
        self.max_score = max_score
        self.norm_ratio = norm_ratio
        self.is_hf_model = is_hf_model
        self.ebd_dim = ebd_dim

        if self.is_hf_model:
            self.model_key = prepare_model(model_type="embedding", model_path=api_or_hf_model, **model_params)
        else:
            self.model_key = prepare_model(
                model_type="api",
                model=api_or_hf_model,
                endpoint=api_endpoint,
                response_path=response_path,
                **model_params,
            )

    def _encode_batch(self, model, texts: List[str]):
        """
        Try to embed all texts with a single ``model.encode`` call.

        :param model: the embedding model; NOTE(review): assumed to accept
            a list of texts in ``encode`` -- if it does not, the caller
            falls back to per-text encoding.
        :param texts: the texts to embed.
        :return: a float32 array with one row per text, or None when the
            call fails or the result does not have that layout.
        """
        try:
            batch = np.asarray(model.encode(list(texts)), dtype=np.float32)
        except Exception as e:
            print(
                f"Failed to embed texts in batch. Error: {e}. Falling back to per-text embedding.",
                file=sys.stderr,
            )
            return None

        # Only trust the result when it is one row per input text; anything
        # else means the model did not treat the input as a batch.
        if batch.ndim != 2 or batch.shape[0] != len(texts):
            return None
        return batch

    def _embed_texts(self, texts: List[str], rank: int) -> np.ndarray:
        """
        Embed a list of texts using the initialized model.

        HuggingFace models are first tried with one batched ``encode``
        call. If that is not usable, or for API models, texts are embedded
        one by one and a text that fails is replaced by a zero vector.

        :param texts: the texts to embed.
        :param rank: rank passed to ``get_model`` when fetching the model.
        :return: array of embeddings, one entry per text, in input order.
        """
        model = get_model(self.model_key, rank, self.use_cuda())

        if self.is_hf_model and texts:
            batch = self._encode_batch(model, texts)
            if batch is not None:
                return batch

        # None marks a failed text; it is filled with zeros once the real
        # embedding shape is known.
        embeddings = []
        for text in tqdm(texts, desc="Embedding texts", leave=False):
            try:
                if self.is_hf_model:
                    embedding = model.encode(text)
                else:
                    embedding = model(text, dimensions=self.ebd_dim, encoding_format="float")
                embeddings.append(np.array(embedding, dtype=np.float32))
            except Exception as e:
                embeddings.append(None)
                print(f"Failed to embed text: '{text}'. Error: {e}. Using zero vector.", file=sys.stderr)

        if not embeddings:
            return np.array(embeddings)

        # Size the zero vectors after a real embedding so the stacked array
        # is never ragged, even if the model's output size differs from the
        # configured dimension.
        template = next((emb for emb in embeddings if emb is not None), None)
        if template is None:
            dim = model.get_sentence_embedding_dimension() if self.is_hf_model else self.ebd_dim
            template = np.zeros(dim, dtype=np.float32)

        return np.array([np.zeros_like(template) if emb is None else emb for emb in embeddings])

    def compute_stats_batched(self, samples: Dict, rank: int = 0) -> Dict:
        """
        Compute each sample's cosine similarity to the batch mean embedding.

        The result is written to ``StatsKeys.llm_embd_diversity`` in each
        sample's stats. If the first sample already has this stat, the
        batch is returned unchanged. The stat is ``0.0`` for every sample
        when there are no texts or no text could be embedded.

        :param samples: batched samples holding the text field and stats.
        :param rank: rank forwarded to the embedding model loader.
        :return: the same ``samples`` dict with stats filled in.
        """
        stats_list = samples[Fields.stats]
        if stats_list and StatsKeys.llm_embd_diversity in stats_list[0]:
            return samples

        texts_to_embed = samples[self.text_key]
        if not texts_to_embed:
            for stat in stats_list:
                stat[StatsKeys.llm_embd_diversity] = 0.0
            return samples

        embeddings_array = self._embed_texts(texts_to_embed, rank=rank)

        # All-zero rows are failed embeddings; exclude them from the mean.
        valid_mask = ~np.all(embeddings_array == 0, axis=1)
        valid_embeddings = embeddings_array[valid_mask]

        if len(valid_embeddings) == 0:
            for stat in stats_list:
                stat[StatsKeys.llm_embd_diversity] = 0.0
            return samples

        avg_embedding = np.mean(valid_embeddings, axis=0)

        # Similarity is computed for every row, failed ones included.
        cos_sims = (
            torch.nn.functional.cosine_similarity(
                torch.from_numpy(embeddings_array), torch.from_numpy(avg_embedding).unsqueeze(0), dim=1
            )
            .cpu()
            .numpy()
            .tolist()
        )

        for i, stat in enumerate(stats_list):
            stat[StatsKeys.llm_embd_diversity] = cos_sims[i]

        return samples

    def process_batched(self, samples: Dict) -> List[bool]:
        """
        Decide which samples of the batch to keep.

        Similarities are min-max normalized over the batch and inverted, so
        the score is ``0`` for the most similar sample and ``norm_ratio``
        for the least similar. When all similarities are (nearly) equal,
        every score is ``0.0``.

        :param samples: batched samples whose stats hold
            ``StatsKeys.llm_embd_diversity``.
        :return: one boolean per sample, True when its score lies within
            ``[min_score, max_score]``; an empty list for an empty batch.
        """
        stats_list = samples[Fields.stats]
        cos_sims = [stat[StatsKeys.llm_embd_diversity] for stat in stats_list]

        # min()/max() raise ValueError on an empty sequence.
        if not cos_sims:
            return []

        min_sim, max_sim = min(cos_sims), max(cos_sims)
        range_sim = max_sim - min_sim

        if range_sim < 1e-8:
            normalized_scores = [0.0] * len(cos_sims)
        else:
            normalized_scores = [self.norm_ratio * (max_sim - sim) / range_sim for sim in cos_sims]

        return [self.min_score <= score <= self.max_score for score in normalized_scores]
1 change: 1 addition & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ class StatsKeysConstant(object):
llm_perplexity = "llm_perplexity"
llm_task_relevance = "llm_task_relevance"
llm_task_relevance_record = "llm_task_relevance_record"
llm_embd_diversity = "llm_embd_diversity"

# === image ===
aspect_ratios = "aspect_ratios"
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Data-Juicer 中的算子分为以下 8 种类型。
|------|:------:|-------------|
| [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 |
| [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 |
| [filter](#filter) | 56 | Filters out low-quality samples. 过滤低质量样本。 |
| [filter](#filter) | 57 | Filters out low-quality samples. 过滤低质量样本。 |
| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 |
| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 |
| [mapper](#mapper) | 104 | Edits and transforms samples. 对数据样本进行编辑和转换。 |
Expand Down Expand Up @@ -106,6 +106,7 @@ All the specific operators are listed below, each featured with several capabili
| character_repetition_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with character-level n-gram repetition ratio within a specific range. 过滤器将具有字符级n-gram重复比的样本保持在特定范围内。 | [info](operators/filter/character_repetition_filter.md) | - |
| flagged_words_filter | 🔤Text 💻CPU 🟢Stable | Filter to keep samples with flagged-word ratio in a specified range. 过滤器将标记词比率的样本保留在指定范围内。 | [info](operators/filter/flagged_words_filter.md) | - |
| general_field_filter | 💻CPU 🟡Beta | Filter to keep samples based on a general field filter condition. 根据常规字段筛选条件保留样本。 | [info](operators/filter/general_field_filter.md) | - |
| group_diversity_filter | 🔤Text 🚀GPU 🔗API 🟡Beta | Filter samples based on their semantic diversity within a group. 基于样本在组内的语义多样性来过滤样本。 | [info](operators/filter/group_diversity_filter.md) | - |
| image_aesthetics_filter | 🏞Image 🚀GPU 🧩HF 🟢Stable | Filter to keep samples with aesthetics scores within a specific range. 过滤以保持美学分数在特定范围内的样品。 | [info](operators/filter/image_aesthetics_filter.md) | - |
| image_aspect_ratio_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with image aspect ratio within a specific range. 过滤器,以保持样本的图像纵横比在特定范围内。 | [info](operators/filter/image_aspect_ratio_filter.md) | - |
| image_face_count_filter | 🏞Image 💻CPU 🟢Stable | Filter to keep samples with the number of faces within a specific range. 过滤以保持样本的面数在特定范围内。 | [info](operators/filter/image_face_count_filter.md) | - |
Expand Down
Loading