Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/instructlab/sdg/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
# Standard
from abc import ABC
from typing import Any, Dict, Union
+import logging
import os.path

# Third Party
import yaml

-# Local
-from .logger_config import setup_logger
-
-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/sdg/datamixing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Standard
from typing import Optional
import json
+import logging
import os.path
import random
import uuid
Expand All @@ -10,11 +11,10 @@
import yaml

# First Party
-from instructlab.sdg.logger_config import setup_logger
from instructlab.sdg.utils import GenerateException, pandas

ALLOWED_COLS = ["id", "messages", "metadata"]
-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)


def _adjust_train_sample_size(ds: Dataset, num_samples: int):
Expand Down
6 changes: 2 additions & 4 deletions src/instructlab/sdg/eval_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Standard
from importlib import resources
from typing import Any
+import logging
import re

# Third Party
Expand All @@ -10,10 +11,7 @@
# First Party
from instructlab.sdg.pipeline import EVAL_PIPELINES_PKG, Pipeline

-# Local
-from .logger_config import setup_logger
-
-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)


def _extract_options(text: str) -> list[Any]:
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/sdg/filterblock.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
+import logging
import operator

# Third Party
from datasets import Dataset

# Local
from .block import Block
-from .logger_config import setup_logger

-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
24 changes: 12 additions & 12 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from typing import Optional
import json
+import logging
import os
import time

Expand Down Expand Up @@ -35,6 +36,8 @@
read_taxonomy_leaf_nodes,
)

+logger = logging.getLogger(__name__)
+
_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."


Expand Down Expand Up @@ -73,9 +76,7 @@ def _convert_to_messages(sample):
return sample


-def _gen_train_data(
-logger, machine_instruction_data, output_file_train, output_file_messages
-):
+def _gen_train_data(machine_instruction_data, output_file_train, output_file_messages):
"""
Generate training data in the legacy system/user/assistant format
used in train_*.jsonl as well as the legacy messages format used
Expand Down Expand Up @@ -257,9 +258,9 @@ def _mixer_init(ctx, output_dir, date_suffix):

# This is part of the public API, and used by instructlab.
# TODO - parameter removal needs to be done in sync with a CLI change.
-# pylint: disable=unused-argument
+# to be removed: logger, prompt_file_path, rouge_threshold, tls_*
def generate_data(
-logger,
+logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
api_base: Optional[str] = None,
api_key: Optional[str] = None,
model_family: Optional[str] = None,
Expand All @@ -270,9 +271,9 @@ def generate_data(
taxonomy_base: Optional[str] = None,
output_dir: Optional[str] = None,
# TODO - not used and should be removed from the CLI
-prompt_file_path: Optional[str] = None,
+prompt_file_path: Optional[str] = None, # pylint: disable=unused-argument
# TODO - probably should be removed
-rouge_threshold: Optional[float] = None,
+rouge_threshold: Optional[float] = None, # pylint: disable=unused-argument
console_output=True,
yaml_rules: Optional[str] = None,
chunk_word_count=None,
Expand Down Expand Up @@ -382,9 +383,9 @@ def generate_data(
else:
sdg = sdg_freeform_skill

-logger.debug("Samples: %s" % samples)
+logger.debug("Samples: %s", samples)
ds = Dataset.from_list(samples)
-logger.debug("Dataset: %s" % ds)
+logger.debug("Dataset: %s", ds)
new_generated_data = sdg.generate(ds)
if len(new_generated_data) == 0:
raise EmptyDatasetError(
Expand All @@ -395,8 +396,8 @@ def generate_data(
if generated_data is None
else generated_data + [new_generated_data]
)
-logger.info("Generated %d samples" % len(generated_data))
-logger.debug("Generated data: %s" % generated_data)
+logger.info("Generated %d samples", len(generated_data))
+logger.debug("Generated data: %s", generated_data)

if is_knowledge:
# generate mmlubench data for the current leaf node
Expand All @@ -414,7 +415,6 @@ def generate_data(
generated_data = []

_gen_train_data(
-logger,
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages),
Expand Down
6 changes: 4 additions & 2 deletions src/instructlab/sdg/importblock.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
+# Standard
+import logging
+
# Third Party
from datasets import Dataset

# Local
from .block import Block
-from .logger_config import setup_logger

-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Standard
from collections import ChainMap
from typing import Any, Dict
+import logging
import re

# Third Party
Expand All @@ -10,9 +11,8 @@

# Local
from .block import Block
-from .logger_config import setup_logger

-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"
Expand Down
18 changes: 0 additions & 18 deletions src/instructlab/sdg/logger_config.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from importlib import resources
from typing import Iterable, Optional
+import logging
import math
import os.path

Expand All @@ -18,9 +19,8 @@
# Local
from . import filterblock, importblock, llmblock, utilblocks
from .block import Block
-from .logger_config import setup_logger

-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
6 changes: 4 additions & 2 deletions src/instructlab/sdg/utilblocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
+# Standard
+import logging
+
# Third Party
from datasets import Dataset

Expand All @@ -7,9 +10,8 @@

# Local
from .block import Block
-from .logger_config import setup_logger

-logger = setup_logger(__name__)
+logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down