diff --git a/.github/workflows/github-build-actions-python314t.yaml b/.github/workflows/github-build-actions-python314t.yaml
index 010a72b..4f90686 100644
--- a/.github/workflows/github-build-actions-python314t.yaml
+++ b/.github/workflows/github-build-actions-python314t.yaml
@@ -19,7 +19,7 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v3
with:
- submodules: true # Ensure submodules are checked out
+          submodules: false # Skip automatic submodule checkout
- name: Install Miniconda
shell: bash
@@ -84,45 +84,103 @@ jobs:
source $HOME/.elan/env
install-itp-interface
- - name: Check and Init opam version
- run: |
- opam --version
- opam init --disable-sandboxing --yes
-
- - name: Install Coq
+ - name: Build lean4_proj
+ shell: bash
run: |
- opam switch create simple_grp_theory 4.14.2
- opam switch simple_grp_theory
- eval $(opam env)
- opam repo add coq-released https://coq.inria.fr/opam/released
- opam pin add -y coq-lsp 0.1.8+8.18
+ source $HOME/.elan/env
+ cd src/data/test/lean4_proj && lake exe cache get && lake build
- name: List repository files (debug step)
run: find . -type f
- - name: Run Simple Env Test
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
+ - name: Ray Cleanup
+ shell: bash
+ run: |
+ rm -rf /tmp/* --verbose
+
+ - name: Run Simple Env Lean Test
shell: bash
run: |
export PATH="$HOME/miniconda/bin:$PATH"
source $HOME/miniconda/bin/activate py314-ft
- eval $(opam env)
source $HOME/.elan/env
- python src/test/simple_env_test.py
+ python src/test/simple_env_lean_test.py
+
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
+ - name: Ray Cleanup
+ shell: bash
+ run: |
+ rm -rf /tmp/* --verbose
- name: Run Data Gen Test
shell: bash
run: |
export PATH="$HOME/miniconda/bin:$PATH"
source $HOME/miniconda/bin/activate py314-ft
- eval $(opam env)
source $HOME/.elan/env
python src/test/simple_data_gen_test.py
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
+ - name: Ray Cleanup
+ shell: bash
+ run: |
+ rm -rf /tmp/* --verbose
+
- name: Run Data Extraction Test
shell: bash
run: |
export PATH="$HOME/miniconda/bin:$PATH"
source $HOME/miniconda/bin/activate py314-ft
- eval $(opam env)
source $HOME/.elan/env
python src/test/simple_data_extract_test.py
+
+ - name: Ray Cleanup
+ shell: bash
+ run: |
+ rm -rf /tmp/* --verbose
+
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
+ - name: Clean up .lake in lean4_proj
+ shell: bash
+ run: |
+ rm -rf src/data/test/lean4_proj/.lake
+ echo "Cleaned .lake directory in lean4_proj for fresh parallel execution test"
+
+ - name: Check and Init opam version
+ run: |
+ opam --version
+ opam init --disable-sandboxing --yes
+
+ - name: Install Coq
+ run: |
+ opam switch create simple_grp_theory 4.14.2
+ opam switch simple_grp_theory
+ eval $(opam env)
+ opam repo add coq-released https://coq.inria.fr/opam/released
+ opam pin add -y coq-lsp 0.1.8+8.18
+
+ - name: Run Simple Env Coq Test
+ shell: bash
+ run: |
+ export PATH="$HOME/miniconda/bin:$PATH"
+ source $HOME/miniconda/bin/activate py314-ft
+ eval $(opam env)
+ source $HOME/.elan/env
+ python src/test/simple_env_coq_test.py
diff --git a/.github/workflows/github-build-actions.yaml b/.github/workflows/github-build-actions.yaml
index 7814b40..e63b419 100644
--- a/.github/workflows/github-build-actions.yaml
+++ b/.github/workflows/github-build-actions.yaml
@@ -19,7 +19,7 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v3
with:
- submodules: true # Ensure submodules are checked out
+          submodules: false # Skip automatic submodule checkout
- name: Install Python and pip
run: |
@@ -53,41 +53,52 @@ jobs:
source $HOME/.elan/env
install-itp-interface
- - name: Check and Init opam version
- run: |
- opam --version
- opam init --disable-sandboxing --yes
-
- - name: Install Coq
+ - name: Build lean4_proj
+ shell: bash
run: |
- opam switch create simple_grp_theory 4.14.2
- opam switch simple_grp_theory
- eval $(opam env)
- opam repo add coq-released https://coq.inria.fr/opam/released
- opam pin add -y coq-lsp 0.1.8+8.18
+ source $HOME/.elan/env
+ cd src/data/test/lean4_proj && lake exe cache get && lake build
- name: List repository files (debug step)
run: find . -type f
- - name: Run Simple Env Test
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
+ - name: Run Simple Env Lean Test
shell: bash
run: |
- eval $(opam env)
source $HOME/.elan/env
- python src/test/simple_env_test.py
-
+ python src/test/simple_env_lean_test.py
+
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
- name: Ray Cleanup
shell: bash
run: |
rm -rf /tmp/* --verbose
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
- name: Run Data Gen Test
shell: bash
run: |
- eval $(opam env)
source $HOME/.elan/env
python src/test/simple_data_gen_test.py
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
- name: Ray Cleanup
shell: bash
run: |
@@ -96,11 +107,41 @@ jobs:
- name: Run Data Extraction Test
shell: bash
run: |
- eval $(opam env)
source $HOME/.elan/env
python src/test/simple_data_extract_test.py
+ - name: Clean up logs
+ run: |
+ rm -rf .log
+ echo "Cleaned .log directory for fresh parallel execution test"
+
- name: Ray Cleanup
shell: bash
run: |
- rm -rf /tmp/* --verbose
\ No newline at end of file
+ rm -rf /tmp/* --verbose
+
+ - name: Clean up .lake in lean4_proj
+ shell: bash
+ run: |
+ rm -rf src/data/test/lean4_proj/.lake
+ echo "Cleaned .lake directory in lean4_proj for fresh parallel execution test"
+
+ - name: Check and Init opam version
+ run: |
+ opam --version
+ opam init --disable-sandboxing --yes
+
+ - name: Install Coq
+ run: |
+ opam switch create simple_grp_theory 4.14.2
+ opam switch simple_grp_theory
+ eval $(opam env)
+ opam repo add coq-released https://coq.inria.fr/opam/released
+ opam pin add -y coq-lsp 0.1.8+8.18
+
+ - name: Run Simple Env Coq Test
+ shell: bash
+ run: |
+ eval $(opam env)
+ source $HOME/.elan/env
+ python src/test/simple_env_coq_test.py
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
index 65bf71a..1b68b33 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -13,3 +13,6 @@
[submodule "src/itp_interface/tools/repl"]
path = src/itp_interface/tools/repl
url = https://github.com/amit9oct/repl.git
+[submodule "src/data/test/batteries"]
+ path = src/data/test/batteries
+ url = https://github.com/leanprover-community/batteries.git
diff --git a/pyproject.toml b/pyproject.toml
index e9aac2c..6cc0954 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ requires = [
build-backend = "hatchling.build"
[project]
name = "itp_interface"
-version = "1.2.0"
+version = "1.3.0"
authors = [
{ name="Amitayush Thakur", email="amitayush@utexas.edu" },
]
@@ -47,8 +47,11 @@ dependencies = [
[project.optional-dependencies]
app = [
- "flask>=2.3.0",
- "flask-cors>=4.0.0"
+ "streamlit>=1.28.0",
+ "scipy>=1.16.0",
+ "plotly>=5.17.0",
+ "networkx>=3.1",
+ "pandas>=2.0.0"
]
[project.urls]
diff --git a/src/app/__init__.py b/src/app/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/app/data_explorer/README.md b/src/app/data_explorer/README.md
new file mode 100644
index 0000000..744ae1e
--- /dev/null
+++ b/src/app/data_explorer/README.md
@@ -0,0 +1,125 @@
+# Lean Declaration Database Explorer
+
+A Streamlit web application for exploring and analyzing Lean declaration databases created by the ITP Interface data extraction tool.
+
+## Features
+
+- **Custom SQL Queries**: Write and execute custom SQL queries with pre-built query templates
+- **Declaration Search**: Search declarations by name, type, file path, and namespace
+- **Dependency Explorer**: Visualize dependencies and dependents of declarations
+- **Forest Analysis**: Analyze connected components in the dependency graph
+
+## Installation
+
+Install the required dependencies using pip with the `app` extra:
+
+```bash
+pip install -e ".[app]"
+```
+
+This will install:
+- streamlit
+- scipy
+- plotly
+- networkx
+- pandas
+
+## Usage
+
+1. **Generate a database** using the data extraction transform:
+ ```python
+   from itp_interface.lean.lean4_local_data_extraction_transform import Local4DataExtractionTransform
+
+ transform = Local4DataExtractionTransform(
+ buffer_size=1000,
+ db_path="lean_declarations.db"
+ )
+ # ... run your extraction ...
+ ```
+
+2. **Launch the explorer**:
+ ```bash
+ cd src/app/data_explorer
+ streamlit run lean_db_explorer.py
+ ```
+
+3. **Load your database**:
+ - Enter the path to your `.db` file in the sidebar
+ - Click "Load Database"
+ - Explore using the tabs!
+
+## App Tabs
+
+### Custom Query
+Write and execute custom SQL queries against the database. Includes pre-built queries for common tasks:
+- Show all files
+- Show all declarations
+- Count declarations by type
+- Find most depended-on declarations
+- And more...
+
+### Search
+Search declarations with filters:
+- Name pattern (supports partial matching)
+- File path
+- Declaration type (theorem, def, axiom, etc.)
+- Namespace
+
+### 🌲 Dependencies
+Explore dependency relationships:
+- **Show Dependencies**: What does this declaration depend on?
+- **Show Dependents**: What depends on this declaration?
+- **Configurable depth**: Control how deep to traverse
+- **Dual view**: Table and interactive graph visualization
+
+### 🌳 Forests
+Analyze connected components in the dependency graph (see the sketch after this list):
+- **Find All Forests**: Discover all connected components
+- **Statistics**: See forest sizes, roots, and leaves
+- **Visualization**: View selected forests as graphs
+
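+A forest here is a connected component of the dependency graph viewed as an undirected
+graph. A minimal sketch of the same idea with networkx (illustrative only; `"A"`, `"B"`
+and `"C"` stand in for declaration IDs):
+
+```python
+import networkx as nx
+
+# Edge B -> A means "A depends on B", matching the app's graph convention.
+G = nx.DiGraph([("B", "A"), ("C", "A")])
+forests = list(nx.connected_components(G.to_undirected()))
+print(forests)  # one forest: {'A', 'B', 'C'}
+```
+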
+## Database Schema
+
+The app expects a SQLite database with the following tables:
+
+- `files`: File metadata
+- `imports`: Import relationships
+- `declarations`: All declarations with full info
+- `declaration_dependencies`: Dependency edges (from_decl_id → to_decl_id)
+
+See `src/itp_interface/tools/simple_sqlite.py` for the complete schema.
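+
+For example, the most depended-on declarations can be listed directly with Python's
+built-in `sqlite3` module (a minimal sketch; it assumes the table and column names
+listed above and a database file named `lean_declarations.db`):
+
+```python
+import sqlite3
+
+conn = sqlite3.connect("lean_declarations.db")
+rows = conn.execute("""
+    SELECT d.name, COUNT(*) AS dependents
+    FROM declarations d
+    JOIN declaration_dependencies dd ON dd.to_decl_id = d.decl_id
+    GROUP BY d.decl_id
+    ORDER BY dependents DESC
+    LIMIT 10
+""").fetchall()
+for name, dependents in rows:
+    print(f"{name}: {dependents}")
+conn.close()
+```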
+
+## Tips
+
+- **Large graphs**: Forest visualization is limited to 100 nodes for performance
+- **Query results**: All query results can be downloaded as CSV
+- **Unresolved declarations**: Entries with `file_path IS NULL` were only seen as dependencies and have not been resolved to a source file
+
+## Troubleshooting
+
+**Database not loading?**
+- Check the file path is correct
+- Ensure the database was created with the correct schema
+
+**Graph visualization slow?**
+- Reduce the max depth for dependency exploration
+- Use filters in Search to narrow results
+
+**Import errors?**
+- Ensure you've installed the app dependencies: `pip install -e ".[app]"`
+- Run from the correct directory: `cd src/app/data_explorer`
+
+## Development
+
+File structure:
+```
+src/app/data_explorer/
+├── lean_db_explorer.py    # Main Streamlit app
+├── db_utils.py            # Database query utilities
+├── graph_utils.py         # Graph analysis and visualization
+└── README.md              # This file
+```
+
+To modify:
+1. Edit the Python files
+2. Streamlit will auto-reload on file changes
+3. Refresh browser to see updates
diff --git a/src/app/data_explorer/__init__.py b/src/app/data_explorer/__init__.py
new file mode 100644
index 0000000..c35c013
--- /dev/null
+++ b/src/app/data_explorer/__init__.py
@@ -0,0 +1 @@
+"""Lean Declaration Database Explorer package."""
diff --git a/src/app/data_explorer/db_utils.py b/src/app/data_explorer/db_utils.py
new file mode 100644
index 0000000..6ab2635
--- /dev/null
+++ b/src/app/data_explorer/db_utils.py
@@ -0,0 +1,197 @@
+"""
+Database utility functions for the Lean Declaration DB Explorer.
+"""
+
+import sys
+from pathlib import Path
+
+# Add parent directories to path to import from itp_interface
+root_dir = Path(__file__).parent.parent.parent.parent
+if str(root_dir) not in sys.path:
+ sys.path.insert(0, str(root_dir))
+
+import pandas as pd
+from typing import Dict, List, Any, Optional
+from itp_interface.tools.simple_sqlite import LeanDeclarationDB
+
+
+def execute_query(db: LeanDeclarationDB, query: str) -> pd.DataFrame:
+ """
+ Execute a SQL query and return results as a DataFrame.
+
+ Args:
+ db: LeanDeclarationDB instance
+ query: SQL query string
+
+ Returns:
+ DataFrame with query results
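+
+    Example (illustrative; assumes a database produced by the extraction transform):
+        db = LeanDeclarationDB("lean_declarations.db")
+        df = execute_query(db, "SELECT decl_type, COUNT(*) AS n FROM declarations GROUP BY decl_type")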
+ """
+ cursor = db.conn.cursor()
+ cursor.execute(query)
+ rows = cursor.fetchall()
+
+ if rows:
+ # Get column names
+ columns = [description[0] for description in cursor.description]
+ # Convert to DataFrame
+ return pd.DataFrame([dict(row) for row in rows], columns=columns)
+ else:
+ return pd.DataFrame()
+
+
+def get_common_queries() -> Dict[str, str]:
+ """
+ Return a dictionary of common pre-built queries.
+
+ Returns:
+ Dict mapping query name to SQL query string
+ """
+ return {
+ "Show all files": """
+ SELECT file_path, module_name
+ FROM files
+ ORDER BY file_path
+ """,
+ "Show all declarations": """
+ SELECT name, namespace, decl_type, file_path, module_name
+ FROM declarations
+ WHERE file_path IS NOT NULL
+ ORDER BY file_path, line
+ LIMIT 100
+ """,
+ "Count declarations by type": """
+ SELECT decl_type, COUNT(*) as count
+ FROM declarations
+ WHERE decl_type IS NOT NULL
+ GROUP BY decl_type
+ ORDER BY count DESC
+ """,
+ "Count declarations by file": """
+ SELECT file_path, COUNT(*) as count
+ FROM declarations
+ WHERE file_path IS NOT NULL
+ GROUP BY file_path
+ ORDER BY count DESC
+ LIMIT 20
+ """,
+ "Show unresolved declarations": """
+ SELECT name, namespace
+ FROM declarations
+ WHERE file_path IS NULL
+ LIMIT 100
+ """,
+ "Show imports for all files": """
+ SELECT f.file_path, i.module_name as import_module, i.text
+ FROM files f
+ JOIN imports i ON i.file_id = f.id
+ ORDER BY f.file_path
+ """,
+ "Show most depended-on declarations": """
+ SELECT d.name, d.namespace, d.file_path, COUNT(*) as dependents_count
+ FROM declarations d
+ JOIN declaration_dependencies dd ON dd.to_decl_id = d.decl_id
+ GROUP BY d.decl_id
+ ORDER BY dependents_count DESC
+ LIMIT 20
+ """,
+ "Show declarations with most dependencies": """
+ SELECT d.name, d.namespace, d.file_path, COUNT(*) as dependencies_count
+ FROM declarations d
+ JOIN declaration_dependencies dd ON dd.from_decl_id = d.decl_id
+ GROUP BY d.decl_id
+ ORDER BY dependencies_count DESC
+ LIMIT 20
+ """
+ }
+
+
+def search_declarations(
+ db: LeanDeclarationDB,
+ name_pattern: Optional[str] = None,
+ namespace: Optional[str] = None,
+ file_path: Optional[str] = None,
+ decl_type: Optional[str] = None
+) -> pd.DataFrame:
+ """
+ Search declarations with filters.
+
+ Args:
+ db: LeanDeclarationDB instance
+ name_pattern: Name pattern (SQL LIKE syntax)
+ namespace: Namespace filter
+ file_path: File path filter
+ decl_type: Declaration type filter
+
+ Returns:
+ DataFrame with matching declarations
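+
+    Example (illustrative): theorems under Mathlib/Data whose names contain "add":
+        df = search_declarations(db, name_pattern="add", file_path="Mathlib/Data", decl_type="theorem")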
+ """
+ query = "SELECT * FROM declarations WHERE 1=1"
+ params = []
+
+ if name_pattern:
+ query += " AND name LIKE ?"
+ params.append(f"%{name_pattern}%")
+
+ if namespace:
+ query += " AND namespace = ?"
+ params.append(namespace)
+
+ if file_path:
+ query += " AND file_path LIKE ?"
+ params.append(f"%{file_path}%")
+
+ if decl_type:
+ query += " AND decl_type = ?"
+ params.append(decl_type)
+
+ query += " LIMIT 1000"
+
+ cursor = db.conn.cursor()
+ cursor.execute(query, params)
+ rows = cursor.fetchall()
+
+ if rows:
+ columns = [description[0] for description in cursor.description]
+ return pd.DataFrame([dict(row) for row in rows], columns=columns)
+ else:
+ return pd.DataFrame()
+
+
+def get_all_declaration_names(db: LeanDeclarationDB) -> List[str]:
+ """
+ Get all unique declaration names for autocomplete.
+
+ Args:
+ db: LeanDeclarationDB instance
+
+ Returns:
+ List of declaration names
+ """
+ cursor = db.conn.cursor()
+ cursor.execute("""
+ SELECT DISTINCT name
+ FROM declarations
+ WHERE file_path IS NOT NULL
+ ORDER BY name
+ """)
+ return [row[0] for row in cursor.fetchall()]
+
+
+def get_declaration_types(db: LeanDeclarationDB) -> List[str]:
+ """
+ Get all unique declaration types.
+
+ Args:
+ db: LeanDeclarationDB instance
+
+ Returns:
+ List of declaration types
+ """
+ cursor = db.conn.cursor()
+ cursor.execute("""
+ SELECT DISTINCT decl_type
+ FROM declarations
+ WHERE decl_type IS NOT NULL
+ ORDER BY decl_type
+ """)
+ return [row[0] for row in cursor.fetchall()]
diff --git a/src/app/data_explorer/graph_utils.py b/src/app/data_explorer/graph_utils.py
new file mode 100644
index 0000000..235773f
--- /dev/null
+++ b/src/app/data_explorer/graph_utils.py
@@ -0,0 +1,336 @@
+"""
+Graph analysis utilities for dependency visualization and forest detection.
+"""
+
+import sys
+from pathlib import Path
+
+# Add parent directories to path
+root_dir = Path(__file__).parent.parent.parent.parent
+if str(root_dir) not in sys.path:
+ sys.path.insert(0, str(root_dir))
+
+import networkx as nx
+import plotly.graph_objects as go
+from typing import List, Set, Dict, Tuple, Any
+from itp_interface.tools.simple_sqlite import LeanDeclarationDB
+
+
+def build_dependency_graph(db: LeanDeclarationDB) -> nx.DiGraph:
+ """
+ Build a directed graph from the dependency relationships.
+
+ Only includes declarations with non-None decl_type (filters out unresolved declarations).
+ This is a read-only operation - the database is not modified.
+
+ Args:
+ db: LeanDeclarationDB instance
+
+ Returns:
+ NetworkX directed graph where edges point from dependency to dependent
+ (B -> A means "A depends on B")
+ - Root nodes (depend on nothing): in_degree == 0
+ - Leaf nodes (nothing depends on them): out_degree == 0
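+
+        Example (illustrative): with G = build_dependency_graph(db),
+            list(G.predecessors(decl_id))  # declarations that decl_id depends on
+            list(G.successors(decl_id))    # declarations that depend on decl_id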
+ """
+ G = nx.DiGraph()
+
+ # Add all declarations as nodes, excluding those with None decl_type
+ cursor = db.conn.cursor()
+ cursor.execute("""
+ SELECT decl_id, name, namespace, decl_type, file_path
+ FROM declarations
+ WHERE decl_type IS NOT NULL
+ """)
+
+ for row in cursor.fetchall():
+ decl_id, name, namespace, decl_type, file_path = row
+ full_name = f"{namespace}.{name}" if namespace else name
+ G.add_node(decl_id, name=name, full_name=full_name,
+ namespace=namespace, decl_type=decl_type,
+ file_path=file_path)
+
+ # Add all dependency edges
+ # Edge direction: to_id -> from_id means "from_id depends on to_id"
+ # This way, root nodes (declarations that depend on nothing) have in_degree == 0
+ cursor.execute("""
+ SELECT from_decl_id, to_decl_id
+ FROM declaration_dependencies
+ """)
+
+ for from_id, to_id in cursor.fetchall():
+ # Only add edge if both nodes exist in the graph (not filtered out)
+ if to_id in G.nodes and from_id in G.nodes:
+ # Flip edge direction: dependency -> dependent
+ G.add_edge(to_id, from_id)
+
+ return G
+
+
+def find_all_forests(G: nx.DiGraph) -> List[Set[str]]:
+ """
+ Find all connected components (forests) in the graph.
+
+ Args:
+ G: NetworkX directed graph
+
+ Returns:
+ List of sets, where each set contains decl_ids in a connected component
+ """
+ # Convert to undirected for finding connected components
+ G_undirected = G.to_undirected()
+ components = list(nx.connected_components(G_undirected))
+
+ # Sort by size (largest first)
+ components.sort(key=len, reverse=True)
+
+ return components
+
+
+def get_dependencies_closure(
+ db: LeanDeclarationDB,
+ decl_id: str,
+ max_depth: int = 10
+) -> Tuple[List[Dict[str, Any]], nx.DiGraph]:
+ """
+ Get all dependencies (transitive) of a declaration.
+
+ Args:
+ db: LeanDeclarationDB instance
+ decl_id: Starting declaration ID
+ max_depth: Maximum depth to traverse
+
+ Returns:
+ Tuple of (list of declaration dicts, subgraph)
+ """
+ try:
+ G = build_dependency_graph(db)
+
+ if decl_id not in G:
+ return [], nx.DiGraph()
+
+ # Get all dependencies (follow incoming edges since edge direction is flipped)
+ # Edge B -> A means "A depends on B", so dependencies of A are predecessors
+ visited = set()
+ queue = [(decl_id, 0)]
+ visited.add(decl_id)
+
+ while queue:
+ current, depth = queue.pop(0)
+ if depth >= max_depth:
+ continue
+
+ try:
+ # Follow incoming edges to find dependencies
+ for neighbor in G.predecessors(current):
+ if neighbor not in visited:
+ visited.add(neighbor)
+ queue.append((neighbor, depth + 1))
+ except Exception as e:
+ print(f"Warning: Error processing predecessors of {current}: {e}")
+ continue
+
+ # Create subgraph
+ subgraph = G.subgraph(visited).copy()
+
+ # Get declaration info for all nodes
+ decls = []
+ for node_id in visited:
+ try:
+ decl = db.get_declaration_by_decl_id(node_id)
+ if decl:
+ decls.append(decl)
+ except Exception as e:
+ print(f"Warning: Error fetching declaration {node_id}: {e}")
+ continue
+
+ return decls, subgraph
+ except Exception as e:
+ print(f"Error in get_dependencies_closure: {e}")
+ raise
+
+
+def get_dependents_closure(
+ db: LeanDeclarationDB,
+ decl_id: str,
+ max_depth: int = 10
+) -> Tuple[List[Dict[str, Any]], nx.DiGraph]:
+ """
+ Get all dependents (transitive) of a declaration.
+
+ Args:
+ db: LeanDeclarationDB instance
+ decl_id: Starting declaration ID
+ max_depth: Maximum depth to traverse
+
+ Returns:
+ Tuple of (list of declaration dicts, subgraph)
+ """
+ try:
+ G = build_dependency_graph(db)
+
+ if decl_id not in G:
+ return [], nx.DiGraph()
+
+ # Get all dependents (follow outgoing edges since edge direction is flipped)
+ # Edge B -> A means "A depends on B", so dependents of B are successors
+ visited = set()
+ queue = [(decl_id, 0)]
+ visited.add(decl_id)
+
+ while queue:
+ current, depth = queue.pop(0)
+ if depth >= max_depth:
+ continue
+
+ try:
+ # Follow outgoing edges to find dependents
+ for neighbor in G.successors(current):
+ if neighbor not in visited:
+ visited.add(neighbor)
+ queue.append((neighbor, depth + 1))
+ except Exception as e:
+ print(f"Warning: Error processing successors of {current}: {e}")
+ continue
+
+ # Create subgraph
+ subgraph = G.subgraph(visited).copy()
+
+ # Get declaration info for all nodes
+ decls = []
+ for node_id in visited:
+ try:
+ decl = db.get_declaration_by_decl_id(node_id)
+ if decl:
+ decls.append(decl)
+ except Exception as e:
+ print(f"Warning: Error fetching declaration {node_id}: {e}")
+ continue
+
+ return decls, subgraph
+ except Exception as e:
+ print(f"Error in get_dependents_closure: {e}")
+ raise
+
+
+def visualize_graph(G: nx.DiGraph, title: str = "Dependency Graph") -> go.Figure:
+ """
+ Create an interactive Plotly visualization of the graph.
+
+ Args:
+ G: NetworkX directed graph
+ title: Title for the plot
+
+ Returns:
+ Plotly Figure object
+ """
+ if len(G.nodes()) == 0:
+ # Empty graph
+ fig = go.Figure()
+ fig.update_layout(title=title, annotations=[
+ dict(text="No dependencies found", showarrow=False,
+ xref="paper", yref="paper", x=0.5, y=0.5)
+ ])
+ return fig
+
+ # Use spring layout for positioning
+ pos = nx.spring_layout(G, k=1, iterations=50)
+
+ # Create edge traces
+ edge_x = []
+ edge_y = []
+ for edge in G.edges():
+ x0, y0 = pos[edge[0]]
+ x1, y1 = pos[edge[1]]
+ edge_x.extend([x0, x1, None])
+ edge_y.extend([y0, y1, None])
+
+ edge_trace = go.Scatter(
+ x=edge_x, y=edge_y,
+ line=dict(width=0.5, color='#888'),
+ hoverinfo='none',
+ mode='lines')
+
+ # Create node traces
+ node_x = []
+ node_y = []
+ node_text = []
+ node_color = []
+
+ # Color map for declaration types
+ type_colors = {
+ 'theorem': '#FF6B6B',
+ 'def': '#4ECDC4',
+ 'axiom': '#45B7D1',
+ 'instance': '#FFA07A',
+ 'class': '#98D8C8',
+ None: '#95A5A6'
+ }
+
+ for node in G.nodes():
+ x, y = pos[node]
+ node_x.append(x)
+ node_y.append(y)
+
+ # Get node info
+ node_data = G.nodes[node]
+ name = node_data.get('full_name', node_data.get('name', 'Unknown'))
+ decl_type = node_data.get('decl_type')
+ file_path = node_data.get('file_path', 'Unknown')
+
+        node_text.append(f"{name}<br>Type: {decl_type}<br>File: {file_path}")
+ node_color.append(type_colors.get(decl_type, type_colors[None]))
+
+ node_trace = go.Scatter(
+ x=node_x, y=node_y,
+ mode='markers',
+ hoverinfo='text',
+ text=node_text,
+ marker=dict(
+ showscale=False,
+ color=node_color,
+ size=10,
+ line_width=2))
+
+ # Create figure
+ fig = go.Figure(data=[edge_trace, node_trace],
+ layout=go.Layout(
+ title=dict(text=title, font=dict(size=16)),
+ showlegend=False,
+ hovermode='closest',
+ margin=dict(b=0, l=0, r=0, t=40),
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
+ )
+
+ return fig
+
+
+def get_forest_statistics(G: nx.DiGraph, forest: Set[str]) -> Dict[str, Any]:
+ """
+ Get statistics for a forest (connected component).
+
+ Args:
+ G: NetworkX directed graph (edge B->A means "A depends on B")
+ forest: Set of decl_ids in the forest
+
+ Returns:
+ Dictionary with forest statistics
+ """
+ subgraph = G.subgraph(forest)
+
+ # Find root nodes (declarations that depend on nothing)
+ # With edge direction B->A meaning "A depends on B", roots have in_degree == 0
+ roots = [node for node in forest if subgraph.in_degree(node) == 0]
+
+ # Find leaf nodes (declarations that nothing depends on)
+ # With edge direction B->A meaning "A depends on B", leaves have out_degree == 0
+ leaves = [node for node in forest if subgraph.out_degree(node) == 0]
+
+ return {
+ 'size': len(forest),
+ 'num_edges': subgraph.number_of_edges(),
+ 'num_roots': len(roots),
+ 'num_leaves': len(leaves),
+ 'roots': [G.nodes[r].get('full_name', r) for r in roots[:5]], # First 5 roots
+ 'leaves': [G.nodes[l].get('full_name', l) for l in leaves[:5]] # First 5 leaves
+ }
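+
+
+# Example usage (illustrative; ties the helpers above together):
+#     G = build_dependency_graph(db)
+#     forests = find_all_forests(G)
+#     stats = get_forest_statistics(G, forests[0])  # statistics for the largest component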
diff --git a/src/app/data_explorer/lean_db_explorer.py b/src/app/data_explorer/lean_db_explorer.py
new file mode 100644
index 0000000..675189b
--- /dev/null
+++ b/src/app/data_explorer/lean_db_explorer.py
@@ -0,0 +1,480 @@
+"""
+Lean Declaration Database Explorer
+
+A Streamlit web app for exploring and analyzing Lean declaration databases.
+"""
+
+import sys
+from pathlib import Path
+
+# Add parent directories to path
+root_dir = Path(__file__).parent.parent.parent.parent
+if str(root_dir) not in sys.path:
+ sys.path.insert(0, str(root_dir))
+
+import streamlit as st
+import pandas as pd
+from itp_interface.tools.simple_sqlite import LeanDeclarationDB
+import db_utils
+import graph_utils
+
+
+# Page configuration
+st.set_page_config(
+ page_title="Lean Declaration DB Explorer",
+    page_icon="🔍",
+ layout="wide",
+ initial_sidebar_state="expanded"
+)
+
+
+def load_database(db_path: str) -> LeanDeclarationDB:
+ """Load database and cache in session state."""
+ try:
+ db = LeanDeclarationDB(db_path)
+ return db
+ except Exception as e:
+ st.error(f"Failed to load database: {e}")
+ return None
+
+
+def display_statistics(db: LeanDeclarationDB):
+ """Display database statistics in the sidebar."""
+ stats = db.get_statistics()
+
+    st.sidebar.markdown("### Database Statistics")
+ col1, col2 = st.sidebar.columns(2)
+
+ with col1:
+ st.metric("Files", stats.get('total_files', 0))
+ st.metric("Declarations", stats.get('total_declarations', 0))
+
+ with col2:
+ st.metric("Dependencies", stats.get('total_dependencies', 0))
+ st.metric("Imports", stats.get('total_imports', 0))
+
+ st.sidebar.metric("Unresolved", stats.get('unresolved_declarations', 0))
+
+
+def tab_custom_query(db: LeanDeclarationDB):
+ """Custom Query tab interface."""
+    st.header("Custom SQL Query")
+
+ # Example queries dropdown
+ st.markdown("**Pre-built Queries:**")
+ common_queries = db_utils.get_common_queries()
+ query_name = st.selectbox(
+ "Select a query to load",
+ [""] + list(common_queries.keys()),
+ label_visibility="collapsed"
+ )
+
+ # Query text area - use query_name as part of key to force refresh
+ if query_name:
+ default_query = common_queries.get(query_name)
+ else:
+ default_query = "SELECT * FROM declarations LIMIT 10"
+
+ query = st.text_area(
+ "SQL Query",
+ value=default_query,
+ height=150,
+ key=f"custom_query_{query_name}"
+ )
+
+ col1, col2 = st.columns([1, 5])
+ with col1:
+ execute_button = st.button("Execute Query", type="primary")
+
+ if execute_button:
+ try:
+ with st.spinner("Executing query..."):
+ df = db_utils.execute_query(db, query)
+
+ if not df.empty:
+ st.success(f"Query returned {len(df)} rows")
+ st.dataframe(df, use_container_width=True, height=500)
+
+ # Download button
+ csv = df.to_csv(index=False)
+ st.download_button(
+ label="Download CSV",
+ data=csv,
+ file_name="query_results.csv",
+ mime="text/csv"
+ )
+ else:
+ st.info("Query returned no results")
+
+ except Exception as e:
+ st.error(f"Query error: {e}")
+
+
+def tab_search(db: LeanDeclarationDB):
+ """Search & Browse tab interface."""
+    st.header("Search Declarations")
+
+ col1, col2 = st.columns(2)
+
+ with col1:
+ name_pattern = st.text_input("Name pattern", placeholder="e.g., add, mul, theorem")
+ file_path = st.text_input("File path contains", placeholder="e.g., Mathlib/Data")
+
+ with col2:
+ decl_types = [""] + db_utils.get_declaration_types(db)
+ decl_type = st.selectbox("Declaration type", decl_types)
+
+ namespace = st.text_input("Namespace", placeholder="e.g., Nat, List")
+
+ if st.button("Search", type="primary"):
+ with st.spinner("Searching..."):
+ df = db_utils.search_declarations(
+ db,
+ name_pattern=name_pattern if name_pattern else None,
+ namespace=namespace if namespace else None,
+ file_path=file_path if file_path else None,
+ decl_type=decl_type if decl_type else None
+ )
+
+ if not df.empty:
+ st.success(f"Found {len(df)} declarations")
+
+ # Show results
+ st.dataframe(df, use_container_width=True, height=500)
+
+ # Download button
+ csv = df.to_csv(index=False)
+ st.download_button(
+ label="Download Results",
+ data=csv,
+ file_name="search_results.csv",
+ mime="text/csv"
+ )
+ else:
+ st.info("No declarations found matching the criteria")
+
+
+def tab_dependencies(db: LeanDeclarationDB):
+ """Dependency Explorer tab interface."""
+    st.header("🌲 Dependency Explorer")
+
+ # Declaration selector with option to search by name or decl_id
+ st.markdown("**Select a declaration:**")
+
+ search_mode = st.radio(
+ "Search by",
+ ["Declaration name", "Declaration ID"],
+ horizontal=True,
+ help="Search by name or directly by decl_id"
+ )
+
+ if search_mode == "Declaration name":
+ decl_input = st.text_input("Declaration name", placeholder="e.g., dvd_trans, Nat.add")
+ else:
+ decl_input = st.text_input("Declaration ID", placeholder="e.g., Nat.add_comm_123abc")
+
+ col1, col2, col3 = st.columns(3)
+
+ with col1:
+ show_deps = st.button("Show Dependencies", type="primary")
+ with col2:
+ show_dependents = st.button("Show Dependents", type="primary")
+ with col3:
+ max_depth = st.number_input("Max depth", min_value=1, max_value=20, value=5)
+
+ if show_deps or show_dependents:
+ if not decl_input:
+            st.warning(f"Please enter a {search_mode.lower()}")
+ return
+
+ # Find declaration by name or ID
+ decl_id = None
+ display_name = decl_input
+
+ if search_mode == "Declaration name":
+ search_df = db_utils.search_declarations(db, name_pattern=decl_input)
+
+ if search_df.empty:
+ st.error(f"Declaration '{decl_input}' not found")
+ return
+
+ if len(search_df) > 1:
+ st.warning(f"Found {len(search_df)} declarations with this name. Using the first one.")
+ st.dataframe(search_df[['decl_id', 'name', 'namespace', 'file_path', 'decl_type']])
+
+ decl_id = search_df.iloc[0]['decl_id']
+ display_name = search_df.iloc[0]['name']
+ else:
+ # Direct decl_id lookup
+ decl = db.get_declaration_by_decl_id(decl_input)
+ if decl:
+ decl_id = decl_input
+ display_name = decl.get('name', decl_input)
+ # Show the declaration info
+ st.info(f"Found: {display_name} (Type: {decl.get('decl_type', 'unknown')})")
+ else:
+ st.error(f"Declaration with ID '{decl_input}' not found")
+ return
+
+ try:
+ with st.spinner("Analyzing dependencies..."):
+ if show_deps:
+ decls, subgraph = graph_utils.get_dependencies_closure(db, decl_id, max_depth)
+ title = f"Dependencies of {display_name}"
+ else:
+ decls, subgraph = graph_utils.get_dependents_closure(db, decl_id, max_depth)
+ title = f"Dependents of {display_name}"
+
+ if decls:
+ st.success(f"Found {len(decls)} related declarations")
+
+ # Show tabs for table and graph views
+                    tab1, tab2 = st.tabs(["Table View", "Graph View"])
+
+ with tab1:
+ df = pd.DataFrame(decls)
+ st.dataframe(df[['decl_id', 'name', 'namespace', 'decl_type', 'file_path', 'line']],
+ use_container_width=True, height=400)
+
+ with tab2:
+ fig = graph_utils.visualize_graph(subgraph, title)
+ st.plotly_chart(fig, use_container_width=True)
+ else:
+ st.info("No dependencies found")
+ except Exception as e:
+ st.error(f"Error analyzing dependencies: {str(e)}")
+ st.exception(e)
+
+
+def tab_forests(db: LeanDeclarationDB):
+ """Forest Analysis tab interface."""
+    st.header("🌳 Forest Analysis")
+
+ st.markdown("""
+ A **forest** is a connected component in the dependency graph. Declarations in the same
+ forest are connected through dependency relationships. **Root nodes** are declarations
+ with no dependencies within the forest (no incoming edges).
+ """)
+
+ # Initialize session state for forests
+ if 'forests' not in st.session_state:
+ st.session_state.forests = None
+ st.session_state.forest_graph = None
+
+ if st.button("Find All Forests", type="primary"):
+ with st.spinner("Analyzing dependency graph..."):
+ G = graph_utils.build_dependency_graph(db)
+ forests = graph_utils.find_all_forests(G)
+ # Store in session state
+ st.session_state.forests = forests
+ st.session_state.forest_graph = G
+
+ st.success(f"Found {len(forests)} forests")
+
+ # Use forests from session state
+ forests = st.session_state.forests
+ G = st.session_state.forest_graph
+
+ if forests is not None:
+ # Display statistics
+ st.markdown("### Forest Statistics")
+
+ forest_data = []
+ for i, forest in enumerate(forests[:20]): # Limit to top 20
+ stats = graph_utils.get_forest_statistics(G, forest)
+ forest_data.append({
+ 'Forest ID': i + 1,
+ 'Size': stats['size'],
+ 'Edges': stats['num_edges'],
+ 'Roots': stats['num_roots'],
+ 'Leaves': stats['num_leaves'],
+ 'Sample Roots': ', '.join(stats['roots'][:3])
+ })
+
+ df = pd.DataFrame(forest_data)
+ st.dataframe(df, use_container_width=True)
+
+ # Show declarations in selected forest
+ st.markdown("### View Forest Details")
+ forest_id = st.number_input(
+ "Forest ID to view",
+ min_value=1,
+ max_value=len(forests),
+ value=1
+ )
+
+ col1, col2 = st.columns(2)
+ with col1:
+ show_all_decls = st.button("Show All Declarations")
+ with col2:
+ show_roots = st.button("Show Root Nodes Only")
+
+ if show_all_decls or show_roots:
+ if G is None:
+ st.error("Graph not available. Please click 'Find All Forests' first.")
+ return
+
+ forest = forests[forest_id - 1]
+ stats = graph_utils.get_forest_statistics(G, forest)
+
+ if show_roots:
+                st.markdown(f"#### 🌱 Root Nodes of Forest #{forest_id}")
+ st.info(f"Found {stats['num_roots']} root nodes (declarations with no dependencies)")
+
+ # Get root node details
+ root_decls = []
+ subgraph = G.subgraph(forest)
+ root_ids = [node for node in forest if subgraph.in_degree(node) == 0]
+
+ for decl_id in root_ids:
+ decl = db.get_declaration_by_decl_id(decl_id)
+ if decl:
+ # Count direct dependents
+ num_dependents = subgraph.out_degree(decl_id)
+ root_decls.append({
+ 'decl_id': decl['decl_id'],
+ 'name': decl['name'],
+ 'namespace': decl.get('namespace', ''),
+ 'decl_type': decl.get('decl_type', ''),
+ 'dependents_count': num_dependents,
+ 'file_path': decl.get('file_path', ''),
+ 'line': decl.get('line', '')
+ })
+
+ if root_decls:
+ root_df = pd.DataFrame(root_decls)
+ # Sort by number of dependents
+ root_df = root_df.sort_values('dependents_count', ascending=False)
+ st.dataframe(root_df, use_container_width=True, height=500)
+
+ # Download button
+ csv = root_df.to_csv(index=False)
+ st.download_button(
+ label="Download Root Nodes",
+ data=csv,
+ file_name=f"forest_{forest_id}_roots.csv",
+ mime="text/csv"
+ )
+ else:
+ st.warning("No root nodes found in this forest.")
+
+ else: # show_all_decls
+ st.info(f"Forest #{forest_id} contains {len(forest)} declarations")
+
+ # Get all declarations in this forest
+ forest_decls = []
+ subgraph = G.subgraph(forest)
+
+ for decl_id in forest:
+ decl = db.get_declaration_by_decl_id(decl_id)
+ if decl:
+ # Check if it's a root node
+ is_root = subgraph.in_degree(decl_id) == 0
+ # Truncate text and proof for display
+ text = decl.get('text', '')
+ proof = decl.get('proof', '')
+ forest_decls.append({
+ 'decl_id': decl['decl_id'],
+ 'name': decl['name'],
+ 'namespace': decl.get('namespace', ''),
+ 'decl_type': decl.get('decl_type', ''),
+                        'is_root': '🌱' if is_root else '',
+ 'text': text[:100] + '...' if text and len(text) > 100 else text,
+ 'proof': proof[:100] + '...' if proof and len(proof) > 100 else proof,
+ 'file_path': decl.get('file_path', ''),
+ 'line': decl.get('line', '')
+ })
+
+ if forest_decls:
+ forest_df = pd.DataFrame(forest_decls)
+ st.dataframe(forest_df, use_container_width=True, height=500)
+
+ # Download button
+ csv = forest_df.to_csv(index=False)
+ st.download_button(
+ label="Download Forest Declarations",
+ data=csv,
+ file_name=f"forest_{forest_id}_declarations.csv",
+ mime="text/csv"
+ )
+ else:
+ st.warning("No declarations found in this forest.")
+ else:
+ st.info("Click 'Find All Forests' to analyze the dependency graph")
+
+
+def main():
+ """Main application."""
+    st.title("Lean Declaration Database Explorer")
+
+ # Sidebar
+ st.sidebar.title("Database Connection")
+
+ # Database path input
+ db_path = st.sidebar.text_input(
+ "Database Path",
+ value="lean_declarations.db",
+ help="Path to the SQLite database file"
+ )
+
+ load_button = st.sidebar.button("Load Database", type="primary")
+
+ # Initialize session state
+ if 'db' not in st.session_state:
+ st.session_state.db = None
+
+ if load_button:
+ if Path(db_path).exists():
+ with st.spinner("Loading database..."):
+ st.session_state.db = load_database(db_path)
+
+ if st.session_state.db:
+ st.sidebar.success("Database loaded successfully!")
+ else:
+ st.sidebar.error(f"Database file not found: {db_path}")
+
+ # Show statistics if database is loaded
+ if st.session_state.db:
+ display_statistics(st.session_state.db)
+
+ # Main content tabs
+ tab1, tab2, tab3, tab4 = st.tabs([
+            "Custom Query",
+            "Search",
+            "🌲 Dependencies",
+            "🌳 Forests"
+ ])
+
+ with tab1:
+ tab_custom_query(st.session_state.db)
+
+ with tab2:
+ tab_search(st.session_state.db)
+
+ with tab3:
+ tab_dependencies(st.session_state.db)
+
+ with tab4:
+ tab_forests(st.session_state.db)
+
+ else:
+        st.info("Please load a database to get started")
+
+ st.markdown("""
+ ### Getting Started
+
+ 1. Enter the path to your SQLite database file in the sidebar
+ 2. Click "Load Database"
+ 3. Explore your data using the tabs above
+
+ ### Features
+
+ - **Custom Query**: Write and execute custom SQL queries
+ - **Search**: Search declarations by name, type, file, etc.
+ - **Dependencies**: Explore dependency relationships
+ - **Forests**: Analyze connected components in the dependency graph
+ """)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/app/itp-gui/app.py b/src/app/itp-gui/app.py
index a169744..1f42a92 100644
--- a/src/app/itp-gui/app.py
+++ b/src/app/itp-gui/app.py
@@ -19,7 +19,7 @@
from typing import Optional, Dict, Any, List
import json
-from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.lean_server.lean_context import ProofContext
app = Flask(__name__, static_folder='static', template_folder='templates')
diff --git a/src/data/test/batteries b/src/data/test/batteries
new file mode 160000
index 0000000..8da40b7
--- /dev/null
+++ b/src/data/test/batteries
@@ -0,0 +1 @@
+Subproject commit 8da40b72fece29b7d3fe3d768bac4c8910ce9bee
diff --git a/src/itp_interface/lean/__init__.py b/src/itp_interface/lean/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/itp_interface/lean/lean4_local_data_extraction_transform.py b/src/itp_interface/lean/lean4_local_data_extraction_transform.py
new file mode 100644
index 0000000..3b6be6d
--- /dev/null
+++ b/src/itp_interface/lean/lean4_local_data_extraction_transform.py
@@ -0,0 +1,381 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+dir_name = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+root_dir = os.path.abspath(dir_name)
+if root_dir not in sys.path:
+ sys.path.append(root_dir)
+import typing
+import uuid
+import json
+from pathlib import Path
+from filelock import FileLock
+from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.lean.parsing_helpers import LeanDeclParser, LeanParseResult, LeanDeclType
+from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
+from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, ExtractionDataCollection
+from itp_interface.tools.training_data import TrainingData, DataLayoutFormat
+from itp_interface.lean.tactic_parser import FileDependencyAnalysis, DeclWithDependencies
+from itp_interface.tools.simple_sqlite import LeanDeclarationDB
+
+class Local4DataExtractionTransform(GenericTrainingDataGenerationTransform):
+ def __init__(self,
+ depth = None,
+ max_search_results = None,
+ buffer_size : int = 10000,
+ logger = None,
+ max_parallelism : int = 4,
+ db_path : typing.Optional[str] = None,
+ enable_file_export: bool = False,
+ enable_dependency_extraction: bool = False):
+ super().__init__(TrainingDataGenerationType.LOCAL, buffer_size, logger)
+ self.depth = depth
+ self.max_search_results = max_search_results
+ self.max_parallelism = max_parallelism
+ self.db_path = db_path # Store path, don't create connection yet (for Ray actors)
+ self.enable_file_export = enable_file_export
+ self.enable_dependency_extraction = enable_dependency_extraction
+ # Everything except UNKNOWN
+ self.supported_declaration_types = [ str(dt) for dt in LeanDeclType if dt != LeanDeclType.UNKNOWN]
+ if self.db_path is not None:
+ temp_db_path = Path(self.db_path)
+ cache_path = temp_db_path.parent / "cache"
+ os.makedirs(str(cache_path), exist_ok=True)
+ mapping_path = cache_path / f"{temp_db_path.stem}_declaration_mapping.json"
+ self.cache_path = str(cache_path)
+ self.mapping_path = str(mapping_path)
+ self._lock_path = os.path.join(self.cache_path, "lean_declaration_mapping.lock")
+ else:
+ self.cache_path = None
+ self.mapping_path = None
+ self._lock_path = None
+
+ def get_meta_object(self) -> TrainingDataMetadataFormat:
+ return TrainingDataMetadataFormat(
+ training_data_buffer_size=self.buffer_size,
+ data_filename_prefix="extraction_data_",
+ lemma_ref_filename_prefix="extraction_lemma_refs_")
+
+ def get_data_collection_object(self) -> MergableCollection:
+ return ExtractionDataCollection()
+
+ def load_meta_from_file(self, file_path) -> MergableCollection:
+ return TrainingDataMetadataFormat.load_from_file(file_path)
+
+ def load_data_from_file(self, file_path) -> MergableCollection:
+ return ExtractionDataCollection.load_from_file(file_path, self.logger)
+
+ def _remove_lake_package_prefix(self, module_name: str) -> str:
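+        # Example (illustrative): "lake.packages.mathlib.Mathlib.Order.Basic" -> "Mathlib.Order.Basic"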
+ if module_name.startswith("lake.packages.") or module_name.startswith(".lake.packages."):
+ parts = module_name.split('.')
+ if len(parts) > 3:
+ module_name = '.'.join(parts[3:])
+ else:
+ module_name = ''
+ return module_name
+
+ def _remove_lake_file_prefix(self, path: Path) -> str:
+ parts = path.parts
+ if ".lake" in parts:
+ lake_index = parts.index(".lake")
+ new_parts = parts[lake_index + 3:]
+ new_path = Path(*new_parts)
+ return str(new_path)
+ return str(path)
+
+ def _check_if_lean_core_file(self, file_path: Path) -> bool:
+ parts = file_path.parts
+ # If `leanprover--lean4---v4.24.0/src/lean/` is in the path, it's a core file
+ for i in range(len(parts) - 4):
+ if (parts[i].startswith("leanprover--lean4---v") and
+ parts[i+1] == "src" and
+ parts[i+2] == "lean"):
+ return True
+ return False
+
+ def _check_if_lean_core_module(self, module_name: str) -> bool:
+ parts = module_name.split('.')
+ # leanprover--lean4---v4.24.0.src.lean
+ idx = None
+ for i in range(len(parts) - 2):
+ if parts[i].startswith("leanprover--lean4---v"):
+ idx = i
+ if idx is not None:
+ for i in range(idx, len(parts)):
+ if parts[i] == "src" and i + 1 < len(parts) and parts[i + 1] == "lean":
+ return True
+ return False
+
+ def _strip_core_file_path(self, file_path: Path) -> str:
+ # Take only the parts after `leanprover--lean4---v4.24.0/src/lean/`
+ parts = file_path.parts
+ for i in range(len(parts) - 4):
+ if (parts[i].startswith("leanprover--lean4---v") and
+ parts[i+1] == "src" and
+ parts[i+2] == "lean"):
+ new_parts = parts[i+3:]
+ new_path = Path(*new_parts)
+ return str(new_path)
+ return str(file_path)
+
+ def _strip_core_module_name(self, module_name: str) -> str:
+ parts = module_name.split('.')
+ # leanprover--lean4---v4.24.0.src.lean
+ idx: int | None = None
+ for i in range(len(parts) - 2):
+ if parts[i].startswith("leanprover--lean4---v"):
+ idx = i
+ if idx is not None:
+ for i in range(idx, len(parts)):
+ if parts[i] == "src" and i + 1 < len(parts) and parts[i + 1] == "lean":
+ new_parts = parts[i + 2:]
+ return '.'.join(new_parts)
+ return module_name
+
+ def _get_file_path(self, project_path: str, file_path: str) -> str:
+ fp = Path(file_path)
+ pp = Path(project_path)
+ fp_abs = fp.resolve()
+ pp_abs = pp.resolve()
+ if self._check_if_lean_core_file(fp_abs):
+ stripped_path = self._strip_core_file_path(fp_abs)
+ return stripped_path
+ relative_path = fp_abs.relative_to(pp_abs)
+ rel_file_path = self._remove_lake_file_prefix(relative_path)
+ return str(rel_file_path)
+
+ def _get_module_name(self, project_path: str, module_name: str) -> str:
+ if self._check_if_lean_core_module(module_name):
+ stripped_module = self._strip_core_module_name(module_name)
+ return stripped_module
+ pp = Path(project_path)
+ pp_abs = pp.resolve()
+ pp_module = str(pp_abs).replace('/', '.')
+ pp_module = pp_module.lstrip('.')
+ # self.logger.info(f"Project module prefix: {pp_module}")
+ if module_name.startswith(pp_module):
+ module_name = module_name[len(pp_module):]
+ module_name = module_name.lstrip('.')
+ # self.logger.info(f"Module name after removing project prefix: {module_name}")
+ module_name = self._remove_lake_package_prefix(module_name)
+ return module_name
+
+ def _remap_local_dependency(self,
+ project_path: str,
+ file_dep_analyses: typing.List[FileDependencyAnalysis]) -> None:
+ # Map local dependencies
+ all_decls : dict[str, DeclWithDependencies] = {}
+ for fda in file_dep_analyses:
+ for decl in fda.declarations:
+ name = decl.decl_info.name
+ if name is not None:
+ all_decls[name] = decl
+ for fda in file_dep_analyses:
+ fda_rel_path = self._get_file_path(project_path, fda.file_path)
+ # Remove the pp_module prefix from the fda module name
+ fda_module_name = self._get_module_name(project_path, fda.module_name)
+ for decl in fda.declarations:
+ for decl_dep in decl.dependencies:
+ if decl_dep.name in all_decls:
+ local_decl = all_decls[decl_dep.name]
+ if decl_dep.file_path is None and \
+ local_decl.decl_info.decl_type in self.supported_declaration_types:
+ # We need to ensure that it is locally defined.
+ decl_dep.file_path = fda_rel_path
+ if decl_dep.module_name is None:
+ decl_dep.module_name = fda_module_name
+ if decl_dep.namespace is None:
+ decl_dep.namespace = decl.decl_info.namespace
+
+ def _preprocess_declarations(self,
+ project_path: str,
+ file_dep_analyses: typing.List[FileDependencyAnalysis]) -> None:
+ # Preprocess declarations to set file paths and module names
+ for fda in file_dep_analyses:
+ for decl in fda.declarations:
+ self._preprocess_declaration(project_path, decl)
+
+ def _preprocess_declaration(
+ self,
+ project_path: str,
+ decl: DeclWithDependencies) -> None:
+ # Filter all unknown types
+ if decl.decl_info.decl_type == LeanDeclType.UNKNOWN.value:
+ # Check if we have docstring or not
+ if decl.decl_info.doc_string is None or decl.decl_info.doc_string.strip() != "":
+ parser = LeanDeclParser(decl.decl_info.text)
+ parse_result = parser.parse()
+ if parse_result.doc_string is not None:
+ decl.decl_info.doc_string = parse_result.doc_string.strip()
+ if parse_result.decl_type != LeanDeclType.UNKNOWN:
+ decl.decl_info.decl_type = str(parse_result.decl_type)
+ if parse_result.text is not None:
+ full_text = []
+ if parse_result.text_before is not None:
+ full_text.append(parse_result.text_before.strip())
+ full_text.append(parse_result.text.strip())
+ text = "\n".join(full_text)
+ decl.decl_info.text = text
+ # Update the proof if not already present
+ if decl.decl_info.proof is None and parse_result.proof is not None:
+ decl.decl_info.proof = parse_result.proof.strip()
+ if parse_result.name is not None:
+ decl.decl_info.name = parse_result.name.strip()
+ pass
+
+
+ def __call__(self,
+ training_data: TrainingData,
+ project_id : str,
+ lean_executor: SimpleLean4SyncExecutor,
+ print_coq_executor_callback: typing.Callable[[], SimpleLean4SyncExecutor],
+ theorems: typing.List[str] = None,
+ other_args: dict = {}) -> TrainingData:
+ file_path = lean_executor.main_file
+ project_path = project_id
+ rel_file_path = self._get_file_path(project_path, file_path)
+ file_namespace = rel_file_path.replace('/', '.')
+ self.logger.info(f"=========================Processing {file_namespace}=========================")
+
+ # Create database connection for this Ray actor (if db_path is provided)
+ db = None
+ if self.db_path:
+ db = LeanDeclarationDB(self.db_path)
+ self.logger.info(f"Connected to database: {self.db_path}")
+
+ try:
+ if isinstance(theorems, list) and len(theorems) == 1 and theorems[0] == "*":
+ theorems = None
+ else:
+ theorems = set(theorems) if theorems is not None else None
+ cnt = 0
+ temp_dir = os.path.join(training_data.folder, "temp") if self.cache_path is None else self.cache_path
+ temp_dir = temp_dir.rstrip('/')
+ if self.cache_path is None:
+ os.makedirs(temp_dir, exist_ok=True)
+ json_output_path = f"{temp_dir}/{file_namespace.replace('.', '_')}.lean.deps.json"
+ already_mapped = False
+ if self._lock_path is not None:
+ with FileLock(self._lock_path):
+ assert self.mapping_path is not None
+ if not os.path.exists(self.mapping_path):
+ with open(self.mapping_path, "w") as f:
+ json.dump({}, f)
+ with open(self.mapping_path, "r") as f:
+ mapping_data = json.load(f)
+ if json_output_path in mapping_data:
+ already_mapped = True
+ else:
+ mapping_data[json_output_path] = True
+ with open(self.mapping_path, "w") as f:
+ json.dump(mapping_data, f)
+ if already_mapped and os.path.exists(json_output_path):
+ # Read from existing file
+ self.logger.info(f"Using existing dependency extraction file: {json_output_path}")
+ with open(json_output_path, 'r') as f:
+ data = json.load(f)
+ file_dep_analyses = [FileDependencyAnalysis.model_validate(data)]
+ else:
+ file_dep_analyses = lean_executor.extract_all_theorems_and_definitions(json_output_path=json_output_path)
+ self.logger.info(f"Extracted {len(file_dep_analyses)} FileDependencyAnalysis objects from {file_namespace}")
+ self.logger.info(f"file_dep_analyses: {file_dep_analyses}")
+ assert len(file_dep_analyses) == 1, "Expected exactly one FileDependencyAnalysis object"
+
+ last_decl_id = None
+ self._preprocess_declarations(project_path, file_dep_analyses)
+ self._remap_local_dependency(project_path, file_dep_analyses)
+
+ for fda in file_dep_analyses:
+ fda_rel_path = self._get_file_path(project_path, fda.file_path)
+ # Remove the pp_module prefix from the fda module name
+ fda_module_name = self._get_module_name(project_path, fda.module_name)
+
+ # Insert file and imports into database (if db is enabled)
+ if db and fda.imports:
+ db.insert_file_imports(fda_rel_path, fda_module_name, fda.imports)
+ self.logger.info(f"Inserted file and {len(fda.imports)} imports for {fda_rel_path}")
+
+ for decl in fda.declarations:
+ # Get or create decl_id from database (or generate new one if no DB)
+ if db:
+ if decl.decl_info.decl_type not in self.supported_declaration_types:
+ self.logger.info(f"Skipping declaration '{decl.decl_info.name}' of unsupported type '{decl.decl_info.decl_type}'")
+ continue
+ decl_id = db.process_declaration(
+ fda_file_path=fda_rel_path,
+ fda_module_name=fda_module_name,
+ decl=decl,
+ enable_dependency_extraction=self.enable_dependency_extraction
+ )
+ self.logger.info(f"Processed declaration '{decl.decl_info.name}' with ID: {decl_id}")
+ else:
+ # Fallback: generate unique ID without database
+ timestamp = str(int(uuid.uuid1().time_low))
+ random_id = str(uuid.uuid4())
+ decl_id = f"{timestamp}_{random_id}"
+ if self.enable_file_export:
+ new_fda = FileDependencyAnalysis(
+ file_path=str(fda_rel_path),
+ module_name=fda_module_name,
+ imports=fda.imports,
+ declarations=[])
+ line_info = decl.decl_info
+ if theorems is not None and line_info.name not in theorems:
+ continue
+ decl.decl_id = decl_id
+ new_fda.declarations.append(decl)
+ training_data.merge(new_fda)
+ cnt += 1
+ last_decl_id = decl_id
+
+ if last_decl_id:
+ training_data.meta.last_proof_id = last_decl_id
+ self.logger.info(f"===============Finished processing {file_namespace}=====================")
+ self.logger.info(f"Total declarations processed in this transform: {cnt}")
+ return training_data
+ finally:
+ # Clean up database connection
+ if db:
+ db.close()
+ self.logger.info("Closed database connection")
+
+
+if __name__ == "__main__":
+ import os
+ import logging
+ import time
+ os.chdir(root_dir)
+ # project_dir = 'data/test/lean4_proj/'
+ project_dir = 'data/test/Mathlib'
+ # file_name = 'data/test/lean4_proj/Lean4Proj/Basic.lean'
+ # file_name = 'data/test/Mathlib/.lake/packages/mathlib/Mathlib/Algebra/Divisibility/Basic.lean'
+ home_path = str(Path.home())
+ file_name = f'{home_path}/.elan/toolchains/leanprover--lean4---v4.24.0/src/lean/Init/Prelude.lean'
+ project_id = project_dir #.replace('/', '.')
+ time_str = time.strftime("%Y%m%d-%H%M%S")
+ output_path = f".log/local_data_generation_transform/data/{time_str}"
+ log_path = f".log/local_data_generation_transform/log/{time_str}"
+ log_file = f"{log_path}/local_data_generation_transform-{time_str}.log"
+ db_path = f"{log_path}/lean_declarations.db"
+ os.makedirs(output_path, exist_ok=True)
+ os.makedirs(log_path, exist_ok=True)
+ logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
+ logger = logging.getLogger(__name__)
+ def _print_lean_executor_callback():
+ search_lean_exec = SimpleLean4SyncExecutor(main_file=file_name, project_root=project_dir)
+ search_lean_exec.__enter__()
+ return search_lean_exec
+ # Create transform with database enabled
+ transform = Local4DataExtractionTransform(0, buffer_size=1000, db_path=db_path)
+ logger.info(f"Using database: {db_path}")
+ training_data = TrainingData(
+ output_path,
+ "training_metadata.json",
+ training_meta=transform.get_meta_object(),
+ logger=logger,
+ layout=DataLayoutFormat.DECLARATION_EXTRACTION)
+ with SimpleLean4SyncExecutor(project_root=project_dir, main_file=file_name, use_human_readable_proof_context=True, suppress_error_log=True) as coq_exec:
+ transform(training_data, project_id, coq_exec, _print_lean_executor_callback, theorems=["*"])
+ save_info = training_data.save()
+ logger.info(f"Saved training data to {save_info}")
\ No newline at end of file
diff --git a/src/itp_interface/lean/parsing_helpers.py b/src/itp_interface/lean/parsing_helpers.py
new file mode 100644
index 0000000..de16b24
--- /dev/null
+++ b/src/itp_interface/lean/parsing_helpers.py
@@ -0,0 +1,376 @@
+import sys
+import re
+from enum import Enum
+from dataclasses import dataclass
+from typing import Optional, Set
+
+class LeanDeclType(Enum):
+ LEMMA = "lemma"
+ THEOREM = "theorem"
+ DEF = "def"
+ STRUCTURE = "structure"
+ CLASS = "class"
+ INDUCTIVE = "inductive"
+ INSTANCE = "instance"
+ ABBREV = "abbrev"
+ ABBREVIATION = "abbreviation"
+ AXIOM = "axiom"
+ EXAMPLE = "example"
+ OPAQUE = "opaque"
+ CONSTANT = "constant"
+ MUTUAL = "mutual"
+ UNKNOWN = "unknown"
+
+ def __str__(self) -> str:
+ return str(self.value).lower()
+
+@dataclass
+class LeanParseResult:
+ decl_type: LeanDeclType
+ name: Optional[str] = None
+ text_before: Optional[str] = None
+ doc_string: Optional[str] = None
+ text: Optional[str] = None
+ proof: Optional[str] = None
+
+class LeanDeclParser:
+ """
+ Parses Lean 4 declarations to separate context, docstrings,
+ declaration headers, and proofs/bodies.
+ """
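+
+    # Example (illustrative of the intended behaviour):
+    #     result = LeanDeclParser('/-- Doc. -/\ntheorem foo : 1 = 1 := rfl').parse()
+    #     # result is a LeanParseResult; the docstring, declaration header and proof are
+    #     # expected in result.doc_string, result.text and result.proof respectively.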
+
+ # Keywords that mark the start of a declaration
+ DECL_KEYWORDS = {m.value for m in LeanDeclType if m != LeanDeclType.UNKNOWN}
+
+ # Types for which we should NOT attempt to extract a proof/body
+ NO_PROOF_TYPES = {
+ LeanDeclType.INDUCTIVE,
+ LeanDeclType.MUTUAL,
+ LeanDeclType.STRUCTURE,
+ LeanDeclType.CLASS
+ }
+
+ # Types that typically don't have a name
+ NO_NAME_TYPES = {
+ LeanDeclType.EXAMPLE,
+ LeanDeclType.MUTUAL,
+ LeanDeclType.UNKNOWN
+ }
+
+ def __init__(self, text: str):
+ self.text = text
+ self.n = len(text)
+ self.tokens = []
+ self.docstring_range = None # (start, end)
+
+ # Key Indices and Info
+ self.decl_start = -1
+ self.proof_start = -1
+ self.decl_type: LeanDeclType = LeanDeclType.UNKNOWN
+ self.decl_name: Optional[str] = None
+
+ def parse(self) -> LeanParseResult:
+ self._tokenize()
+ self._analyze_structure()
+ return self._construct_result()
+
+ def _tokenize(self):
+ """
+ Scans text to find tokens, respecting comments, strings, and nesting.
+ """
+ i = 0
+ # States
+ NORMAL = 0
+ IN_STRING = 1
+ IN_CHAR = 2
+
+ state = NORMAL
+ nesting = 0 # () [] {}
+
+ while i < self.n:
+ # 1. Handle Comments (Line and Block)
+ if state == NORMAL:
+ if self.text.startswith("--", i):
+ # Line comment
+ end_line = self.text.find('\n', i)
+ if end_line == -1: end_line = self.n
+ i = end_line
+ continue
+
+ if self.text.startswith("/-", i):
+ # Block comment
+ is_doc = self.text.startswith("/--", i)
+ start_idx = i
+
+ # Find end of block comment (handle nesting)
+ depth = 1
+ i += 2
+ while i < self.n and depth > 0:
+ if self.text.startswith("/-", i):
+ depth += 1
+ i += 2
+ elif self.text.startswith("-/", i):
+ depth -= 1
+ i += 2
+ else:
+ i += 1
+
+ # Capture the FIRST docstring found
+ if is_doc and self.docstring_range is None:
+ self.docstring_range = (start_idx, i)
+ continue
+
+ # 2. Handle Strings/Chars
+ if state == NORMAL:
+ if self.text[i] == '"':
+ state = IN_STRING
+ i += 1
+ continue
+ if self.text[i] == "'":
+ state = IN_CHAR
+ i += 1
+ continue
+
+ elif state == IN_STRING:
+ if self.text[i] == '\\': i += 2; continue
+ if self.text[i] == '"': state = NORMAL; i += 1; continue
+ i += 1
+ continue
+
+ elif state == IN_CHAR:
+ if self.text[i] == '\\': i += 2; continue
+ if self.text[i] == "'": state = NORMAL; i += 1; continue
+ i += 1
+ continue
+
+ # 3. Handle Structure Tokens in NORMAL state
+ char = self.text[i]
+
+ # Nesting tracking
+ if char in "([{":
+ nesting += 1
+ i += 1
+ continue
+ elif char in ")]}":
+ nesting = max(0, nesting - 1)
+ i += 1
+ continue
+
+ # Token detection (only at top level)
+            if nesting == 0:
+                # Keywords must start at a word boundary; without this guard the
+                # trailing "in" of an identifier such as `main` would be
+                # tokenized as the `in` keyword.
+                at_word_start = i == 0 or not (self.text[i-1].isalnum() or self.text[i-1] in "_.")
+                # Check for 'in' keyword (standalone)
+                if at_word_start and self._is_keyword_at(i, "in"):
+                    self.tokens.append(("IN", i, i+2))
+                    i += 2
+                    continue
+
+                # Check for Declaration Keywords
+                match_kw = self._match_any_keyword(i, self.DECL_KEYWORDS) if at_word_start else None
+ if match_kw:
+ kw, length = match_kw
+ self.tokens.append(("KW", i, i+length))
+ i += length
+ continue
+
+ # Check for Attribute Start
+ if self.text.startswith("@[", i):
+ self.tokens.append(("ATTR", i, i+2))
+ i += 2
+ nesting += 1 # The '[' counts as nesting
+ continue
+
+ # Check for Proof Starters
+ if self.text.startswith(":=", i):
+ self.tokens.append(("PROOF", i, i+2))
+ i += 2
+ continue
+ if self.text.startswith("where", i) and self._is_word_boundary(i+5):
+ self.tokens.append(("PROOF", i, i+5))
+ i += 5
+ continue
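+                # A bare '|' at top level usually introduces a pattern-matching
+                # alternative (e.g. `| 0 => ...`), so treat it as the start of
+                # the proof/body.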
+ if char == '|':
+ self.tokens.append(("PROOF", i, i+1))
+ i += 1
+ continue
+
+ i += 1
+
+ def _analyze_structure(self):
+ """
+ Interpret the token stream to find split points.
+ """
+ candidate_decl = -1
+ decl_keyword_str = None
+ decl_keyword_end = -1
+
+ # Pass 1: Find Declaration Start
+ for t_type, t_start, t_end in self.tokens:
+ if t_type == "IN":
+ candidate_decl = -1 # Reset candidate
+ decl_keyword_str = None
+
+ elif t_type == "KW":
+ if candidate_decl == -1:
+ candidate_decl = t_start
+ if decl_keyword_str is None:
+ decl_keyword_str = self.text[t_start:t_end]
+ decl_keyword_end = t_end
+
+ elif t_type == "ATTR":
+ if candidate_decl == -1:
+ candidate_decl = t_start
+
+ if candidate_decl != -1:
+ self.decl_start = candidate_decl
+
+ # Resolve Enum Type
+ if decl_keyword_str:
+ try:
+ self.decl_type = LeanDeclType(decl_keyword_str)
+ except ValueError:
+ self.decl_type = LeanDeclType.UNKNOWN
+ else:
+ self.decl_type = LeanDeclType.UNKNOWN
+
+ # Extract Name
+ if self.decl_type not in self.NO_NAME_TYPES and decl_keyword_end != -1:
+ self.decl_name = self._extract_name_after(decl_keyword_end)
+
+ # Pass 2: Find Proof Start
+ skip_proof = self.decl_type in self.NO_PROOF_TYPES
+
+ if not skip_proof:
+ for t_type, t_start, t_end in self.tokens:
+ if t_start > self.decl_start and t_type == "PROOF":
+ self.proof_start = t_start
+ break
+
+ def _extract_name_after(self, idx: int) -> Optional[str]:
+ """
+ Finds the first identifier after the given index, skipping comments and whitespace.
+ Returns None if it hits a symbol (e.g. '(', '{', ':') before a name.
+ """
+ i = idx
+ while i < self.n:
+ c = self.text[i]
+
+ # Skip Whitespace
+ if c.isspace():
+ i += 1
+ continue
+
+ # Skip Line Comments
+ if self.text.startswith("--", i):
+ i = self.text.find('\n', i)
+ if i == -1: return None
+ continue
+
+ # Skip Block Comments
+ if self.text.startswith("/-", i):
+ # Quick skip for simple block comments, logic same as tokenizer
+ depth = 1
+ i += 2
+ while i < self.n and depth > 0:
+ if self.text.startswith("/-", i):
+ depth += 1
+ i += 2
+ elif self.text.startswith("-/", i):
+ depth -= 1
+ i += 2
+ else:
+ i += 1
+ continue
+
+ # Identifier Start Check
+            # Lean identifiers can be French-quoted «name» or standard
+ # If it starts with a symbol like (, {, [, :, it's anonymous
+            if not (c.isalnum() or c == '_' or c == '«'):
+ return None
+
+ # Extract
+ start = i
+            if c == '«':
+                end = self.text.find('»', start)
+ if end != -1:
+ return self.text[start:end+1]
+ else:
+                    # Malformed guillemet name (no closing '»'); give up on the name
+ return None
+ else:
+ # Standard identifier (alphanum + . + _)
+ while i < self.n:
+ curr = self.text[i]
+ if curr.isalnum() or curr == '_' or curr == '.':
+ i += 1
+ else:
+ break
+ return self.text[start:i]
+
+ return None
+
+ def _construct_result(self) -> LeanParseResult:
+ # Case 1: No declaration found
+ if self.decl_start == -1:
+ return LeanParseResult(
+ decl_type=LeanDeclType.UNKNOWN,
+ text_before=self.text
+ )
+
+ # Case 2: Declaration found
+ split_idx = self.decl_start
+
+ decl_end = self.n
+ proof_content = None
+ if self.proof_start != -1:
+ decl_end = self.proof_start
+ proof_content = self.text[self.proof_start:].strip()
+
+ raw_pre = self.text[:split_idx]
+ raw_decl = self.text[split_idx:decl_end]
+ doc_content = None
+
+ if self.docstring_range:
+ ds_start, ds_end = self.docstring_range
+ doc_content = self.text[ds_start:ds_end]
+
+ if ds_start < split_idx:
+ # Remove docstring from raw_pre
+ pre_part1 = self.text[:ds_start]
+ pre_part2 = self.text[ds_end:split_idx]
+ raw_pre = pre_part1 + pre_part2
+
+ return LeanParseResult(
+ decl_type=self.decl_type,
+ name=self.decl_name,
+ text_before=raw_pre.strip() or None,
+ doc_string=doc_content or None,
+ text=raw_decl.strip() or None,
+ proof=proof_content or None
+ )
+
+ # --- Helpers ---
+ def _is_keyword_at(self, idx, kw):
+ if not self.text.startswith(kw, idx): return False
+ return self._is_word_boundary(idx + len(kw))
+
+ def _match_any_keyword(self, idx, keywords):
+ if not self.text[idx].isalpha(): return None
+ j = idx
+ while j < self.n and (self.text[j].isalnum() or self.text[j] == '_'):
+ j += 1
+ word = self.text[idx:j]
+ if word in keywords:
+ return word, len(word)
+ return None
+
+ def _is_word_boundary(self, idx):
+ if idx >= self.n: return True
+ c = self.text[idx]
+ return not (c.isalnum() or c == '_')
+
+
+def parse_lean_text(text: str) -> LeanParseResult:
+ parser = LeanDeclParser(text)
+ return parser.parse()
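+
+# Illustrative usage (a sketch of the expected result, not executed here):
+#   r = parse_lean_text('/-- doc -/ theorem foo : 1 + 1 = 2 := by simp')
+#   r.decl_type  == LeanDeclType.THEOREM
+#   r.name       == "foo"
+#   r.doc_string == "/-- doc -/"
+#   r.text       == "theorem foo : 1 + 1 = 2"
+#   r.proof      == ":= by simp"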
\ No newline at end of file
diff --git a/src/itp_interface/tools/simple_lean4_sync_executor.py b/src/itp_interface/lean/simple_lean4_sync_executor.py
similarity index 99%
rename from src/itp_interface/tools/simple_lean4_sync_executor.py
rename to src/itp_interface/lean/simple_lean4_sync_executor.py
index 49250fc..b957aa0 100644
--- a/src/itp_interface/tools/simple_lean4_sync_executor.py
+++ b/src/itp_interface/lean/simple_lean4_sync_executor.py
@@ -10,7 +10,7 @@
import typing
import bisect
import subprocess
-from itp_interface.tools.tactic_parser import (
+from itp_interface.lean.tactic_parser import (
TacticParser,
ErrorInfo,
LeanLineInfo,
diff --git a/src/itp_interface/tools/tactic_parser.py b/src/itp_interface/lean/tactic_parser.py
similarity index 94%
rename from src/itp_interface/tools/tactic_parser.py
rename to src/itp_interface/lean/tactic_parser.py
index 966f781..36cb2c3 100644
--- a/src/itp_interface/tools/tactic_parser.py
+++ b/src/itp_interface/lean/tactic_parser.py
@@ -106,6 +106,7 @@ class LeanLineInfo(BaseModel):
name: Optional[str] = None
doc_string: Optional[str] = None
namespace: Optional[str] = None
+ proof: Optional[str] = None # Full proof text (if available)
def __repr__(self) -> str:
return f"LeanLineInfo(text={self.text!r}, line={self.line}, column={self.column})"
@@ -188,7 +189,8 @@ def from_dict(decl_data: Dict) -> 'DeclWithDependencies':
decl_type=decl_dict.get('declType'),
name=decl_dict.get('name'),
doc_string=decl_dict.get('docString'),
- namespace=decl_dict.get('namespace')
+ namespace=decl_dict.get('namespace'),
+ proof=decl_dict.get('proof')
)
# Parse dependencies
@@ -248,6 +250,23 @@ class FileDependencyAnalysis(BaseModel):
def __repr__(self) -> str:
return f"FileDependencyAnalysis({self.module_name}, {len(self.declarations)} decls)"
+
+ def to_json(self, indent=0) -> str:
+ if indent == 0:
+ return self.model_dump_json()
+ else:
+ return self.model_dump_json(indent=indent)
+
+ @staticmethod
+ def load_from_string(json_text: str) -> 'FileDependencyAnalysis':
+ return FileDependencyAnalysis.model_validate_json(json_text)
+
+ @staticmethod
+ def load_from_file(file_path: str) -> 'FileDependencyAnalysis':
+ with open(file_path, 'r', encoding='utf-8') as f:
+ data = f.read()
+ return FileDependencyAnalysis.load_from_string(data)
+
# Create an enum for parsing request type
class RequestType(Enum):
@@ -259,8 +278,8 @@ class RequestType(Enum):
def get_path_to_tactic_parser_project() -> str:
"""Get the path to the tactic parser project directory."""
- tools_dir = os.path.dirname(__file__)
- tactic_parser_path = os.path.join(tools_dir, "tactic_parser")
+ lean_dir = os.path.dirname(__file__)
+ tactic_parser_path = os.path.join(lean_dir, "tactic_parser")
abs_path = os.path.abspath(tactic_parser_path)
return abs_path
@@ -734,10 +753,31 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
else:
print(msg)
+def print_errors(errors: List[ErrorInfo], logger: Optional[logging.Logger] = None):
+ for error in errors:
+ msg = f"Error at Line {error.position.line}, Col {error.position.column}: {error.message}"
+ if logger:
+ logger.error(msg)
+ else:
+ print(msg)
+
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
project_path = str(Path(__file__).parent.parent.parent / "data" / "test" / "lean4_proj")
+ with TacticParser() as parser:
+        # Example 0: Pattern-matching proof (no top-level `by` block)
+        lean_code = """theorem test : ∀ {a : Nat}, a + 0 = a
+| 0 => by simp
+| n + 1 => by simp
+"""
+
+ print("Parsing example 0...")
+ tactics, errors = parser.parse(lean_code)
+ print_tactics(tactics)
+ if errors:
+ print(f"Error: {errors}")
+
with TacticParser() as parser:
# Example 1: Simple proof
lean_code = "example : True := by trivial"
@@ -746,7 +786,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics, errors = parser.parse(lean_code)
print_tactics(tactics)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
p_path = "/home/amthakur/Projects/copra/data/test/miniF2F-lean4"
with TacticParser(project_path=p_path) as parser:
@@ -765,7 +805,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics, errors = parser.parse(lean_code, fail_on_error=False)
print_tactics(tactics)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=p_path) as parser:
# Example 1a: Simple proof with multiple tactics
@@ -786,7 +826,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics, errors = parser.parse(lean_code, fail_on_error=False)
print_tactics(tactics)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=p_path) as parser:
# Example 1a: Simple proof with multiple tactics
@@ -807,7 +847,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics, errors = parser.parse(lean_code, fail_on_error=False)
print_tactics(tactics)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser() as parser:
# Example 1b: Simple have proofs
@@ -824,7 +864,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics, errors = parser.parse(lean_code, fail_on_error=False)
print_tactics(tactics)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
@@ -836,7 +876,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics2, errors = parser.parse(lean_code2, fail_on_error=False)
print_tactics(tactics2)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
# Check if linarith is parsed correctly
lean_code3 = """
@@ -853,7 +893,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics3, errors = parser.parse(lean_code3)
print_tactics(tactics3)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
file_path = str(Path(__file__).parent.parent.parent / "data" / "test" / "lean4_proj" / "Lean4Proj" / "Basic.lean")
@@ -863,7 +903,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics4, errors = parser.parse_file(file_path)
print_tactics(tactics4)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=project_path) as parser:
# Example 2: Multiline with params
@@ -873,7 +913,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics5, errors = parser.parse(lean_code4)
print_tactics(tactics5)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=project_path) as parser:
# Example 6: Parse tactics from file with multiple theorems
@@ -881,7 +921,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics6, errors = parser.parse(lean_code3 + "\n" + lean_code4, parse_type=RequestType.PARSE_TACTICS)
print_tactics(tactics6)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=project_path) as parser:
# Example 7: Parse tactics which are wrong
@@ -890,7 +930,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics7, errors = parser.parse(lean_code5, fail_on_error=False)
print_tactics(tactics7)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=project_path) as parser:
# Example 8: Parse tactics just before `by`
@@ -899,7 +939,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics8, errors = parser.parse(lean_code8, fail_on_error=False)
print_tactics(tactics8)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=project_path) as parser:
# Example 9: Parse tactics just before `by`
@@ -908,7 +948,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics9, errors = parser.parse(lean_code9, fail_on_error=False)
print_tactics(tactics9)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
with TacticParser(project_path=project_path) as parser:
# Example 10: Test checkpointing
@@ -926,7 +966,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
tactics10, errors = parser.parse(lean_code10, fail_on_error=True, parse_type=RequestType.CHKPT_TACTICS)
print_tactics(tactics10)
if errors:
- print(f"Error: {errors}")
+ print_errors(errors)
# Now just execute from the checkpoint
lean_code10b = """
theorem temp2: 1 + 2 = 3 :=
@@ -938,7 +978,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
print_tactics(tactics10b)
if errors:
# The error should contain h_temp
- print(f"Error: {errors}")
+ print_errors(errors)
print("\nBreaking checkpoint...")
new_lean_code10c = lean_code10 + lean_code10b
@@ -946,4 +986,4 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger]
# ^This will reimport everything all run all theorems from scratch
print_tactics(tactics10c)
if errors:
- print(f"Error: {errors}")
\ No newline at end of file
+ print_errors(errors)
\ No newline at end of file
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser.lean b/src/itp_interface/lean/tactic_parser/TacticParser.lean
similarity index 60%
rename from src/itp_interface/tools/tactic_parser/TacticParser.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser.lean
index da652e7..762a727 100644
--- a/src/itp_interface/tools/tactic_parser/TacticParser.lean
+++ b/src/itp_interface/lean/tactic_parser/TacticParser.lean
@@ -2,3 +2,5 @@ import TacticParser.Base64
import TacticParser.Types
import TacticParser.SyntaxWalker
import TacticParser.Main
+import TacticParser.SyntaxWalkerMain
+import TacticParser.Example.simple
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Base64.lean
similarity index 100%
rename from src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser/Base64.lean
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean b/src/itp_interface/lean/tactic_parser/TacticParser/DependencyParser.lean
similarity index 98%
rename from src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser/DependencyParser.lean
index 4e9cd0a..984aea6 100644
--- a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean
+++ b/src/itp_interface/lean/tactic_parser/TacticParser/DependencyParser.lean
@@ -3,6 +3,7 @@ import Lean.Data.Json
import Lean.Elab.Frontend
import TacticParser.Types
import TacticParser.LineParser
+import TacticParser.ProofExtractor
namespace TacticParser
@@ -261,6 +262,9 @@ unsafe def analyzeFileDependencies (filepath : System.FilePath) : IO FileDepende
-- Parse all declarations using LineParser
let decls ← parseFile filepath
+ -- Extract proofs from theorem/lemma/example declarations
+  let declsWithProofs ← extractProofsFromDecls decls content
+
-- Load environment with all imports and elaborate the file
Lean.initSearchPath (← Lean.findSysroot)
Lean.enableInitializersExecution
@@ -282,7 +286,7 @@ unsafe def analyzeFileDependencies (filepath : System.FilePath) : IO FileDepende
-- Analyze each declaration
let mut declsWithDeps : Array DeclWithDependencies := #[]
- for declInfo in decls do
+ for declInfo in declsWithProofs do
-- Construct the fully qualified name for this declaration
let declName := match declInfo.namespc with
| some ns =>
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParserMain.lean b/src/itp_interface/lean/tactic_parser/TacticParser/DependencyParserMain.lean
similarity index 100%
rename from src/itp_interface/tools/tactic_parser/TacticParser/DependencyParserMain.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser/DependencyParserMain.lean
diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/Example/complex.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Example/complex.lean
new file mode 100644
index 0000000..f05e8e1
--- /dev/null
+++ b/src/itp_interface/lean/tactic_parser/TacticParser/Example/complex.lean
@@ -0,0 +1,25 @@
+import TacticParser.Example.simple
+
+namespace TacticParser.Example
+
+theorem additive:
+∀ {a b: Nat}, addNat a b = a + b := by
+ intro a b
+ induction a generalizing b
+ simp [addNat]
+ rename_i a ih
+ simp [addNat, *]
+ grind
+
+theorem additive_identity1 :
+∀ {a : Nat}, addNat 0 a = a := by
+ simp [addNat]
+
+theorem additive_identity2 :
+∀ {a : Nat}, addNat a 0 = a := by
+ simp [additive]
+
+theorem additive_comm:
+∀ {a b : Nat}, addNat a b = addNat b a := by
+ simp [additive]
+ grind
diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/Example/simple.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Example/simple.lean
new file mode 100644
index 0000000..d7b1690
--- /dev/null
+++ b/src/itp_interface/lean/tactic_parser/TacticParser/Example/simple.lean
@@ -0,0 +1,13 @@
+namespace TacticParser.Example
+
+theorem test : ∀ {a : Nat}, a + 0 = a := by grind
+
+theorem test1 : ∀ {a : Nat}, a + 0 = a
+ | 0 => by simp
+ | n + 1 => by simp
+
+def addNat : Nat → Nat → Nat
+ | 0, m => m
+ | n + 1, m => addNat n (m + 1)
+
+end TacticParser.Example
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/LineParser.lean b/src/itp_interface/lean/tactic_parser/TacticParser/LineParser.lean
similarity index 100%
rename from src/itp_interface/tools/tactic_parser/TacticParser/LineParser.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser/LineParser.lean
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Main.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Main.lean
similarity index 100%
rename from src/itp_interface/tools/tactic_parser/TacticParser/Main.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser/Main.lean
diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/ProofExtractor.lean b/src/itp_interface/lean/tactic_parser/TacticParser/ProofExtractor.lean
new file mode 100644
index 0000000..368e198
--- /dev/null
+++ b/src/itp_interface/lean/tactic_parser/TacticParser/ProofExtractor.lean
@@ -0,0 +1,179 @@
+/-
+Proof extractor for theorem, lemma, and example declarations.
+Uses a validation-based approach: tests candidate delimiters by replacing
+the proof with `sorry` and checking if it parses successfully.
+-/
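+/-
+Illustrative sketch (an assumed example, not part of the build): for
+`theorem foo : 1 + 1 = 2 := by simp`, the top-level `:=` is a candidate;
+the text before it plus `:= sorry` parses successfully, so everything from
+that `:=` onwards is recorded as the proof and the rest as the statement.
+-/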
+import Lean
+import Lean.Parser
+import TacticParser.Types
+import TacticParser.SyntaxWalker
+import Lean.Elab.Frontend
+
+namespace TacticParser
+
+open Lean Parser Elab
+
+/-- Represents a candidate delimiter position and type -/
+structure DelimiterCandidate where
+  position : Nat -- Character index in the text
+ delimiterType : String -- ":=", "where", or "|"
+ deriving Repr, BEq
+
+instance : Ord DelimiterCandidate where
+ compare a b := compare a.position b.position
+
+/-- Try to parse a text snippet and return true if it parses without errors -/
+def tryParseSuccessfully (text : String) (cmdState : Option Command.State): IO Bool := do
+  let chkpt_result ← parseTactics text none cmdState
+  let parse_res := chkpt_result.1
+  let new_cmd_state := chkpt_result.2
+  pure (parse_res.errors.size == 0 ∧ new_cmd_state.isSome)
+
+/-- Check if a substring starts with a given substring at a specific position -/
+def substringStartsWith (text : String) (startPos : Nat) (substr : String) : Bool :=
+ if startPos + substr.length > text.length then
+ false
+ else
+ let extracted := text.drop startPos
+ extracted.startsWith substr
+
+/-- Find all occurrences of a substring in a text -/
+def findSubstrOccurences (text : String) (substr : String) : List Nat :=
+ if substr.length > text.length then
+ []
+ else
+ let all_pos := List.range (text.length - substr.length + 1)
+ let all_occurences := all_pos.filter (fun pos => substringStartsWith text pos substr)
+ all_occurences
+
+/-- Find all candidate delimiter positions -/
+def findCandidateDelimiters (text : String) : List DelimiterCandidate :=
+ let assignPositions := findSubstrOccurences text ":="
+ let wherePositions := findSubstrOccurences text "where"
+ let pipePositions := findSubstrOccurences text "|"
+
+ let assignCandidates := assignPositions.map fun pos =>
+ { position := pos, delimiterType := ":=" }
+ let whereCandidates := wherePositions.map fun pos =>
+ { position := pos, delimiterType := "where" }
+ let pipeCandidates := pipePositions.map fun pos =>
+ { position := pos, delimiterType := "|" }
+ let assignCandidatesSorted := assignCandidates.toArray.qsort (fun a b => a.position < b.position)
+ let whereCandidatesSorted := whereCandidates.toArray.qsort (fun a b => a.position < b.position)
+ let pipeCandidatesSorted := pipeCandidates.toArray.qsort (fun a b => a.position < b.position)
+
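+  -- Note: candidates are grouped by delimiter kind (all ":=" first, then
+  -- "where", then "|"), each group ordered by position, which fixes the
+  -- order in which extractProofFromDecl tries them.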
+ let allCandidates := assignCandidatesSorted.toList ++ whereCandidatesSorted.toList ++ pipeCandidatesSorted.toList
+ allCandidates
+
+def is_proof_extraction_needed (declInfo : DeclInfo) : Bool :=
+ declInfo.declType == DeclType.theorem ||
+ declInfo.declType == DeclType.lemma ||
+ declInfo.declType == DeclType.example
+
+def get_content_before_decl (fileContent : String) (declLineNum : Nat) : String :=
+ let lines := fileContent.splitOn "\n"
+ let beforeLines := lines.take (declLineNum - 1)
+ (String.intercalate "\n" beforeLines) ++ "\n"
+
+def get_in_between_content (fileContent : String) (startLine : Nat) (endLine : Nat) : String :=
+ let lines := fileContent.splitOn "\n"
+ let betweenLines := (lines.take endLine).drop (startLine - 1)
+ String.intercalate "\n" betweenLines
+
+/-- Extract proof from a declaration by testing candidate delimiters -/
+unsafe def extractProofFromDecl
+ (declInfo : DeclInfo)
+ (cmdState : Option Command.State)
+ (extra_content: Option String := none) : IO DeclInfo := do
+ -- Only process theorem, lemma, example
+ if !is_proof_extraction_needed declInfo then
+ panic! s!"extractProofFromDecl called on non-proof decl: {declInfo.declType}"
+
+ let text := declInfo.text
+
+  -- Find all candidate delimiters in the declaration text
+ let candidates := findCandidateDelimiters text
+
+ -- Test each candidate
+ for candidate in candidates do
+ let beforeDelimiter := text.take candidate.position
+
+ -- IO.println s!"beforeDelimiter:\n{beforeDelimiter}\n---"
+
+ -- Build test text with full context
+ let mut statementOnly := match candidate.delimiterType with
+ | "|" => beforeDelimiter.trim ++ " := sorry"
+ | ":=" => beforeDelimiter ++ " := sorry"
+ | "where" => beforeDelimiter ++ " := sorry"
+ | _ => beforeDelimiter ++ " := sorry"
+
+
+ -- If extra content is provided, prepend it
+ statementOnly :=
+ match extra_content with
+ | some extra => extra ++ (if extra.endsWith "\n" then statementOnly else "\n" ++ statementOnly)
+ | none => statementOnly
+
+ -- IO.println s!"statementOnly:\n{statementOnly}\n---"
+ -- Try to parse
+    let success ← tryParseSuccessfully statementOnly cmdState
+ if success then
+ -- Found valid split!
+ -- IO.println s!"Found valid split at position {candidate.position} with delimiter {candidate.delimiterType}"
+ let proof := text.drop candidate.position
+ let thrm := text.take candidate.position
+ return { declInfo with proof := some proof.trim , text := thrm.trim }
+ -- else
+ -- IO.println s!"Failed split at position {candidate.position} with delimiter {candidate.delimiterType}"
+
+ -- No valid split found - no proof
+ return { declInfo with proof := none }
+
+unsafe def parse_between
+(fileContent : String)
+(prev: Nat)
+(next: Nat)
+(cmd_state : Option Command.State)
+: IO CheckpointedParseResult := do
+ -- IO.println s!"Extracting proof for segment between lines {prev} and {next}"
+ let contextBeforeDecl := get_in_between_content fileContent prev next
+ -- IO.println s!"Processing declaration at line {decl.startPos.line}-{decl.endPos.line}"
+ -- IO.println s!"Declaration text:\n{decl.text}\n---"
+ -- IO.println s!"--- Context Before Decl ----\n{contextBeforeDecl}\n--- Context Before Decl ----\n"
+  let chkpt_parse_res ← parseTactics contextBeforeDecl none cmd_state
+ return chkpt_parse_res
+
+/-- Extract proofs from multiple declarations -/
+unsafe def extractProofsFromDecls (decls : Array DeclInfo) (fileContent : String) : IO (Array DeclInfo) := do
+ let mut result := #[]
+ let mut prev := 0
+ let mut cmd_state : Option Command.State := none
+ let mut next := 0
+ let mut extra_content : Option String := none
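+  -- Strategy: walk declarations in file order, advancing a parser checkpoint
+  -- over the text between declarations; if that in-between text fails to
+  -- parse on its own, it is carried along as extra context for the next
+  -- proof-extraction attempt instead of updating the checkpoint.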
+ for decl in decls do
+ if is_proof_extraction_needed decl then
+ next := decl.startPos.line - 1
+      let chkpt_parse_res ← parse_between fileContent prev next cmd_state
+ -- IO.println s!"--- Context Before Decl ----\n{contextBeforeDecl}\n--- Context Before Decl ----\n"
+ let parse_res := chkpt_parse_res.parseResult
+      if parse_res.errors.size > 0 ∨ chkpt_parse_res.chkptState.isNone then
+ -- supply the extra content to compile from
+ extra_content := get_in_between_content fileContent prev next
+ -- IO.println s!"Re-parsing declaration at lines: \n{extra_content.get!}"
+ -- IO.println s!"\nDeclaration text:\n{decl.text}\n---"
+ -- DO NOT update cmd_state yet
+ else
+ cmd_state := chkpt_parse_res.chkptState
+ extra_content := none
+ prev := next + 1
+      -- let cmd_st ← match cmd_state with
+ -- | some st => pure st
+ -- | none => panic! "Failed to get valid cmd_state before processing declaration"
+      let processed ← extractProofFromDecl decl cmd_state extra_content
+ result := result.push processed
+ else
+ result := result.push decl
+ return result
+
+end TacticParser
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean b/src/itp_interface/lean/tactic_parser/TacticParser/SyntaxWalker.lean
similarity index 98%
rename from src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser/SyntaxWalker.lean
index 46689ac..5dd8d3a 100644
--- a/src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean
+++ b/src/itp_interface/lean/tactic_parser/TacticParser/SyntaxWalker.lean
@@ -308,6 +308,10 @@ unsafe def parseInCurrentContext (input : String) (filePath : Option String := n
let extentStruct := lineExtents.map getInfoNodeStruct
-- Go over all line extents and reassign the end_pos of the next node
let mut adjusted_trees : Array InfoNodeStruct := #[]
+ -- IO.println s!"Total line extents found: {lineExtents.size}"
+ if lineExtents.size == 0 then
+ let parseResult : ParseResult := { trees := #[], errors := errorInfos }
+ return { parseResult := parseResult, chkptState := cmdState , lineNum := none }
for i in [1:lineExtents.size] do
let prev_node := extentStruct[i - 1]!.getD default
let curr_node := extentStruct[i]!.getD default
@@ -350,6 +354,7 @@ unsafe def parseTacticsWithElaboration (input : String) (filePath : Option Strin
-- Initialize Lean from current directory (finds .lake/build if present)
Lean.initSearchPath (← Lean.findSysroot)
Lean.enableInitializersExecution
+ -- IO.println "Initialized Lean environment for elaboration-based parsing."
return ← parseInCurrentContext input filePath chkptState
catch e =>
let errorInfo := ErrorInfo.mk (s!"Error in parseTacticsWithElaboration: {e}") { line := 0, column := 0 }
diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/TacticExtractorMain.lean b/src/itp_interface/lean/tactic_parser/TacticParser/TacticExtractorMain.lean
new file mode 100644
index 0000000..5c66a6d
--- /dev/null
+++ b/src/itp_interface/lean/tactic_parser/TacticParser/TacticExtractorMain.lean
@@ -0,0 +1,52 @@
+import TacticParser.SyntaxWalker
+import Lean
+
+open Lean
+open TacticParser
+
+/-- Print usage information -/
+def printUsage : IO Unit := do
+ IO.println "Usage: dependency_parser "
+ IO.println ""
+ IO.println "Arguments:"
+ IO.println " Path to the Lean file to analyze"
+ IO.println " Path where JSON output will be written"
+ IO.println ""
+ IO.println "Example:"
+ IO.println " lake env .lake/build/bin/syntax-walker MyFile.lean output.json"
+
+unsafe def main (args : List String) : IO UInt32 := do
+ match args with
+  | [leanFilePath, jsonOutputPath] =>
+ try
+ let filepath : System.FilePath := leanFilePath
+ let jsonPath : System.FilePath := jsonOutputPath
+
+ -- Check if input file exists
+    if !(← filepath.pathExists) then
+ IO.eprintln s!"Error: Input file not found: {filepath}"
+ return 1
+
+    let fileContent ← IO.FS.readFile filepath
+
+ IO.println s!"Tactic Parsing file: {filepath}"
+
+ -- Analyze the file and export to JSON
+    let tactics ← parseTactics fileContent none none
+ let tacticsJson := Lean.ToJson.toJson tactics.parseResult
+ IO.FS.writeFile jsonPath tacticsJson.compress
+ return 0
+ catch e =>
+ IO.eprintln s!"Error: {e}"
+ return 1
+
+ | _ =>
+ IO.eprintln "Error: Invalid number of arguments"
+ IO.eprintln ""
+ printUsage
+ return 1
diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Types.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Types.lean
similarity index 98%
rename from src/itp_interface/tools/tactic_parser/TacticParser/Types.lean
rename to src/itp_interface/lean/tactic_parser/TacticParser/Types.lean
index fee67f1..5dbac32 100644
--- a/src/itp_interface/tools/tactic_parser/TacticParser/Types.lean
+++ b/src/itp_interface/lean/tactic_parser/TacticParser/Types.lean
@@ -71,6 +71,7 @@ structure DeclInfo where
text : String
docString : Option String -- Extracted documentation comment
namespc : Option String -- Current namespace
+ proof : Option String := none -- Extracted proof (for theorem, lemma, example)
deriving Repr
instance : ToJson DeclInfo where
@@ -83,7 +84,8 @@ instance : ToJson DeclInfo where
("end_column", d.endPos.column),
("text", toJson d.text),
("doc_string", toJson d.docString),
- ("namespace", toJson d.namespc)
+ ("namespace", toJson d.namespc),
+ ("proof", toJson d.proof)
]
/-- Information about an import statement -/
diff --git a/src/itp_interface/tools/tactic_parser/lake-manifest.json b/src/itp_interface/lean/tactic_parser/lake-manifest.json
similarity index 100%
rename from src/itp_interface/tools/tactic_parser/lake-manifest.json
rename to src/itp_interface/lean/tactic_parser/lake-manifest.json
diff --git a/src/itp_interface/tools/tactic_parser/lakefile.toml b/src/itp_interface/lean/tactic_parser/lakefile.toml
similarity index 53%
rename from src/itp_interface/tools/tactic_parser/lakefile.toml
rename to src/itp_interface/lean/tactic_parser/lakefile.toml
index 16e7605..e9c3958 100644
--- a/src/itp_interface/tools/tactic_parser/lakefile.toml
+++ b/src/itp_interface/lean/tactic_parser/lakefile.toml
@@ -1,14 +1,22 @@
name = "TacticParser"
-defaultTargets = ["tactic-parser", "dependency-parser"]
+defaultTargets = ["tactic-parser", "tactic-extractor", "dependency-parser"]
[[lean_lib]]
name = "TacticParser"
+[[lean_lib]]
+name = "TacticParser.Example"
+
[[lean_exe]]
name = "tactic-parser"
root = "TacticParser.Main"
supportInterpreter = true
+[[lean_exe]]
+name = "tactic-extractor"
+root = "TacticParser.TacticExtractorMain"
+supportInterpreter = true
+
[[lean_exe]]
name = "dependency-parser"
root = "TacticParser.DependencyParserMain"
diff --git a/src/itp_interface/tools/tactic_parser/lean-toolchain b/src/itp_interface/lean/tactic_parser/lean-toolchain
similarity index 100%
rename from src/itp_interface/tools/tactic_parser/lean-toolchain
rename to src/itp_interface/lean/tactic_parser/lean-toolchain
diff --git a/src/itp_interface/lean/tactic_parser/test_user_example.lean b/src/itp_interface/lean/tactic_parser/test_user_example.lean
new file mode 100644
index 0000000..e8b4b94
--- /dev/null
+++ b/src/itp_interface/lean/tactic_parser/test_user_example.lean
@@ -0,0 +1,26 @@
+import Lean
+import TacticParser.DependencyParser
+import TacticParser.ProofExtractor
+import TacticParser.Types
+
+open TacticParser
+
+
+unsafe def analyze (filePath : String) : IO Unit := do
+  let fileDepAnalysis ← analyzeFileDependencies filePath
+ IO.println s!"Dependency Analysis for {filePath}:"
+ let json := Lean.toJson fileDepAnalysis
+ IO.println json.pretty
+
+def testCodes : List String := [
+ "TacticParser/Example/simple.lean",
+ "TacticParser/Example/complex.lean"
+]
+
+unsafe def main : IO Unit := do
+ IO.println "Testing User's Example"
+ IO.println (String.mk (List.replicate 70 '='))
+
+ for test in testCodes do
+ let filePath := test
+ analyze filePath
diff --git a/src/itp_interface/main/config.py b/src/itp_interface/main/config.py
index 5d0c042..86e49f2 100644
--- a/src/itp_interface/main/config.py
+++ b/src/itp_interface/main/config.py
@@ -79,6 +79,8 @@ class ExtractFile(object):
class EvalDataset(object):
project: str
files: typing.Union[typing.List[EvalFile], typing.List[ExtractFile]]
+ exclude_files: typing.List[str] = field(default_factory=list)
+ include_files: typing.List[str] = field(default_factory=list)
@dataclass_json
@dataclass
@@ -135,7 +137,6 @@ def add_theorem_to_maps(self, path: str, theorem: str, proof_result: ProofSearch
def parse_config(cfg):
- is_extraction_request = False
env_settings_cfg = cfg["env_settings"]
env_settings = EnvSettings(
name=env_settings_cfg["name"],
@@ -179,7 +180,6 @@ def parse_config(cfg):
eval_files.append(EvalFile(
path=file_cfg["path"],
theorems=theorems))
- is_extraction_request = False
elif "declarations" in file_cfg:
declarations = None
if type(file_cfg["declarations"]) == str:
@@ -189,12 +189,14 @@ def parse_config(cfg):
eval_files.append(ExtractFile(
path=file_cfg["path"],
declarations=declarations))
- is_extraction_request = True
else:
raise ValueError(f"File config must have either 'theorems' or 'declarations': {file_cfg}")
eval_datasets.append(EvalDataset(
project=dataset_cfg["project"],
- files=eval_files))
+ files=eval_files,
+ exclude_files=dataset_cfg.get("exclude_files", []),
+ include_files=dataset_cfg.get("include_files", [])
+ ))
language = ProofAction.Language(benchmark_cfg["language"])
benchmark = EvalBenchmark(
name=benchmark_cfg["name"],
@@ -205,6 +207,6 @@ def parse_config(cfg):
few_shot_metadata_filename_for_retrieval=benchmark_cfg["few_shot_metadata_filename_for_retrieval"],
dfs_data_path_for_retrieval=benchmark_cfg["dfs_data_path_for_retrieval"],
dfs_metadata_filename_for_retrieval=benchmark_cfg["dfs_metadata_filename_for_retrieval"],
- is_extraction_request=is_extraction_request and benchmark_cfg.get("is_extraction_request", False),
+ is_extraction_request=benchmark_cfg.get("is_extraction_request", False),
setup_cmds=benchmark_cfg["setup_cmds"] if "setup_cmds" in benchmark_cfg else [])
return Experiments(env_settings=env_settings, run_settings=eval_settings, benchmark=benchmark)
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/benchmark/mathlib_benchmark_lean_ext.yaml b/src/itp_interface/main/configs/benchmark/mathlib_benchmark_lean_ext.yaml
new file mode 100644
index 0000000..b2bb43b
--- /dev/null
+++ b/src/itp_interface/main/configs/benchmark/mathlib_benchmark_lean_ext.yaml
@@ -0,0 +1,40 @@
+name: mathlib_benchmark_lean_ext
+num_files: 1
+language: LEAN4
+few_shot_data_path_for_retrieval:
+few_shot_metadata_filename_for_retrieval:
+dfs_data_path_for_retrieval:
+dfs_metadata_filename_for_retrieval:
+is_extraction_request: true
+datasets:
+ - project: src/data/test/Mathlib
+ files: []
+ exclude_files:
+ - src/data/test/Mathlib/.lake/packages/aesop
+ - src/data/test/Mathlib/.lake/packages/batteries
+ - src/data/test/Mathlib/.lake/packages/Cli
+ - src/data/test/Mathlib/.lake/packages/importGraph
+ - src/data/test/Mathlib/.lake/packages/LeanSearchClient
+ - src/data/test/Mathlib/.lake/packages/plausible
+ - src/data/test/Mathlib/.lake/packages/proofwidgets
+ - src/data/test/Mathlib/.lake/packages/Qq
+ - src/data/test/Mathlib/.lake/packages/mathlib/.devcontainer
+ - src/data/test/Mathlib/.lake/packages/mathlib/.github
+ - src/data/test/Mathlib/.lake/packages/mathlib/.docker
+ - src/data/test/Mathlib/.lake/packages/mathlib/.lake
+ - src/data/test/Mathlib/.lake/packages/mathlib/.vscode
+ - src/data/test/Mathlib/.lake/packages/mathlib/Archive
+ - src/data/test/Mathlib/.lake/packages/mathlib/Cache
+ - src/data/test/Mathlib/.lake/packages/mathlib/docs
+ - src/data/test/Mathlib/.lake/packages/mathlib/Counterexamples
+ - src/data/test/Mathlib/.lake/packages/mathlib/DownstreamTest
+ - src/data/test/Mathlib/.lake/packages/mathlib/LongestPole
+ - src/data/test/Mathlib/.lake/packages/mathlib/MathlibTest
+ - src/data/test/Mathlib/.lake/packages/mathlib/scripts
+ - src/data/test/Mathlib/.lake/packages/mathlib/widget
+ - src/data/test/Mathlib/.lake/packages/mathlib/Counterexamples.lean
+ - src/data/test/Mathlib/.lake/packages/mathlib/lakefile.lean
+ - src/data/test/Mathlib/.lake/packages/mathlib/Mathlib.lean
+ - src/data/test/Mathlib/.lake/packages/mathlib/Archive.lean
+ - src/data/test/Mathlib/lakefile.lean
+ - src/data/test/Mathlib/ReplMathlibTests.lean
diff --git a/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml b/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml
index 2ed8c1d..586c2f8 100644
--- a/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml
+++ b/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml
@@ -8,6 +8,6 @@ dfs_metadata_filename_for_retrieval:
is_extraction_request: true
datasets:
- project: src/data/test/lean4_proj
- files:
- - path: Lean4Proj/Basic.lean
- declarations: "*"
\ No newline at end of file
+ files: []
+ exclude_files:
+ - src/data/test/lean4_proj/.lake
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/benchmark/stdlib_benchmark_lean_ext.yaml b/src/itp_interface/main/configs/benchmark/stdlib_benchmark_lean_ext.yaml
new file mode 100644
index 0000000..fde9d37
--- /dev/null
+++ b/src/itp_interface/main/configs/benchmark/stdlib_benchmark_lean_ext.yaml
@@ -0,0 +1,31 @@
+name: stdlib_benchmark_lean_ext
+num_files: 1
+language: LEAN4
+few_shot_data_path_for_retrieval:
+few_shot_metadata_filename_for_retrieval:
+dfs_data_path_for_retrieval:
+dfs_metadata_filename_for_retrieval:
+is_extraction_request: true
+datasets:
+ - project: src/data/test/batteries
+ files: []
+ exclude_files:
+ - src/data/test/batteries/.lake
+ - src/data/test/batteries/.vscode
+ - src/data/test/batteries/BatteriesTest
+ - src/data/test/batteries/Batteries
+ - src/data/test/batteries/scripts
+ - src/data/test/batteries/docs
+ - src/data/test/batteries/Shake
+ - src/data/test/batteries/Batteries.lean
+ - src/data/test/batteries/bors.toml
+ - src/data/test/batteries/lake-manifest.toml
+ - src/data/test/batteries/.github
+ - src/data/test/batteries/.docker
+ - src/data/test/batteries/lakefile.toml
+ - src/data/test/batteries/.gitpod.yml
+ - src/data/test/batteries/lake-manifest.json
+ include_files:
+ - ~/.elan/toolchains/leanprover--lean4---v4.24.0/src/lean/Init
+ - ~/.elan/toolchains/leanprover--lean4---v4.24.0/src/lean/Std
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/mathlib_lean_data_extract.yaml b/src/itp_interface/main/configs/mathlib_lean_data_extract.yaml
new file mode 100644
index 0000000..14e39be
--- /dev/null
+++ b/src/itp_interface/main/configs/mathlib_lean_data_extract.yaml
@@ -0,0 +1,13 @@
+defaults:
+ # - benchmark: simple_benchmark_lean_training_data
+ # - run_settings: default_lean_data_generation_transforms
+ # - benchmark: simple_benchmark_1
+ # - run_settings: default_lean4_data_generation_transforms
+ - benchmark: mathlib_benchmark_lean_ext
+ - run_settings: default_lean4_data_generation_transforms
+ - env_settings: no_retrieval
+ - override hydra/job_logging: 'disabled'
+
+run_settings:
+ output_dir: .log/data_generation/benchmark/mathlib_benchmark_lean_ext
+ pool_size: 12
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/simple_lean_data_extract.yaml b/src/itp_interface/main/configs/simple_lean_data_extract.yaml
index 4fed34c..590c3e9 100644
--- a/src/itp_interface/main/configs/simple_lean_data_extract.yaml
+++ b/src/itp_interface/main/configs/simple_lean_data_extract.yaml
@@ -10,4 +10,4 @@ defaults:
run_settings:
output_dir: .log/data_generation/benchmark/simple_benchmark_lean_ext
- pool_size: 2
\ No newline at end of file
+ pool_size: 1
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/simple_lean_data_gen.yaml b/src/itp_interface/main/configs/simple_lean_data_gen.yaml
index d93dd0f..882a810 100644
--- a/src/itp_interface/main/configs/simple_lean_data_gen.yaml
+++ b/src/itp_interface/main/configs/simple_lean_data_gen.yaml
@@ -10,4 +10,4 @@ defaults:
run_settings:
output_dir: .log/data_generation/benchmark/simple_benchmark_lean
- pool_size: 2
\ No newline at end of file
+ pool_size: 1
\ No newline at end of file
diff --git a/src/itp_interface/main/configs/stdlib_lean_data_extract.yaml b/src/itp_interface/main/configs/stdlib_lean_data_extract.yaml
new file mode 100644
index 0000000..2269810
--- /dev/null
+++ b/src/itp_interface/main/configs/stdlib_lean_data_extract.yaml
@@ -0,0 +1,13 @@
+defaults:
+ # - benchmark: simple_benchmark_lean_training_data
+ # - run_settings: default_lean_data_generation_transforms
+ # - benchmark: simple_benchmark_1
+ # - run_settings: default_lean4_data_generation_transforms
+ - benchmark: stdlib_benchmark_lean_ext
+ - run_settings: default_lean4_data_generation_transforms
+ - env_settings: no_retrieval
+ - override hydra/job_logging: 'disabled'
+
+run_settings:
+ output_dir: .log/data_generation/benchmark/stdlib_benchmark_lean_ext
+ pool_size: 12
\ No newline at end of file
diff --git a/src/itp_interface/main/init_ray.py b/src/itp_interface/main/init_ray.py
new file mode 100644
index 0000000..41c72e8
--- /dev/null
+++ b/src/itp_interface/main/init_ray.py
@@ -0,0 +1,79 @@
+def main():
+ # Start the ray cluster
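+    # Rough protocol implemented below: a FileLock guards a session file; if
+    # the session file exists and the lock is already held by another process,
+    # an existing Ray cluster is assumed to be running and this process exits;
+    # otherwise a fresh cluster is started and its session info is written out
+    # so other processes can reuse it.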
+ from filelock import FileLock
+ import json
+ import os
+ import ray
+ import logging
+ import time
+ import sys
+ import argparse
+ argument_parser = argparse.ArgumentParser()
+ argument_parser.add_argument("--num_cpus", type=int, default=10)
+ argument_parser.add_argument("--object_store_memory", type=int, default=150*2**30)
+ argument_parser.add_argument("--memory", type=int, default=300*2**30)
+ argument_parser.add_argument("--metrics_report_interval_ms", type=int, default=3*10**8)
+ args = argument_parser.parse_args()
+ root_dir = f"{os.path.abspath(__file__).split('itp_interface')[-2]}"
+ if root_dir not in sys.path:
+ sys.path.append(root_dir)
+ os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"
+ os.environ["PYTHONPATH"] = f"{root_dir}:{os.environ.get('PYTHONPATH', '')}"
+ os.makedirs(".log/locks", exist_ok=True)
+ os.makedirs(".log/ray", exist_ok=True)
+ ray_was_started = False
+ pid = os.getpid()
+ print("Initializing Ray")
+ print("PID: ", pid)
+ ray_session_path = ".log/ray/session_latest" if os.environ.get("RAY_SESSION_PATH") is None else os.environ.get("RAY_SESSION_PATH")
+ # Try to first acquire the lock
+ file_path = ".log/locks/ray.lock"
+ # set RAY_SESSION_PATH environment variable
+ os.environ["RAY_SESSION_PATH"] = ray_session_path
+ # set lock file path
+ os.environ["RAY_LOCK_FILE_PATH"] = file_path
+ temp_lock = FileLock(file_path)
+ if os.path.exists(ray_session_path):
+ # try to acquire the lock for reading
+ try:
+ temp_lock.acquire(timeout=10)
+ temp_lock.release()
+ except:
+ with open(ray_session_path, "r") as f:
+ ray_session = f.read()
+ ray_session = json.loads(ray_session)
+ ray_address = ray_session["address"]
+ # ray.init(address=ray_address)
+ print("Ray was already started")
+ print("Ray session: ", ray_session)
+ sys.exit(0)
+ with FileLock(file_path):
+ if os.path.exists(ray_session_path):
+ # Remove the ray_session_path
+ os.remove(ray_session_path)
+ os.environ["RAY_INITIALIZED"] = "1"
+ ray_session = ray.init(
+ num_cpus=args.num_cpus,
+ object_store_memory=args.object_store_memory,
+ _memory=args.memory,
+ logging_level=logging.CRITICAL,
+ ignore_reinit_error=False,
+ log_to_driver=False,
+ configure_logging=False,
+ _system_config={"metrics_report_interval_ms": args.metrics_report_interval_ms})
+ ray_session = dict(ray_session)
+ ray_session["main_pid"] = pid
+ print("Ray session: ", ray_session)
+ with open(ray_session_path, "w") as f:
+ f.write(json.dumps(ray_session))
+ ray_was_started = True
+ print("Ray was started")
+ print("Ray session: ", ray_session)
+ # Flush the stdout buffer
+ sys.stdout.flush()
+ while ray_was_started:
+ # Keep the ray cluster alive till killed
+ time.sleep(10000)
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/src/itp_interface/main/install.py b/src/itp_interface/main/install.py
index e375f57..0c38ed0 100644
--- a/src/itp_interface/main/install.py
+++ b/src/itp_interface/main/install.py
@@ -3,7 +3,7 @@
import string
import logging
import traceback
-from itp_interface.tools.tactic_parser import build_tactic_parser_if_needed
+from itp_interface.lean.tactic_parser import build_tactic_parser_if_needed
file_path = os.path.abspath(__file__)
@@ -20,8 +20,8 @@ def generate_random_string(length, allowed_chars=None):
def install_itp_interface():
print("Installing itp_interface")
itp_dir = os.path.dirname(os.path.dirname(file_path))
- tools_dir = os.path.join(itp_dir, "tools")
- tactic_parser_dir = os.path.join(tools_dir, "tactic_parser")
+ lean_dir = os.path.join(itp_dir, "lean")
+ tactic_parser_dir = os.path.join(lean_dir, "tactic_parser")
assert os.path.exists(tactic_parser_dir), f"tactic_parser_dir: {tactic_parser_dir} does not exist"
assert os.path.exists(os.path.join(tactic_parser_dir, "lean-toolchain")), f"lean-toolchain does not exist in {tactic_parser_dir}, build has failed"
print("tactic_parser_dir: ", tactic_parser_dir)
diff --git a/src/itp_interface/main/run_tool.py b/src/itp_interface/main/run_tool.py
index 7fbd7da..24c0366 100644
--- a/src/itp_interface/main/run_tool.py
+++ b/src/itp_interface/main/run_tool.py
@@ -26,6 +26,7 @@
RayResourcePoolActor = None
TimedRayExec = None
RayUtils = None
+from pathlib import Path
from itp_interface.rl.proof_action import ProofAction
from itp_interface.tools.isabelle_server import IsabelleServer
from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
@@ -33,7 +34,7 @@
from itp_interface.tools.coq_local_data_generation_transform import LocalDataGenerationTransform as CoqLocalDataGenerationTransform
from itp_interface.tools.lean_local_data_generation_transform import LocalDataGenerationTransform as LeanLocalDataGenerationTransform
from itp_interface.tools.lean4_local_data_generation_transform import Local4DataGenerationTransform
-from itp_interface.tools.lean4_local_data_extraction_transform import Local4DataExtractionTransform
+from itp_interface.lean.lean4_local_data_extraction_transform import Local4DataExtractionTransform
from itp_interface.tools.isabelle_local_data_generation_transform import LocalDataGenerationTransform as IsabelleLocalDataGenerationTransform
from itp_interface.tools.run_data_generation_transforms import RunDataGenerationTransforms
from itp_interface.tools.log_utils import setup_logger
@@ -120,14 +121,35 @@ def _get_all_lemmas_impl(
logger.info(f"Discovered {len(lemmas_to_prove)} lemmas")
return lemmas_to_prove
-
def _get_all_lean_files_in_folder_recursively(
- project_folder: str) -> typing.List[str]:
+ project_folder: str, exclude_list: typing.List[str], include_list: typing.List[str]) -> typing.List[str]:
lean_files = []
+ for include in include_list:
+ # the include list should be a full path
+ include_path = Path(include)
+ include_path = include_path.expanduser().resolve()
+ # Get all files recursively under include_path
+ for root, dirs, files in os.walk(include_path):
+ for file in files:
+ if file.endswith(".lean"):
+ full_path = os.path.join(root, file)
+ full_path_obj = Path(full_path)
+ full_path = full_path_obj.expanduser().resolve()
+ if str(full_path) not in lean_files:
+ lean_files.append(str(full_path))
+
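+    # Note: entries gathered from include_files above are stored as absolute
+    # paths, while files discovered under the project folder below are stored
+    # relative to it.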
+ project_folder_path = Path(project_folder)
for root, dirs, files in os.walk(project_folder):
for file in files:
if file.endswith(".lean"):
- lean_files.append(os.path.join(root, file))
+ # Get full file path
+ full_path = os.path.join(root, file)
+ # Check if the full path starts with any exclude path
+ if any(full_path.startswith(exclude) for exclude in exclude_list):
+ continue
+ full_path_obj = Path(full_path)
+ rel_to_project = str(full_path_obj.relative_to(project_folder_path))
+ lean_files.append(rel_to_project)
return lean_files
# Create Ray remote version if Ray is available
@@ -248,7 +270,7 @@ def create_yaml(project_to_theorems, name, eval_benchmark: EvalBenchmark, output
with open(output_file, 'w') as yaml_file:
yaml.dump(data, yaml_file, sort_keys=False)
-def add_transform(experiment: Experiments, clone_dir: str, resources: list, transforms: list, logger: logging.Logger = None):
+def add_transform(experiment: Experiments, clone_dir: str, resources: list, transforms: list, logger: logging.Logger = None, str_time: str = None):
global ray_resource_pool
if experiment.run_settings.transform_type == TransformType.LOCAL:
if experiment.benchmark.language == ProofAction.Language.LEAN:
@@ -260,10 +282,31 @@ def add_transform(experiment: Experiments, clone_dir: str, resources: list, tran
os.makedirs(clone_dir, exist_ok=True)
elif experiment.benchmark.language == ProofAction.Language.LEAN4:
if experiment.benchmark.is_extraction_request:
- transform = Local4DataExtractionTransform(
+ output_dir_path = os.path.join(experiment.run_settings.output_dir, str_time)
+ db_path = None
+ if os.environ.get("EXTRACTION_DB_PATH") is not None:
+ db_path = os.environ.get("EXTRACTION_DB_PATH", "")
+ if not os.path.exists(db_path) and db_path.strip() != "":
+ db_path = None
+ if db_path is None:
+ os.makedirs(output_dir_path, exist_ok=True)
+ db_path = os.path.join(output_dir_path, "lean4_extraction_db.sqlite")
+ transform1 = Local4DataExtractionTransform(
experiment.run_settings.dep_depth,
buffer_size=experiment.run_settings.buffer_size,
- logger=logger)
+ logger=logger,
+ db_path=db_path,
+ enable_file_export=False,
+ enable_dependency_extraction=False)
+ transforms.append(transform1) # First just add the definitions
+ transform = Local4DataExtractionTransform(
+ experiment.run_settings.dep_depth,
+ buffer_size=experiment.run_settings.buffer_size,
+ logger=logger,
+ db_path=db_path,
+ enable_file_export=True,
+ enable_dependency_extraction=True
+ ) # This will be later appended to the transforms list
else:
transform = Local4DataGenerationTransform(
experiment.run_settings.dep_depth,
@@ -320,8 +363,10 @@ def get_decl_lemmas_to_parse(
if experiment.benchmark.language == ProofAction.Language.LEAN4 \
and experiment.benchmark.is_extraction_request:
if len(dataset.files) == 0:
+ logger.warning(f"No files specified for Lean4 extraction in dataset {dataset.project}, extracting from all Lean4 files in the project")
# List all the files recursively in the project folder
- files_in_dataset = _get_all_lean_files_in_folder_recursively(dataset.project)
+ files_in_dataset = _get_all_lean_files_in_folder_recursively(dataset.project, dataset.exclude_files, dataset.include_files)
+ logger.info(f"Found {len(files_in_dataset)} Lean4 files in the project {dataset.project}")
for file_path in files_in_dataset:
file_to_theorems[file_path] = ["*"]
file_args[file_path] = {}
@@ -459,7 +504,7 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi
transforms = []
str_time = time.strftime("%Y%m%d-%H%M%S")
clone_dir = os.path.join(experiment.run_settings.output_dir, "clone{}".format(str_time))
- clone_dir = add_transform(experiment, clone_dir, resources, transforms, logger)
+ clone_dir = add_transform(experiment, clone_dir, resources, transforms, logger, str_time)
# Find all the lemmas to prove
project_to_theorems = {}
other_args = {}
@@ -592,4 +637,6 @@ def main(cfg):
if __name__ == "__main__":
# from itp_interface.tools.ray_utils import RayUtils
# RayUtils.init_ray(num_of_cpus=20, object_store_memory_in_gb=50, memory_in_gb=1, runtime_env={"working_dir": root_dir, "excludes": [".log", "data"]})
+ from itp_interface.tools.ray_utils import RayUtils
+ RayUtils.init_ray(num_of_cpus=20, object_store_memory_in_gb=100, memory_in_gb=50)
main()
\ No newline at end of file
diff --git a/src/itp_interface/tools/dynamic_lean4_proof_exec.py b/src/itp_interface/tools/dynamic_lean4_proof_exec.py
index 5417e83..86cd9c9 100644
--- a/src/itp_interface/tools/dynamic_lean4_proof_exec.py
+++ b/src/itp_interface/tools/dynamic_lean4_proof_exec.py
@@ -6,7 +6,7 @@
import copy
import enum
import logging
-from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.tools.training_data_format import Goal, TheoremProvingTrainingDataFormat
from itp_interface.tools.lean_parse_utils import LeanLineByLineReader
from itp_interface.tools.lean_context_helper import Lean3ContextHelper
diff --git a/src/itp_interface/tools/lean4_local_data_extraction_transform.py b/src/itp_interface/tools/lean4_local_data_extraction_transform.py
deleted file mode 100644
index 9205638..0000000
--- a/src/itp_interface/tools/lean4_local_data_extraction_transform.py
+++ /dev/null
@@ -1,110 +0,0 @@
-#!/usr/bin/env python3
-
-import os
-import sys
-dir_name = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-root_dir = os.path.abspath(dir_name)
-if root_dir not in sys.path:
- sys.path.append(root_dir)
-import typing
-import uuid
-from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
-from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
-from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, ExtractionDataCollection, TheoremProvingTrainingDataFormat
-from itp_interface.tools.training_data import TrainingData, DataLayoutFormat
-
-class Local4DataExtractionTransform(GenericTrainingDataGenerationTransform):
- def __init__(self,
- depth = None,
- max_search_results = None,
- buffer_size : int = 10000,
- logger = None,
- max_parallelism : int = 4):
- super().__init__(TrainingDataGenerationType.LOCAL, buffer_size, logger)
- self.depth = depth
- self.max_search_results = max_search_results
- self.max_parallelism = max_parallelism
-
- def get_meta_object(self) -> TrainingDataMetadataFormat:
- return TrainingDataMetadataFormat(
- training_data_buffer_size=self.buffer_size,
- data_filename_prefix="extraction_data_",
- lemma_ref_filename_prefix="extraction_lemma_refs_")
-
- def get_data_collection_object(self) -> MergableCollection:
- return ExtractionDataCollection()
-
- def load_meta_from_file(self, file_path) -> MergableCollection:
- return TrainingDataMetadataFormat.load_from_file(file_path)
-
- def load_data_from_file(self, file_path) -> MergableCollection:
- return ExtractionDataCollection.load_from_file(file_path, self.logger)
-
- def __call__(self,
- training_data: TrainingData,
- project_id : str,
- lean_executor: SimpleLean4SyncExecutor,
- print_coq_executor_callback: typing.Callable[[], SimpleLean4SyncExecutor],
- theorems: typing.List[str] = None,
- other_args: dict = {}) -> TrainingData:
- file_namespace = lean_executor.main_file.replace('/', '.')
- self.logger.info(f"=========================Processing {file_namespace}=========================")
- theorem_id = str(uuid.uuid4())
- if isinstance(theorems, list) and len(theorems) == 1 and theorems[0] == "*":
- theorems = None
- else:
- theorems = set(theorems) if theorems is not None else None
- cnt = 0
- temp_dir = os.path.join(training_data.folder, "temp")
- os.makedirs(temp_dir, exist_ok=True)
- json_output_path = f"{temp_dir}/{file_namespace.replace('.', '_')}.lean.deps.json"
- file_dep_analyses = lean_executor.extract_all_theorems_and_definitions(json_output_path=json_output_path)
- self.logger.info(f"Extracted {len(file_dep_analyses)} FileDependencyAnalysis objects from {file_namespace}")
- self.logger.info(f"file_dep_analyses: {file_dep_analyses}")
- assert len(file_dep_analyses) == 1, "Expected exactly one FileDependencyAnalysis object"
- file_dep_analysis = file_dep_analyses[0]
- for decls in file_dep_analysis.declarations:
- line_info = decls.decl_info
- if theorems is not None and line_info.name not in theorems:
- continue
- training_data.merge(decls)
- cnt += 1
- training_data.meta.last_proof_id = theorem_id
- self.logger.info(f"===============Finished processing {file_namespace}=====================")
- self.logger.info(f"Total declarations processed in this transform: {cnt}")
- return training_data
-
-
-if __name__ == "__main__":
- import os
- import logging
- import time
- os.chdir(root_dir)
- # project_dir = 'data/test/lean4_proj/'
- project_dir = 'data/test/Mathlib'
- # file_name = 'data/test/lean4_proj/Lean4Proj/Basic.lean'
- file_name = 'data/test/Mathlib/.lake/packages/mathlib/Mathlib/Algebra/Divisibility/Basic.lean'
- project_id = project_dir.replace('/', '.')
- time_str = time.strftime("%Y%m%d-%H%M%S")
- output_path = f".log/local_data_generation_transform/data/{time_str}"
- log_path = f".log/local_data_generation_transform/log/{time_str}"
- log_file = f"{log_path}/local_data_generation_transform-{time_str}.log"
- os.makedirs(output_path, exist_ok=True)
- os.makedirs(log_path, exist_ok=True)
- logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
- logger = logging.getLogger(__name__)
- def _print_lean_executor_callback():
- search_lean_exec = SimpleLean4SyncExecutor(main_file=file_name, project_root=project_dir)
- search_lean_exec.__enter__()
- return search_lean_exec
- transform = Local4DataExtractionTransform(0, buffer_size=1000)
- training_data = TrainingData(
- output_path,
- "training_metadata.json",
- training_meta=transform.get_meta_object(),
- logger=logger,
- layout=DataLayoutFormat.DECLARATION_EXTRACTION)
- with SimpleLean4SyncExecutor(project_root=project_dir, main_file=file_name, use_human_readable_proof_context=True, suppress_error_log=True) as coq_exec:
- transform(training_data, project_id, coq_exec, _print_lean_executor_callback, theorems=["*"])
- save_info = training_data.save()
- logger.info(f"Saved training data to {save_info}")
\ No newline at end of file
diff --git a/src/itp_interface/tools/lean4_local_data_generation_transform.py b/src/itp_interface/tools/lean4_local_data_generation_transform.py
index 4854b02..8550ff3 100644
--- a/src/itp_interface/tools/lean4_local_data_generation_transform.py
+++ b/src/itp_interface/tools/lean4_local_data_generation_transform.py
@@ -6,7 +6,7 @@
sys.path.append(root_dir)
import typing
import uuid
-from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.tools.lean4_context_helper import Lean4ContextHelper
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat
diff --git a/src/itp_interface/tools/proof_exec_callback.py b/src/itp_interface/tools/proof_exec_callback.py
index bcb9744..9c3b031 100644
--- a/src/itp_interface/tools/proof_exec_callback.py
+++ b/src/itp_interface/tools/proof_exec_callback.py
@@ -15,7 +15,7 @@
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.lean_cmd_executor import Lean3Executor
from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor
-from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.tools.isabelle_executor import IsabelleExecutor
from itp_interface.tools.dynamic_coq_proof_exec import DynamicProofExecutor as DynamicCoqProofExecutor
from itp_interface.tools.dynamic_lean_proof_exec import DynamicProofExecutor as DynamicLeanProofExecutor
diff --git a/src/itp_interface/tools/ray_utils.py b/src/itp_interface/tools/ray_utils.py
index 807ef7d..d462740 100644
--- a/src/itp_interface/tools/ray_utils.py
+++ b/src/itp_interface/tools/ray_utils.py
@@ -14,12 +14,41 @@
class RayUtils(object):
+ @staticmethod
+ def connect_to_ray():
+ if os.environ.get("RAY_INITIALIZED", "0") == "0":
+ return None
+ root_dir = f"{os.path.abspath(__file__).split('itp_interface')[-2]}"
+ os.environ["PYTHONPATH"] = f"{root_dir}:{os.environ.get('PYTHONPATH', '')}"
+ from filelock import FileLock
+ import json
+ os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"
+ lock_path = os.environ.get("RAY_LOCK_FILE_PATH", ".log/locks/ray.lock")
+ ray_session_path = os.environ.get("RAY_SESSION_PATH", ".log/ray/session_latest")
+ temp_lock = FileLock(lock_path)
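+        # If the lock can be acquired immediately, no other process holds a live Ray
+        # session, so there is nothing to connect to. If acquisition times out, another
+        # process owns the session; its address is read from the session file and joined.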
+ try:
+ temp_lock.acquire(timeout=10)
+ temp_lock.release()
+ except:
+ if os.path.exists(ray_session_path):
+ with open(ray_session_path, "r") as f:
+ ray_session = f.read()
+ ray_session = json.loads(ray_session)
+ ray_address = ray_session["address"]
+ obj = ray.init(address=ray_address)
+ return obj
+ return None
+
@staticmethod
def init_ray(num_of_cpus: int = 10, object_store_memory_in_gb: float = 25, memory_in_gb: float = 0.5, runtime_env: typing.Dict[str, str] = None):
gb = 2**30
object_store_memory = int(object_store_memory_in_gb * gb)
memory = int(memory_in_gb * gb)
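+        # Reuse an already-running Ray instance started by another worker process when
+        # one is detected; otherwise fall through and start a fresh local instance below.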
+ obj = RayUtils.connect_to_ray()
+ if obj is not None:
+ return obj
os.environ["RAY_INITIALIZED"] = "1"
+ os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1"
obj = ray.init(num_cpus=num_of_cpus, object_store_memory=object_store_memory, _memory=memory, ignore_reinit_error=True, runtime_env=runtime_env)
return obj
diff --git a/src/itp_interface/tools/repl b/src/itp_interface/tools/repl
index 8fff855..ebeaf40 160000
--- a/src/itp_interface/tools/repl
+++ b/src/itp_interface/tools/repl
@@ -1 +1 @@
-Subproject commit 8fff8552292860d349b459d6a811e6915671dc0d
+Subproject commit ebeaf409fc93d64d5a52652c87e74bfb64805a66
diff --git a/src/itp_interface/tools/run_data_generation_transforms.py b/src/itp_interface/tools/run_data_generation_transforms.py
index 4701c36..bacbb25 100644
--- a/src/itp_interface/tools/run_data_generation_transforms.py
+++ b/src/itp_interface/tools/run_data_generation_transforms.py
@@ -10,6 +10,7 @@
import typing
import shutil
import gc
+from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from itp_interface.tools.training_data import TrainingData, DataLayoutFormat
@@ -24,12 +25,12 @@
RayUtils = None
from itp_interface.tools.coq_executor import CoqExecutor
from itp_interface.tools.lean_cmd_executor import Lean3Executor
-from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor
+from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor
from itp_interface.tools.isabelle_executor import IsabelleExecutor
from itp_interface.tools.coq_local_data_generation_transform import LocalDataGenerationTransform as CoqLocalDataGenerationTransform
from itp_interface.tools.lean_local_data_generation_transform import LocalDataGenerationTransform as LeanLocalDataGenerationTransform
from itp_interface.tools.lean4_local_data_generation_transform import Local4DataGenerationTransform as Lean4LocalDataGenerationTransform
-from itp_interface.tools.lean4_local_data_extraction_transform import Local4DataExtractionTransform as Lean4LocalDataExtractionTransform
+from itp_interface.lean.lean4_local_data_extraction_transform import Local4DataExtractionTransform as Lean4LocalDataExtractionTransform
from itp_interface.tools.isabelle_local_data_generation_transform import LocalDataGenerationTransform as IsabelleLocalDataGenerationTransform
from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType
@@ -290,6 +291,7 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
job_idx = 0
project_names = list(projects.keys())
project_names.sort()
+ file_index = 0
for project in project_names:
# Create temporary directory for each project
proj_name = os.path.basename(project)
@@ -305,9 +307,26 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD
some_files_processed = True
job_more_args = file_args.get(file_path, {})
# Create temporary directory for each file
- full_file_path = os.path.join(project_path, file_path)
- relative_file_path = file_path
- relative_file_path = relative_file_path.replace("/", ".").replace(".v", "").replace(".lean", "").replace(".thy", "")
+ fp = Path(file_path)
+ pp = Path(project_path)
+ if fp.is_absolute():
+ # then leave it as is
+ full_file_path = file_path
+ else:
+ # Check if the file path is relative to the project path
+ fp = fp.resolve()
+ pp = pp.resolve()
+ # Now check if fp is relative to pp
+ fp_is_rel_to_pp = fp.is_relative_to(pp)
+ if fp_is_rel_to_pp:
+ full_file_path = str(fp.relative_to(pp))
+ else:
+ # Just make it relative to project path
+ full_file_path = os.path.join(project_path, file_path)
+ assert os.path.exists(full_file_path), f"File path {full_file_path} does not exist"
+ file_index += 1
+
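+            # Name each per-file temp directory with a running index rather than the
+            # dotted file path, so directory names stay short and unique.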
+ relative_file_path = f"transformer_{file_index}"
temp_file_dir = os.path.join(temp_project_dir, relative_file_path)
os.makedirs(temp_file_dir, exist_ok=True)
log_file = os.path.join(self.logging_dir, f"{relative_file_path}.log")
@@ -410,10 +429,13 @@ def _transform_output(results):
shutil.rmtree(temp_output_dir)
def run_all_local_transforms(self, pool_size: int, projects: typing.Dict[str, typing.Dict[str, str]], use_human_readable: bool, new_output_dir: str, log_error: bool, other_args: typing.Dict[str, typing.Dict[str, dict]] = {}):
+ os.makedirs(new_output_dir, exist_ok=True)
for idx, transform in enumerate(self.transforms):
last_transform = idx == len(self.transforms) - 1
save_transform = self.save_intermidiate_transforms or last_transform
- self.run_local_transform(pool_size, transform, projects, use_human_readable, new_output_dir, log_error, save_transform, preserve_temp=self.save_intermidiate_transforms, other_args=other_args)
+ temp_new_output_dir = str(Path(new_output_dir) / str(idx))
+ os.makedirs(temp_new_output_dir, exist_ok=True)
+ self.run_local_transform(pool_size, transform, projects, use_human_readable, temp_new_output_dir, log_error, save_transform, preserve_temp=self.save_intermidiate_transforms, other_args=other_args)
pass
# Create Ray remote version if Ray is available
diff --git a/src/itp_interface/tools/simple_sqlite.py b/src/itp_interface/tools/simple_sqlite.py
new file mode 100644
index 0000000..fd8aa47
--- /dev/null
+++ b/src/itp_interface/tools/simple_sqlite.py
@@ -0,0 +1,877 @@
+"""
+Simple SQLite database for storing and querying Lean declaration data.
+
+This module provides a thread-safe SQLite database for storing Lean declarations,
+their dependencies, and file metadata. Designed to work with Ray actors processing
+files in parallel.
+"""
+
+import sqlite3
+import uuid
+from typing import List, Dict, Any, Optional
+from contextlib import contextmanager
+
+
+class LeanDeclarationDB:
+ """
+ Thread-safe SQLite database for storing Lean file and declaration information.
+
+ Key features:
+ - Automatic ID assignment on first discovery (declaration or dependency)
+ - Simplified dependency storage (edges only: A depends on B)
+ - Thread-safe operations with WAL mode for concurrent Ray actors
+ - Idempotent operations for safe parallel execution
+ """
+
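+    # A minimal usage sketch (the db path and `fda_list` below are illustrative, not
+    # part of any fixed API surface):
+    #
+    #     with LeanDeclarationDB(".log/decls.sqlite3") as db:
+    #         decl_ids = db.process_fda_list(fda_list)  # FileDependencyAnalysis objects
+    #         print(db.get_statistics())
+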
+ def __init__(self, db_path: str, timeout: float = 30.0):
+ """
+ Initialize the database connection.
+
+ Args:
+ db_path: Path to the SQLite database file
+ timeout: Timeout in seconds for database locks (default 30s)
+ """
+ self.db_path = db_path
+ self.timeout = timeout
+ self.conn = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False)
+ # Explicitly set UTF-8 encoding
+ self.conn.execute("PRAGMA encoding = 'UTF-8'")
+ self.conn.row_factory = sqlite3.Row
+ # Enable WAL mode BEFORE creating tables for better concurrency
+ self.enable_wal_mode()
+ # Create tables (safe with IF NOT EXISTS even from multiple actors)
+ self._create_tables()
+
+ def _generate_unique_id(self) -> str:
+ """
+ Generate a unique ID for a declaration.
+
+ Format: {timestamp}_{uuid4}
+ Same format as used in lean4_local_data_extraction_transform.py
+
+ Returns:
+ Unique identifier string
+ """
+ timestamp = str(int(uuid.uuid1().time_low))
+ random_id = str(uuid.uuid4())
+ return f"{timestamp}_{random_id}"
+
+ def enable_wal_mode(self):
+ """
+ Enable Write-Ahead Logging mode for better concurrent write performance.
+
+ WAL mode allows multiple readers and one writer to access the database
+ simultaneously, which is essential for Ray actors processing files in parallel.
+ """
+ self.conn.execute("PRAGMA journal_mode=WAL")
+ self.conn.execute("PRAGMA synchronous=NORMAL")
+ self.conn.commit()
+
+ def _create_tables(self):
+ """
+ Create the database schema with proper indexes and constraints.
+
+ Safe for concurrent execution from multiple Ray actors:
+ - Uses CREATE TABLE IF NOT EXISTS (idempotent)
+ - Uses CREATE INDEX IF NOT EXISTS (idempotent)
+ - WAL mode is enabled before this is called
+ - 30s timeout handles lock contention
+ """
+ cursor = self.conn.cursor()
+
+ # Files table - stores file metadata
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS files (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ file_path TEXT UNIQUE NOT NULL,
+ module_name TEXT NOT NULL
+ )
+ """)
+
+ # Imports table - stores file import relationships
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS imports (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ file_id INTEGER NOT NULL,
+ end_pos INTEGER,
+ module_name TEXT,
+ start_pos INTEGER,
+ text TEXT,
+ FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE
+ )
+ """)
+
+ # Declarations table - stores all declarations (complete or partial)
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS declarations (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ decl_id TEXT UNIQUE NOT NULL,
+ name TEXT NOT NULL,
+ namespace TEXT,
+ file_path TEXT,
+ module_name TEXT,
+ decl_type TEXT,
+ text TEXT,
+ line INTEGER,
+ column INTEGER,
+ end_line INTEGER,
+ end_column INTEGER,
+ doc_string TEXT,
+ proof TEXT,
+ -- A declaration is uniquely identified by its name and location
+ UNIQUE(name, namespace, file_path, module_name)
+ )
+ """)
+
+ # Simplified dependencies table - stores only edges (A depends on B)
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS declaration_dependencies (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ from_decl_id TEXT NOT NULL,
+ to_decl_id TEXT NOT NULL,
+ UNIQUE(from_decl_id, to_decl_id),
+ FOREIGN KEY (from_decl_id) REFERENCES declarations(decl_id) ON DELETE CASCADE,
+ FOREIGN KEY (to_decl_id) REFERENCES declarations(decl_id) ON DELETE CASCADE
+ )
+ """)
+
+ # Create indexes for faster queries
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_files_path
+ ON files(file_path)
+ """)
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_files_module
+ ON files(module_name)
+ """)
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_declarations_name
+ ON declarations(name)
+ """)
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_declarations_namespace
+ ON declarations(namespace)
+ """)
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_declarations_decl_id
+ ON declarations(decl_id)
+ """)
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_declarations_lookup
+ ON declarations(name, namespace, file_path, module_name)
+ """)
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_dependencies_from
+ ON declaration_dependencies(from_decl_id)
+ """)
+ cursor.execute("""
+ CREATE INDEX IF NOT EXISTS idx_dependencies_to
+ ON declaration_dependencies(to_decl_id)
+ """)
+
+ self.conn.commit()
+
+ @contextmanager
+ def transaction(self):
+ """
+ Context manager for database transactions.
+
+ Usage:
+ with db.transaction():
+ db.insert_something(...)
+ db.insert_something_else(...)
+ """
+ try:
+ yield self.conn
+ self.conn.commit()
+ except Exception as e:
+ self.conn.rollback()
+ raise e
+
+ def get_or_create_decl_id(
+ self,
+ name: str,
+ namespace: Optional[str] = None,
+ file_path: Optional[str] = None,
+ module_name: Optional[str] = None,
+ assert_exists: bool = False
+ ) -> str:
+ """
+ Get existing decl_id or create a new one for a declaration.
+
+ This is the core method for ID assignment. IDs are assigned as soon as
+ a declaration is discovered (either as a dependency or as a declaration itself).
+
+ Args:
+ name: Declaration name (required)
+ namespace: Namespace (can be None)
+ file_path: File path (can be None for unresolved dependencies)
+            module_name: Module name (can be None for unresolved dependencies)
+            assert_exists: If True, raise ValueError when no matching declaration
+                exists instead of creating a placeholder record
+
+ Returns:
+ The decl_id (existing or newly created)
+ """
+ cursor = self.conn.cursor()
+
+ # Try to find existing declaration
+ # Handle NULL values properly in SQL
+ cursor.execute("""
+ SELECT decl_id FROM declarations
+ WHERE name = ?
+ AND (namespace IS ? OR (namespace IS NULL AND ? IS NULL) OR (namespace IS NOT NULL AND ? IS NULL))
+ AND (file_path IS ? OR (file_path IS NULL AND ? IS NULL) OR (file_path IS NOT NULL AND ? IS NULL))
+ AND (module_name IS ? OR (module_name IS NULL AND ? IS NULL) OR (module_name IS NOT NULL AND ? IS NULL))
+ """, (name,
+ namespace, namespace, namespace,
+ file_path, file_path, file_path,
+ module_name, module_name, module_name))
+
+ row = cursor.fetchone()
+ if row:
+ return row[0] # Return existing ID
+
+ if assert_exists:
+ raise ValueError(f"Declaration not found: name={name}, namespace={namespace}, file_path={file_path}, module_name={module_name}")
+
+ # Generate new ID and insert minimal record
+ new_decl_id = self._generate_unique_id()
+
+ try:
+ cursor.execute("""
+ INSERT INTO declarations (decl_id, name, namespace, file_path, module_name)
+ VALUES (?, ?, ?, ?, ?)
+ """, (new_decl_id, name, namespace, file_path, module_name))
+ self.conn.commit()
+ except sqlite3.IntegrityError:
+ # Race condition: another process inserted it between our SELECT and INSERT
+ # Query again to get the existing ID
+ cursor.execute("""
+ SELECT decl_id FROM declarations
+ WHERE name = ?
+ AND (namespace IS ? OR (namespace IS NULL AND ? IS NULL) OR (namespace IS NOT NULL AND ? IS NULL))
+ AND (file_path IS ? OR (file_path IS NULL AND ? IS NULL) OR (file_path IS NOT NULL AND ? IS NULL))
+ AND (module_name IS ? OR (module_name IS NULL AND ? IS NULL) OR (module_name IS NOT NULL AND ? IS NULL))
+ """, (name,
+ namespace, namespace, namespace,
+ file_path, file_path, file_path,
+ module_name, module_name, module_name))
+ row = cursor.fetchone()
+ if row:
+ return row[0]
+ else:
+ # This shouldn't happen, but raise if it does
+ raise
+
+ return new_decl_id
+
+ def upsert_declaration_full_info(
+ self,
+ decl_id: str,
+ name: str,
+ namespace: Optional[str],
+ file_path: str,
+ module_name: str,
+ decl_type: Optional[str] = None,
+ text: Optional[str] = None,
+ line: Optional[int] = None,
+ column: Optional[int] = None,
+ end_line: Optional[int] = None,
+ end_column: Optional[int] = None,
+ doc_string: Optional[str] = None,
+ proof: Optional[str] = None
+ ):
+ """
+ Update a declaration with complete information.
+
+ This is called when we process the actual declaration (not just a reference).
+ Updates the record with full metadata.
+
+ Args:
+ decl_id: The declaration ID
+ name: Declaration name
+ namespace: Namespace
+ file_path: File path
+ module_name: Module name
+ decl_type: Declaration type (theorem, def, etc.)
+ text: Full declaration text
+ line: Starting line number
+ column: Starting column number
+ end_line: Ending line number
+ end_column: Ending column number
+ doc_string: Documentation string
+ proof: Proof text
+ """
+ cursor = self.conn.cursor()
+
+ cursor.execute("""
+ UPDATE declarations
+ SET decl_type = ?,
+ text = ?,
+ line = ?,
+ column = ?,
+ end_line = ?,
+ end_column = ?,
+ doc_string = ?,
+ proof = ?
+ WHERE decl_id = ?
+ """, (decl_type, text, line, column, end_line, end_column, doc_string, proof, decl_id))
+
+ self.conn.commit()
+
+ def insert_dependency_edge(self, from_decl_id: str, to_decl_id: str):
+ """
+ Insert a dependency edge: from_decl_id depends on to_decl_id.
+
+ Uses INSERT OR IGNORE for idempotency (safe to call multiple times).
+
+ Args:
+ from_decl_id: The declaration that has the dependency
+ to_decl_id: The declaration being depended on
+ """
+ cursor = self.conn.cursor()
+
+ cursor.execute("""
+ INSERT OR IGNORE INTO declaration_dependencies (from_decl_id, to_decl_id)
+ VALUES (?, ?)
+ """, (from_decl_id, to_decl_id))
+
+ self.conn.commit()
+
+ def process_fda_list(self, fda_list: List) -> List[str]:
+ """
+ Process a list of FileDependencyAnalysis objects.
+
+ Args:
+ fda_list: List of FileDependencyAnalysis objects
+
+ Returns:
+ List of all decl_ids that were processed
+ """
+ all_decl_ids = []
+
+ for fda in fda_list:
+ # Insert file and imports first
+ if fda.imports:
+ self.insert_file_imports(fda.file_path, fda.module_name, fda.imports)
+
+ # Process each declaration in this file
+ for decl in fda.declarations:
+ decl_id = self.process_declaration(fda.file_path, fda.module_name, decl)
+ all_decl_ids.append(decl_id)
+
+ return all_decl_ids
+
+ def process_declaration(
+ self,
+ fda_file_path: str,
+ fda_module_name: str,
+ decl,
+ enable_dependency_extraction: bool = True
+ ) -> str:
+ """
+ Process a declaration from a FileDependencyAnalysis object.
+
+        This is the main high-level method that:
+        1. Gets or creates the decl_id for this declaration
+        2. Either updates the full declaration info (when dependency extraction is
+           disabled), or resolves all dependencies and inserts dependency edges
+           (when dependency extraction is enabled)
+
+ Args:
+ fda_file_path: File path from FileDependencyAnalysis
+ fda_module_name: Module name from FileDependencyAnalysis
+            decl: DeclWithDependencies object
+            enable_dependency_extraction: If True, only resolve dependencies and insert
+                dependency edges (the declarations must already exist); if False, insert
+                or update the declaration's full information without touching dependencies
+
+ Returns:
+ The decl_id for this declaration
+ """
+ # Get or create ID for this declaration
+ decl_id = self.get_or_create_decl_id(
+ name=decl.decl_info.name,
+ namespace=decl.decl_info.namespace,
+ file_path=fda_file_path,
+ module_name=fda_module_name,
+ # If we are extracting dependencies, we expect the declaration to exist
+ assert_exists=enable_dependency_extraction
+ )
+
+ if not enable_dependency_extraction:
+ # Update with full declaration info
+ # This is a new declaration, so we can safely update all info
+ self.upsert_declaration_full_info(
+ decl_id=decl_id,
+ name=decl.decl_info.name,
+ namespace=decl.decl_info.namespace,
+ file_path=fda_file_path,
+ module_name=fda_module_name,
+ decl_type=decl.decl_info.decl_type,
+ text=decl.decl_info.text,
+ line=decl.decl_info.line,
+ column=decl.decl_info.column,
+ end_line=decl.decl_info.end_line,
+ end_column=decl.decl_info.end_column,
+ doc_string=decl.decl_info.doc_string,
+ proof=decl.decl_info.proof
+ )
+ else:
+ # Process dependencies
+ for dep in decl.dependencies:
+ # Get or create ID for the dependency
+ try:
+ dep_decl_id = self.get_or_create_decl_id(
+ name=dep.name,
+ namespace=dep.namespace,
+ file_path=dep.file_path,
+ module_name=dep.module_name,
+ # At this point all dependencies should exist
+ assert_exists=True
+ )
+ except ValueError as e:
+ # This dependency is probably from outside the project
+ # Let's just skip it
+ continue
+
+ # Propagate the decl_id back to the dependency object
+ dep.decl_id = dep_decl_id
+
+ # Insert dependency edge
+ self.insert_dependency_edge(decl_id, dep_decl_id)
+
+ return decl_id
+
+ def get_or_create_file(self, file_path: str, module_name: str) -> int:
+ """
+ Get or create a file record.
+
+ Args:
+ file_path: The file path
+ module_name: The module name
+
+ Returns:
+ The file_id (existing or newly created)
+ """
+ cursor = self.conn.cursor()
+
+ cursor.execute("SELECT id FROM files WHERE file_path = ?", (file_path,))
+ row = cursor.fetchone()
+ if row:
+ return row[0]
+
+ cursor.execute("""
+ INSERT INTO files (file_path, module_name)
+ VALUES (?, ?)
+ """, (file_path, module_name))
+ self.conn.commit()
+
+ file_id = cursor.lastrowid
+ if file_id is None:
+ raise RuntimeError("Failed to get file_id after insert")
+ return file_id
+
+ def insert_file_imports(self, file_path: str, module_name: str, imports: List[Dict]):
+ """
+ Insert imports for a file.
+
+ Args:
+ file_path: The file path
+ module_name: The module name
+ imports: List of import dictionaries
+ """
+ file_id = self.get_or_create_file(file_path, module_name)
+ cursor = self.conn.cursor()
+
+ for import_data in imports:
+ cursor.execute("""
+ INSERT INTO imports (file_id, end_pos, module_name, start_pos, text)
+ VALUES (?, ?, ?, ?, ?)
+ """, (file_id, import_data.get('end_pos'), import_data.get('module_name'),
+ import_data.get('start_pos'), import_data.get('text')))
+
+ self.conn.commit()
+
+ # Query methods
+
+ def get_declaration_by_decl_id(self, decl_id: str) -> Optional[Dict[str, Any]]:
+ """
+ Get a declaration by its decl_id.
+
+ Args:
+ decl_id: The unique declaration ID
+
+ Returns:
+ Dictionary containing declaration information or None
+ """
+ cursor = self.conn.cursor()
+ cursor.execute("SELECT * FROM declarations WHERE decl_id = ?", (decl_id,))
+ row = cursor.fetchone()
+ return dict(row) if row else None
+
+ def get_declarations_by_name(
+ self,
+ name: str,
+ namespace: Optional[str] = None
+ ) -> List[Dict[str, Any]]:
+ """
+ Get declarations by name and optionally namespace.
+
+ Args:
+ name: The declaration name
+ namespace: Optional namespace to filter by
+
+ Returns:
+ List of dictionaries containing declaration information
+ """
+ cursor = self.conn.cursor()
+
+ if namespace is not None:
+ cursor.execute("""
+ SELECT * FROM declarations
+ WHERE name = ? AND namespace = ?
+ """, (name, namespace))
+ else:
+ cursor.execute("""
+ SELECT * FROM declarations
+ WHERE name = ?
+ """, (name,))
+
+ return [dict(row) for row in cursor.fetchall()]
+
+ def get_declarations_by_file(self, file_path: str) -> List[Dict[str, Any]]:
+ """
+ Get all declarations in a specific file.
+
+ Args:
+ file_path: The file path
+
+ Returns:
+ List of dictionaries containing declaration information
+ """
+ cursor = self.conn.cursor()
+ cursor.execute("""
+ SELECT * FROM declarations
+ WHERE file_path = ?
+ ORDER BY line
+ """, (file_path,))
+
+ return [dict(row) for row in cursor.fetchall()]
+
+ def get_declarations_by_module(self, module_name: str) -> List[Dict[str, Any]]:
+ """
+ Get all declarations in a specific module.
+
+ Args:
+ module_name: The module name
+
+ Returns:
+ List of dictionaries containing declaration information
+ """
+ cursor = self.conn.cursor()
+ cursor.execute("""
+ SELECT * FROM declarations
+ WHERE module_name = ?
+ """, (module_name,))
+
+ return [dict(row) for row in cursor.fetchall()]
+
+ def get_dependencies(self, decl_id: str) -> List[Dict[str, Any]]:
+ """
+ Get all declarations that this declaration depends on.
+
+ Args:
+ decl_id: The declaration ID
+
+ Returns:
+ List of declarations that this declaration depends on
+ """
+ cursor = self.conn.cursor()
+ cursor.execute("""
+ SELECT d.* FROM declarations d
+ JOIN declaration_dependencies dd ON dd.to_decl_id = d.decl_id
+ WHERE dd.from_decl_id = ?
+ """, (decl_id,))
+
+ return [dict(row) for row in cursor.fetchall()]
+
+ def get_dependents(self, decl_id: str) -> List[Dict[str, Any]]:
+ """
+ Get all declarations that depend on this declaration.
+
+ Args:
+ decl_id: The declaration ID
+
+ Returns:
+ List of declarations that depend on this one
+ """
+ cursor = self.conn.cursor()
+ cursor.execute("""
+ SELECT d.* FROM declarations d
+ JOIN declaration_dependencies dd ON dd.from_decl_id = d.decl_id
+ WHERE dd.to_decl_id = ?
+ """, (decl_id,))
+
+ return [dict(row) for row in cursor.fetchall()]
+
+ def get_dependency_graph(
+ self,
+ decl_id: str,
+ max_depth: int = 10,
+ direction: str = 'dependencies'
+ ) -> Dict[str, Any]:
+ """
+ Get the dependency graph for a declaration (recursive).
+
+ Args:
+ decl_id: The declaration ID to start from
+ max_depth: Maximum depth to traverse (default 10)
+ direction: 'dependencies' (what this depends on) or 'dependents' (what depends on this)
+
+ Returns:
+ Dictionary containing the dependency graph
+ """
+ visited = set()
+
+ def _get_graph_recursive(current_decl_id: str, depth: int) -> Optional[Dict[str, Any]]:
+ if depth > max_depth or current_decl_id in visited:
+ return None
+
+ visited.add(current_decl_id)
+
+ decl = self.get_declaration_by_decl_id(current_decl_id)
+ if not decl:
+ return None
+
+ if direction == 'dependencies':
+ related = self.get_dependencies(current_decl_id)
+ else:
+ related = self.get_dependents(current_decl_id)
+
+ result = {
+ 'declaration': decl,
+ direction: []
+ }
+
+ for rel_decl in related:
+ sub_graph = _get_graph_recursive(rel_decl['decl_id'], depth + 1)
+ if sub_graph:
+ result[direction].append(sub_graph)
+
+ return result
+
+ graph = _get_graph_recursive(decl_id, 0)
+ return graph if graph else {}
+
+ def search_declarations(
+ self,
+ name: Optional[str] = None,
+ namespace: Optional[str] = None,
+ file_path: Optional[str] = None,
+ module_name: Optional[str] = None,
+ decl_type: Optional[str] = None
+ ) -> List[Dict[str, Any]]:
+ """
+ Search declarations with multiple optional filters.
+
+ Args:
+ name: Declaration name (optional)
+ namespace: Namespace (optional)
+ file_path: File path (optional)
+ module_name: Module name (optional)
+ decl_type: Declaration type (optional)
+
+ Returns:
+ List of matching declarations
+ """
+ cursor = self.conn.cursor()
+
+ query = "SELECT * FROM declarations WHERE 1=1"
+ params = []
+
+ if name is not None:
+ query += " AND name = ?"
+ params.append(name)
+
+ if namespace is not None:
+ query += " AND namespace = ?"
+ params.append(namespace)
+
+ if file_path is not None:
+ query += " AND file_path = ?"
+ params.append(file_path)
+
+ if module_name is not None:
+ query += " AND module_name = ?"
+ params.append(module_name)
+
+ if decl_type is not None:
+ query += " AND decl_type = ?"
+ params.append(decl_type)
+
+ cursor.execute(query, params)
+ return [dict(row) for row in cursor.fetchall()]
+
+ def get_statistics(self) -> Dict[str, int]:
+ """
+ Get database statistics.
+
+ Returns:
+ Dictionary with counts of files, declarations, and dependencies
+ """
+ cursor = self.conn.cursor()
+
+ stats = {}
+
+ try:
+ cursor.execute("SELECT COUNT(*) FROM files")
+ result = cursor.fetchone()
+ stats['total_files'] = result[0] if result else 0
+ except Exception as e:
+ print(f"Warning: Error fetching files count: {e}")
+ stats['total_files'] = 0
+
+ try:
+ cursor.execute("SELECT COUNT(*) FROM declarations")
+ result = cursor.fetchone()
+ stats['total_declarations'] = result[0] if result else 0
+ except Exception as e:
+ print(f"Warning: Error fetching declarations count: {e}")
+ stats['total_declarations'] = 0
+
+ try:
+ cursor.execute("SELECT COUNT(*) FROM declaration_dependencies")
+ result = cursor.fetchone()
+ stats['total_dependencies'] = result[0] if result else 0
+ except Exception as e:
+ print(f"Warning: Error fetching dependencies count: {e}")
+ stats['total_dependencies'] = 0
+
+ try:
+ cursor.execute("SELECT COUNT(*) FROM declarations WHERE file_path IS NULL")
+ result = cursor.fetchone()
+ stats['unresolved_declarations'] = result[0] if result else 0
+ except Exception as e:
+ print(f"Warning: Error fetching unresolved declarations count: {e}")
+ stats['unresolved_declarations'] = 0
+
+ try:
+ cursor.execute("SELECT COUNT(*) FROM imports")
+ result = cursor.fetchone()
+ stats['total_imports'] = result[0] if result else 0
+ except Exception as e:
+ print(f"Warning: Error fetching imports count: {e}")
+ stats['total_imports'] = 0
+
+ return stats
+
+ def close(self):
+ """Close the database connection."""
+ if self.conn:
+ self.conn.close()
+
+ def __enter__(self):
+ """Context manager entry."""
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit."""
+ self.close()
+ return False
+
+
+if __name__ == "__main__":
+ # Example usage
+ print("Creating test database...")
+
+ with LeanDeclarationDB("test_lean_declarations.db") as db:
+ # Test with file and imports
+ print("\nTesting file and imports insertion:")
+ test_imports = [
+ {
+ "end_pos": 292,
+ "module_name": "Mathlib.Algebra.Group.Basic",
+ "start_pos": 258,
+ "text": "import Mathlib.Algebra.Group.Basic"
+ },
+ {
+ "end_pos": 321,
+ "module_name": "Mathlib.Tactic.Common",
+ "start_pos": 293,
+ "text": "import Mathlib.Tactic.Common"
+ }
+ ]
+
+ db.insert_file_imports(
+ file_path="Mathlib/Algebra/Divisibility/Basic.lean",
+ module_name="Mathlib.Algebra.Divisibility.Basic",
+ imports=test_imports
+ )
+ print("Inserted file and imports")
+
+ # Test ID generation
+ print("\nTesting ID generation:")
+ decl_id1 = db.get_or_create_decl_id(
+ name="dvd_trans",
+ namespace=None,
+ file_path="Mathlib/Algebra/Divisibility/Basic.lean",
+ module_name="Mathlib.Algebra.Divisibility.Basic"
+ )
+ print(f"Generated ID for dvd_trans: {decl_id1}")
+
+ # Get same ID again (should return existing)
+ decl_id2 = db.get_or_create_decl_id(
+ name="dvd_trans",
+ namespace=None,
+ file_path="Mathlib/Algebra/Divisibility/Basic.lean",
+ module_name="Mathlib.Algebra.Divisibility.Basic"
+ )
+ print(f"Retrieved ID for dvd_trans: {decl_id2}")
+ print(f"IDs match: {decl_id1 == decl_id2}")
+
+ # Update with full info
+ db.upsert_declaration_full_info(
+ decl_id=decl_id1,
+ name="dvd_trans",
+ namespace=None,
+ file_path="Mathlib/Algebra/Divisibility/Basic.lean",
+ module_name="Mathlib.Algebra.Divisibility.Basic",
+ decl_type="theorem",
+ text="@[trans]\ntheorem dvd_trans : a # b ļæ½ b # c ļæ½ a # c",
+ line=63,
+ column=0,
+ end_line=68,
+ end_column=0,
+ doc_string=None,
+ proof="| �d, h��, �e, h�� => �d * e, h� � h�.trans <| mul_assoc a d e�"
+ )
+ print("Updated declaration with full info")
+
+ # Create a dependency
+ dep_id = db.get_or_create_decl_id(
+ name="mul_assoc",
+ namespace=None,
+ file_path=None,
+ module_name=None
+ )
+ print(f"\nGenerated ID for dependency mul_assoc: {dep_id}")
+
+ # Insert dependency edge
+ db.insert_dependency_edge(decl_id1, dep_id)
+ print("Inserted dependency edge")
+
+ # Query
+ print("\nQuerying declaration:")
+ decl = db.get_declaration_by_decl_id(decl_id1)
+ if decl:
+ print(f"Declaration: {decl['name']} ({decl['decl_type']})")
+ else:
+ print("Declaration not found!")
+
+ print("\nQuerying dependencies:")
+ deps = db.get_dependencies(decl_id1)
+ print(f"Dependencies: {[d['name'] for d in deps]}")
+
+ # Statistics
+ print("\nDatabase statistics:")
+ stats = db.get_statistics()
+ for key, value in stats.items():
+ print(f" {key}: {value}")
+
+ print("\nTest complete!")
diff --git a/src/itp_interface/tools/training_data_format.py b/src/itp_interface/tools/training_data_format.py
index e7dabfa..326aee7 100644
--- a/src/itp_interface/tools/training_data_format.py
+++ b/src/itp_interface/tools/training_data_format.py
@@ -12,7 +12,7 @@
from collections import OrderedDict
from typing import List, Optional, Union, runtime_checkable, Protocol
from pydantic import BaseModel
-from itp_interface.tools.tactic_parser import DeclWithDependencies
+from itp_interface.lean.tactic_parser import FileDependencyAnalysis
@runtime_checkable
class TrainingDataFormat(Protocol):
@@ -503,7 +503,7 @@ def load_from_string(json_text: str, logger: logging.Logger = None):
return deserialized
class ExtractionDataCollection(BaseModel):
- training_data: list[DeclWithDependencies] = []
+ training_data: list[FileDependencyAnalysis] = []
def __len__(self) -> int:
return len(self.training_data)
diff --git a/src/test/parsing_helpers_test.py b/src/test/parsing_helpers_test.py
new file mode 100644
index 0000000..1726c05
--- /dev/null
+++ b/src/test/parsing_helpers_test.py
@@ -0,0 +1,69 @@
+import unittest
+from itp_interface.lean.parsing_helpers import parse_lean_text, LeanDeclType
+
+
+class ParsingHelpersTest(unittest.TestCase):
+ def test_lean_declaration_parsing(self):
+ """Test parsing of various Lean 4 declaration types"""
+ test_cases = [
+ (
+ "Simple Theorem",
+ "@[simp] lemma foo (x : Nat) : x = x := rfl"
+ ),
+ (
+ "Context and Doc",
+ "open Algebra.TensorProduct in\n/-- Doc -/\ntheorem left_of_tensor [Module R] : True where\n out := sorry"
+ ),
+ (
+ "Inductive (No Proof)",
+ "/-- The base e -/\ninductive ExBase : Type\n| A\n| B"
+ ),
+ (
+ "Structure (No Proof)",
+ "structure MyStruct where\n field1 : Nat\n field2 : Int"
+ ),
+ (
+ "Mutual (No Proof)",
+ "mutual\n inductive A | a\n inductive B | b\nend"
+ ),
+ (
+ "Run Cmd (Fallback)",
+ "open Lean in\nrun_cmd Command.liftTermElabM do\n logInfo \"hi\""
+ ),
+ (
+ "Inductive Nat (hard case)",
+ "set_option genCtorIdx false in\n/--\nThe natural numbers, starting at zero.\n\nThis type is special-cased by both the kernel and the compiler, and overridden with an efficient\nimplementation. Both use a fast arbitrary-precision arithmetic library (usually\n[GMP](https://gmplib.org/)); at runtime, `Nat` values that are sufficiently small are unboxed.\n-/\ninductive Nat where\n /--\n Zero, the smallest natural number.\n\n Using `Nat.zero` explicitly should usually be avoided in favor of the literal `0`, which is the\n [simp normal form](lean-manual://section/simp-normal-forms).\n -/\n | zero : Nat\n /--\n The successor of a natural number `n`.\n\n Using `Nat.succ n` should usually be avoided in favor of `n + 1`, which is the [simp normal\n form](lean-manual://section/simp-normal-forms).\n -/\n | succ (n : Nat) : Nat\n"
+ )
+ ]
+
+ print(f"{'TYPE':<12} | {'NAME':<10} | {'TEXT BEFORE':<20} | {'DOC':<10} | {'TEXT':<20} | {'PROOF':<15}")
+ print("-" * 115)
+
+ for test_name, inp in test_cases:
+ res = parse_lean_text(inp)
+
+ tp = res.decl_type.value
+ nm = res.name or ""
+ tb = (res.text_before or "").replace('\n', '\\n')
+ ds = (res.doc_string or "")
+ tx = (res.text or "").replace('\n', '\\n')
+ pf = (res.proof or "").replace('\n', '\\n')
+
+ if len(tb) > 18: tb = tb[:18] + "..."
+ if len(ds) > 8: ds = "/--...-/"
+ if len(tx) > 18: tx = tx[:18] + "..."
+ if len(pf) > 12: pf = pf[:12] + "..."
+
+ print(f"{tp:<12} | {nm:<10} | {tb:<20} | {ds:<10} | {tx:<20} | {pf:<15}")
+
+ # Basic assertions to verify parsing works
+ self.assertIsNotNone(res.decl_type)
+ self.assertIsInstance(res.decl_type, LeanDeclType)
+
+
+def main():
+ unittest.main()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/src/test/simple_data_extract_test.py b/src/test/simple_data_extract_test.py
index 36f1dca..07aba42 100644
--- a/src/test/simple_data_extract_test.py
+++ b/src/test/simple_data_extract_test.py
@@ -17,7 +17,7 @@ def pretty_print_file_contents(dir_path):
print(f"Printing all files in the directory: {dir_path}")
for f in os.listdir(dir_path):
file_path = os.path.join(dir_path, f)
- if os.path.isfile(file_path):
+ if os.path.isfile(file_path) and any(file_path.endswith(ext) for ext in [".json", ".yaml", ".yml", ".txt", ".log"]):
print('-'*50)
print(f"Contents of {file_path}:")
with open(file_path, "r") as file:
@@ -63,7 +63,7 @@ def test_lean_data_extract(self):
# Print the directory contents
last_dir_path = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean_ext", last_dir)
print("Last Directory Contents:", os.listdir(last_dir_path))
- train_data = os.path.join(last_dir_path, "train")
+ train_data = os.path.join(last_dir_path, "train", "1")
list_files = os.listdir(train_data)
print("Train Directory Contents:", list_files)
data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
diff --git a/src/test/simple_data_gen_test.py b/src/test/simple_data_gen_test.py
index e39ce93..48fd226 100644
--- a/src/test/simple_data_gen_test.py
+++ b/src/test/simple_data_gen_test.py
@@ -62,7 +62,7 @@ def test_proof_step_data_gen(self):
# Print the directory contents
last_dir_path = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir)
print("Last Directory Contents:", os.listdir(last_dir_path))
- train_data = os.path.join(last_dir_path, "train")
+ train_data = os.path.join(last_dir_path, "train", "0")
list_files = os.listdir(train_data)
print("Train Directory Contents:", list_files)
data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")]
diff --git a/src/test/simple_env_coq_test.py b/src/test/simple_env_coq_test.py
new file mode 100644
index 0000000..77c5108
--- /dev/null
+++ b/src/test/simple_env_coq_test.py
@@ -0,0 +1,131 @@
+import unittest
+import os
+
+def pretty_print(s1, s2, proof_step, done):
+ print(f"Current Goal:")
+ print('-'*30)
+ for goal in s1.training_data_format.start_goals:
+ hyps = '\n'.join([hyp for hyp in goal.hypotheses])
+ print(hyps)
+ print('|- ', end='')
+ print(goal.goal)
+ print(f'*'*30)
+ print(f"="*30)
+ print(f"Action: {proof_step}")
+ print(f"="*30)
+ print(f"Next Goal:")
+ print('-'*30)
+ if s2 is not None:
+ for goal in s2.training_data_format.start_goals:
+ hyps = '\n'.join([hyp for hyp in goal.hypotheses])
+ print(hyps)
+ print('|- ', end='')
+ print(goal.goal)
+ print(f'*'*30)
+ print(f"="*30)
+ print(f"DONE: {done}")
+ print('-'*30)
+ if s2 is None and done:
+ print("No more goals. Proof Finished!")
+
+class CoqHelper():
+ def __init__(self):
+ self.current_switch = None
+
+ def build_coq_project(self, project_folder):
+ try:
+ with os.popen("opam switch show") as proc:
+ self.current_switch = proc.read().strip()
+ except:
+ self.current_switch = None
+ # Check if the switch exists
+ # opam switch create simple_grp_theory 4.14.2
+ if os.system("opam switch simple_grp_theory") != 0:
+ cmds = [
+ 'opam switch create simple_grp_theory 4.14.2',
+ 'opam switch simple_grp_theory',
+ 'eval $(opam env)',
+ 'opam repo add coq-released https://coq.inria.fr/opam/released',
+ 'opam pin add -y coq-lsp 0.1.8+8.18'
+ ]
+ final_cmd = ' && '.join(cmds)
+ os.system(final_cmd)
+ # IMPORTANT NOTE: Make sure to switch to the correct switch before running the code.
+ os.system("opam switch simple_grp_theory && eval $(opam env)")
+ # Clean the project
+ os.system(f"eval $(opam env) && cd {project_folder} && make clean")
+ # Build the project
+ with os.popen(f"eval $(opam env) && cd {project_folder} && make") as proc:
+ print("Building Coq project...")
+ print('-'*15 + 'Build Logs' + '-'*15)
+ print(proc.read())
+ print('-'*15 + 'End Build Logs' + '-'*15)
+
+ def switch_to_current_switch(self):
+ if self.current_switch is not None:
+ try:
+ proc = os.popen(f"opam switch {self.current_switch} && eval $(opam env)")
+ print(proc.read())
+ finally:
+ proc.close()
+
+
+class CoqTest(unittest.TestCase):
+ def test_simple_coq(self):
+ from itp_interface.rl.proof_state import ProofState
+ from itp_interface.rl.proof_action import ProofAction
+ from itp_interface.rl.simple_proof_env import ProofEnv
+ from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
+ from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
+ project_folder = "src/data/test/coq/custom_group_theory/theories"
+ file_path = "src/data/test/coq/custom_group_theory/theories/grpthm.v"
+ # Build the project
+ # cd src/data/test/coq/custom_group_theory/theories && make
+ helper = CoqHelper()
+ helper.build_coq_project(project_folder)
+ language = ProofAction.Language.COQ
+ theorem_name = "algb_identity_sum"
+ # Theorem algb_identity_sum :
+ # forall a, algb_add a e = a.
+ proof_exec_callback = ProofExecutorCallback(
+ project_folder=project_folder,
+ file_path=file_path,
+ language=language,
+ always_use_retrieval=False,
+ keep_local_context=True
+ )
+ always_retrieve_thms = False
+ retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
+ env = ProofEnv("test_coq", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
+ proof_steps = [
+ 'intros.',
+ 'destruct a.',
+ '- reflexivity.',
+ '- reflexivity.'
+ ]
+ with env:
+ for proof_step in proof_steps:
+ state, _, next_state, _, done, info = env.step(ProofAction(
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
+ tactics=[proof_step]))
+ if info.error_message is not None:
+ print(f"Error: {info.error_message}")
+ # This prints StateChanged, StateUnchanged, Failed, or Done
+ print(info.progress)
+ print('-'*30)
+ if done:
+ print("Proof Finished!!")
+ else:
+ s1 : ProofState = state
+ s2 : ProofState = next_state
+ pretty_print(s1, s2, proof_step, done)
+ helper.switch_to_current_switch()
+
+
+def main():
+ unittest.main()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/src/test/simple_env_test.py b/src/test/simple_env_lean_test.py
similarity index 81%
rename from src/test/simple_env_test.py
rename to src/test/simple_env_lean_test.py
index 8306fb5..abc6d40 100644
--- a/src/test/simple_env_test.py
+++ b/src/test/simple_env_lean_test.py
@@ -1,6 +1,6 @@
import unittest
import os
-from itp_interface.tools.tactic_parser import build_lean4_project, build_tactic_parser_if_needed
+from itp_interface.lean.tactic_parser import build_lean4_project, build_tactic_parser_if_needed
def pretty_print(s1, s2, proof_step, done):
print(f"Current Goal:")
@@ -29,10 +29,7 @@ def pretty_print(s1, s2, proof_step, done):
if s2 is None and done:
print("No more goals. Proof Finished!")
-class Helper():
- def __init__(self):
- self.current_switch = None
-
+class LeanHelper():
def build_lean4_project(self, project_folder):
build_tactic_parser_if_needed()
# Build the project
@@ -41,43 +38,6 @@ def build_lean4_project(self, project_folder):
build_lean4_project(project_folder)
- def build_coq_project(self, project_folder):
- try:
- with os.popen("opam switch show") as proc:
- self.current_switch = proc.read().strip()
- except:
- self.current_switch = None
- # Check if the switch exists
- # opam switch create simple_grp_theory 4.14.2
- if os.system("opam switch simple_grp_theory") != 0:
- cmds = [
- 'opam switch create simple_grp_theory 4.14.2',
- 'opam switch simple_grp_theory',
- 'eval $(opam env)',
- 'opam repo add coq-released https://coq.inria.fr/opam/released',
- 'opam pin add -y coq-lsp 0.1.8+8.18'
- ]
- final_cmd = ' && '.join(cmds)
- os.system(final_cmd)
- # IMPORTANT NOTE: Make sure to switch to the correct switch before running the code.
- os.system("opam switch simple_grp_theory && eval $(opam env)")
- # Clean the project
- os.system(f"eval $(opam env) && cd {project_folder} && make clean")
- # Build the project
- with os.popen(f"eval $(opam env) && cd {project_folder} && make") as proc:
- print("Building Coq project...")
- print('-'*15 + 'Build Logs' + '-'*15)
- print(proc.read())
- print('-'*15 + 'End Build Logs' + '-'*15)
-
- def switch_to_current_switch(self):
- if self.current_switch is not None:
- try:
- proc = os.popen(f"opam switch {self.current_switch} && eval $(opam env)")
- print(proc.read())
- finally:
- proc.close()
-
class Lean4Test(unittest.TestCase):
def test_simple_lean4(self):
from itp_interface.rl.proof_state import ProofState
@@ -89,7 +49,7 @@ def test_simple_lean4(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
@@ -119,8 +79,8 @@ def test_simple_lean4(self):
proof_was_finished = False
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
@@ -145,7 +105,7 @@ def test_lean4_backtracking(self):
project_folder = "src/data/test/lean4_proj"
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
@@ -176,20 +136,20 @@ def test_lean4_backtracking(self):
print(f"Backtracking at step {idx + 1} i.e. {proof_step}")
state, _, next_state, _, done, info = env.step(
ProofAction(
- ProofAction.ActionType.BACKTRACK,
+ ProofAction.ActionType.BACKTRACK,
language))
assert next_state == prev_state, "Backtracking failed"
# Replay the last action
last_proof_step = proof_steps[idx-1]
state, _, next_state, _, done, info = env.step(
ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[last_proof_step]))
state, _, next_state, _, done, info = env.step(
ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
prev_state = state
if done:
@@ -197,57 +157,6 @@ def test_lean4_backtracking(self):
proof_was_finished = True
assert proof_was_finished, "Proof was not finished"
- def test_simple_coq(self):
- from itp_interface.rl.proof_state import ProofState
- from itp_interface.rl.proof_action import ProofAction
- from itp_interface.rl.simple_proof_env import ProofEnv
- from itp_interface.tools.proof_exec_callback import ProofExecutorCallback
- from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy
- project_folder = "src/data/test/coq/custom_group_theory/theories"
- file_path = "src/data/test/coq/custom_group_theory/theories/grpthm.v"
- # Build the project
- # cd src/data/test/coq/custom_group_theory/theories && make
- helper = Helper()
- helper.build_coq_project(project_folder)
- language = ProofAction.Language.COQ
- theorem_name = "algb_identity_sum"
- # Theorem algb_identity_sum :
- # forall a, algb_add a e = a.
- proof_exec_callback = ProofExecutorCallback(
- project_folder=project_folder,
- file_path=file_path,
- language=language,
- always_use_retrieval=False,
- keep_local_context=True
- )
- always_retrieve_thms = False
- retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK
- env = ProofEnv("test_coq", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms)
- proof_steps = [
- 'intros.',
- 'destruct a.',
- '- reflexivity.',
- '- reflexivity.'
- ]
- with env:
- for proof_step in proof_steps:
- state, _, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
- tactics=[proof_step]))
- if info.error_message is not None:
- print(f"Error: {info.error_message}")
- # This prints StateChanged, StateUnchanged, Failed, or Done
- print(info.progress)
- print('-'*30)
- if done:
- print("Proof Finished!!")
- else:
- s1 : ProofState = state
- s2 : ProofState = next_state
- pretty_print(s1, s2, proof_step, done)
- helper.switch_to_current_switch()
-
def test_simple_lean_calc(self):
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
@@ -258,7 +167,7 @@ def test_simple_lean_calc(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
@@ -290,8 +199,8 @@ def test_simple_lean_calc(self):
proof_was_finished = False
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
@@ -319,7 +228,7 @@ def test_simple_lean_calc_with_validation(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
@@ -351,8 +260,8 @@ def test_simple_lean_calc_with_validation(self):
proof_was_finished = False
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
@@ -386,7 +295,7 @@ def test_simple_lean_enforce_done_test(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}"
@@ -420,8 +329,8 @@ def test_simple_lean_enforce_done_test(self):
proof_finished = False
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
@@ -450,7 +359,7 @@ def test_simple_lean4_done_test(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
@@ -477,8 +386,8 @@ def test_simple_lean4_done_test(self):
with env:
for proof_step in proof_steps:
state, _, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
@@ -502,7 +411,7 @@ def test_simple_lean4_have_test(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"imo_1959_p1\"}'
@@ -540,8 +449,8 @@ def test_simple_lean4_have_test(self):
env.set_max_proof_step_length(10000)
for proof_step in proof_steps:
state, m_action, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
@@ -560,7 +469,7 @@ def test_simple_lean4_have_test(self):
s1 : ProofState = state
s2 : ProofState = next_state
pretty_print(s1, s2, proof_step, done)
-
+
def test_simple_lean4_with_error(self):
from itp_interface.rl.proof_state import ProofState
from itp_interface.rl.proof_action import ProofAction
@@ -571,7 +480,7 @@ def test_simple_lean4_with_error(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}'
@@ -599,8 +508,8 @@ def test_simple_lean4_with_error(self):
with env:
for i, proof_step in enumerate(proof_steps):
state, _, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
if info.error_message is not None:
print(f"Error: {info.error_message}")
@@ -631,7 +540,7 @@ def test_simple_lean4_multiline_multigoal(self):
file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean"
# Build the project
# cd src/data/test/lean4_proj && lake build
- helper = Helper()
+ helper = LeanHelper()
helper.build_lean4_project(project_folder)
language = ProofAction.Language.LEAN4
theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"complicated_have\"}'
@@ -659,8 +568,8 @@ def test_simple_lean4_multiline_multigoal(self):
proof_was_finished = False
for proof_step in proof_steps:
state, action, next_state, _, done, info = env.step(ProofAction(
- ProofAction.ActionType.RUN_TACTIC,
- language,
+ ProofAction.ActionType.RUN_TACTIC,
+ language,
tactics=[proof_step]))
proof_step = action.kwargs.get('tactics', ['INVALID'])[0]
if info.error_message is not None:
@@ -677,20 +586,10 @@ def test_simple_lean4_multiline_multigoal(self):
pretty_print(s1, s2, proof_step, done)
assert proof_was_finished, "Proof was not finished"
+
def main():
- # unittest.main()
- # Run only the Lean 4 tests
- t = Lean4Test()
- # t.test_simple_lean4_multiline_multigoal()
- # t.test_simple_lean4()
- # t.test_lean4_backtracking()
- # t.test_simple_lean4_done_test()
- # t.test_simple_lean_calc()
- # t.test_simple_lean_calc_with_validation()
- # t.test_simple_lean4_with_error()
- t.test_simple_lean4_have_test()
- # t.test_simple_lean_enforce_done_test()
+ unittest.main()
if __name__ == '__main__':
- main()
\ No newline at end of file
+ main()
diff --git a/src/test/test_tactic_parser.py b/src/test/test_tactic_parser.py
index cb79825..d988769 100644
--- a/src/test/test_tactic_parser.py
+++ b/src/test/test_tactic_parser.py
@@ -11,7 +11,7 @@
# Add parent directory to path to import tactic_parser
sys.path.insert(0, str(Path(__file__).parent.parent / "itp_interface" / "tools"))
-from itp_interface.tools.tactic_parser import TacticParser, print_tactics
+from itp_interface.lean.tactic_parser import TacticParser, print_tactics
project_path = str(Path(__file__).parent.parent / "data" / "test" / "lean4_proj")