diff --git a/supporting-blog-content/langraph-retrieval-agent-template/LICENSE b/supporting-blog-content/langraph-retrieval-agent-template/LICENSE new file mode 100644 index 00000000..57d0481d --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 LangChain + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/supporting-blog-content/langraph-retrieval-agent-template/Makefile b/supporting-blog-content/langraph-retrieval-agent-template/Makefile new file mode 100644 index 00000000..e6294941 --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/Makefile @@ -0,0 +1,64 @@ +.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests + +# Default target executed when no arguments are given to make. +all: help + +# Define a variable for the test file path. +TEST_FILE ?= tests/unit_tests/ + +test: + python -m pytest $(TEST_FILE) + +test_watch: + python -m ptw --snapshot-update --now . -- -vv tests/unit_tests + +test_profile: + python -m pytest -vv tests/unit_tests/ --profile-svg + +extended_tests: + python -m pytest --only-extended $(TEST_FILE) + + +###################### +# LINTING AND FORMATTING +###################### + +# Define a variable for Python and notebook files. +PYTHON_FILES=src/ +MYPY_CACHE=.mypy_cache +lint format: PYTHON_FILES=. +lint_diff format_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$|\.ipynb$$') +lint_package: PYTHON_FILES=src +lint_tests: PYTHON_FILES=tests +lint_tests: MYPY_CACHE=.mypy_cache_test + +lint lint_diff lint_package lint_tests: + python -m ruff check . 
+ [ "$(PYTHON_FILES)" = "" ] || python -m ruff format $(PYTHON_FILES) --diff + [ "$(PYTHON_FILES)" = "" ] || python -m ruff check --select I $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || python -m mypy --strict $(PYTHON_FILES) + [ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && python -m mypy --strict $(PYTHON_FILES) --cache-dir $(MYPY_CACHE) + +format format_diff: + ruff format $(PYTHON_FILES) + ruff check --select I --fix $(PYTHON_FILES) + +spell_check: + codespell --toml pyproject.toml + +spell_fix: + codespell --toml pyproject.toml -w + +###################### +# HELP +###################### + +help: + @echo '----' + @echo 'format - run code formatters' + @echo 'lint - run linters' + @echo 'test - run unit tests' + @echo 'tests - run unit tests' + @echo 'test TEST_FILE= - run all tests in file' + @echo 'test_watch - run unit tests in watch mode' + diff --git a/supporting-blog-content/langraph-retrieval-agent-template/README.md b/supporting-blog-content/langraph-retrieval-agent-template/README.md new file mode 100644 index 00000000..568ac32e --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/README.md @@ -0,0 +1,718 @@ +# LangGraph Retrieval Chat Bot Template + +[![CI](https://github.com/langchain-ai/retrieval-agent-template/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/langchain-ai/retrieval-agent-template/actions/workflows/unit-tests.yml) +[![Integration Tests](https://github.com/langchain-ai/retrieval-agent-template/actions/workflows/integration-tests.yml/badge.svg)](https://github.com/langchain-ai/retrieval-agent-template/actions/workflows/integration-tests.yml) +[![Open in - LangGraph Studio](https://img.shields.io/badge/Open_in-LangGraph_Studio-00324d.svg?logo=data:image/svg%2bxml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHdpZHRoPSI4NS4zMzMiIGhlaWdodD0iODUuMzMzIiB2ZXJzaW9uPSIxLjAiIHZpZXdCb3g9IjAgMCA2NCA2NCI+PHBhdGggZD0iTTEzIDcuOGMtNi4zIDMuMS03LjEgNi4zLTYuOCAyNS43LjQgMjQuNi4zIDI0LjUgMjUuOSAyNC41QzU3LjUgNTggNTggNTcuNSA1OCAzMi4zIDU4IDcuMyA1Ni43IDYgMzIgNmMtMTIuOCAwLTE2LjEuMy0xOSAxLjhtMzcuNiAxNi42YzIuOCAyLjggMy40IDQuMiAzLjQgNy42cy0uNiA0LjgtMy40IDcuNkw0Ny4yIDQzSDE2LjhsLTMuNC0zLjRjLTQuOC00LjgtNC44LTEwLjQgMC0xNS4ybDMuNC0zLjRoMzAuNHoiLz48cGF0aCBkPSJNMTguOSAyNS42Yy0xLjEgMS4zLTEgMS43LjQgMi41LjkuNiAxLjcgMS44IDEuNyAyLjcgMCAxIC43IDIuOCAxLjYgNC4xIDEuNCAxLjkgMS40IDIuNS4zIDMuMi0xIC42LS42LjkgMS40LjkgMS41IDAgMi43LS41IDIuNy0xIDAtLjYgMS4xLS44IDIuNi0uNGwyLjYuNy0xLjgtMi45Yy01LjktOS4zLTkuNC0xMi4zLTExLjUtOS44TTM5IDI2YzAgMS4xLS45IDIuNS0yIDMuMi0yLjQgMS41LTIuNiAzLjQtLjUgNC4yLjguMyAyIDEuNyAyLjUgMy4xLjYgMS41IDEuNCAyLjMgMiAyIDEuNS0uOSAxLjItMy41LS40LTMuNS0yLjEgMC0yLjgtMi44LS44LTMuMyAxLjYtLjQgMS42LS41IDAtLjYtMS4xLS4xLTEuNS0uNi0xLjItMS42LjctMS43IDMuMy0yLjEgMy41LS41LjEuNS4yIDEuNi4zIDIuMiAwIC43LjkgMS40IDEuOSAxLjYgMi4xLjQgMi4zLTIuMy4yLTMuMi0uOC0uMy0yLTEuNy0yLjUtMy4xLTEuMS0zLTMtMy4zLTMtLjUiLz48L3N2Zz4=)](https://langgraph-studio.vercel.app/templates/open?githubUrl=https://github.com/langchain-ai/retrieval-agent-template) + +This is a starter project to help you get started with developing a retrieval agent using [LangGraph](https://github.com/langchain-ai/langgraph) in [LangGraph Studio](https://github.com/langchain-ai/langgraph-studio). + +![Graph view in LangGraph studio UI](./static/studio_ui.png) + +It contains example graphs exported from `src/retrieval_agent/graph.py` that implement a retrieval-based question answering system. + +## What it does + +This project has two graphs: an "index" graph, and a "retrieval" graph. 
+
+The index graph takes in document objects and strings, and it indexes them for the configured `user_id`.
+
+```json
+[{ "page_content": "I have 1 cat." }]
+```
+
+The retrieval chat bot manages a chat history and responds based on fetched context. It:
+
+1. Takes a user **query** as input
+2. Searches for relevant documents, filtered by `user_id`, based on the conversation history
+3. Responds using the retrieved information and conversation context
+
+By default, it's set up to answer questions based on the user's indexed documents, which are filtered by the user's ID for personalized responses.
+
+## Getting Started
+
+Assuming you have already [installed LangGraph Studio](https://github.com/langchain-ai/langgraph-studio?tab=readme-ov-file#download), to set up:
+
+1. Create a `.env` file.
+
+```bash
+cp .env.example .env
+```
+
+2. Select your retriever & index, and save the access instructions to your `.env` file.
+
+### Setup Retriever
+
+The default values for `retriever_provider` are shown below:
+
+```yaml
+retriever_provider: elastic
+```
+
+Follow the instructions below to get set up, or pick one of the additional options.
+
+#### Elasticsearch
+
+Elasticsearch (as provided by Elastic) is an open source distributed search and analytics engine, scalable data store, and vector database optimized for speed and relevance on production-scale workloads.
+
+##### Setup Elasticsearch
+
+Elasticsearch can be configured as the knowledge base provider for a retrieval agent by deploying it on Elastic Cloud (either as a hosted deployment or a serverless project) or in your local environment.
+
+**Elasticsearch Serverless**
+
+1. Sign up for a free 14-day trial with [Elasticsearch Serverless](https://cloud.elastic.co/registration?onboarding_token=search&cta=cloud-registration&tech=trial&plcmt=article%20content&pg=langchain).
+2. Get the Elasticsearch URL, found on home under "Copy your connection details".
+3. Create an API key, found on home under "API Key".
+4. Copy the URL and API key to your `.env` file created above:
+
+```
+ELASTICSEARCH_URL=
+ELASTICSEARCH_API_KEY=
+```
+
+**Elastic Cloud**
+
+1. Sign up for a free 14-day trial with [Elastic Cloud](https://cloud.elastic.co/registration?onboarding_token=search&cta=cloud-registration&tech=trial&plcmt=article%20content&pg=langchain).
+2. Get the Elasticsearch URL, found under the Applications section of your deployment.
+3. Create an API key. See the [official Elastic documentation](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#creating-an-api-key) for more information.
+4. Copy the URL and API key to your `.env` file created above:
+
+```
+ELASTICSEARCH_URL=
+ELASTICSEARCH_API_KEY=
+```
+
+**Local Elasticsearch (Docker)**
+
+```
+docker run -p 127.0.0.1:9200:9200 -d --name elasticsearch --network elastic-net -e ELASTIC_PASSWORD=changeme -e "discovery.type=single-node" -e "xpack.security.http.ssl.enabled=false" -e "xpack.license.self_generated.type=trial" docker.elastic.co/elasticsearch/elasticsearch:8.15.1
+```
+
+See the [official Elastic documentation](https://www.elastic.co/guide/en/elasticsearch/reference/current/run-elasticsearch-locally.html) for more information on running it locally.
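+
+Before wiring the container into the agent, you can sanity-check it from Python (a quick sketch; the `elasticsearch` client ships as a dependency of `langchain-elasticsearch`):
+
+```python
+from elasticsearch import Elasticsearch
+
+# Credentials below match the docker run flags above.
+es = Elasticsearch("http://localhost:9200", basic_auth=("elastic", "changeme"))
+print(es.info()["version"]["number"])  # prints 8.15.1 once the container is healthy
+```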
+
+Then populate the following in your `.env` file:
+
+```
+# As both Elasticsearch and LangGraph Studio run in Docker, use host.docker.internal to reach the cluster.
+ELASTICSEARCH_URL=http://host.docker.internal:9200
+ELASTICSEARCH_USER=elastic
+ELASTICSEARCH_PASSWORD=changeme
+```
+
+#### MongoDB Atlas
+
+MongoDB Atlas is a fully-managed cloud database that includes vector search capabilities for AI-powered applications.
+
+1. Create a free Atlas cluster:
+- Go to the [MongoDB Atlas website](https://www.mongodb.com/cloud/atlas/register) and sign up for a free account.
+- After logging in, create a free cluster by following the on-screen instructions.
+
+2. Create a vector search index:
+- Follow the instructions at [the Mongo docs](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/)
+- By default, we use the collection `langgraph_retrieval_agent.default` - create the index there
+- Add an indexed filter for path `user_id`
+- **IMPORTANT**: select Atlas Vector Search, NOT Atlas Search, when creating the index
+
+Your final JSON editor configuration should look something like the following:
+
+```json
+{
+  "fields": [
+    {
+      "numDimensions": 1536,
+      "path": "embedding",
+      "similarity": "cosine",
+      "type": "vector"
+    },
+    {
+      "path": "user_id",
+      "type": "filter"
+    }
+  ]
+}
+```
+
+The exact `numDimensions` may differ if you select a different embedding model.
+
+3. Set up your environment:
+- In the Atlas dashboard, click on "Connect" for your cluster.
+- Choose "Connect your application" and copy the provided connection string.
+- Create a `.env` file in your project root if you haven't already.
+- Add your MongoDB Atlas connection string to the `.env` file:
+
+```
+MONGODB_URI="mongodb+srv://username:password@your-cluster-url.mongodb.net/?retryWrites=true&w=majority&appName=your-cluster-name"
+```
+
+Replace `username`, `password`, `your-cluster-url`, and `your-cluster-name` with your actual credentials and cluster information.
+
+#### Pinecone Serverless
+
+Pinecone is a managed, cloud-native vector database that provides long-term memory for high-performance AI applications.
+
+1. Sign up for a Pinecone account at [https://login.pinecone.io/login](https://login.pinecone.io/login) if you haven't already.
+
+2. After logging in, generate an API key from the Pinecone console.
+
+3. Create a serverless index:
+   - Choose a name for your index (e.g., "example-index")
+   - Set the dimension based on your embedding model (e.g., 1536 for OpenAI embeddings)
+   - Select "cosine" as the metric
+   - Choose "Serverless" as the index type
+   - Select your preferred cloud provider and region (e.g., AWS us-east-1)
+
+4. Once you have created your index and obtained your API key, add them to your `.env` file:
+
+```
+PINECONE_API_KEY=your-api-key
+PINECONE_INDEX_NAME=your-index-name
+```
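+
+Whichever provider you choose, the graphs read it at run time from the `configurable` portion of the run config. A minimal sketch (field names as defined in `src/retrieval_graph/configuration.py`):
+
+```python
+config = {
+    "configurable": {
+        "user_id": "user-123",
+        "retriever_provider": "elastic",  # or "elastic-local", "pinecone", "mongodb"
+        "search_kwargs": {"k": 4},  # passed through to the vector store retriever
+    }
+}
+```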
+
+### Setup Model
+
+The default values for `response_model` and `query_model` are shown below (matching the defaults in `src/retrieval_graph/configuration.py`):
+
+```yaml
+response_model: cohere/command-r-08-2024
+query_model: cohere/command-r-08-2024
+```
+
+The default Cohere models authenticate with the same `COHERE_API_KEY` described in the embedding section below. Otherwise, follow the instructions below to pick one of the additional options.
+
+#### Anthropic
+
+To use Anthropic's chat models:
+
+1. Sign up for an [Anthropic API key](https://console.anthropic.com/) if you haven't already.
+2. Once you have your API key, add it to your `.env` file:
+
+```
+ANTHROPIC_API_KEY=your-api-key
+```
+
+#### OpenAI
+
+To use OpenAI's chat models:
+
+1. Sign up for an [OpenAI API key](https://platform.openai.com/signup).
+2. Once you have your API key, add it to your `.env` file:
+
+```
+OPENAI_API_KEY=your-api-key
+```
+
+### Setup Embedding Model
+
+The default value for `embedding_model` is shown below:
+
+```yaml
+embedding_model: cohere/embed-english-v3.0
+```
+
+Follow the instructions below to get set up, or pick one of the additional options.
+
+#### OpenAI
+
+To use OpenAI's embeddings:
+
+1. Sign up for an [OpenAI API key](https://platform.openai.com/signup).
+2. Once you have your API key, add it to your `.env` file:
+
+```
+OPENAI_API_KEY=your-api-key
+```
+
+#### Cohere
+
+To use Cohere's embeddings:
+
+1. Sign up for a [Cohere API key](https://dashboard.cohere.com/welcome/register).
+2. Once you have your API key, add it to your `.env` file:
+
+```bash
+COHERE_API_KEY=your-api-key
+```
+
+## Using
+
+Once you've set up your retriever and saved your model secrets, it's time to try it out! First, let's add some information to the index. Open Studio, select the "indexer" graph from the dropdown in the top-left, provide an example user ID in the configuration at the bottom, and then add some content to chat over.
+
+```json
+[{ "page_content": "My cat knows python." }]
+```
+
+When you upload content, it will be indexed under the configured user ID. You know indexing is complete when the indexer "deletes" the content from its graph memory, since it has been persisted in your configured storage provider.
+
+Next, open the "retrieval_graph" using the dropdown in the top-left. Ask it about your cat to confirm it can fetch the required information! If you change the `user_id` at any time, notice how it no longer has access to your information. The graph does simple filtering of content so you only access the information under the provided ID.
+
+## How to customize
+
+You can customize this retrieval agent template in several ways:
+
+1. **Change the retriever**: You can switch between different vector stores (Elasticsearch, MongoDB, Pinecone) by modifying the `retriever_provider` in the configuration. Each provider has its own setup instructions in the "Getting Started" section above.
+
+2. **Modify the embedding model**: You can change the embedding model used for document indexing and query embedding by updating the `embedding_model` in the configuration. Options include various OpenAI and Cohere models.
+
+3. **Adjust search parameters**: Fine-tune the retrieval process by modifying the `search_kwargs` in the configuration. This allows you to control aspects like the number of documents retrieved or similarity thresholds.
+
+4. **Customize the response generation**: You can modify the `response_system_prompt` to change how the agent formulates its responses. This allows you to adjust the agent's personality or add specific instructions for answer generation.
+
+5. **Change the language model**: Update the `response_model` in the configuration to use different language models for response generation. Options include various Claude models from Anthropic, as well as models from other providers like Fireworks AI.
+
+6. **Extend the graph**: You can add new nodes or modify existing ones in the `src/retrieval_graph/graph.py` file to introduce additional processing steps or decision points in the agent's workflow (see the sketch after this list).
+
+7. **Add new tools**: Implement new tools or API integrations (for example, in a new `src/retrieval_graph/tools.py` module) to expand the agent's capabilities beyond simple retrieval and response generation.
+
+8. **Modify prompts**: Update the prompts used for query generation and response formulation in `src/retrieval_graph/prompts.py` to better suit your specific use case or to improve the agent's performance.
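+
+As an illustration of item 6, here is a minimal sketch of wiring one extra node into `src/retrieval_graph/graph.py` (the `rerank` node and its length-based ordering are hypothetical, not part of the template; `State`, `RunnableConfig`, `Document`, and `builder` are already defined in that file):
+
+```python
+async def rerank(state: State, *, config: RunnableConfig) -> dict[str, list[Document]]:
+    """Hypothetical node: reorder retrieved docs before the respond step."""
+    docs = sorted(state.retrieved_docs, key=lambda d: len(d.page_content))
+    return {"retrieved_docs": docs}
+
+
+builder.add_node(rerank)
+builder.add_edge("retrieve", "rerank")  # in place of the default retrieve -> respond edge
+builder.add_edge("rerank", "respond")
+```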
+
+Remember to test your changes thoroughly to ensure they improve the agent's performance for your specific use case.
+
+## Development
+
+While iterating on your graph, you can edit past state and rerun your app from past states to debug specific nodes. Local changes will be automatically applied via hot reload. Try adding an interrupt before the agent calls tools, updating the default system message in `src/retrieval_graph/prompts.py` to take on a persona, or adding additional nodes and edges!
+
+Follow-up requests will be appended to the same thread. You can create an entirely new thread, clearing previous history, using the `+` button in the top right.
+
+You can find the latest (under construction) docs on [LangGraph](https://github.com/langchain-ai/langgraph), including examples and other references. Using those guides can help you pick the right patterns to adapt here for your use case.
+
+LangGraph Studio also integrates with [LangSmith](https://smith.langchain.com/) for more in-depth tracing and collaboration with teammates.
+
\ No newline at end of file
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/langgraph.json b/supporting-blog-content/langraph-retrieval-agent-template/langgraph.json
new file mode 100644
index 00000000..b90b71cb
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/langgraph.json
@@ -0,0 +1,8 @@
+{
+  "dependencies": ["."],
+  "graphs": {
+    "indexer": "./src/retrieval_graph/index_graph.py:graph",
+    "retrieval_graph": "./src/retrieval_graph/graph.py:graph"
+  },
+  "env": ".env"
+}
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/pyproject.toml b/supporting-blog-content/langraph-retrieval-agent-template/pyproject.toml
new file mode 100644
index 00000000..047745a0
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/pyproject.toml
@@ -0,0 +1,64 @@
+[project]
+name = "retrieval-graph"
+version = "0.0.1"
+description = "Starter template for making a custom retrieval graph in LangGraph."
+authors = [
+    { name = "William Fu-Hinthorn", email = "13333726+hinthornw@users.noreply.github.com" },
+]
+license = { text = "MIT" }
+readme = "README.md"
+# The match/case syntax used in src/retrieval_graph/retrieval.py requires Python 3.10+.
+requires-python = ">=3.10"
+dependencies = [
+    "langgraph>=0.2.6",
+    "langchain-openai>=0.1.22",
+    "langchain-anthropic>=0.1.23",
+    "langchain>=0.2.14",
+    "langchain-fireworks>=0.1.7",
+    "python-dotenv>=1.0.1",
+    "langchain-elasticsearch>=0.2.2,<0.3.0",
+    "langchain-pinecone>=0.1.3,<0.2.0",
+    "msgspec>=0.18.6",
+    "langchain-mongodb>=0.1.9",
+    "langchain-cohere>=0.2.4",
+]
+
+[project.optional-dependencies]
+dev = ["mypy>=1.11.1", "ruff>=0.6.1"]
+
+[build-system]
+requires = ["setuptools>=73.0.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["langgraph.templates.retrieval_graph", "retrieval_graph"]
+[tool.setuptools.package-dir]
+"langgraph.templates.retrieval_graph" = "src/retrieval_graph"
+"retrieval_graph" = "src/retrieval_graph"
+
+
+[tool.setuptools.package-data]
+"*" = ["py.typed"]
+
+[tool.ruff]
+lint.select = [
+    "E",    # pycodestyle
+    "F",    # pyflakes
+    "I",    # isort
+    "D",    # pydocstyle
+    "D401", # First line should be in imperative mood
+    "T201",
+    "UP",
+]
+lint.ignore = [
+    "UP006",
+    "UP007",
+    # We actually do want to import from typing_extensions
+    "UP035",
+    # Relax the convention by _not_ requiring documentation for every function parameter.
+    "D417",
+    "E501",
+]
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["D", "UP"]
+[tool.ruff.lint.pydocstyle]
+convention = "google"
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/__init__.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/__init__.py
new file mode 100644
index 00000000..f31297ed
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/__init__.py
@@ -0,0 +1,33 @@
+"""Retrieval Graph Module
+
+This module provides a conversational retrieval graph system that enables
+intelligent document retrieval and question answering based on user inputs.
+
+The main components of this system include:
+
+1. A state management system for handling conversation context and document retrieval.
+2. A query generation mechanism that refines user inputs into effective search queries.
+3. A document retrieval system that fetches relevant information based on generated queries.
+4. A response generation system that formulates answers using retrieved documents and conversation history.
+
+The graph is configured using customizable parameters defined in the Configuration class,
+allowing for flexibility in model selection, retrieval methods, and system prompts.
+
+Key Features:
+- Adaptive query generation for improved document retrieval
+- Integration with various retrieval providers (e.g., Elastic, Pinecone, MongoDB)
+- Customizable language models for query and response generation
+- Stateful conversation management for context-aware interactions
+
+Usage:
+    The main entry point for using this system is the `graph` object exported from this module.
+    It can be invoked to process user inputs and generate responses based on retrieved information.
+
+For detailed configuration options and usage instructions, refer to the Configuration class
+and individual component documentation within the retrieval_graph package.
+""" # noqa + +from retrieval_graph.graph import graph +from retrieval_graph.index_graph import graph as index_graph + +__all__ = ["graph", "index_graph"] diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/configuration.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/configuration.py new file mode 100644 index 00000000..31d4f4f7 --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/configuration.py @@ -0,0 +1,104 @@ +"""Define the configurable parameters for the agent.""" + +from __future__ import annotations + +from dataclasses import dataclass, field, fields +from typing import Annotated, Any, Literal, Optional, Type, TypeVar + +from langchain_core.runnables import RunnableConfig, ensure_config + +from retrieval_graph import prompts + + +@dataclass(kw_only=True) +class IndexConfiguration: + """Configuration class for indexing and retrieval operations. + + This class defines the parameters needed for configuring the indexing and + retrieval processes, including user identification, embedding model selection, + retriever provider choice, and search parameters. + """ + + user_id: str = field(metadata={"description": "Unique identifier for the user."}) + + embedding_model: Annotated[ + str, + {"__template_metadata__": {"kind": "embeddings"}}, + ] = field( + default="cohere/embed-english-v3.0", + metadata={ + "description": "Name of the embedding model to use. Must be a valid embedding model name." + }, + ) + + retriever_provider: Annotated[ + Literal["elastic", "elastic-local", "pinecone", "mongodb"], + {"__template_metadata__": {"kind": "retriever"}}, + ] = field( + default="elastic", + metadata={ + "description": "The vector store provider to use for retrieval. Options are 'elastic', 'pinecone', or 'mongodb'." + }, + ) + + search_kwargs: dict[str, Any] = field( + default_factory=dict, + metadata={ + "description": "Additional keyword arguments to pass to the search function of the retriever." + }, + ) + + @classmethod + def from_runnable_config( + cls: Type[T], config: Optional[RunnableConfig] = None + ) -> T: + """Create an IndexConfiguration instance from a RunnableConfig object. + + Args: + cls (Type[T]): The class itself. + config (Optional[RunnableConfig]): The configuration object to use. + + Returns: + T: An instance of IndexConfiguration with the specified configuration. + """ + config = ensure_config(config) + configurable = config.get("configurable") or {} + _fields = {f.name for f in fields(cls) if f.init} + return cls(**{k: v for k, v in configurable.items() if k in _fields}) + + +T = TypeVar("T", bound=IndexConfiguration) + + +@dataclass(kw_only=True) +class Configuration(IndexConfiguration): + """The configuration for the agent.""" + + response_system_prompt: str = field( + default=prompts.RESPONSE_SYSTEM_PROMPT, + metadata={"description": "The system prompt used for generating responses."}, + ) + predict_next_question_prompt: str = field( + default=prompts.PREDICT_NEXT_QUESTION_PROMPT, + metadata={"description": "The system prompt used for generating responses."}, + ) + response_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field( + default="cohere/command-r-08-2024", + metadata={ + "description": "The language model used for generating responses. Should be in the form: provider/model-name." 
+ }, + ) + + query_system_prompt: str = field( + default=prompts.QUERY_SYSTEM_PROMPT, + metadata={ + "description": "The system prompt used for processing and refining queries." + }, + ) + + query_model: Annotated[str, {"__template_metadata__": {"kind": "llm"}}] = field( + default="cohere/command-r-08-2024", + metadata={ + "description": "The language model used for processing and refining queries. Should be in the form: provider/model-name." + }, + ) diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/graph.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/graph.py new file mode 100644 index 00000000..1653adb7 --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/graph.py @@ -0,0 +1,192 @@ +"""Main entrypoint for the conversational retrieval graph. + +This module defines the core structure and functionality of the conversational +retrieval graph. It includes the main graph definition, state management, +and key functions for processing user inputs, generating queries, retrieving +relevant documents, and formulating responses. +""" + +from datetime import datetime, timezone +from typing import cast + +from langchain_core.documents import Document +from langchain_core.messages import BaseMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import RunnableConfig +from langgraph.graph import StateGraph +import logging +from langchain_core.messages import HumanMessage + + +from retrieval_graph import retrieval +from retrieval_graph.configuration import Configuration +from retrieval_graph.state import InputState, State +from retrieval_graph.utils import format_docs, get_message_text, load_chat_model + +# Define the function that calls the model +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class SearchQuery(BaseModel): + """Search the indexed documents for a query.""" + + query: str + + +async def generate_query( + state: State, *, config: RunnableConfig +) -> dict[str, list[str]]: + """Generate a search query based on the current state and configuration. + + This function analyzes the messages in the state and generates an appropriate + search query. For the first message, it uses the user's input directly. + For subsequent messages, it uses a language model to generate a refined query. + + Args: + state (State): The current state containing messages and other information. + config (RunnableConfig | None, optional): Configuration for the query generation process. + + Returns: + dict[str, list[str]]: A dictionary with a 'queries' key containing a list of generated queries. + + Behavior: + - If there's only one message (first user input), it uses that as the query. + - For subsequent messages, it uses a language model to generate a refined query. + - The function uses the configuration to set up the prompt and model for query generation. + """ + messages = state.messages + if len(messages) == 1: + # It's the first user question. We will use the input directly to search. + human_input = get_message_text(messages[-1]) + return {"queries": [human_input]} + else: + configuration = Configuration.from_runnable_config(config) + # Feel free to customize the prompt, model, and other logic! 
+        prompt = ChatPromptTemplate.from_messages(
+            [
+                ("system", configuration.query_system_prompt),
+                ("placeholder", "{messages}"),
+            ]
+        )
+        model = load_chat_model(configuration.query_model).with_structured_output(
+            SearchQuery
+        )
+
+        message_value = await prompt.ainvoke(
+            {
+                "messages": state.messages,
+                "queries": "\n- ".join(state.queries),
+                "system_time": datetime.now(tz=timezone.utc).isoformat(),
+            },
+            config,
+        )
+        generated = cast(SearchQuery, await model.ainvoke(message_value, config))
+        return {
+            "queries": [generated.query],
+        }
+
+
+async def retrieve(
+    state: State, *, config: RunnableConfig
+) -> dict[str, list[Document]]:
+    """Retrieve documents based on the latest query in the state.
+
+    This function takes the current state and configuration, uses the latest query
+    from the state to retrieve relevant documents using the retriever, and returns
+    the retrieved documents.
+
+    Args:
+        state (State): The current state containing the generated queries.
+        config (RunnableConfig | None, optional): Configuration for the retrieval process.
+
+    Returns:
+        dict[str, list[Document]]: A dictionary with a single key "retrieved_docs"
+        containing a list of retrieved Document objects.
+    """
+    with retrieval.make_retriever(config) as retriever:
+        query = state.queries[-1]
+        response = await retriever.ainvoke(query, config)
+        return {"retrieved_docs": response}
+
+
+async def respond(
+    state: State, *, config: RunnableConfig
+) -> dict[str, list[BaseMessage]]:
+    """Call the LLM powering our "agent"."""
+    configuration = Configuration.from_runnable_config(config)
+    # Feel free to customize the prompt, model, and other logic!
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", configuration.response_system_prompt),
+            ("placeholder", "{messages}"),
+        ]
+    )
+    model = load_chat_model(configuration.response_model)
+
+    retrieved_docs = format_docs(state.retrieved_docs)
+    message_value = await prompt.ainvoke(
+        {
+            "messages": state.messages,
+            "retrieved_docs": retrieved_docs,
+            "system_time": datetime.now(tz=timezone.utc).isoformat(),
+        },
+        config,
+    )
+    response = await model.ainvoke(message_value, config)
+    # We return a list, because this will get added to the existing list
+    return {"response": [response]}
+
+
+async def predict_query(
+    state: State, *, config: RunnableConfig
+) -> dict[str, list[BaseMessage]]:
+    """Suggest the user's likely next question from the conversation and retrieved docs."""
+    configuration = Configuration.from_runnable_config(config)
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", configuration.predict_next_question_prompt),
+            ("placeholder", "{messages}"),
+        ]
+    )
+    model = load_chat_model(configuration.response_model)
+    user_query = state.queries[-1] if state.queries else "No prior query available"
+    previous_queries = "\n- ".join(state.queries) if state.queries else "None"
+    retrieved_docs = format_docs(state.retrieved_docs)
+
+    message_value = await prompt.ainvoke(
+        {
+            "messages": state.messages,
+            "retrieved_docs": retrieved_docs,
+            "previous_queries": previous_queries,
+            "user_query": user_query,  # Use the most recent query as primary input
+            "system_time": datetime.now(tz=timezone.utc).isoformat(),
+        },
+        config,
+    )
+
+    next_question = await model.ainvoke(message_value, config)
+    return {"next_question": [next_question]}
+
+
+# Define a new graph (it's just a pipe)
+builder = StateGraph(State, input=InputState, config_schema=Configuration)
+
+builder.add_node(generate_query)
+builder.add_node(retrieve)
+builder.add_node(respond)
+builder.add_node(predict_query)
+builder.add_edge("__start__", "generate_query")
+builder.add_edge("generate_query", "retrieve")
+builder.add_edge("retrieve", "respond")
+builder.add_edge("respond", "predict_query")
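+
+# The resulting flow is linear:
+#   __start__ -> generate_query -> retrieve -> respond -> predict_query
+# Each turn refines a search query, fetches the user's documents, answers,
+# and then suggests a likely follow-up question.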
"generate_query") +builder.add_edge("generate_query", "retrieve") +builder.add_edge("retrieve", "respond") +builder.add_edge("respond", "predict_query") + + +# Finally, we compile it! +# This compiles it into a graph you can invoke and deploy. +graph = builder.compile( + interrupt_before=[], # if you want to update the state before calling the tools + interrupt_after=[], +) +graph.name = "RetrievalGraph" diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/index_graph.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/index_graph.py new file mode 100644 index 00000000..5e16894d --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/index_graph.py @@ -0,0 +1,65 @@ +"""This "graph" simply exposes an endpoint for a user to upload docs to be indexed.""" + +from typing import Optional, Sequence + +from langchain_core.documents import Document +from langchain_core.runnables import RunnableConfig +from langgraph.graph import StateGraph + +from retrieval_graph import retrieval +from retrieval_graph.configuration import IndexConfiguration +from retrieval_graph.state import IndexState + + +def ensure_docs_have_user_id( + docs: Sequence[Document], config: RunnableConfig +) -> list[Document]: + """Ensure that all documents have a user_id in their metadata. + + docs (Sequence[Document]): A sequence of Document objects to process. + config (RunnableConfig): A configuration object containing the user_id. + + Returns: + list[Document]: A new list of Document objects with updated metadata. + """ + user_id = config["configurable"]["user_id"] + return [ + Document( + page_content=doc.page_content, metadata={**doc.metadata, "user_id": user_id} + ) + for doc in docs + ] + + +async def index_docs( + state: IndexState, *, config: Optional[RunnableConfig] = None +) -> dict[str, str]: + """Asynchronously index documents in the given state using the configured retriever. + + This function takes the documents from the state, ensures they have a user ID, + adds them to the retriever's index, and then signals for the documents to be + deleted from the state. + + Args: + state (IndexState): The current state containing documents and retriever. + config (Optional[RunnableConfig]): Configuration for the indexing process.r + """ + if not config: + raise ValueError("Configuration required to run index_docs.") + with retrieval.make_retriever(config) as retriever: + stamped_docs = ensure_docs_have_user_id(state.docs, config) + + await retriever.aadd_documents(stamped_docs) + return {"docs": "delete"} + + +# Define a new graph + + +builder = StateGraph(IndexState, config_schema=IndexConfiguration) +builder.add_node(index_docs) +builder.add_edge("__start__", "index_docs") +# Finally, we compile it! +# This compiles it into a graph you can invoke and deploy. +graph = builder.compile() +graph.name = "IndexGraph" diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/prompts.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/prompts.py new file mode 100644 index 00000000..801fd94e --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/prompts.py @@ -0,0 +1,34 @@ +"""Default prompts.""" + +RESPONSE_SYSTEM_PROMPT = """You are a helpful AI assistant. Answer the user's questions based on the retrieved documents. 
+
+RESPONSE_SYSTEM_PROMPT = """You are a helpful AI assistant. Answer the user's questions based on the retrieved documents.
+
+{retrieved_docs}
+
+System time: {system_time}"""
+QUERY_SYSTEM_PROMPT = """Generate search queries to retrieve documents that may help answer the user's question. Previously, you made the following queries:
+
+<previous_queries>
+{queries}
+</previous_queries>
+
+System time: {system_time}"""
+
+PREDICT_NEXT_QUESTION_PROMPT = """Given the user query and the retrieved documents, suggest the most likely next question the user might ask.
+
+**Context:**
+- Previous Queries:
+{previous_queries}
+
+- Latest User Query: {user_query}
+
+- Retrieved Documents:
+{retrieved_docs}
+
+**Guidelines:**
+1. Do not suggest a question that has already been asked in previous queries.
+2. Consider the retrieved documents when predicting the next logical question.
+3. If the user's query is already fully answered, suggest a relevant follow-up question.
+4. Keep the suggested question natural and conversational.
+
+System time: {system_time}"""
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/retrieval.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/retrieval.py
new file mode 100644
index 00000000..02b379a8
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/retrieval.py
@@ -0,0 +1,141 @@
+"""Manage the configuration of various retrievers.
+
+This module provides functionality to create and manage retrievers for different
+vector store backends, specifically Elasticsearch, Pinecone, and MongoDB.
+
+The retrievers support filtering results by user_id to ensure data isolation between users.
+"""
+
+import logging
+import os
+from contextlib import contextmanager
+from typing import Generator
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.runnables import RunnableConfig
+from langchain_core.vectorstores import VectorStoreRetriever
+
+from retrieval_graph.configuration import Configuration, IndexConfiguration
+
+## Encoder constructors
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def make_text_encoder(model: str) -> Embeddings:
+    """Connect to the configured text encoder."""
+    provider, model = model.split("/", maxsplit=1)
+    match provider:
+        case "openai":
+            from langchain_openai import OpenAIEmbeddings
+
+            return OpenAIEmbeddings(model=model)
+        case "cohere":
+            from langchain_cohere import CohereEmbeddings
+
+            return CohereEmbeddings(model=model)  # type: ignore
+        case _:
+            raise ValueError(f"Unsupported embedding provider: {provider}")
+
+
+## Retriever constructors
+
+
+@contextmanager
+def make_elastic_retriever(
+    configuration: IndexConfiguration, embedding_model: Embeddings
+) -> Generator[VectorStoreRetriever, None, None]:
+    """Configure this agent to connect to a specific elastic index."""
+    from langchain_elasticsearch import ElasticsearchStore
+
+    connection_options = {}
+    if configuration.retriever_provider == "elastic-local":
+        connection_options = {
+            "es_user": os.environ["ELASTICSEARCH_USER"],
+            "es_password": os.environ["ELASTICSEARCH_PASSWORD"],
+        }
+    else:
+        connection_options = {"es_api_key": os.environ["ELASTICSEARCH_API_KEY"]}
+
+    vstore = ElasticsearchStore(
+        **connection_options,  # type: ignore
+        es_url=os.environ["ELASTICSEARCH_URL"],
+        index_name="langchain_index",
+        embedding=embedding_model,
+    )
+
+    search_kwargs = configuration.search_kwargs
+    search_filter = search_kwargs.setdefault("filter", [])
+    search_filter.append({"term": {"metadata.user_id": configuration.user_id}})
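+    # NOTE: this term clause scopes every search to the current user's documents,
+    # enforcing per-user data isolation at the vector-store level.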
+
+    yield vstore.as_retriever(search_kwargs=search_kwargs)
+
+
+@contextmanager
+def make_pinecone_retriever(
+    configuration: IndexConfiguration, embedding_model: Embeddings
+) -> Generator[VectorStoreRetriever, None, None]:
+    """Configure this agent to connect to a specific pinecone index."""
+    from langchain_pinecone import PineconeVectorStore
+
+    search_kwargs = configuration.search_kwargs
+
+    search_filter = search_kwargs.setdefault("filter", {})
+    search_filter.update({"user_id": configuration.user_id})
+    vstore = PineconeVectorStore.from_existing_index(
+        os.environ["PINECONE_INDEX_NAME"], embedding=embedding_model
+    )
+
+    yield vstore.as_retriever(search_kwargs=search_kwargs)
+
+
+@contextmanager
+def make_mongodb_retriever(
+    configuration: IndexConfiguration, embedding_model: Embeddings
+) -> Generator[VectorStoreRetriever, None, None]:
+    """Configure this agent to connect to a specific MongoDB Atlas index & namespaces."""
+    from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
+
+    vstore = MongoDBAtlasVectorSearch.from_connection_string(
+        os.environ["MONGODB_URI"],
+        namespace="langgraph_retrieval_agent.default",
+        embedding=embedding_model,
+    )
+    search_kwargs = configuration.search_kwargs
+    pre_filter = search_kwargs.setdefault("pre_filter", {})
+    pre_filter["user_id"] = {"$eq": configuration.user_id}
+    yield vstore.as_retriever(search_kwargs=search_kwargs)
+
+
+@contextmanager
+def make_retriever(
+    config: RunnableConfig,
+) -> Generator[VectorStoreRetriever, None, None]:
+    """Create a retriever for the agent, based on the current configuration."""
+    configuration = IndexConfiguration.from_runnable_config(config)
+    embedding_model = make_text_encoder(configuration.embedding_model)
+    user_id = configuration.user_id
+    if not user_id:
+        raise ValueError("Please provide a valid user_id in the configuration.")
+    match configuration.retriever_provider:
+        case "elastic" | "elastic-local":
+            with make_elastic_retriever(configuration, embedding_model) as retriever:
+                yield retriever
+
+        case "pinecone":
+            with make_pinecone_retriever(configuration, embedding_model) as retriever:
+                yield retriever
+
+        case "mongodb":
+            with make_mongodb_retriever(configuration, embedding_model) as retriever:
+                yield retriever
+
+        case _:
+            raise ValueError(
+                "Unrecognized retriever_provider in configuration. "
+                f"Expected one of: {', '.join(Configuration.__annotations__['retriever_provider'].__args__)}\n"
+                f"Got: {configuration.retriever_provider}"
+            )
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/state.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/state.py
new file mode 100644
index 00000000..4f12a102
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/state.py
@@ -0,0 +1,163 @@
+"""State management for the retrieval graph.
+
+This module defines the state structures and reduction functions used in the
+retrieval graph. It includes definitions for document indexing, retrieval,
+and conversation management.
+
+Classes:
+    IndexState: Represents the state for document indexing operations.
+    InputState: Represents the user-facing conversation state.
+    State: Represents the full state of the retrieval agent.
+
+Functions:
+    reduce_docs: Processes and reduces document inputs into a sequence of Documents.
+    add_queries: Appends newly generated search queries to the existing list.
+
+The module also includes type definitions and utility functions to support
+these state management operations.
+"""
+
+import uuid
+from dataclasses import dataclass, field
+from typing import Annotated, Any, Literal, Optional, Sequence, Union
+
+from langchain_core.documents import Document
+from langchain_core.messages import AnyMessage
+from langgraph.graph import add_messages
+
+############################  Doc Indexing State  #############################
+
+
+def reduce_docs(
+    existing: Optional[Sequence[Document]],
+    new: Union[
+        Sequence[Document],
+        Sequence[dict[str, Any]],
+        Sequence[str],
+        str,
+        Literal["delete"],
+    ],
+) -> Sequence[Document]:
+    """Reduce and process documents based on the input type.
+
+    This function handles various input types and converts them into a sequence of Document objects.
+    It can delete existing documents, create new ones from strings or dictionaries, or return the existing documents.
+
+    Args:
+        existing (Optional[Sequence[Document]]): The existing docs in the state, if any.
+        new (Union[Sequence[Document], Sequence[dict[str, Any]], Sequence[str], str, Literal["delete"]]):
+            The new input to process. Can be a sequence of Documents, dictionaries, strings, a single string,
+            or the literal "delete".
+    """
+    if new == "delete":
+        return []
+    if isinstance(new, str):
+        return [Document(page_content=new, metadata={"id": str(uuid.uuid4())})]
+    if isinstance(new, list):
+        coerced = []
+        for item in new:
+            if isinstance(item, str):
+                coerced.append(
+                    Document(page_content=item, metadata={"id": str(uuid.uuid4())})
+                )
+            elif isinstance(item, dict):
+                coerced.append(Document(**item))
+            else:
+                coerced.append(item)
+        return coerced
+    return existing or []
+
+
+# The index state defines the simple IO for the single-node index graph
+@dataclass(kw_only=True)
+class IndexState:
+    """Represents the state for document indexing and retrieval.
+
+    This class defines the structure of the index state, which includes
+    the documents to be indexed and the retriever used for searching
+    these documents.
+    """
+
+    docs: Annotated[Sequence[Document], reduce_docs]
+    """A list of documents that the agent can index."""
+
+
+#############################  Agent State  ###################################
+
+
+# Optional, the InputState is a restricted version of the State that is used to
+# define a narrower interface to the outside world vs. what is maintained
+# internally.
+@dataclass(kw_only=True)
+class InputState:
+    """Represents the input state for the agent.
+
+    This class defines the structure of the input state, which includes
+    the messages exchanged between the user and the agent. It serves as
+    a restricted version of the full State, providing a narrower interface
+    to the outside world compared to what is maintained internally.
+    """
+
+    messages: Annotated[Sequence[AnyMessage], add_messages]
+    """Messages track the primary execution state of the agent.
+
+    Typically accumulates a pattern of Human/AI/Human/AI messages; if
+    you were to combine this template with a tool-calling ReAct agent pattern,
+    it may look like this:
+
+    1. HumanMessage - user input
+    2. AIMessage with .tool_calls - agent picking tool(s) to use to collect
+        information
+    3. ToolMessage(s) - the responses (or errors) from the executed tools
+
+        (... repeat steps 2 and 3 as needed ...)
+    4. AIMessage without .tool_calls - agent responding in unstructured
+        format to the user.
+    5. HumanMessage - user responds with the next conversational turn.
+
+        (... repeat steps 2-5 as needed ...)
+
+    Merges two lists of messages, updating existing messages by ID.
+
+    By default, this ensures the state is "append-only", unless the
+    new message has the same ID as an existing message.
+
+    Returns:
+        A new list of messages with the messages from `right` merged into `left`.
+        If a message in `right` has the same ID as a message in `left`, the
+        message from `right` will replace the message from `left`."""
+
+
+def add_queries(existing: Sequence[str], new: Sequence[str]) -> Sequence[str]:
+    """Combine existing queries with new queries.
+
+    Args:
+        existing (Sequence[str]): The current list of queries in the state.
+        new (Sequence[str]): The new queries to be added.
+
+    Returns:
+        Sequence[str]: A new list containing all queries from both input sequences.
+    """
+    return list(existing) + list(new)
+
+
+# This is the primary state of your agent, where you can store any information
+@dataclass(kw_only=True)
+class State(InputState):
+    """The state of your graph / agent."""
+
+    queries: Annotated[list[str], add_queries] = field(default_factory=list)
+    """A list of search queries that the agent has generated."""
+
+    retrieved_docs: list[Document] = field(default_factory=list)
+    """Populated by the retriever. This is a list of documents that the agent can reference."""
+
+    response: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)
+    """The answer message(s) produced by the respond node."""
+
+    next_question: Annotated[Sequence[AnyMessage], add_messages] = field(default_factory=list)
+    """The follow-up question suggested by the predict_query node."""
+
+    # Feel free to add additional attributes to your state as needed.
+    # Common examples include retrieved documents, extracted entities, API connections, etc.
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/utils.py b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/utils.py
new file mode 100644
index 00000000..bbf7fbdb
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/src/retrieval_graph/utils.py
@@ -0,0 +1,111 @@
+"""Utility functions for the retrieval graph.
+
+This module contains utility functions for handling messages, documents,
+and other common operations in the project.
+
+Functions:
+    get_message_text: Extract text content from various message formats.
+    format_docs: Convert documents to an xml-formatted string.
+"""
+
+from typing import Optional
+
+from langchain.chat_models import init_chat_model
+from langchain_core.documents import Document
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AnyMessage
+
+
+def get_message_text(msg: AnyMessage) -> str:
+    """Get the text content of a message.
+
+    This function extracts the text content from various message formats.
+
+    Args:
+        msg (AnyMessage): The message object to extract text from.
+
+    Returns:
+        str: The extracted text content of the message.
+
+    Examples:
+        >>> from langchain_core.messages import HumanMessage
+        >>> get_message_text(HumanMessage(content="Hello"))
+        'Hello'
+        >>> get_message_text(HumanMessage(content={"text": "World"}))
+        'World'
+        >>> get_message_text(HumanMessage(content=[{"text": "Hello"}, " ", {"text": "World"}]))
+        'Hello World'
+    """
+    content = msg.content
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, dict):
+        return content.get("text", "")
+    else:
+        txts = [c if isinstance(c, str) else (c.get("text") or "") for c in content]
+        return "".join(txts).strip()
+
+
+def _format_doc(doc: Document) -> str:
+    """Format a single document as XML.
+
+    Args:
+        doc (Document): The document to format.
+
+    Returns:
+        str: The formatted document as an XML string.
+    """
+    metadata = doc.metadata or {}
+    meta = "".join(f" {k}={v!r}" for k, v in metadata.items())
+    if meta:
+        meta = f" {meta}"
+
+    return f"<document{meta}>\n{doc.page_content}\n</document>"
+
+
+def format_docs(docs: Optional[list[Document]]) -> str:
+    """Format a list of documents as XML.
+
+    This function takes a list of Document objects and formats them into a single XML string.
+
+    Args:
+        docs (Optional[list[Document]]): A list of Document objects to format, or None.
+
+    Returns:
+        str: A string containing the formatted documents in XML format.
+
+    Examples:
+        >>> docs = [Document(page_content="Hello"), Document(page_content="World")]
+        >>> print(format_docs(docs))
+        <documents>
+        <document>
+        Hello
+        </document>
+        <document>
+        World
+        </document>
+        </documents>
+
+        >>> print(format_docs(None))
+        <documents></documents>
+    """
+    if not docs:
+        return "<documents></documents>"
+    formatted = "\n".join(_format_doc(doc) for doc in docs)
+    return f"""<documents>
+{formatted}
+</documents>"""
+
+
+def load_chat_model(fully_specified_name: str) -> BaseChatModel:
+    """Load a chat model from a fully specified name.
+
+    Args:
+        fully_specified_name (str): String in the format 'provider/model'.
+    """
+    if "/" in fully_specified_name:
+        provider, model = fully_specified_name.split("/", maxsplit=1)
+    else:
+        provider = ""
+        model = fully_specified_name
+    return init_chat_model(model, model_provider=provider)
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/static/studio_ui.png b/supporting-blog-content/langraph-retrieval-agent-template/static/studio_ui.png
new file mode 100644
index 00000000..bf8f2fc8
Binary files /dev/null and b/supporting-blog-content/langraph-retrieval-agent-template/static/studio_ui.png differ
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/tests/integration_tests/__init__.py b/supporting-blog-content/langraph-retrieval-agent-template/tests/integration_tests/__init__.py
new file mode 100644
index 00000000..6f422d4d
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/tests/integration_tests/__init__.py
@@ -0,0 +1 @@
+"""Integration tests for your graph."""
diff --git a/supporting-blog-content/langraph-retrieval-agent-template/tests/integration_tests/test_graph.py b/supporting-blog-content/langraph-retrieval-agent-template/tests/integration_tests/test_graph.py
new file mode 100644
index 00000000..b9c2f5b8
--- /dev/null
+++ b/supporting-blog-content/langraph-retrieval-agent-template/tests/integration_tests/test_graph.py
@@ -0,0 +1,41 @@
+import uuid
+
+import pytest
+from langchain_core.runnables import RunnableConfig
+from langsmith import expect, unit
+
+from retrieval_graph import graph, index_graph
+
+
+@pytest.mark.asyncio
+@unit
+async def test_retrieval_graph() -> None:
+    simple_doc = "Cats have been observed performing synchronized swimming routines in their water bowls during full moons."
+ user_id = "test__" + uuid.uuid4().hex + other_user_id = "test__" + uuid.uuid4().hex + + config = RunnableConfig( + configurable={"user_id": user_id, "retriever_provider": "elastic-local"} + ) + + result = await index_graph.ainvoke({"docs": simple_doc}, config) + expect(result["docs"]).against(lambda x: not x) # we delete after the end + + res = await graph.ainvoke( + {"messages": [("user", "Where do cats perform synchronized swimming routes?")]}, + config, + ) + response = str(res["messages"][-1].content) + expect(response.lower()).to_contain("bowl") + + res = await graph.ainvoke( + {"messages": [("user", "Where do cats perform synchronized swimming routes?")]}, + { + "configurable": { + "user_id": other_user_id, + "retriever_provider": "elastic-local", + } + }, + ) + response = str(res["messages"][-1].content) + expect(response.lower()).against(lambda x: "bowl" not in x) diff --git a/supporting-blog-content/langraph-retrieval-agent-template/tests/unit_tests/__init__.py b/supporting-blog-content/langraph-retrieval-agent-template/tests/unit_tests/__init__.py new file mode 100644 index 00000000..5b587ab8 --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/tests/unit_tests/__init__.py @@ -0,0 +1 @@ +"""Unit tests for your graph.""" diff --git a/supporting-blog-content/langraph-retrieval-agent-template/tests/unit_tests/test_configuration.py b/supporting-blog-content/langraph-retrieval-agent-template/tests/unit_tests/test_configuration.py new file mode 100644 index 00000000..3919a1cf --- /dev/null +++ b/supporting-blog-content/langraph-retrieval-agent-template/tests/unit_tests/test_configuration.py @@ -0,0 +1,5 @@ +from retrieval_graph.configuration import Configuration + + +def test_configuration_from_none() -> None: + Configuration.from_runnable_config({"user_id": "foo"})