diff --git a/pyproject.toml b/pyproject.toml index 134f235da2..8046b9d392 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,8 @@ issues = "https://github.com/instructlab/instructlab/issues" [project.entry-points."instructlab.command.data"] "generate" = "instructlab.cli.data.generate:generate" "list" = "instructlab.cli.data.list:list_datasets" +"process" = "instructlab.cli.data.process:process" +"ingest" = "instructlab.cli.data.ingest:ingest" [project.entry-points."instructlab.command.model"] "chat" = "instructlab.model.chat:chat" @@ -90,6 +92,7 @@ optional-dependencies.cuda = { file = ["requirements/cuda.txt"] } optional-dependencies.hpu = { file = ["requirements/hpu.txt"] } optional-dependencies.mps = { file = ["requirements/mps.txt"] } optional-dependencies.rocm = { file = ["requirements/rocm.txt"] } +optional-dependencies.milvus = { file = ["requirements/milvus.txt"] } [tool.setuptools.packages.find] where = ["src"] diff --git a/requirements.txt b/requirements.txt index b6b7130687..24b2492326 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,4 @@ wandb>=0.16.4 xdg-base-dirs>=6.0.1 psutil>=6.0.0 huggingface_hub[hf_transfer]>=0.1.8 +haystack-ai>=2.8 diff --git a/requirements/milvus.txt b/requirements/milvus.txt new file mode 100644 index 0000000000..dc5dc64098 --- /dev/null +++ b/requirements/milvus.txt @@ -0,0 +1,4 @@ +# Cannot upgrade because of https://github.com/milvus-io/milvus-haystack/issues/39 +milvus_haystack==0.0.11 +docling-core[chunking]>=2.10.0 +sentence-transformers>=3.0.0 \ No newline at end of file diff --git a/src/instructlab/cli/data/ingest.py b/src/instructlab/cli/data/ingest.py new file mode 100644 index 0000000000..7a02364053 --- /dev/null +++ b/src/instructlab/cli/data/ingest.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import logging + +# Third Party +import click + +# First Party +from instructlab import clickext +from instructlab.data.ingest_docs import ingest_docs +from 
instructlab.data.taxonomy_utils import lookup_processed_documents_folder +from instructlab.defaults import DEFAULTS +from instructlab.rag.rag_configuration import ( # type: ignore + document_store_configuration, + embedder_configuration, +) + +logger = logging.getLogger(__name__) + + +# TODO: fill-in help fields +@click.command() +@click.option( + "--document-store-type", + default="milvuslite", + envvar="ILAB_DOCUMENT_STORE_TYPE", + type=click.STRING, + help="The document store type, one of: `milvuslite`.", +) +@click.option( + "--document-store-uri", + default="embeddings.db", + envvar="ILAB_DOCUMENT_STORE_URI", + type=click.STRING, + help="The document store URI", +) +@click.option( + "--document-store-collection-name", + default="IlabEmbeddings", + envvar="ILAB_DOCUMENT_STORE_COLLECTION_NAME", + type=click.STRING, + help="The document store collection name", +) +@click.option( + "--model-dir", + default=lambda: DEFAULTS.MODELS_DIR, + envvar="ILAB_MODEL_DIR", + show_default="The default system model location store, located in the data directory.", + help="Base directories where models are stored.", +) +@click.option( + "--embedding-model", + "embedding_model_name", + default="sentence-transformers/all-minilm-l6-v2", + envvar="ILAB_EMBEDDING_MODEL_NAME", + type=click.STRING, + help="The embedding model name", +) +@click.option( + "--output-dir", + envvar="ILAB_OUTPUT_DIR", + help="Directory where generated datasets are stored.", +) +@click.option( + "--input", + "input_dir", + required=False, + type=click.Path( + exists=True, + dir_okay=True, + file_okay=False, + ), + envvar="ILAB_INPUT_DIR", + help="Directory where pre-processed documents are located.", +) +@click.pass_context +@clickext.display_params +def ingest( + ctx, + document_store_type, + document_store_uri, + document_store_collection_name, + model_dir, + embedding_model_name, + output_dir, + input_dir, +): + """The embedding ingestion pipeline""" + + document_store_config = 
document_store_configuration( + type=document_store_type, + uri=document_store_uri, + collection_name=document_store_collection_name, + ) + embedder_config = embedder_configuration( + model_dir=model_dir, + model_name=embedding_model_name, + ) + logger.info(f"VectorDB params: {vars(document_store_config)}") + logger.info(f"Embedding model: {vars(embedder_config)}") + if not embedder_config.validate_local_model_path(): + raise click.UsageError( + f"Cannot find local embedding model {embedding_model_name} in {model_dir}. Download the model before running the pipeline." + ) + + if input_dir is None: + if output_dir is None: + output_dir = ctx.obj.config.generate.output_dir + if output_dir is None: + output_dir = DEFAULTS.DATASETS_DIR + logger.info(f"Ingesting latest taxonomy changes at {output_dir}") + processed_docs_folder = lookup_processed_documents_folder(output_dir) + if processed_docs_folder is None: + click.secho( + f"Cannot find the latest processed documents folders from {output_dir}." + + " Please verify that you executed `ilab data generate` and you have updated or new knowledge" + + " documents in the current taxonomy." 
# TODO: fill-in help fields
@click.command()
@click.option(
    "--input",
    "input_dir",
    required=False,
    default=None,
    envvar="ILAB_PROCESS_INPUT",
    help="The folder with user documents to process. In case it's missing, the knowledge taxonomy files will be processed instead.",
    type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True),
)
@click.option(
    "--taxonomy-path",
    type=click.Path(),
    envvar="ILAB_TAXONOMY_PATH",
    help="Directory where taxonomy is stored and accessed from.",
)
@click.option(
    "--taxonomy-base",
    envvar="ILAB_TAXONOMY_BASE",
    help="Branch of taxonomy used to calculate diff against.",
)
@click.option(
    "--output",
    "output_dir",
    type=click.Path(
        exists=True,
        dir_okay=True,
        file_okay=False,
    ),
    envvar="ILAB_OUTPUT_DIR",
    help="Directory where processed docs are stored.",
)
@click.pass_context
@clickext.display_params
def process(
    ctx,
    taxonomy_path,
    taxonomy_base,
    input_dir,
    output_dir,
):
    """The document processing pipeline"""
    # NOTE: the docstring above is what click prints as the command help, so
    # the implementation notes are kept in comments.
    #
    # Two modes:
    #   * `--input` given: every file in that folder is converted.
    #   * `--input` missing: the knowledge documents referenced by the taxonomy
    #     are converted; path and base fall back to the `generate` section of
    #     the configuration and then to the built-in defaults.
    #
    # Raises click.UsageError when no output directory was supplied.

    # Both modes write into `output_dir`. Without this guard a missing
    # `--output` only surfaced later as a TypeError from `Path(None)`.
    if output_dir is None:
        raise click.UsageError(
            "Missing option '--output' (or environment variable ILAB_OUTPUT_DIR): "
            "the directory where processed docs are stored."
        )

    if input_dir is not None:
        logger.info(f"Pre-processing documents from {input_dir} to {output_dir}")
        process_docs_from_folder(
            input_dir=input_dir,
            output_dir=output_dir,
        )
        return

    # Resolution order for each taxonomy setting: CLI flag / env var,
    # then the `generate` configuration, then the package default.
    generate_config = ctx.obj.config.generate
    if taxonomy_path is None:
        taxonomy_path = generate_config.taxonomy_path
    if taxonomy_path is None:
        taxonomy_path = DEFAULTS.TAXONOMY_DIR
    if taxonomy_base is None:
        taxonomy_base = generate_config.taxonomy_base
    if taxonomy_base is None:
        taxonomy_base = DEFAULTS.TAXONOMY_BASE

    logger.info(
        f"Pre-processing latest taxonomy changes at {taxonomy_path}@{taxonomy_base}"
    )
    process_docs_from_taxonomy(
        taxonomy_path=taxonomy_path,
        taxonomy_base=taxonomy_base,
        output_dir=output_dir,
    )
description="Controls the randomness of the model's responses. Lower values make the output more deterministic, while higher values produce more random results.", ) + rag: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="The RAG chat configuration" + ) class _serve_vllm(BaseModel): diff --git a/src/instructlab/data/ingest_docs.py b/src/instructlab/data/ingest_docs.py new file mode 100644 index 0000000000..3cb4886739 --- /dev/null +++ b/src/instructlab/data/ingest_docs.py @@ -0,0 +1,117 @@ +# Standard +from pathlib import Path +import glob +import logging +import os + +# Third Party +from haystack import Pipeline # type: ignore +from haystack.components.converters import TextFileToDocument # type: ignore +from haystack.components.embedders import ( # type: ignore + SentenceTransformersDocumentEmbedder, +) +from haystack.components.preprocessors import DocumentCleaner # type: ignore +from haystack.components.writers import DocumentWriter # type: ignore +from milvus_haystack import MilvusDocumentStore # type: ignore + +# First Party +from instructlab.rag.haystack.docling_splitter import ( # type: ignore + DoclingDocumentSplitter, +) +from instructlab.rag.rag_configuration import ( + document_store_configuration, + embedder_configuration, +) + +logger = logging.getLogger(__name__) + + +def ingest_docs( + input_dir: str, + document_store_config: document_store_configuration, + embedder_config: embedder_configuration, +): + pipeline = _create_pipeline( + document_store_config=document_store_config, + embedder_config=embedder_config, + ) + _connect_components(pipeline) + _ingest_docs(pipeline=pipeline, input_dir=input_dir) + + +def _create_pipeline( + document_store_config: document_store_configuration, + embedder_config: embedder_configuration, +) -> Pipeline: + pipeline = Pipeline() + pipeline.add_component(instance=_converter_component(), name="converter") + pipeline.add_component(instance=DocumentCleaner(), name="document_cleaner") + # TODO make 
def _ingest_docs(pipeline, input_dir):
    """
    Run the ingestion `pipeline` on the docling JSON files found under `input_dir`.

    When `input_dir` contains a `docling-artifacts` sub-folder the files are read
    from there, otherwise from `input_dir` itself. Only files with a `.json`
    extension are picked up, in sorted order so that runs are reproducible.

    Args:
        pipeline: pipeline exposing a `converter` component (fed with the list of
            source paths) and a `document_writer` component (whose document store
            and `documents_written` result are logged).
        input_dir: folder holding the pre-processed documents.
    """
    artifacts_dir = Path(input_dir) / "docling-artifacts"
    search_dir = artifacts_dir if artifacts_dir.exists() else Path(input_dir)

    # "*.json" rather than "*json": the latter also matched any file whose name
    # merely ends in "json" (e.g. "notes_json").
    sources = sorted(glob.glob(os.path.join(search_dir, "*.json")))
    if not sources:
        # Still run the pipeline so the behaviour is unchanged, but make the
        # empty input visible instead of silently writing nothing.
        logger.warning(f"No JSON documents found in {search_dir}")

    ingestion_results = pipeline.run({"converter": {"sources": sources}})

    document_store = pipeline.get_component("document_writer").document_store
    logger.info(f"count_documents: {document_store.count_documents()}")
    logger.info(
        f"document_writer.documents_written: {ingestion_results['document_writer']['documents_written']}"
    )
def process_docs_from_folder(input_dir, output_dir):
    """
    Process user documents from a given `input_dir` folder to the given `output_dir` folder, using docling converters.
    Latest version of docling schema is used (currently, v2).

    The output folder is passed to `clear_directory` before the conversion
    starts (presumably emptying it -- see `instructlab.utils`), then every file
    returned by `_load_source_files` is converted and exported.

    Args:
        input_dir: folder whose files are converted.
        output_dir: folder receiving the exported documents.

    Raises:
        RuntimeError: when at least one document failed to convert.
    """
    logger.info(f"Processing {input_dir} to {output_dir}")

    output_path = Path(output_dir)
    clear_directory(output_path)

    source_files = _load_source_files(input_dir=input_dir)
    logger.info(f"Transforming source files {[p.name for p in source_files]}")

    doc_converter = DocumentConverter()

    # Monotonic clock: the elapsed time must not be affected by wall-clock
    # adjustments. The timer covers the export as well as the conversion.
    started_at = time.monotonic()
    conv_results = doc_converter.convert_all(
        source_files,
        # Failures are counted by _export_documents and reported once, below.
        raises_on_error=False,
    )
    _, _, failure_count = _export_documents(conv_results, output_dir=output_path)
    elapsed = time.monotonic() - started_at
    logger.info(f"Document conversion complete in {elapsed:.2f} seconds.")

    if failure_count > 0:
        raise RuntimeError(
            f"Failed to convert {failure_count} of {len(source_files)} documents."
        )
# Standard
from pathlib import Path
from typing import Optional
import logging

logger = logging.getLogger(__name__)


def lookup_knowledge_files(taxonomy_path, taxonomy_base, temp_dir) -> list[Path]:
    """
    Lookup updated or new knowledge files in the taxonomy repo.
    Download the documents referenced in the configured datasets folder under a temporary folder.
    Finally, groups all the documents at the root of this folder and returns the list of paths.

    Args:
        taxonomy_path: location of the taxonomy, forwarded to `read_taxonomy_leaf_nodes`.
        taxonomy_base: base used to compute the taxonomy diff, forwarded as-is.
        temp_dir: folder receiving the documents; every file ends up directly under it.

    Returns:
        The paths of the documents *after* they were moved to the root of `temp_dir`.
    """
    # Imported lazily so that the rest of this module can be used without
    # loading the SDG package.
    # pylint: disable=import-outside-toplevel
    # Third Party
    from instructlab.sdg.utils.taxonomy import read_taxonomy_leaf_nodes

    yaml_rules = None
    leaf_nodes = read_taxonomy_leaf_nodes(
        taxonomy_path, taxonomy_base, yaml_rules, temp_dir
    )

    knowledge_files: list[Path] = []
    for leaf_node in leaf_nodes.values():
        # Only the first entry of a leaf node is inspected, as before. Leaf
        # nodes that are empty or carry no "filepaths" entry (presumably
        # skills -- verify against instructlab-sdg) contribute nothing instead
        # of raising IndexError/KeyError.
        if leaf_node:
            knowledge_files.extend(leaf_node[0].get("filepaths") or [])

    # Path.rename returns the new location. Those are the paths that exist once
    # the loop is done; the original ones no longer do.
    # NOTE(review): two documents with the same file name overwrite each other.
    root = Path(temp_dir)
    return [
        Path(knowledge_file).rename(root / Path(knowledge_file).name)
        for knowledge_file in knowledge_files
    ]


def lookup_processed_documents_folder(output_dir: str) -> Optional[Path]:
    """
    Locate the `docling-artifacts` folder of the most recently modified run under `output_dir`.

    Args:
        output_dir: folder containing one sub-folder per run.

    Returns:
        `<latest sub-folder>/docling-artifacts` when it exists; None when
        `output_dir` is missing, has no sub-folder, or the latest sub-folder has
        no `docling-artifacts` folder.
    """
    root = Path(output_dir)
    # Only directories can hold artifacts; a stray file must not win the
    # "latest" selection.
    candidates = [entry for entry in root.iterdir() if entry.is_dir()] if root.is_dir() else []
    # default=None: an empty folder is "nothing processed yet", not a ValueError.
    latest_folder = max(candidates, key=lambda d: d.stat().st_mtime, default=None)
    logger.info(f"Latest processed folder is {latest_folder}")

    if latest_folder is None:
        return None

    docling_folder = latest_folder / "docling-artifacts"
    logger.debug(f"Docling folder: {docling_folder}")
    return docling_folder if docling_folder.exists() else None
instructlab import configuration as cfg from instructlab import log from instructlab.client_utils import HttpClientParams +from instructlab.rag.rag import RagHandler +from instructlab.rag.rag_configuration import RagConfig, no_rag_config, rag_options # Local from ..client_utils import http_client @@ -37,6 +39,7 @@ logger = logging.getLogger(__name__) + HELP_MD = """ Help / TL;DR - `/q`: **q**uit @@ -50,6 +53,7 @@ - `/N`: **n**ew session (ignoring loaded) - `/d `: **d**isplay previous response based on input, if passed 1 then previous, if 2 then second last response and so on. - `/p `: previous response in **p**lain text based on input, if passed 1 then previous, if 2 then second last response and so on. +- `/r`: toggle the status of the RAG pipeline. - `/md `: previous response in **M**ark**d**own based on input, if passed 1 then previous, if 2 then second last response and so on. - `/s filepath`: **s**ave current session to `filepath` - `/l filepath`: **l**oad `filepath` and start a new session @@ -172,6 +176,7 @@ def is_openai_server_and_serving_model( "--temperature", cls=clickext.ConfigOption, ) +@rag_options @click.pass_context @clickext.display_params def chat( @@ -191,6 +196,7 @@ def chat( model_family, serving_log_file, temperature, + **kwargs, ): """Runs a chat using the modified model""" # pylint: disable=import-outside-toplevel @@ -297,6 +303,9 @@ def chat( backend_instance.shutdown() raise click.exceptions.Exit(1) from exc + rag_config = RagConfig(ctx.obj.config.chat.rag, **kwargs) + logger.info(f"rag_config is {vars(rag_config)}") + try: chat_cli( ctx, @@ -309,6 +318,7 @@ def chat( qq=quick_question, max_tokens=max_tokens, temperature=temperature, + rag_config=rag_config, ) except ChatException as exc: click.secho(f"Executing chat failed with: {exc}", fg="red") @@ -339,6 +349,7 @@ def __init__( log_file=None, max_tokens=None, temperature=1.0, + rag_config: RagConfig = no_rag_config, ): self.client = client self.model = model @@ -348,6 +359,7 @@ def 
__init__( self.log_file = log_file self.max_tokens = max_tokens self.temperature = temperature + self.rag_handler = RagHandler(rag_config) self.console = Console() @@ -397,6 +409,10 @@ def model_name(self): def _right_prompt(self): return FormattedText( [ + ( + "#3f7cac bold", + "[RAG]" if self.rag_handler.is_enabled() else "", + ), # info blue for multiple ( "#3f7cac bold", f"[{'M' if self.multiline else 'S'}]", @@ -444,6 +460,7 @@ def _handle_context(self, content): Markdown("**WARNING**: No contexts loaded from the config file.") ) raise KeyboardInterrupt + cs = content.split() if len(cs) < 2: self._sys_print( @@ -594,6 +611,10 @@ def _handle_list_contexts(self, _): self._sys_print(Markdown(f"**Available contexts:**\n\n{context_list}")) raise KeyboardInterrupt + def _handle_rag(self, _): + self.rag_handler.toggle_state() + raise KeyboardInterrupt + def start_prompt( self, logger, # pylint: disable=redefined-outer-name @@ -615,6 +636,7 @@ def start_prompt( "/s": self._handle_save_session, "/l": self._handle_load_session, "/lc": self._handle_list_contexts, + "/r": self._handle_rag, } if content is None: @@ -636,6 +658,9 @@ def start_prompt( self.log_message(PROMPT_PREFIX + content + "\n\n") + if self.rag_handler.is_enabled(): + content = self.rag_handler.augment_user_query(content) + # Update message history and token counters self._update_conversation(content, "user") @@ -766,6 +791,7 @@ def chat_cli( qq, max_tokens, temperature, + rag_config: RagConfig, ): """Starts a CLI-based chat with the server""" client = OpenAI( @@ -832,6 +858,7 @@ def chat_cli( loaded=loaded, temperature=(temperature if temperature is not None else config.temperature), max_tokens=(max_tokens if max_tokens else config.max_tokens), + rag_config=rag_config, ) if not qq and session is None: diff --git a/src/instructlab/rag/component_factory.py b/src/instructlab/rag/component_factory.py new file mode 100644 index 0000000000..b1122ae692 --- /dev/null +++ 
def create_document_store(
    document_store_config: document_store_configuration, drop_old: bool
):
    """
    Build the document store described by `document_store_config`.

    Only the `milvuslite` type is handled; `drop_old` is forwarded to the
    Milvus store unchanged.

    Raises:
        ValueError: for any other document store type.
    """
    # Guard clause: reject unsupported back-ends first.
    if document_store_config.type != "milvuslite":
        raise ValueError(f"Unmanaged document store type {document_store_config.type}")

    connection_args = {"uri": document_store_config.uri}
    return MilvusDocumentStore(
        connection_args=connection_args,
        collection_name=document_store_config.collection_name,
        drop_old=drop_old,
    )


def create_retriever(
    document_store_config: document_store_configuration,
    retriever_config: _retriever_configuration,
    document_store: MilvusDocumentStore,
):
    """
    Build the embedding retriever that queries `document_store`.

    The store type selects the retriever implementation; the number of results
    comes from `retriever_config.top_k`.

    Raises:
        ValueError: when the document store type is not `milvuslite`.
    """
    if document_store_config.type != "milvuslite":
        raise ValueError(f"Unmanaged document store type {document_store_config.type}")

    return MilvusEmbeddingRetriever(
        document_store=document_store,
        top_k=retriever_config.top_k,
    )


def create_embedder(retriever_config: _retriever_configuration):
    """
    Build the text embedder from the embedder settings of `retriever_config`.

    The model is loaded from the embedder's local model path.

    Raises:
        ValueError: when `retriever_config.embedder` is not set.
    """
    embedder_settings = retriever_config.embedder
    if embedder_settings is None:
        raise ValueError(
            f"Missing value for field embedder in {vars(retriever_config)}"
        )

    model_path = embedder_settings.local_model_path()
    return SentenceTransformersTextEmbedder(model=model_path)
HybridChunker +from docling_core.types import DoclingDocument +from docling_core.types.legacy_doc.document import ( + ExportedCCSDocument as LegacyDoclingDocument, +) +from docling_core.utils.legacy import legacy_to_docling_document +from haystack import Document, component # type: ignore +from haystack.core.serialization import ( # type: ignore + default_from_dict, + default_to_dict, +) +from pydantic_core._pydantic_core import ValidationError + +logger = logging.getLogger(__name__) + + +@component +class DoclingDocumentSplitter: + SUPPORTED_CONTENT_FORMATS = ["json"] + + def __init__(self, embedding_model_id=None, content_format=None, max_tokens=None): + self.__chunker = HybridChunker( + tokenizer=embedding_model_id, max_tokens=max_tokens + ) + self.__embedding_model_id = embedding_model_id + + if content_format not in self.SUPPORTED_CONTENT_FORMATS: + raise ValueError( + f"Only the following input formats are currently supported: {self.SUPPORTED_CONTENT_FORMATS}." + ) + self.__content_format = content_format + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + if not isinstance(documents, list) or ( + documents and not isinstance(documents[0], Document) + ): + raise TypeError( + "DoclingDocumentSplitter expects a List of Documents as input." 
+ ) + + split_docs = [] + for doc in documents: + if doc.content is None: + raise ValueError(f"Missing content for document ID {doc.id}.") + + chunks = self._split_with_docling(doc.meta["file_path"], doc.content) + current_split_docs = [Document(content=chunk) for chunk in chunks] + split_docs.extend(current_split_docs) + + return {"documents": split_docs} + + def _split_with_docling(self, file_path: str, text: str) -> List[str]: + if self.__content_format == "json": + try: + # We expect the JSON coming from instructlab-sdg, so in docling "legacy" schema + # See this note about the content that will not be preserved in the transformation: + # https://github.com/DS4SD/docling-core/blob/3f631f06277a2a7301c4c7a4e45792242512ce11/docling_core/utils/legacy.py#L352 + legacy_document: LegacyDoclingDocument = ( + LegacyDoclingDocument.model_validate_json(text) + ) + document = legacy_to_docling_document(legacy_document) + except ValidationError: + logger.info( + f"Document at {file_path} not in legacy docling format. Tring the updated schema instead." + ) + try: + document = DoclingDocument.model_validate_json(text) + except ValidationError as e: + logger.error( + f"Expected {file_path} to be in docling format, but schema validation failed: {e}" + ) + raise e + + else: + raise ValueError(f"Unexpected content format {self.__content_format}") + + chunk_iter = self.__chunker.chunk(dl_doc=document) + chunks = list(chunk_iter) + return [self.__chunker.serialize(chunk=chunk) for chunk in chunks] + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + """ + return default_to_dict( # type: ignore[no-any-return] + self, + embedding_model_id=self.__embedding_model_id, + content_format=self.__content_format, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DoclingDocumentSplitter": + """ + Deserializes the component from a dictionary. 
+ """ + return cast("DoclingDocumentSplitter", default_from_dict(cls, data)) diff --git a/src/instructlab/rag/rag.py b/src/instructlab/rag/rag.py new file mode 100644 index 0000000000..ba0ec74dfd --- /dev/null +++ b/src/instructlab/rag/rag.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import logging + +# Third Party +from haystack import Pipeline # type: ignore + +# First Party +from instructlab.rag.component_factory import ( + create_document_store, + create_embedder, + create_retriever, +) +from instructlab.rag.rag_configuration import RagConfig + +logger = logging.getLogger(__name__) + + +_DEFAULT_RAG_PROMPT = """ +Given the following information, answer the question. +Context: +{context} +Question: +{user_query} +Answer: +""" + + +def rag_prompt() -> str: + return _DEFAULT_RAG_PROMPT + + +def _init_rag_chat_pipeline( + rag_config: RagConfig, +): + document_store = create_document_store(rag_config.document_store, drop_old=False) + logger.debug(f"RAG document_store created {document_store}") + + document_retriever = create_retriever( + document_store_config=rag_config.document_store, + retriever_config=rag_config.retriever, + document_store=document_store, + ) + logger.debug(f"RAG document_retriever created {document_retriever}") + + text_embedder = create_embedder(rag_config.retriever) + logger.debug(f"RAG text_embedder created {text_embedder}") + + pipeline = Pipeline() + pipeline.add_component("embedder", text_embedder) + pipeline.add_component("retriever", document_retriever) + pipeline.connect("embedder.embedding", "retriever.query_embedding") + logger.debug(f"RAG pipeline created {pipeline}") + + return pipeline + + +class RagHandler: + def __init__(self, rag_config: RagConfig): + self._rag_config = rag_config + self.rag_pipeline: Pipeline = None + self.rag_prompt = rag_prompt() + + def is_enabled(self) -> bool: + return bool(self._rag_config.enabled) + + def toggle_state(self): + self._rag_config.enabled = not 
self._rag_config.enabled + + def augment_user_query(self, user_query: str) -> str: + if self.rag_pipeline is None: + self.rag_pipeline = _init_rag_chat_pipeline( + rag_config=self._rag_config, + ) + + retrieval_results = self.rag_pipeline.run( + { + "embedder": {"text": user_query}, + } + ) + context = "\n".join( + [doc.content for doc in retrieval_results["retriever"]["documents"]] + ) + + logger.debug("-" * 10) + logger.debug(f"RAG context is {context}") + logger.debug("-" * 10) + + updated_user_query = self.rag_prompt.format( + context=context, user_query=user_query + ) + logger.debug(f"Updated user query is {updated_user_query}") + return updated_user_query diff --git a/src/instructlab/rag/rag_configuration.py b/src/instructlab/rag/rag_configuration.py new file mode 100644 index 0000000000..e9bfdc2719 --- /dev/null +++ b/src/instructlab/rag/rag_configuration.py @@ -0,0 +1,142 @@ +# Standard +from typing import Any, Optional +import logging +import os + +# Third Party +from pydantic import BaseModel, ConfigDict +import click + +# First Party +from instructlab.configuration import DEFAULTS + +logger = logging.getLogger(__name__) + + +def rag_options(command): + """Wrapper to apply common options.""" + command = click.option( + "--rag", + "is_rag", + is_flag=True, + envvar="ILAB_RAG", + help="To enable Retrieval-Augmented Generation", + )(command) + command = click.option( + "--document-store-type", + default="milvuslite", + envvar="ILAB_DOCUMENT_STORE_TYPE", + type=click.STRING, + help="The document store type, one of: `milvuslite`.", + )(command) + command = click.option( + "--document-store-uri", + default="embeddings.db", + envvar="ILAB_DOCUMENT_STORE_URI", + type=click.STRING, + help="The document store URI", + )(command) + command = click.option( + "--document-store-collection-name", + default="IlabEmbeddings", + envvar="ILAB_DOCUMENT_STORE_COLLECTION_NAME", + type=click.STRING, + help="The document store collection name", + )(command) + command = 
def init_from_flags(model: BaseModel, prefix="", **kwargs):
    """
    Override the attributes of `model`, and of its nested models, with CLI flag values.

    The flag prefix of a model is derived from its class name: a leading `_` and
    a trailing `_configuration` are dropped, then `prefix` is prepended and `_`
    appended. A keyword `<flag prefix><attribute>` whose value is not None
    overrides `model.<attribute>` when that attribute exists. Nested models are
    processed recursively with the current flag prefix, e.g.
    `retriever_embedder_model_dir` sets `model_dir` on the
    `embedder_configuration` held by a `_retriever_configuration`.

    Args:
        model: the configuration model to update in place.
        prefix: flag prefix accumulated from the enclosing models.
        **kwargs: the CLI parameters, keyed by parameter name.
    """
    # Strip only the leading/trailing markers. str.replace would also remove
    # occurrences in the middle of the name or of the flag.
    model_name = model.__class__.__name__.removeprefix("_")
    model_name = model_name.removesuffix("_configuration")
    model_name = prefix + model_name + "_"
    logger.debug(f"model_name is {model_name}")

    for key, value in kwargs.items():
        if value is None:
            continue
        logger.debug(f"key is {key}")
        if not key.startswith(model_name):
            continue
        attr_name = key[len(model_name) :]
        logger.debug(f"attr_name is {attr_name}")
        if attr_name and hasattr(model, attr_name):
            logger.debug(f"Overriding from flag {key}")
            setattr(model, attr_name, value)

    # Nested models inherit this model's prefix.
    for value in vars(model).values():
        if isinstance(value, BaseModel):
            init_from_flags(model=value, prefix=model_name, **kwargs)
os.path.isdir(local_model_path) + + def local_model_path(self) -> str: + if self.model_dir is None: + click.secho(f"Missing value for field model_dir in {vars(self)}") + raise click.exceptions.Exit(1) + + if self.model_name is None: + click.secho(f"Missing value for field model_name in {vars(self)}") + raise click.exceptions.Exit(1) + + return os.path.join(self.model_dir, self.model_name) + + +class _retriever_configuration(BaseModel): + top_k: Optional[int] = None + embedder: Optional[embedder_configuration] = embedder_configuration() + + +class RagConfig: + def __init__(self, rag_config: dict[str, Any], **kwargs): + logger.debug(f"init from {rag_config}") + logger.debug(f"init from {kwargs}") + self.enabled = rag_config.get("enable") or kwargs.get("is_rag") + self.document_store = document_store_configuration( + **(rag_config.get("document_store") or {}) + ) + self.retriever = _retriever_configuration(**(rag_config.get("retriever") or {})) + + logger.debug(f"Before injecting config: {vars(self)}") + init_from_flags(model=self.document_store, **kwargs) + init_from_flags(model=self.retriever, **kwargs) + logger.debug(f"After injecting config: {vars(self)}") + + +no_rag_config: RagConfig = RagConfig({"enable": False}) diff --git a/tests/test_lab.py b/tests/test_lab.py index 9c5f9e476f..e9e92db71c 100644 --- a/tests/test_lab.py +++ b/tests/test_lab.py @@ -129,6 +129,8 @@ def has_debug_params(self) -> bool: Command(("data",), needs_config=False, should_fail=False), Command(("data", "generate")), Command(("data", "list")), + Command(("data", "process", ".")), + Command(("data", "ingest", ".")), Command(("system",), needs_config=False, should_fail=False), Command(("system", "info"), needs_config=False, should_fail=False), Command(("taxonomy",), needs_config=False, should_fail=False), @@ -261,7 +263,7 @@ def test_ilab_commands_tested(): if not command.args: continue sub = tested.setdefault(command.args[0], set()) - if len(command.args) == 2: + if len(command.args) >= 2: 
sub.add(command.args[1]) else: sub.add("") diff --git a/tests/testdata/default_config.yaml b/tests/testdata/default_config.yaml index 0463050c44..15ea93f7c7 100644 --- a/tests/testdata/default_config.yaml +++ b/tests/testdata/default_config.yaml @@ -16,6 +16,9 @@ chat: # Model to be used for chatting with. # Default: /cache/instructlab/models/granite-7b-lab-Q4_K_M.gguf model: /cache/instructlab/models/granite-7b-lab-Q4_K_M.gguf + # The RAG chat configuration + # Default: {} + rag: {} # Filepath of a dialog session file. # Default: None session: diff --git a/tests/testdata/leanimports.py b/tests/testdata/leanimports.py index c1f9051520..741e3bac60 100644 --- a/tests/testdata/leanimports.py +++ b/tests/testdata/leanimports.py @@ -5,7 +5,7 @@ import sys # block slow imports -for unwanted in ["deepspeed", "llama_cpp", "torch", "vllm"]: +for unwanted in ["deepspeed", "llama_cpp", "vllm"]: # importlib raises ModuleNotFound when sys.modules value is None. assert unwanted not in sys.modules sys.modules[unwanted] = None # type: ignore[assignment] diff --git a/tox.ini b/tox.ini index b0fcd5fc8e..dea55949c1 100644 --- a/tox.ini +++ b/tox.ini @@ -19,7 +19,7 @@ setenv = package = wheel wheel_build_env = pkg # equivalent to `pip install instructlab[cpu]` -extras = cpu +extras = cpu,milvus deps = -r requirements-dev.txt commands = ilab --version