Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/instructlab/sdg/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
}


def get_model_family(forced, model_path):
forced = MODEL_FAMILY_MAPPINGS.get(forced, forced)
if forced and forced.lower() not in MODEL_FAMILIES:
raise GenerateException("Unknown model family: %s" % forced)
def get_model_family(model_family, model_path):
model_family_retrieved = MODEL_FAMILY_MAPPINGS.get(model_family, model_family)
if model_family_retrieved and model_family_retrieved.lower() not in MODEL_FAMILIES:
raise GenerateException("Unknown model family: %s" % model_family_retrieved)

# Try to guess the model family based on the model's filename
guess = re.match(r"^\w*", os.path.basename(model_path)).group(0).lower()
Expand Down
17 changes: 13 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# First Party
from instructlab.sdg.pipeline import PipelineContext

# Local
from .taxonomy import MockTaxonomy


def get_ctx(**kwargs) -> PipelineContext:
kwargs.setdefault("client", mock.MagicMock())
Expand Down Expand Up @@ -39,10 +42,16 @@ def single_threaded_ctx() -> PipelineContext:


@pytest.fixture
def threaded_ctx() -> PipelineContext:
return get_threaded_ctx()
def sample_dataset():
return Dataset.from_list([{"foo": i} for i in range(10)])


@pytest.fixture
def sample_dataset():
return Dataset.from_list([{"foo": i} for i in range(10)])
def taxonomy_dir(tmp_path):
with MockTaxonomy(tmp_path) as taxonomy:
yield taxonomy


@pytest.fixture
def threaded_ctx() -> PipelineContext:
return get_threaded_ctx()
68 changes: 68 additions & 0 deletions tests/taxonomy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from pathlib import Path
from typing import List
import shutil

# Third Party
import git


class MockTaxonomy:
    """Throwaway git repository used as a stand-in taxonomy in tests.

    Constructing an instance initialises a repository at ``path`` and commits
    an empty ``README.md`` so later commits have a parent. The object is also
    a context manager: leaving the ``with`` block calls :meth:`teardown`.
    """

    INIT_COMMIT_FILE = "README.md"

    def __init__(self, path: Path) -> None:
        """Initialise the repository and record the initial commit.

        Args:
            path (Path): Directory in which the repository is created; it
                becomes ``self.root``.
        """
        self.root = path
        self._repo = git.Repo.init(path, initial_branch="main")
        # Opening in "wb" mode and writing nothing leaves an empty file.
        with open(path / self.INIT_COMMIT_FILE, "wb"):
            pass
        self._repo.index.add([self.INIT_COMMIT_FILE])
        self._repo.index.commit("Initial commit")

    @property
    def untracked_files(self) -> List[str]:
        """List untracked files in the repository"""
        return self._repo.untracked_files

    def create_untracked(self, rel_path: str, contents: str) -> Path:
        """Create a new untracked file in the repository.

        Missing parent directories are created as needed.

        Args:
            rel_path (str): Relative path (from repository root) to the file.
            contents (str): String to be written to the file (UTF-8).
        Returns:
            file_path: The path to the created file.
        """
        taxonomy_path = Path(rel_path)
        assert not taxonomy_path.is_absolute()
        file_path = self.root.joinpath(taxonomy_path)
        file_path.parent.mkdir(exist_ok=True, parents=True)
        file_path.write_text(contents, encoding="utf-8")
        return file_path

    def add_tracked(self, rel_path: str, contents: str) -> Path:
        """Add a new tracked file to the repository (and commits it).

        Args:
            rel_path (str): Relative path (from repository root) to the file.
            contents (str): String to be written to the file.
        Returns:
            file_path: The path to the added file.
        """
        file_path = self.create_untracked(rel_path, contents)
        self._repo.index.add([rel_path])
        self._repo.index.commit("new commit")
        return file_path

    def teardown(self) -> None:
        """Recursively remove the temporary repository and all of its
        subdirectories and files.

        The repository handle is closed first so that resources it holds are
        released before the working tree is deleted; the directory is removed
        even if closing raises.
        """
        try:
            self._repo.close()
        finally:
            shutil.rmtree(self.root)

    def __enter__(self) -> "MockTaxonomy":
        """Return this instance for use as the ``with`` target."""
        return self

    def __exit__(self, *args) -> None:
        """Tear the repository down; exceptions are not suppressed."""
        self.teardown()
48 changes: 48 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0

# Third Party
import pytest

# First Party
from instructlab.sdg.utils import GenerateException, models


class TestModels:
    """Checks for ``get_model_family`` in instructlab.sdg.utils.models."""

    def test_granite_model_family(self):
        """A granite request with a granite file name yields "merlinite"."""
        family = models.get_model_family(
            "granite", "./models/granite-7b-lab-Q4_K_M.gguf"
        )
        assert family == "merlinite"

    def test_merlinite_model_family(self):
        """A merlinite request with a merlinite file name yields "merlinite"."""
        family = models.get_model_family(
            "merlinite", "./models/merlinite-7b-lab-Q4_K_M.gguf"
        )
        assert family == "merlinite"

    def test_mixtral_model_family(self):
        """A mixtral request with a mixtral file name yields "mixtral"."""
        family = models.get_model_family(
            "mixtral", "./models/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"
        )
        assert family == "mixtral"

    def test_default_model_family(self):
        """An unrecognised file name prefix yields "merlinite"."""
        family = models.get_model_family(
            "mixtral", "./models/foo-8x7b-instruct-v0.1.Q4_K_M.gguf"
        )
        assert family == "merlinite"

    def test_unknown_model_family(self):
        """An unknown requested family raises ``GenerateException``."""
        with pytest.raises(GenerateException) as raised:
            models.get_model_family(
                "foobar", "./models/mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf"
            )
        message = str(raised.value)
        assert "Unknown model family: foobar" in message
2 changes: 1 addition & 1 deletion tests/test_sample_populator_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from instructlab.sdg.utilblocks import SamplePopulatorBlock


class TestFilterByValueBlock(unittest.TestCase):
class TestSamplePopulatorBlock(unittest.TestCase):
def setUp(self):
self.ctx = MagicMock()
self.ctx.dataset_num_procs = 1
Expand Down
75 changes: 75 additions & 0 deletions tests/test_taxonomy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
import os
import pathlib

# Third Party
import pytest

# First Party
from instructlab.sdg.utils import taxonomy

# qna.yaml contents written to both the tracked and the untracked file in
# TestTaxonomy.test_read_taxonomy_leaf_nodes.
TEST_VALID_COMPOSITIONAL_SKILL_YAML = """created_by: rafael-vasquez
version: 1
seed_examples:
- answer: "Sure thing!"
context: "This is a valid YAML."
question: "Can you help me debug this failing unit test?"
- answer: "answer2"
context: "context2"
question: "question2"
- answer: "answer3"
context: "context3"
question: "question3"
- answer: "answer4"
context: "context4"
question: "question4"
- answer: "answer5"
context: "context5"
question: "question5"
task_description: 'This is a task'
"""

# Same text as the first seed example's question above; the test looks for it
# in the "instruction" field of the returned leaf node entries.
TEST_SEED_EXAMPLE = "Can you help me debug this failing unit test?"

# Passed as the second argument to read_taxonomy_leaf_nodes.
# NOTE(review): presumably the git ref the taxonomy is diffed against - confirm.
TEST_TAXONOMY_BASE = "main"

# Passed (as bytes) as the third argument to read_taxonomy_leaf_nodes.
# NOTE(review): looks like a yamllint configuration - confirm against the callee.
TEST_CUSTOM_YAML_RULES = b"""extends: relaxed

rules:
line-length:
max: 180
"""


class TestTaxonomy:
    """Test taxonomy in instructlab.sdg.utils.taxonomy."""

    @pytest.fixture(autouse=True)
    def _init_taxonomy(self, taxonomy_dir):
        """Expose the ``taxonomy_dir`` fixture value as ``self.taxonomy``."""
        self.taxonomy = taxonomy_dir

    def test_read_taxonomy_leaf_nodes(self):
        """An untracked qna.yaml is returned as a leaf node with its seed example.

        One file is committed and another is only created on disk; the test
        then checks that the untracked file's directory appears as a key
        (path separators replaced by "->") and that one of its entries
        carries the expected seed question as its "instruction".
        """
        tracked_file = "compositional_skills/tracked/qna.yaml"
        untracked_file = "compositional_skills/new/qna.yaml"
        self.taxonomy.add_tracked(tracked_file, TEST_VALID_COMPOSITIONAL_SKILL_YAML)
        self.taxonomy.create_untracked(
            untracked_file, TEST_VALID_COMPOSITIONAL_SKILL_YAML
        )

        leaf_nodes = taxonomy.read_taxonomy_leaf_nodes(
            self.taxonomy.root, TEST_TAXONOMY_BASE, TEST_CUSTOM_YAML_RULES
        )
        leaf_node_key = str(pathlib.Path(untracked_file).parent).replace(
            os.path.sep, "->"
        )
        assert leaf_node_key in leaf_nodes

        # Membership was asserted above, so indexing cannot raise KeyError here.
        assert any(
            entry["instruction"] == TEST_SEED_EXAMPLE
            for entry in leaf_nodes[leaf_node_key]
        )