From eb5250fe2a42a18f737a9085c03ce5ac225b0b18 Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Mon, 27 Jan 2025 11:22:29 -0500 Subject: [PATCH] Create memgraph-pydantic-simple.ipynb --- notebooks/memgraph-pydantic-simple.ipynb | 511 +++++++++++++++++++++++ 1 file changed, 511 insertions(+) create mode 100644 notebooks/memgraph-pydantic-simple.ipynb diff --git a/notebooks/memgraph-pydantic-simple.ipynb b/notebooks/memgraph-pydantic-simple.ipynb new file mode 100644 index 0000000..ede536b --- /dev/null +++ b/notebooks/memgraph-pydantic-simple.ipynb @@ -0,0 +1,511 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pydantic is the new Cypher\n", + "\n", + "This notebook will illustrate the challenge of text-to-Cypher approaches with LLMs,\n", + "\n", + "> and why we believe **structured outputs with Pydantic** are the best way to query databases with LLMs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Memgraph Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "John\n", + "{'labels_added': 1, 'labels_removed': 0, 'nodes_created': 1, 'nodes_deleted': 0, 'properties_set': 0, 'relationships_created': 0, 'relationships_deleted': 0}\n", + "John\n", + "MATCH (u:User {name: $name}) RETURN u.name AS name\n" + ] + } + ], + "source": [ + "from neo4j import GraphDatabase\n", + " \n", + "# Define correct URI and AUTH arguments (no AUTH by default)\n", + "URI = \"bolt://localhost:7687\"\n", + "AUTH = (\"\", \"\")\n", + " \n", + "with GraphDatabase.driver(URI, auth=AUTH) as client:\n", + " # Check the connection\n", + " client.verify_connectivity()\n", + " \n", + " # Create a user in the database\n", + " records, summary, keys = client.execute_query(\n", + " \"CREATE (u:User {name: $name, password: $password}) RETURN u.name AS name;\",\n", + " name=\"John\",\n", + " password=\"pass\",\n", + " 
database_=\"memgraph\",\n", + " )\n", + " \n", + " # Get the result\n", + " for record in records:\n", + " print(record[\"name\"])\n", + " \n", + " # Print the query counters\n", + " print(summary.counters)\n", + " \n", + " # Find a user John in the database\n", + " records, summary, keys = client.execute_query(\n", + " \"MATCH (u:User {name: $name}) RETURN u.name AS name\",\n", + " name=\"John\",\n", + " database_=\"memgraph\",\n", + " )\n", + " \n", + " # Get the result\n", + " for record in records:\n", + " print(record[\"name\"])\n", + " \n", + " # Print the query\n", + " print(summary.query)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Verifying relationships:\n", + "Alice Chen -WORKS_AT-> Memgraph\n", + "Alice Chen -EXPERT_IN-> Transactions\n", + "Alice Chen -EXPERT_IN-> Storage Engines\n", + "Bob Smith -WORKS_AT-> Neo4j\n", + "Bob Smith -EXPERT_IN-> Distributed Systems\n", + "Carol Kumar -WORKS_AT-> TigerGraph\n", + "David Garcia -WORKS_AT-> AgensGraph\n", + "Elena Wilson -WORKS_AT-> Neo4j\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2828091927.py:10: DeprecationWarning: Using a driver after it has been closed is deprecated. Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n", + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2828091927.py:26: DeprecationWarning: Using a driver after it has been closed is deprecated. Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n", + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2828091927.py:41: DeprecationWarning: Using a driver after it has been closed is deprecated. 
Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n", + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2828091927.py:57: DeprecationWarning: Using a driver after it has been closed is deprecated. Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n", + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2828091927.py:75: DeprecationWarning: Using a driver after it has been closed is deprecated. Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n", + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2828091927.py:94: DeprecationWarning: Using a driver after it has been closed is deprecated. Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n", + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2828091927.py:106: DeprecationWarning: Using a driver after it has been closed is deprecated. 
Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n" + ] + } + ], + "source": [ + "# Import and setup\n", + "\n", + "# Create Companies\n", + "for company_data in [\n", + " (\"Memgraph\", 2016),\n", + " (\"Neo4j\", 2007),\n", + " (\"AgensGraph\", 2016),\n", + " (\"TigerGraph\", 2012)\n", + "]:\n", + " records, summary, keys = client.execute_query(\n", + " \"CREATE (c:Company {name: $name, founded: $founded})\",\n", + " name=company_data[0],\n", + " founded=company_data[1],\n", + " database_=\"memgraph\"\n", + " )\n", + "\n", + "# Create Topics\n", + "for topic_data in [\n", + " (\"Transactions\", \"Database Core\"),\n", + " (\"Backup Systems\", \"Operations\"),\n", + " (\"Query Languages\", \"User Interface\"),\n", + " (\"Graph Algorithms\", \"Analytics\"),\n", + " (\"Storage Engines\", \"Infrastructure\"),\n", + " (\"Distributed Systems\", \"Infrastructure\")\n", + "]:\n", + " records, summary, keys = client.execute_query(\n", + " \"CREATE (t:Topic {name: $name, field: $field})\",\n", + " name=topic_data[0],\n", + " field=topic_data[1],\n", + " database_=\"memgraph\"\n", + " )\n", + "\n", + "# Create People\n", + "for person_data in [\n", + " (\"Alice Chen\", \"Database Engineer\", 8),\n", + " (\"Bob Smith\", \"Systems Architect\", 12),\n", + " (\"Carol Kumar\", \"Research Engineer\", 5),\n", + " (\"David Garcia\", \"Technical Writer\", 6),\n", + " (\"Elena Wilson\", \"Performance Engineer\", 9)\n", + "]:\n", + " records, summary, keys = client.execute_query(\n", + " \"CREATE (p:Person {name: $name, role: $role, yearsExp: $exp})\",\n", + " name=person_data[0],\n", + " role=person_data[1],\n", + " exp=person_data[2],\n", + " database_=\"memgraph\"\n", + " )\n", + "\n", + "# Create Employment Relationships\n", + "for employment_data in [\n", + " (\"Alice Chen\", \"Memgraph\", 2019),\n", + " (\"Bob Smith\", \"Neo4j\", 2015),\n", + " (\"Carol Kumar\", \"TigerGraph\", 2020),\n", + " (\"David Garcia\", 
\"AgensGraph\", 2018),\n", + " (\"Elena Wilson\", \"Neo4j\", 2017)\n", + "]:\n", + " records, summary, keys = client.execute_query(\n", + " \"\"\"\n", + " MATCH (p:Person), (c:Company)\n", + " WHERE p.name = $person_name AND c.name = $company_name\n", + " CREATE (p)-[:WORKS_AT {since: $since}]->(c)\n", + " \"\"\",\n", + " person_name=employment_data[0],\n", + " company_name=employment_data[1],\n", + " since=employment_data[2],\n", + " database_=\"memgraph\"\n", + " )\n", + "\n", + "# Create Expertise Relationships\n", + "for expertise_data in [\n", + " (\"Alice Chen\", \"Transactions\", \"Advanced\", 5),\n", + " (\"Alice Chen\", \"Storage Engines\", \"Intermediate\", 3),\n", + " (\"Bob Smith\", \"Distributed Systems\", \"Advanced\", 8)\n", + "]:\n", + " records, summary, keys = client.execute_query(\n", + " \"\"\"\n", + " MATCH (p:Person), (t:Topic)\n", + " WHERE p.name = $person_name AND t.name = $topic_name\n", + " CREATE (p)-[:EXPERT_IN {level: $level, years: $years}]->(t)\n", + " \"\"\",\n", + " person_name=expertise_data[0],\n", + " topic_name=expertise_data[1],\n", + " level=expertise_data[2],\n", + " years=expertise_data[3],\n", + " database_=\"memgraph\"\n", + " )\n", + "\n", + "# Create Company Focus Areas\n", + "for focus_data in [\n", + " (\"Memgraph\", \"Graph Algorithms\"),\n", + " (\"Neo4j\", \"Transactions\"),\n", + " (\"TigerGraph\", \"Graph Algorithms\")\n", + "]:\n", + " records, summary, keys = client.execute_query(\n", + " \"\"\"\n", + " MATCH (c:Company), (t:Topic)\n", + " WHERE c.name = $company_name AND t.name = $topic_name\n", + " CREATE (c)-[:FOCUSES_ON {priority: 'High'}]->(t)\n", + " \"\"\",\n", + " company_name=focus_data[0],\n", + " topic_name=focus_data[1],\n", + " database_=\"memgraph\"\n", + " )\n", + "\n", + "# Verify the data\n", + "records, summary, keys = client.execute_query(\n", + " \"\"\"\n", + " RETURN p.name as person, type(r) as relationship, n.name as target\n", + " \"\"\",\n", + " database_=\"memgraph\"\n", + ")\n", + 
"\n", + "print(\"\\nVerifying relationships:\")\n", + "for record in records:\n", + " print(f\"{record['person']} -{record['relationship']}-> {record['target']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Neo4j Employees and their expertise:\n", + "Bob Smith (Systems Architect) - Distributed Systems (Advanced)\n", + "Elena Wilson (Performance Engineer) - No expertise listed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/4176022777.py:1: DeprecationWarning: Using a driver after it has been closed is deprecated. Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n" + ] + } + ], + "source": [ + "records, summary, keys = client.execute_query(\n", + " \"\"\"\n", + " MATCH (p:Person)-[:WORKS_AT]->(c:Company {name: 'Neo4j'})\n", + " OPTIONAL MATCH (p)-[e:EXPERT_IN]->(t:Topic)\n", + " RETURN p.name as person, p.role as role, t.name as expertise, e.level as level\n", + " \"\"\",\n", + " database_=\"memgraph\"\n", + ")\n", + "\n", + "print(\"Neo4j Employees and their expertise:\")\n", + "for record in records:\n", + " expertise = f\"{record['expertise']} ({record['level']})\" if record['expertise'] else \"No expertise listed\"\n", + " print(f\"{record['person']} ({record['role']}) - {expertise}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DSPy setup" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import dspy\n", + "import os\n", + "lm = dspy.LM('openai/gpt-4o', api_key=os.getenv(\"OPENAI_API_KEY\"))\n", + "dspy.configure(lm=lm)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Intelligent, automated, scalable.']" + ] + }, + 
"execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lm(\"What is the future of database systems with AI? In 3 words.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pydantic is the new Cypher" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import BaseModel\n", + "from typing import Optional, List, Tuple, Union\n", + "\n", + "# The Cypher data model you provided earlier:\n", + "class CypherFilter(BaseModel):\n", + " \"\"\"\n", + " Represents a single condition in the WHERE clause.\n", + " \"\"\"\n", + " field: str\n", + " operator: str\n", + " value: Union[str, int, float]\n", + "\n", + "class CypherQuery(BaseModel):\n", + " \"\"\"\n", + " A basic model for generating Cypher (Memgraph) queries, focusing on:\n", + " - MATCH (n:Label)\n", + " - Optional WHERE\n", + " - RETURN\n", + " - Optional ORDER BY\n", + " - Optional LIMIT\n", + " \"\"\"\n", + " label: str\n", + " fields: List[str]\n", + " filters: Optional[List[CypherFilter]] = None\n", + " limit: Optional[int] = None\n", + " sort_by: Optional[List[Tuple[str, str]]] = None # e.g. 
[(\"age\", \"DESC\")]\n", + "\n", + " def to_cypher(self) -> str:\n", + " query = f\"MATCH (n:{self.label})\"\n", + "\n", + " if self.filters:\n", + " conditions = []\n", + " for f in self.filters:\n", + " val = f\"\\\"{f.value}\\\"\" if isinstance(f.value, str) else str(f.value)\n", + " conditions.append(f\"n.{f.field} {f.operator} {val}\")\n", + " query += \" WHERE \" + \" AND \".join(conditions)\n", + "\n", + " if not self.fields:\n", + " query += \" RETURN n\"\n", + " else:\n", + " fields_str = \", \".join([f\"n.{field}\" for field in self.fields])\n", + " query += f\" RETURN {fields_str}\"\n", + "\n", + " if self.sort_by:\n", + " sort_clauses = [f\"n.{field} {direction}\" for field, direction in self.sort_by]\n", + " query += \" ORDER BY \" + \", \".join(sort_clauses)\n", + "\n", + " if self.limit is not None:\n", + " query += f\" LIMIT {self.limit}\"\n", + "\n", + " return query" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DSPy `MemgraphQueryWriter`" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "class MemgraphQueryWriter(dspy.Signature):\n", + " \"\"\"\n", + " Translate a natural language information need into a Memgraph (Cypher) query.\n", + " \n", + " In your rationale for generating the final query string, you should be \n", + " very clear about why you chose specific query operators (e.g., WHERE \n", + " clauses, sort directions, etc.) 
and why you do not need the operators \n", + " that you chose not to include.\n", + " \"\"\"\n", + " \n", + " # The natural language command from the user\n", + " nl_command: str = dspy.InputField(\n", + " desc=\"A natural language command with an underlying information need your db_query should answer.\"\n", + " )\n", + " \n", + " # The database schema or relevant metadata\n", + " db_schema: str = dspy.InputField(\n", + " desc=\"The database schema (Memgraph schema) you can query.\"\n", + " )\n", + "\n", + " # The resulting Cypher query object (built to meet the info need)\n", + " db_query: CypherQuery = dspy.OutputField(\n", + " desc=\"A CypherQuery model that captures how to query Memgraph.\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Node labels: User, Company, Topic, Person\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/41/8dp_379x15d8zz4ppsjthdw40000gn/T/ipykernel_24994/2423293021.py:1: DeprecationWarning: Using a driver after it has been closed is deprecated. Future versions of the driver will raise an error.\n", + " records, summary, keys = client.execute_query(\n" + ] + } + ], + "source": [ + "records, summary, keys = client.execute_query(\n", + " \"MATCH (n) RETURN DISTINCT labels(n) AS labels\",\n", + " database_=\"memgraph\"\n", + ")\n", + "node_labels = []\n", + "for record in records:\n", + " node_labels.extend(record[\"labels\"])\n", + "node_labels_str = \", \".join(node_labels)\n", + "print(f\"Node labels: {node_labels_str}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prediction(\n", + " reasoning='The natural language command asks for the number of people working at Memgraph. In the given database schema, we have entities like User, Company, Topic, and Person. 
To find out how many people work at Memgraph, we need to focus on the `Person` label, as it is the most likely to represent individuals. We assume there is a relationship or property that connects `Person` to `Company`, specifically to Memgraph. However, since the schema does not provide explicit relationships or properties, we will assume that there is a property or relationship that can be used to filter `Person` nodes associated with Memgraph. The query will count the number of `Person` nodes that are associated with Memgraph.',\n", + " db_query=CypherQuery(label='Person', fields=['count(*)'], filters=[CypherFilter(field='company', operator='=', value='Memgraph')], limit=None, sort_by=None)\n", + ")\n" + ] + } + ], + "source": [ + "memgraph_writer = dspy.ChainOfThought(MemgraphQueryWriter)\n", + "\n", + "generated_query = memgraph_writer(\n", + " nl_command = \"How many people work at Memgraph?\",\n", + " db_schema = node_labels_str\n", + ")\n", + "\n", + "print(generated_query)\n", + "db_query = generated_query.db_query" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.undefined.undefined" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}