Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/unit-test-partial.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ jobs:
- name: Run unittest ray
working-directory: dj-${{ github.run_id }}/.github/workflows/docker
run: |
docker compose exec -e OPENAI_BASE_URL=${{ secrets.OPENAI_BASE_URL }} -e OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} ray-head bash -c 'python tests/run.py --tag ray --mode partial'
if [ "${{ github.event.pull_request.head.repo.full_name }}" != "${{ github.repository }}" ]; then
docker compose exec -e OPENAI_BASE_URL=${{ secrets.OPENAI_BASE_URL }} -e OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} ray-head bash -c 'python tests/run.py --tag ray --mode partial --from-fork True'
else
docker compose exec -e OPENAI_BASE_URL=${{ secrets.OPENAI_BASE_URL }} -e OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }} ray-head bash -c 'python tests/run.py --tag ray --mode partial'
fi
docker compose exec ray-head bash -c 'coverage combine'
docker compose exec ray-head bash -c 'mv .coverage .coverage.ray'

Expand Down
57 changes: 7 additions & 50 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
import copy
import importlib.util
import json
import os
import shutil
Expand Down Expand Up @@ -45,57 +44,15 @@ def timing_context(description):
logger.debug(f"{description} took {elapsed_time:.2f} seconds")


def _generate_module_name(abs_path):
"""Generate a module name based on the absolute path of the file."""
return os.path.splitext(os.path.basename(abs_path))[0]
def load_custom_operators(paths):
"""Dynamically load custom operator modules or packages in the specified path.

This is a re-export from ``data_juicer.utils.custom_op`` kept here for
backward compatibility.
"""
from data_juicer.utils.custom_op import load_custom_operators as _impl

def load_custom_operators(paths):
"""Dynamically load custom operator modules or packages in the specified path."""
for path in paths:
abs_path = os.path.abspath(path)
if os.path.isfile(abs_path):
module_name = _generate_module_name(abs_path)
if module_name in sys.modules:
existing_path = sys.modules[module_name].__file__
raise RuntimeError(
f"Module '{module_name}' already loaded from '{existing_path}'. "
f"Conflict detected while loading '{abs_path}'."
)
try:
spec = importlib.util.spec_from_file_location(module_name, abs_path)
if spec is None:
raise RuntimeError(f"Failed to create spec for '{abs_path}'")
module = importlib.util.module_from_spec(spec)
# register the module first to avoid recursive import issues
sys.modules[module_name] = module
spec.loader.exec_module(module)
except Exception as e:
raise RuntimeError(f"Error loading '{abs_path}' as '{module_name}': {e}")

elif os.path.isdir(abs_path):
if not os.path.isfile(os.path.join(abs_path, "__init__.py")):
raise ValueError(f"Package directory '{abs_path}' must contain __init__.py")
package_name = os.path.basename(abs_path)
parent_dir = os.path.dirname(abs_path)
if package_name in sys.modules:
existing_path = sys.modules[package_name].__path__[0]
raise RuntimeError(
f"Package '{package_name}' already loaded from '{existing_path}'. "
f"Conflict detected while loading '{abs_path}'."
)
original_sys_path = sys.path.copy()
try:
sys.path.insert(0, parent_dir)
importlib.import_module(package_name)
# record the loading path of the package (for subsequent conflict detection)
sys.modules[package_name].__loaded_from__ = abs_path
except Exception as e:
raise RuntimeError(f"Error loading package '{abs_path}': {e}")
finally:
sys.path = original_sys_path
else:
raise ValueError(f"Path '{abs_path}' is neither a file nor a directory")
_impl(paths)


def build_base_parser() -> ArgumentParser:
Expand Down
13 changes: 13 additions & 0 deletions data_juicer/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ def timing_context(description):

# yapf: disable
with timing_context('Importing operator modules'):
# 1. Built-in operators (registered via @OPERATORS.register_module decorators
# that fire as each sub-package is imported)
# 2. Persistent custom operators (loaded from ~/.data_juicer/custom_op.json;
# no-op when the registry file does not exist)
from . import aggregator, deduplicator, filter, grouper, mapper, pipeline, selector
from .base_op import (
ATTRIBUTION_FILTERS,
Expand All @@ -38,6 +42,15 @@ def timing_context(description):
op_requirements_to_op_env_spec,
)

from data_juicer.utils.custom_op import load_persistent_custom_ops as _load_persistent # isort: skip # noqa: E501
try:
_load_persistent()
except Exception as _exc:
from loguru import logger as _logger
_logger.warning(f"Failed to load persistent custom ops: {_exc}")
del _logger, _exc
del _load_persistent

__all__ = [
'load_ops',
'Filter',
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/tools/DJ_mcp_granular_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def resolve_signature_annotations(func, sig: inspect.Signature) -> inspect.Signa
try:
module = sys.modules.get(func.__module__, None) if hasattr(func, "__module__") else None
globalns = module.__dict__ if module else {}
hints = get_type_hints(func, globalns=globalns, localns=globalns)
hints = get_type_hints(func, globalns=globalns)
except Exception:
hints = {}

Expand Down Expand Up @@ -65,7 +65,7 @@ def create_operator_function(op, mcp):
param_docstring = op["param_desc"]

# Create new function signature with dataset_path as first parameter
# Consider adding other common parameters later, such as export_psth
# Consider adding other common parameters later, such as export_path
fixed_params = [
inspect.Parameter("dataset_path", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str),
inspect.Parameter(
Expand Down
219 changes: 196 additions & 23 deletions data_juicer/tools/op_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,46 @@ class OPRecord:

def __init__(self, name: str, op_cls: type, op_type: Optional[str] = None):
self.name = name
self.type = op_type or op_cls.__module__.split(".")[2].lower()

# --- module path:
# handling for custom ops ---
if op_type:
self.type = op_type
else:
module_parts = op_cls.__module__.split(".")
if len(module_parts) >= 3:
self.type = module_parts[2].lower()
else:
self.type = self._search_mro_for_type(op_cls)
if self.type not in op_type_list:
self.type = self._search_mro_for_type(op_cls)

self.desc = op_cls.__doc__ or ""
self.tags = analyze_tag_from_cls(op_cls, name)
self.sig = inspect.signature(op_cls.__init__)
self.init_func = op_cls.__init__
self.param_desc = extract_param_docstring(op_cls.__init__.__doc__ or "")
self.param_desc_map = self._parse_param_desc()
self.source_path = str(get_source_path(op_cls))
self.test_path = None

test_path = f"tests/ops/{self.type}/test_{self.name}.py"
if not (PROJECT_ROOT / test_path).exists():
test_path = find_test_by_searching_content(PROJECT_ROOT / "tests", op_cls.__name__ + "Test") or test_path

self.test_path = str(test_path)
# --- source path: handling for custom ops ---
try:
self.source_path = str(get_source_path(op_cls))
except (ValueError, TypeError, OSError):
try:
self.source_path = str(Path(inspect.getfile(op_cls)))
except (TypeError, OSError):
self.source_path = "unknown"

# --- test path: handling for custom ops ---
try:
test_path = f"tests/ops/{self.type}/test_{self.name}.py"
if not (PROJECT_ROOT / test_path).exists():
test_path = (
find_test_by_searching_content(PROJECT_ROOT / "tests", op_cls.__name__ + "Test") or test_path
)
self.test_path = str(test_path)
except Exception:
self.test_path = None

def __getitem__(self, item):
try:
Expand Down Expand Up @@ -443,26 +466,176 @@ def records_map(self):
return self.all_ops


def main(query, tags, op_type):
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def _build_parser():
import argparse

parser = argparse.ArgumentParser(
prog="python -m data_juicer.tools.op_search",
description="Data-Juicer Operator Search & Query Tool",
)
sub = parser.add_subparsers(dest="command", help="Available commands")

# --- list ---
sub.add_parser(
"list",
help="List all operators (built-in + custom)",
)

# --- info ---
p_info = sub.add_parser(
"info",
help="Show detailed information about an operator",
)
p_info.add_argument("name", help="Operator name")

# --- search ---
p_search = sub.add_parser(
"search",
help="Search operators by keyword, regex, or tags",
)
p_search.add_argument(
"query",
nargs="?",
default=None,
help="Search query (natural language or regex pattern)",
)
p_search.add_argument(
"--mode",
choices=["bm25", "regex"],
default="bm25",
help="Search mode (default: bm25)",
)
p_search.add_argument(
"--tags",
nargs="+",
default=None,
help="Filter by tags (e.g., gpu, cpu, text, image)",
)
p_search.add_argument(
"--type",
dest="op_type",
default=None,
help="Filter by operator type (e.g., mapper, filter)",
)
p_search.add_argument(
"--top-k",
type=int,
default=10,
help="Maximum number of results (default: 10)",
)

return parser


def _cmd_list(args) -> int:
    """List all operators (built-in + custom).

    Prints the total, built-in and custom counts, followed by every name in
    ``OPERATORS.modules`` in sorted order; custom operators are suffixed with
    `` [custom]``.

    :param args: parsed CLI namespace; not read by this command.
    :return: process exit code, always 0.
    """
    from data_juicer.utils.custom_op import list_registered

    all_names = sorted(OPERATORS.modules.keys())

    # NOTE(review): assumes the keys of "custom_operators" are operator names,
    # which is how they are compared against OPERATORS.modules below.
    registered_custom = set(list_registered().get("custom_operators", {}).keys())

    # A name can be present in the persistent registry without being loaded
    # (loading failures are only logged as warnings at import time). Count
    # only the custom operators that are actually listed below, so that
    # Built-in + Custom == Total and neither count can be off or negative.
    custom_names = registered_custom.intersection(all_names)

    print(f"Total operators: {len(all_names)}")
    print(f" Built-in: {len(all_names) - len(custom_names)}")
    print(f" Custom: {len(custom_names)}")
    print()
    for name in all_names:
        marker = " [custom]" if name in custom_names else ""
        print(f" {name}{marker}")
    return 0


def _cmd_info(args) -> int:
    """Print the details of a single operator.

    Looks ``args.name`` up in ``OPERATORS.modules``, wraps the class in an
    ``OPRecord`` and prints the fields of its ``to_dict()`` form.

    :param args: parsed CLI namespace; ``args.name`` is the operator name.
    :return: 0 on success, 1 when the operator is not registered (the
        message goes to stderr).
    """
    import sys

    op_cls = OPERATORS.modules.get(args.name)
    if op_cls is None:
        print(f"Operator '{args.name}' not found.", file=sys.stderr)
        return 1

    info = OPRecord(name=args.name, op_cls=op_cls).to_dict()

    # Header: one line per scalar field, with "none" standing in for
    # missing tags or a missing test path.
    print(f"Name: {info['name']}")
    print(f"Type: {info['type']}")
    print(f"Tags: {', '.join(info['tags']) if info['tags'] else 'none'}")
    print(f"Source: {info['source_path']}")
    print(f"Test: {info['test_path'] or 'none'}")
    print(f"Signature: {info['sig']}")
    print()

    # Optional sections are emitted only when they carry content.
    if info["desc"]:
        description = info["desc"].strip()
        print("Description:")
        print(f" {description}")
        print()

    if info["param_desc_map"]:
        print("Parameters:")
        for param_name, param_text in info["param_desc_map"].items():
            print(f" {param_name}: {param_text}")

    return 0


def _cmd_search(args) -> int:
"""Search operators by keyword, regex, or tags."""
searcher = OPSearcher(include_formatter=True)

results = searcher.search_by_bm25(query=query, tags=tags, op_type=op_type)
query = args.query
tags = args.tags
op_type = args.op_type

if args.mode == "regex" and query:
results = searcher.search_by_regex(query=query, tags=tags, op_type=op_type)
elif query:
results = searcher.search_by_bm25(query=query, tags=tags, op_type=op_type, top_k=args.top_k)
else:
results = searcher.search(tags=tags, op_type=op_type)

print(f"\nFound {len(results)} operators:")
print(f"Found {len(results)} operator(s):")
for op in results:
print(f"\n[{op['type'].upper()}] {op['name']}")
print(f"Tags: {', '.join(op['tags'])}")
print(f"Description: {op['desc']}")
print(f"Parameters: {op['param_desc']}")
print(f"Parameter Descriptions: {op['param_desc_map']}")
print(f"Signature: {op['sig']}")
print("-" * 50)
print(f"\n [{op['type'].upper()}] {op['name']}")
print(f" Tags: {', '.join(op['tags'])}")
desc = (op.get("desc") or "").strip()
if desc:
first_line = desc.split("\n")[0].strip()
if len(first_line) > 80:
first_line = first_line[:77] + "..."
print(f" Desc: {first_line}")

print(searcher.records_map["nlpaug_en_mapper"]["source_path"])
print(searcher.records_map["nlpaug_en_mapper"].test_path)
return 0


# Dispatch table for the CLI: maps each subcommand name (the value stored in
# ``args.command`` by the parser) to the handler that ``main`` calls with the
# parsed arguments.
_COMMAND_MAP = {
    "list": _cmd_list,
    "info": _cmd_info,
    "search": _cmd_search,
}


def main(argv=None) -> int:
    """Run the operator search & query CLI.

    :param argv: argument list to parse; ``None`` lets argparse read the
        process arguments.
    :return: the selected handler's exit code, or 1 (after printing the help
        text) when no subcommand was given or it has no handler.
    """
    cli = _build_parser()
    parsed = cli.parse_args(argv)

    # A missing subcommand and an unmapped one are handled the same way:
    # show the usage and signal failure.
    run = _COMMAND_MAP.get(parsed.command) if parsed.command else None
    if run is None:
        cli.print_help()
        return 1

    return run(parsed)


if __name__ == "__main__":
tags = []
op_type = "formatter"
main(query="json", tags=tags, op_type=op_type)
import sys

sys.exit(main())
Loading
Loading