diff --git a/.github/workflows/github-build-actions-python314t.yaml b/.github/workflows/github-build-actions-python314t.yaml index 010a72b..4f90686 100644 --- a/.github/workflows/github-build-actions-python314t.yaml +++ b/.github/workflows/github-build-actions-python314t.yaml @@ -19,7 +19,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 with: - submodules: true # Ensure submodules are checked out + submodules: false # Ensure submodules are checked out - name: Install Miniconda shell: bash @@ -84,45 +84,103 @@ jobs: source $HOME/.elan/env install-itp-interface - - name: Check and Init opam version - run: | - opam --version - opam init --disable-sandboxing --yes - - - name: Install Coq + - name: Build lean4_proj + shell: bash run: | - opam switch create simple_grp_theory 4.14.2 - opam switch simple_grp_theory - eval $(opam env) - opam repo add coq-released https://coq.inria.fr/opam/released - opam pin add -y coq-lsp 0.1.8+8.18 + source $HOME/.elan/env + cd src/data/test/lean4_proj && lake exe cache get && lake build - name: List repository files (debug step) run: find . -type f - - name: Run Simple Env Test + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + + - name: Ray Cleanup + shell: bash + run: | + rm -rf /tmp/* --verbose + + - name: Run Simple Env Lean Test shell: bash run: | export PATH="$HOME/miniconda/bin:$PATH" source $HOME/miniconda/bin/activate py314-ft - eval $(opam env) source $HOME/.elan/env - python src/test/simple_env_test.py + python src/test/simple_env_lean_test.py + + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + + - name: Ray Cleanup + shell: bash + run: | + rm -rf /tmp/* --verbose - name: Run Data Gen Test shell: bash run: | export PATH="$HOME/miniconda/bin:$PATH" source $HOME/miniconda/bin/activate py314-ft - eval $(opam env) source $HOME/.elan/env python src/test/simple_data_gen_test.py + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + + - name: Ray Cleanup + shell: bash + run: | + rm -rf /tmp/* --verbose + - name: Run Data Extraction Test shell: bash run: | export PATH="$HOME/miniconda/bin:$PATH" source $HOME/miniconda/bin/activate py314-ft - eval $(opam env) source $HOME/.elan/env python src/test/simple_data_extract_test.py + + - name: Ray Cleanup + shell: bash + run: | + rm -rf /tmp/* --verbose + + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + + - name: Clean up .lake in lean4_proj + shell: bash + run: | + rm -rf src/data/test/lean4_proj/.lake + echo "Cleaned .lake directory in lean4_proj for fresh parallel execution test" + + - name: Check and Init opam version + run: | + opam --version + opam init --disable-sandboxing --yes + + - name: Install Coq + run: | + opam switch create simple_grp_theory 4.14.2 + opam switch simple_grp_theory + eval $(opam env) + opam repo add coq-released https://coq.inria.fr/opam/released + opam pin add -y coq-lsp 0.1.8+8.18 + + - name: Run Simple Env Coq Test + shell: bash + run: | + export PATH="$HOME/miniconda/bin:$PATH" + source $HOME/miniconda/bin/activate py314-ft + eval $(opam env) + source $HOME/.elan/env + python src/test/simple_env_coq_test.py diff --git a/.github/workflows/github-build-actions.yaml b/.github/workflows/github-build-actions.yaml index 7814b40..e63b419 100644 --- a/.github/workflows/github-build-actions.yaml +++ 
b/.github/workflows/github-build-actions.yaml @@ -19,7 +19,7 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 with: - submodules: true # Ensure submodules are checked out + submodules: false # Ensure submodules are checked out - name: Install Python and pip run: | @@ -53,41 +53,52 @@ jobs: source $HOME/.elan/env install-itp-interface - - name: Check and Init opam version - run: | - opam --version - opam init --disable-sandboxing --yes - - - name: Install Coq + - name: Build lean4_proj + shell: bash run: | - opam switch create simple_grp_theory 4.14.2 - opam switch simple_grp_theory - eval $(opam env) - opam repo add coq-released https://coq.inria.fr/opam/released - opam pin add -y coq-lsp 0.1.8+8.18 + source $HOME/.elan/env + cd src/data/test/lean4_proj && lake exe cache get && lake build - name: List repository files (debug step) run: find . -type f - - name: Run Simple Env Test + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + + - name: Run Simple Env Lean Test shell: bash run: | - eval $(opam env) source $HOME/.elan/env - python src/test/simple_env_test.py - + python src/test/simple_env_lean_test.py + + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + - name: Ray Cleanup shell: bash run: | rm -rf /tmp/* --verbose + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + - name: Run Data Gen Test shell: bash run: | - eval $(opam env) source $HOME/.elan/env python src/test/simple_data_gen_test.py + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + - name: Ray Cleanup shell: bash run: | @@ -96,11 +107,41 @@ jobs: - name: Run Data Extraction Test shell: bash run: | - eval $(opam env) source $HOME/.elan/env python src/test/simple_data_extract_test.py + - name: Clean up logs + run: | + rm -rf .log + echo "Cleaned .log directory for fresh parallel execution test" + - name: Ray Cleanup shell: bash run: | - rm -rf /tmp/* --verbose \ No newline at end of file + rm -rf /tmp/* --verbose + + - name: Clean up .lake in lean4_proj + shell: bash + run: | + rm -rf src/data/test/lean4_proj/.lake + echo "Cleaned .lake directory in lean4_proj for fresh parallel execution test" + + - name: Check and Init opam version + run: | + opam --version + opam init --disable-sandboxing --yes + + - name: Install Coq + run: | + opam switch create simple_grp_theory 4.14.2 + opam switch simple_grp_theory + eval $(opam env) + opam repo add coq-released https://coq.inria.fr/opam/released + opam pin add -y coq-lsp 0.1.8+8.18 + + - name: Run Simple Env Coq Test + shell: bash + run: | + eval $(opam env) + source $HOME/.elan/env + python src/test/simple_env_coq_test.py \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 65bf71a..1b68b33 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,6 @@ [submodule "src/itp_interface/tools/repl"] path = src/itp_interface/tools/repl url = https://github.com/amit9oct/repl.git +[submodule "src/data/test/batteries"] + path = src/data/test/batteries + url = https://github.com/leanprover-community/batteries.git diff --git a/pyproject.toml b/pyproject.toml index e9aac2c..6cc0954 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ build-backend = "hatchling.build" [project] name = "itp_interface" -version = "1.2.0" +version = "1.3.0" authors = [ { name="Amitayush Thakur", 
email="amitayush@utexas.edu" }, ] @@ -47,8 +47,11 @@ dependencies = [ [project.optional-dependencies] app = [ - "flask>=2.3.0", - "flask-cors>=4.0.0" + "streamlit>=1.28.0", + "scipy>=1.16.0", + "plotly>=5.17.0", + "networkx>=3.1", + "pandas>=2.0.0" ] [project.urls] diff --git a/src/app/__init__.py b/src/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/data_explorer/README.md b/src/app/data_explorer/README.md new file mode 100644 index 0000000..744ae1e --- /dev/null +++ b/src/app/data_explorer/README.md @@ -0,0 +1,125 @@ +# Lean Declaration Database Explorer + +A Streamlit web application for exploring and analyzing Lean declaration databases created by the ITP Interface data extraction tool. + +## Features + +- **Custom SQL Queries**: Write and execute custom SQL queries with pre-built query templates +- **Declaration Search**: Search declarations by name, type, file path, and namespace +- **Dependency Explorer**: Visualize dependencies and dependents of declarations +- **Forest Analysis**: Analyze connected components in the dependency graph + +## Installation + +Install the required dependencies using pip with the `app` extra: + +```bash +pip install -e ".[app]" +``` + +This will install: +- streamlit +- plotly +- networkx +- pandas + +## Usage + +1. **Generate a database** using the data extraction transform: + ```python + from itp_interface.tools.lean4_local_data_extraction_transform import Local4DataExtractionTransform + + transform = Local4DataExtractionTransform( + buffer_size=1000, + db_path="lean_declarations.db" + ) + # ... run your extraction ... + ``` + +2. **Launch the explorer**: + ```bash + cd src/app/data_explorer + streamlit run lean_db_explorer.py + ``` + +3. **Load your database**: + - Enter the path to your `.db` file in the sidebar + - Click "Load Database" + - Explore using the tabs! + +## App Tabs + +### šŸ” Custom Query +Write and execute custom SQL queries against the database. Includes pre-built queries for common tasks: +- Show all files +- Show all declarations +- Count declarations by type +- Find most depended-on declarations +- And more... + +### šŸ”Ž Search +Search declarations with filters: +- Name pattern (supports partial matching) +- File path +- Declaration type (theorem, def, axiom, etc.) +- Namespace + +### 🌲 Dependencies +Explore dependency relationships: +- **Show Dependencies**: What does this declaration depend on? +- **Show Dependents**: What depends on this declaration? +- **Configurable depth**: Control how deep to traverse +- **Dual view**: Table and interactive graph visualization + +### 🌳 Forests +Analyze connected components in the dependency graph: +- **Find All Forests**: Discover all connected components +- **Statistics**: See forest sizes, roots, and leaves +- **Visualization**: View selected forests as graphs + +## Database Schema + +The app expects a SQLite database with the following tables: + +- `files`: File metadata +- `imports`: Import relationships +- `declarations`: All declarations with full info +- `declaration_dependencies`: Dependency edges (from_decl_id → to_decl_id) + +See `src/itp_interface/tools/simple_sqlite.py` for the complete schema. 
+ +## Tips + +- **Large graphs**: Forest visualization is limited to 100 nodes for performance +- **Query results**: All query results can be downloaded as CSV +- **Unresolved declarations**: Declarations with `file_path IS NULL` are unresolved (from dependencies) + +## Troubleshooting + +**Database not loading?** +- Check the file path is correct +- Ensure the database was created with the correct schema + +**Graph visualization slow?** +- Reduce the max depth for dependency exploration +- Use filters in Search to narrow results + +**Import errors?** +- Ensure you've installed the app dependencies: `pip install -e ".[app]"` +- Run from the correct directory: `cd src/app/data_explorer` + +## Development + +File structure: +``` +src/app/data_explorer/ +ā”œā”€ā”€ lean_db_explorer.py # Main Streamlit app +ā”œā”€ā”€ db_utils.py # Database query utilities +ā”œā”€ā”€ graph_utils.py # Graph analysis and visualization +└── README.md # This file +``` + +To modify: +1. Edit the Python files +2. Streamlit will auto-reload on file changes +3. Refresh browser to see updates diff --git a/src/app/data_explorer/__init__.py b/src/app/data_explorer/__init__.py new file mode 100644 index 0000000..c35c013 --- /dev/null +++ b/src/app/data_explorer/__init__.py @@ -0,0 +1 @@ +"""Lean Declaration Database Explorer package.""" diff --git a/src/app/data_explorer/db_utils.py b/src/app/data_explorer/db_utils.py new file mode 100644 index 0000000..6ab2635 --- /dev/null +++ b/src/app/data_explorer/db_utils.py @@ -0,0 +1,197 @@ +""" +Database utility functions for the Lean Declaration DB Explorer. +""" + +import sys +from pathlib import Path + +# Add parent directories to path to import from itp_interface +root_dir = Path(__file__).parent.parent.parent.parent +if str(root_dir) not in sys.path: + sys.path.insert(0, str(root_dir)) + +import pandas as pd +from typing import Dict, List, Any, Optional +from itp_interface.tools.simple_sqlite import LeanDeclarationDB + + +def execute_query(db: LeanDeclarationDB, query: str) -> pd.DataFrame: + """ + Execute a SQL query and return results as a DataFrame. + + Args: + db: LeanDeclarationDB instance + query: SQL query string + + Returns: + DataFrame with query results + """ + cursor = db.conn.cursor() + cursor.execute(query) + rows = cursor.fetchall() + + if rows: + # Get column names + columns = [description[0] for description in cursor.description] + # Convert to DataFrame + return pd.DataFrame([dict(row) for row in rows], columns=columns) + else: + return pd.DataFrame() + + +def get_common_queries() -> Dict[str, str]: + """ + Return a dictionary of common pre-built queries. 
+ + Returns: + Dict mapping query name to SQL query string + """ + return { + "Show all files": """ + SELECT file_path, module_name + FROM files + ORDER BY file_path + """, + "Show all declarations": """ + SELECT name, namespace, decl_type, file_path, module_name + FROM declarations + WHERE file_path IS NOT NULL + ORDER BY file_path, line + LIMIT 100 + """, + "Count declarations by type": """ + SELECT decl_type, COUNT(*) as count + FROM declarations + WHERE decl_type IS NOT NULL + GROUP BY decl_type + ORDER BY count DESC + """, + "Count declarations by file": """ + SELECT file_path, COUNT(*) as count + FROM declarations + WHERE file_path IS NOT NULL + GROUP BY file_path + ORDER BY count DESC + LIMIT 20 + """, + "Show unresolved declarations": """ + SELECT name, namespace + FROM declarations + WHERE file_path IS NULL + LIMIT 100 + """, + "Show imports for all files": """ + SELECT f.file_path, i.module_name as import_module, i.text + FROM files f + JOIN imports i ON i.file_id = f.id + ORDER BY f.file_path + """, + "Show most depended-on declarations": """ + SELECT d.name, d.namespace, d.file_path, COUNT(*) as dependents_count + FROM declarations d + JOIN declaration_dependencies dd ON dd.to_decl_id = d.decl_id + GROUP BY d.decl_id + ORDER BY dependents_count DESC + LIMIT 20 + """, + "Show declarations with most dependencies": """ + SELECT d.name, d.namespace, d.file_path, COUNT(*) as dependencies_count + FROM declarations d + JOIN declaration_dependencies dd ON dd.from_decl_id = d.decl_id + GROUP BY d.decl_id + ORDER BY dependencies_count DESC + LIMIT 20 + """ + } + + +def search_declarations( + db: LeanDeclarationDB, + name_pattern: Optional[str] = None, + namespace: Optional[str] = None, + file_path: Optional[str] = None, + decl_type: Optional[str] = None +) -> pd.DataFrame: + """ + Search declarations with filters. + + Args: + db: LeanDeclarationDB instance + name_pattern: Name pattern (SQL LIKE syntax) + namespace: Namespace filter + file_path: File path filter + decl_type: Declaration type filter + + Returns: + DataFrame with matching declarations + """ + query = "SELECT * FROM declarations WHERE 1=1" + params = [] + + if name_pattern: + query += " AND name LIKE ?" + params.append(f"%{name_pattern}%") + + if namespace: + query += " AND namespace = ?" + params.append(namespace) + + if file_path: + query += " AND file_path LIKE ?" + params.append(f"%{file_path}%") + + if decl_type: + query += " AND decl_type = ?" + params.append(decl_type) + + query += " LIMIT 1000" + + cursor = db.conn.cursor() + cursor.execute(query, params) + rows = cursor.fetchall() + + if rows: + columns = [description[0] for description in cursor.description] + return pd.DataFrame([dict(row) for row in rows], columns=columns) + else: + return pd.DataFrame() + + +def get_all_declaration_names(db: LeanDeclarationDB) -> List[str]: + """ + Get all unique declaration names for autocomplete. + + Args: + db: LeanDeclarationDB instance + + Returns: + List of declaration names + """ + cursor = db.conn.cursor() + cursor.execute(""" + SELECT DISTINCT name + FROM declarations + WHERE file_path IS NOT NULL + ORDER BY name + """) + return [row[0] for row in cursor.fetchall()] + + +def get_declaration_types(db: LeanDeclarationDB) -> List[str]: + """ + Get all unique declaration types. 
+ + Args: + db: LeanDeclarationDB instance + + Returns: + List of declaration types + """ + cursor = db.conn.cursor() + cursor.execute(""" + SELECT DISTINCT decl_type + FROM declarations + WHERE decl_type IS NOT NULL + ORDER BY decl_type + """) + return [row[0] for row in cursor.fetchall()] diff --git a/src/app/data_explorer/graph_utils.py b/src/app/data_explorer/graph_utils.py new file mode 100644 index 0000000..235773f --- /dev/null +++ b/src/app/data_explorer/graph_utils.py @@ -0,0 +1,336 @@ +""" +Graph analysis utilities for dependency visualization and forest detection. +""" + +import sys +from pathlib import Path + +# Add parent directories to path +root_dir = Path(__file__).parent.parent.parent.parent +if str(root_dir) not in sys.path: + sys.path.insert(0, str(root_dir)) + +import networkx as nx +import plotly.graph_objects as go +from typing import List, Set, Dict, Tuple, Any +from itp_interface.tools.simple_sqlite import LeanDeclarationDB + + +def build_dependency_graph(db: LeanDeclarationDB) -> nx.DiGraph: + """ + Build a directed graph from the dependency relationships. + + Only includes declarations with non-None decl_type (filters out unresolved declarations). + This is a read-only operation - the database is not modified. + + Args: + db: LeanDeclarationDB instance + + Returns: + NetworkX directed graph where edges point from dependency to dependent + (B -> A means "A depends on B") + - Root nodes (depend on nothing): in_degree == 0 + - Leaf nodes (nothing depends on them): out_degree == 0 + """ + G = nx.DiGraph() + + # Add all declarations as nodes, excluding those with None decl_type + cursor = db.conn.cursor() + cursor.execute(""" + SELECT decl_id, name, namespace, decl_type, file_path + FROM declarations + WHERE decl_type IS NOT NULL + """) + + for row in cursor.fetchall(): + decl_id, name, namespace, decl_type, file_path = row + full_name = f"{namespace}.{name}" if namespace else name + G.add_node(decl_id, name=name, full_name=full_name, + namespace=namespace, decl_type=decl_type, + file_path=file_path) + + # Add all dependency edges + # Edge direction: to_id -> from_id means "from_id depends on to_id" + # This way, root nodes (declarations that depend on nothing) have in_degree == 0 + cursor.execute(""" + SELECT from_decl_id, to_decl_id + FROM declaration_dependencies + """) + + for from_id, to_id in cursor.fetchall(): + # Only add edge if both nodes exist in the graph (not filtered out) + if to_id in G.nodes and from_id in G.nodes: + # Flip edge direction: dependency -> dependent + G.add_edge(to_id, from_id) + + return G + + +def find_all_forests(G: nx.DiGraph) -> List[Set[str]]: + """ + Find all connected components (forests) in the graph. + + Args: + G: NetworkX directed graph + + Returns: + List of sets, where each set contains decl_ids in a connected component + """ + # Convert to undirected for finding connected components + G_undirected = G.to_undirected() + components = list(nx.connected_components(G_undirected)) + + # Sort by size (largest first) + components.sort(key=len, reverse=True) + + return components + + +def get_dependencies_closure( + db: LeanDeclarationDB, + decl_id: str, + max_depth: int = 10 +) -> Tuple[List[Dict[str, Any]], nx.DiGraph]: + """ + Get all dependencies (transitive) of a declaration. 
+ + Args: + db: LeanDeclarationDB instance + decl_id: Starting declaration ID + max_depth: Maximum depth to traverse + + Returns: + Tuple of (list of declaration dicts, subgraph) + """ + try: + G = build_dependency_graph(db) + + if decl_id not in G: + return [], nx.DiGraph() + + # Get all dependencies (follow incoming edges since edge direction is flipped) + # Edge B -> A means "A depends on B", so dependencies of A are predecessors + visited = set() + queue = [(decl_id, 0)] + visited.add(decl_id) + + while queue: + current, depth = queue.pop(0) + if depth >= max_depth: + continue + + try: + # Follow incoming edges to find dependencies + for neighbor in G.predecessors(current): + if neighbor not in visited: + visited.add(neighbor) + queue.append((neighbor, depth + 1)) + except Exception as e: + print(f"Warning: Error processing predecessors of {current}: {e}") + continue + + # Create subgraph + subgraph = G.subgraph(visited).copy() + + # Get declaration info for all nodes + decls = [] + for node_id in visited: + try: + decl = db.get_declaration_by_decl_id(node_id) + if decl: + decls.append(decl) + except Exception as e: + print(f"Warning: Error fetching declaration {node_id}: {e}") + continue + + return decls, subgraph + except Exception as e: + print(f"Error in get_dependencies_closure: {e}") + raise + + +def get_dependents_closure( + db: LeanDeclarationDB, + decl_id: str, + max_depth: int = 10 +) -> Tuple[List[Dict[str, Any]], nx.DiGraph]: + """ + Get all dependents (transitive) of a declaration. + + Args: + db: LeanDeclarationDB instance + decl_id: Starting declaration ID + max_depth: Maximum depth to traverse + + Returns: + Tuple of (list of declaration dicts, subgraph) + """ + try: + G = build_dependency_graph(db) + + if decl_id not in G: + return [], nx.DiGraph() + + # Get all dependents (follow outgoing edges since edge direction is flipped) + # Edge B -> A means "A depends on B", so dependents of B are successors + visited = set() + queue = [(decl_id, 0)] + visited.add(decl_id) + + while queue: + current, depth = queue.pop(0) + if depth >= max_depth: + continue + + try: + # Follow outgoing edges to find dependents + for neighbor in G.successors(current): + if neighbor not in visited: + visited.add(neighbor) + queue.append((neighbor, depth + 1)) + except Exception as e: + print(f"Warning: Error processing successors of {current}: {e}") + continue + + # Create subgraph + subgraph = G.subgraph(visited).copy() + + # Get declaration info for all nodes + decls = [] + for node_id in visited: + try: + decl = db.get_declaration_by_decl_id(node_id) + if decl: + decls.append(decl) + except Exception as e: + print(f"Warning: Error fetching declaration {node_id}: {e}") + continue + + return decls, subgraph + except Exception as e: + print(f"Error in get_dependents_closure: {e}") + raise + + +def visualize_graph(G: nx.DiGraph, title: str = "Dependency Graph") -> go.Figure: + """ + Create an interactive Plotly visualization of the graph. 
+ + Args: + G: NetworkX directed graph + title: Title for the plot + + Returns: + Plotly Figure object + """ + if len(G.nodes()) == 0: + # Empty graph + fig = go.Figure() + fig.update_layout(title=title, annotations=[ + dict(text="No dependencies found", showarrow=False, + xref="paper", yref="paper", x=0.5, y=0.5) + ]) + return fig + + # Use spring layout for positioning + pos = nx.spring_layout(G, k=1, iterations=50) + + # Create edge traces + edge_x = [] + edge_y = [] + for edge in G.edges(): + x0, y0 = pos[edge[0]] + x1, y1 = pos[edge[1]] + edge_x.extend([x0, x1, None]) + edge_y.extend([y0, y1, None]) + + edge_trace = go.Scatter( + x=edge_x, y=edge_y, + line=dict(width=0.5, color='#888'), + hoverinfo='none', + mode='lines') + + # Create node traces + node_x = [] + node_y = [] + node_text = [] + node_color = [] + + # Color map for declaration types + type_colors = { + 'theorem': '#FF6B6B', + 'def': '#4ECDC4', + 'axiom': '#45B7D1', + 'instance': '#FFA07A', + 'class': '#98D8C8', + None: '#95A5A6' + } + + for node in G.nodes(): + x, y = pos[node] + node_x.append(x) + node_y.append(y) + + # Get node info + node_data = G.nodes[node] + name = node_data.get('full_name', node_data.get('name', 'Unknown')) + decl_type = node_data.get('decl_type') + file_path = node_data.get('file_path', 'Unknown') + + node_text.append(f"{name}
<br>Type: {decl_type}<br>
File: {file_path}") + node_color.append(type_colors.get(decl_type, type_colors[None])) + + node_trace = go.Scatter( + x=node_x, y=node_y, + mode='markers', + hoverinfo='text', + text=node_text, + marker=dict( + showscale=False, + color=node_color, + size=10, + line_width=2)) + + # Create figure + fig = go.Figure(data=[edge_trace, node_trace], + layout=go.Layout( + title=dict(text=title, font=dict(size=16)), + showlegend=False, + hovermode='closest', + margin=dict(b=0, l=0, r=0, t=40), + xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), + yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)) + ) + + return fig + + +def get_forest_statistics(G: nx.DiGraph, forest: Set[str]) -> Dict[str, Any]: + """ + Get statistics for a forest (connected component). + + Args: + G: NetworkX directed graph (edge B->A means "A depends on B") + forest: Set of decl_ids in the forest + + Returns: + Dictionary with forest statistics + """ + subgraph = G.subgraph(forest) + + # Find root nodes (declarations that depend on nothing) + # With edge direction B->A meaning "A depends on B", roots have in_degree == 0 + roots = [node for node in forest if subgraph.in_degree(node) == 0] + + # Find leaf nodes (declarations that nothing depends on) + # With edge direction B->A meaning "A depends on B", leaves have out_degree == 0 + leaves = [node for node in forest if subgraph.out_degree(node) == 0] + + return { + 'size': len(forest), + 'num_edges': subgraph.number_of_edges(), + 'num_roots': len(roots), + 'num_leaves': len(leaves), + 'roots': [G.nodes[r].get('full_name', r) for r in roots[:5]], # First 5 roots + 'leaves': [G.nodes[l].get('full_name', l) for l in leaves[:5]] # First 5 leaves + } diff --git a/src/app/data_explorer/lean_db_explorer.py b/src/app/data_explorer/lean_db_explorer.py new file mode 100644 index 0000000..675189b --- /dev/null +++ b/src/app/data_explorer/lean_db_explorer.py @@ -0,0 +1,480 @@ +""" +Lean Declaration Database Explorer + +A Streamlit web app for exploring and analyzing Lean declaration databases. 
+""" + +import sys +from pathlib import Path + +# Add parent directories to path +root_dir = Path(__file__).parent.parent.parent.parent +if str(root_dir) not in sys.path: + sys.path.insert(0, str(root_dir)) + +import streamlit as st +import pandas as pd +from itp_interface.tools.simple_sqlite import LeanDeclarationDB +import db_utils +import graph_utils + + +# Page configuration +st.set_page_config( + page_title="Lean Declaration DB Explorer", + page_icon="šŸ“Š", + layout="wide", + initial_sidebar_state="expanded" +) + + +def load_database(db_path: str) -> LeanDeclarationDB: + """Load database and cache in session state.""" + try: + db = LeanDeclarationDB(db_path) + return db + except Exception as e: + st.error(f"Failed to load database: {e}") + return None + + +def display_statistics(db: LeanDeclarationDB): + """Display database statistics in the sidebar.""" + stats = db.get_statistics() + + st.sidebar.markdown("### šŸ“Š Database Statistics") + col1, col2 = st.sidebar.columns(2) + + with col1: + st.metric("Files", stats.get('total_files', 0)) + st.metric("Declarations", stats.get('total_declarations', 0)) + + with col2: + st.metric("Dependencies", stats.get('total_dependencies', 0)) + st.metric("Imports", stats.get('total_imports', 0)) + + st.sidebar.metric("Unresolved", stats.get('unresolved_declarations', 0)) + + +def tab_custom_query(db: LeanDeclarationDB): + """Custom Query tab interface.""" + st.header("šŸ” Custom SQL Query") + + # Example queries dropdown + st.markdown("**Pre-built Queries:**") + common_queries = db_utils.get_common_queries() + query_name = st.selectbox( + "Select a query to load", + [""] + list(common_queries.keys()), + label_visibility="collapsed" + ) + + # Query text area - use query_name as part of key to force refresh + if query_name: + default_query = common_queries.get(query_name) + else: + default_query = "SELECT * FROM declarations LIMIT 10" + + query = st.text_area( + "SQL Query", + value=default_query, + height=150, + key=f"custom_query_{query_name}" + ) + + col1, col2 = st.columns([1, 5]) + with col1: + execute_button = st.button("Execute Query", type="primary") + + if execute_button: + try: + with st.spinner("Executing query..."): + df = db_utils.execute_query(db, query) + + if not df.empty: + st.success(f"Query returned {len(df)} rows") + st.dataframe(df, use_container_width=True, height=500) + + # Download button + csv = df.to_csv(index=False) + st.download_button( + label="Download CSV", + data=csv, + file_name="query_results.csv", + mime="text/csv" + ) + else: + st.info("Query returned no results") + + except Exception as e: + st.error(f"Query error: {e}") + + +def tab_search(db: LeanDeclarationDB): + """Search & Browse tab interface.""" + st.header("šŸ”Ž Search Declarations") + + col1, col2 = st.columns(2) + + with col1: + name_pattern = st.text_input("Name pattern", placeholder="e.g., add, mul, theorem") + file_path = st.text_input("File path contains", placeholder="e.g., Mathlib/Data") + + with col2: + decl_types = [""] + db_utils.get_declaration_types(db) + decl_type = st.selectbox("Declaration type", decl_types) + + namespace = st.text_input("Namespace", placeholder="e.g., Nat, List") + + if st.button("Search", type="primary"): + with st.spinner("Searching..."): + df = db_utils.search_declarations( + db, + name_pattern=name_pattern if name_pattern else None, + namespace=namespace if namespace else None, + file_path=file_path if file_path else None, + decl_type=decl_type if decl_type else None + ) + + if not df.empty: + st.success(f"Found 
{len(df)} declarations") + + # Show results + st.dataframe(df, use_container_width=True, height=500) + + # Download button + csv = df.to_csv(index=False) + st.download_button( + label="Download Results", + data=csv, + file_name="search_results.csv", + mime="text/csv" + ) + else: + st.info("No declarations found matching the criteria") + + +def tab_dependencies(db: LeanDeclarationDB): + """Dependency Explorer tab interface.""" + st.header("🌲 Dependency Explorer") + + # Declaration selector with option to search by name or decl_id + st.markdown("**Select a declaration:**") + + search_mode = st.radio( + "Search by", + ["Declaration name", "Declaration ID"], + horizontal=True, + help="Search by name or directly by decl_id" + ) + + if search_mode == "Declaration name": + decl_input = st.text_input("Declaration name", placeholder="e.g., dvd_trans, Nat.add") + else: + decl_input = st.text_input("Declaration ID", placeholder="e.g., Nat.add_comm_123abc") + + col1, col2, col3 = st.columns(3) + + with col1: + show_deps = st.button("Show Dependencies", type="primary") + with col2: + show_dependents = st.button("Show Dependents", type="primary") + with col3: + max_depth = st.number_input("Max depth", min_value=1, max_value=20, value=5) + + if show_deps or show_dependents: + if not decl_input: + st.warning(f"Please enter a declaration {search_mode.lower()}") + return + + # Find declaration by name or ID + decl_id = None + display_name = decl_input + + if search_mode == "Declaration name": + search_df = db_utils.search_declarations(db, name_pattern=decl_input) + + if search_df.empty: + st.error(f"Declaration '{decl_input}' not found") + return + + if len(search_df) > 1: + st.warning(f"Found {len(search_df)} declarations with this name. Using the first one.") + st.dataframe(search_df[['decl_id', 'name', 'namespace', 'file_path', 'decl_type']]) + + decl_id = search_df.iloc[0]['decl_id'] + display_name = search_df.iloc[0]['name'] + else: + # Direct decl_id lookup + decl = db.get_declaration_by_decl_id(decl_input) + if decl: + decl_id = decl_input + display_name = decl.get('name', decl_input) + # Show the declaration info + st.info(f"Found: {display_name} (Type: {decl.get('decl_type', 'unknown')})") + else: + st.error(f"Declaration with ID '{decl_input}' not found") + return + + try: + with st.spinner("Analyzing dependencies..."): + if show_deps: + decls, subgraph = graph_utils.get_dependencies_closure(db, decl_id, max_depth) + title = f"Dependencies of {display_name}" + else: + decls, subgraph = graph_utils.get_dependents_closure(db, decl_id, max_depth) + title = f"Dependents of {display_name}" + + if decls: + st.success(f"Found {len(decls)} related declarations") + + # Show tabs for table and graph views + tab1, tab2 = st.tabs(["šŸ“Š Table View", "šŸ“ˆ Graph View"]) + + with tab1: + df = pd.DataFrame(decls) + st.dataframe(df[['decl_id', 'name', 'namespace', 'decl_type', 'file_path', 'line']], + use_container_width=True, height=400) + + with tab2: + fig = graph_utils.visualize_graph(subgraph, title) + st.plotly_chart(fig, use_container_width=True) + else: + st.info("No dependencies found") + except Exception as e: + st.error(f"Error analyzing dependencies: {str(e)}") + st.exception(e) + + +def tab_forests(db: LeanDeclarationDB): + """Forest Analysis tab interface.""" + st.header("🌳 Forest Analysis") + + st.markdown(""" + A **forest** is a connected component in the dependency graph. Declarations in the same + forest are connected through dependency relationships. 
**Root nodes** are declarations + with no dependencies within the forest (no incoming edges). + """) + + # Initialize session state for forests + if 'forests' not in st.session_state: + st.session_state.forests = None + st.session_state.forest_graph = None + + if st.button("Find All Forests", type="primary"): + with st.spinner("Analyzing dependency graph..."): + G = graph_utils.build_dependency_graph(db) + forests = graph_utils.find_all_forests(G) + # Store in session state + st.session_state.forests = forests + st.session_state.forest_graph = G + + st.success(f"Found {len(forests)} forests") + + # Use forests from session state + forests = st.session_state.forests + G = st.session_state.forest_graph + + if forests is not None: + # Display statistics + st.markdown("### Forest Statistics") + + forest_data = [] + for i, forest in enumerate(forests[:20]): # Limit to top 20 + stats = graph_utils.get_forest_statistics(G, forest) + forest_data.append({ + 'Forest ID': i + 1, + 'Size': stats['size'], + 'Edges': stats['num_edges'], + 'Roots': stats['num_roots'], + 'Leaves': stats['num_leaves'], + 'Sample Roots': ', '.join(stats['roots'][:3]) + }) + + df = pd.DataFrame(forest_data) + st.dataframe(df, use_container_width=True) + + # Show declarations in selected forest + st.markdown("### View Forest Details") + forest_id = st.number_input( + "Forest ID to view", + min_value=1, + max_value=len(forests), + value=1 + ) + + col1, col2 = st.columns(2) + with col1: + show_all_decls = st.button("Show All Declarations") + with col2: + show_roots = st.button("Show Root Nodes Only") + + if show_all_decls or show_roots: + if G is None: + st.error("Graph not available. Please click 'Find All Forests' first.") + return + + forest = forests[forest_id - 1] + stats = graph_utils.get_forest_statistics(G, forest) + + if show_roots: + st.markdown(f"#### 🌱 Root Nodes of Forest #{forest_id}") + st.info(f"Found {stats['num_roots']} root nodes (declarations with no dependencies)") + + # Get root node details + root_decls = [] + subgraph = G.subgraph(forest) + root_ids = [node for node in forest if subgraph.in_degree(node) == 0] + + for decl_id in root_ids: + decl = db.get_declaration_by_decl_id(decl_id) + if decl: + # Count direct dependents + num_dependents = subgraph.out_degree(decl_id) + root_decls.append({ + 'decl_id': decl['decl_id'], + 'name': decl['name'], + 'namespace': decl.get('namespace', ''), + 'decl_type': decl.get('decl_type', ''), + 'dependents_count': num_dependents, + 'file_path': decl.get('file_path', ''), + 'line': decl.get('line', '') + }) + + if root_decls: + root_df = pd.DataFrame(root_decls) + # Sort by number of dependents + root_df = root_df.sort_values('dependents_count', ascending=False) + st.dataframe(root_df, use_container_width=True, height=500) + + # Download button + csv = root_df.to_csv(index=False) + st.download_button( + label="Download Root Nodes", + data=csv, + file_name=f"forest_{forest_id}_roots.csv", + mime="text/csv" + ) + else: + st.warning("No root nodes found in this forest.") + + else: # show_all_decls + st.info(f"Forest #{forest_id} contains {len(forest)} declarations") + + # Get all declarations in this forest + forest_decls = [] + subgraph = G.subgraph(forest) + + for decl_id in forest: + decl = db.get_declaration_by_decl_id(decl_id) + if decl: + # Check if it's a root node + is_root = subgraph.in_degree(decl_id) == 0 + # Truncate text and proof for display + text = decl.get('text', '') + proof = decl.get('proof', '') + forest_decls.append({ + 'decl_id': decl['decl_id'], + 
'name': decl['name'], + 'namespace': decl.get('namespace', ''), + 'decl_type': decl.get('decl_type', ''), + 'is_root': '🌱' if is_root else '', + 'text': text[:100] + '...' if text and len(text) > 100 else text, + 'proof': proof[:100] + '...' if proof and len(proof) > 100 else proof, + 'file_path': decl.get('file_path', ''), + 'line': decl.get('line', '') + }) + + if forest_decls: + forest_df = pd.DataFrame(forest_decls) + st.dataframe(forest_df, use_container_width=True, height=500) + + # Download button + csv = forest_df.to_csv(index=False) + st.download_button( + label="Download Forest Declarations", + data=csv, + file_name=f"forest_{forest_id}_declarations.csv", + mime="text/csv" + ) + else: + st.warning("No declarations found in this forest.") + else: + st.info("Click 'Find All Forests' to analyze the dependency graph") + + +def main(): + """Main application.""" + st.title("šŸ“Š Lean Declaration Database Explorer") + + # Sidebar + st.sidebar.title("Database Connection") + + # Database path input + db_path = st.sidebar.text_input( + "Database Path", + value="lean_declarations.db", + help="Path to the SQLite database file" + ) + + load_button = st.sidebar.button("Load Database", type="primary") + + # Initialize session state + if 'db' not in st.session_state: + st.session_state.db = None + + if load_button: + if Path(db_path).exists(): + with st.spinner("Loading database..."): + st.session_state.db = load_database(db_path) + + if st.session_state.db: + st.sidebar.success("Database loaded successfully!") + else: + st.sidebar.error(f"Database file not found: {db_path}") + + # Show statistics if database is loaded + if st.session_state.db: + display_statistics(st.session_state.db) + + # Main content tabs + tab1, tab2, tab3, tab4 = st.tabs([ + "šŸ” Custom Query", + "šŸ”Ž Search", + "🌲 Dependencies", + "🌳 Forests" + ]) + + with tab1: + tab_custom_query(st.session_state.db) + + with tab2: + tab_search(st.session_state.db) + + with tab3: + tab_dependencies(st.session_state.db) + + with tab4: + tab_forests(st.session_state.db) + + else: + st.info("šŸ‘ˆ Please load a database to get started") + + st.markdown(""" + ### Getting Started + + 1. Enter the path to your SQLite database file in the sidebar + 2. Click "Load Database" + 3. Explore your data using the tabs above + + ### Features + + - **Custom Query**: Write and execute custom SQL queries + - **Search**: Search declarations by name, type, file, etc. 
+ - **Dependencies**: Explore dependency relationships + - **Forests**: Analyze connected components in the dependency graph + """) + + +if __name__ == "__main__": + main() diff --git a/src/app/itp-gui/app.py b/src/app/itp-gui/app.py index a169744..1f42a92 100644 --- a/src/app/itp-gui/app.py +++ b/src/app/itp-gui/app.py @@ -19,7 +19,7 @@ from typing import Optional, Dict, Any, List import json -from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor +from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor from itp_interface.lean_server.lean_context import ProofContext app = Flask(__name__, static_folder='static', template_folder='templates') diff --git a/src/data/test/batteries b/src/data/test/batteries new file mode 160000 index 0000000..8da40b7 --- /dev/null +++ b/src/data/test/batteries @@ -0,0 +1 @@ +Subproject commit 8da40b72fece29b7d3fe3d768bac4c8910ce9bee diff --git a/src/itp_interface/lean/__init__.py b/src/itp_interface/lean/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/itp_interface/lean/lean4_local_data_extraction_transform.py b/src/itp_interface/lean/lean4_local_data_extraction_transform.py new file mode 100644 index 0000000..3b6be6d --- /dev/null +++ b/src/itp_interface/lean/lean4_local_data_extraction_transform.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 + +import os +import sys +dir_name = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +root_dir = os.path.abspath(dir_name) +if root_dir not in sys.path: + sys.path.append(root_dir) +import typing +import uuid +import json +from pathlib import Path +from filelock import FileLock +from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor +from itp_interface.lean.parsing_helpers import LeanDeclParser, LeanParseResult, LeanDeclType +from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType +from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, ExtractionDataCollection +from itp_interface.tools.training_data import TrainingData, DataLayoutFormat +from itp_interface.lean.tactic_parser import FileDependencyAnalysis, DeclWithDependencies +from itp_interface.tools.simple_sqlite import LeanDeclarationDB + +class Local4DataExtractionTransform(GenericTrainingDataGenerationTransform): + def __init__(self, + depth = None, + max_search_results = None, + buffer_size : int = 10000, + logger = None, + max_parallelism : int = 4, + db_path : typing.Optional[str] = None, + enable_file_export: bool = False, + enable_dependency_extraction: bool = False): + super().__init__(TrainingDataGenerationType.LOCAL, buffer_size, logger) + self.depth = depth + self.max_search_results = max_search_results + self.max_parallelism = max_parallelism + self.db_path = db_path # Store path, don't create connection yet (for Ray actors) + self.enable_file_export = enable_file_export + self.enable_dependency_extraction = enable_dependency_extraction + # Everything except UNKNOWN + self.supported_declaration_types = [ str(dt) for dt in LeanDeclType if dt != LeanDeclType.UNKNOWN] + if self.db_path is not None: + temp_db_path = Path(self.db_path) + cache_path = temp_db_path.parent / "cache" + os.makedirs(str(cache_path), exist_ok=True) + mapping_path = cache_path / f"{temp_db_path.stem}_declaration_mapping.json" + self.cache_path = str(cache_path) + self.mapping_path = str(mapping_path) + self._lock_path = os.path.join(self.cache_path, 
"lean_declaration_mapping.lock") + else: + self.cache_path = None + self.mapping_path = None + self._lock_path = None + + def get_meta_object(self) -> TrainingDataMetadataFormat: + return TrainingDataMetadataFormat( + training_data_buffer_size=self.buffer_size, + data_filename_prefix="extraction_data_", + lemma_ref_filename_prefix="extraction_lemma_refs_") + + def get_data_collection_object(self) -> MergableCollection: + return ExtractionDataCollection() + + def load_meta_from_file(self, file_path) -> MergableCollection: + return TrainingDataMetadataFormat.load_from_file(file_path) + + def load_data_from_file(self, file_path) -> MergableCollection: + return ExtractionDataCollection.load_from_file(file_path, self.logger) + + def _remove_lake_package_prefix(self, module_name: str) -> str: + if module_name.startswith("lake.packages.") or module_name.startswith(".lake.packages."): + parts = module_name.split('.') + if len(parts) > 3: + module_name = '.'.join(parts[3:]) + else: + module_name = '' + return module_name + + def _remove_lake_file_prefix(self, path: Path) -> str: + parts = path.parts + if ".lake" in parts: + lake_index = parts.index(".lake") + new_parts = parts[lake_index + 3:] + new_path = Path(*new_parts) + return str(new_path) + return str(path) + + def _check_if_lean_core_file(self, file_path: Path) -> bool: + parts = file_path.parts + # If `leanprover--lean4---v4.24.0/src/lean/` is in the path, it's a core file + for i in range(len(parts) - 4): + if (parts[i].startswith("leanprover--lean4---v") and + parts[i+1] == "src" and + parts[i+2] == "lean"): + return True + return False + + def _check_if_lean_core_module(self, module_name: str) -> bool: + parts = module_name.split('.') + # leanprover--lean4---v4.24.0.src.lean + idx = None + for i in range(len(parts) - 2): + if parts[i].startswith("leanprover--lean4---v"): + idx = i + if idx is not None: + for i in range(idx, len(parts)): + if parts[i] == "src" and i + 1 < len(parts) and parts[i + 1] == "lean": + return True + return False + + def _strip_core_file_path(self, file_path: Path) -> str: + # Take only the parts after `leanprover--lean4---v4.24.0/src/lean/` + parts = file_path.parts + for i in range(len(parts) - 4): + if (parts[i].startswith("leanprover--lean4---v") and + parts[i+1] == "src" and + parts[i+2] == "lean"): + new_parts = parts[i+3:] + new_path = Path(*new_parts) + return str(new_path) + return str(file_path) + + def _strip_core_module_name(self, module_name: str) -> str: + parts = module_name.split('.') + # leanprover--lean4---v4.24.0.src.lean + idx: int | None = None + for i in range(len(parts) - 2): + if parts[i].startswith("leanprover--lean4---v"): + idx = i + if idx is not None: + for i in range(idx, len(parts)): + if parts[i] == "src" and i + 1 < len(parts) and parts[i + 1] == "lean": + new_parts = parts[i + 2:] + return '.'.join(new_parts) + return module_name + + def _get_file_path(self, project_path: str, file_path: str) -> str: + fp = Path(file_path) + pp = Path(project_path) + fp_abs = fp.resolve() + pp_abs = pp.resolve() + if self._check_if_lean_core_file(fp_abs): + stripped_path = self._strip_core_file_path(fp_abs) + return stripped_path + relative_path = fp_abs.relative_to(pp_abs) + rel_file_path = self._remove_lake_file_prefix(relative_path) + return str(rel_file_path) + + def _get_module_name(self, project_path: str, module_name: str) -> str: + if self._check_if_lean_core_module(module_name): + stripped_module = self._strip_core_module_name(module_name) + return stripped_module + pp = 
Path(project_path) + pp_abs = pp.resolve() + pp_module = str(pp_abs).replace('/', '.') + pp_module = pp_module.lstrip('.') + # self.logger.info(f"Project module prefix: {pp_module}") + if module_name.startswith(pp_module): + module_name = module_name[len(pp_module):] + module_name = module_name.lstrip('.') + # self.logger.info(f"Module name after removing project prefix: {module_name}") + module_name = self._remove_lake_package_prefix(module_name) + return module_name + + def _remap_local_dependency(self, + project_path: str, + file_dep_analyses: typing.List[FileDependencyAnalysis]) -> None: + # Map local dependencies + all_decls : dict[str, DeclWithDependencies] = {} + for fda in file_dep_analyses: + for decl in fda.declarations: + name = decl.decl_info.name + if name is not None: + all_decls[name] = decl + for fda in file_dep_analyses: + fda_rel_path = self._get_file_path(project_path, fda.file_path) + # Remove the pp_module prefix from the fda module name + fda_module_name = self._get_module_name(project_path, fda.module_name) + for decl in fda.declarations: + for decl_dep in decl.dependencies: + if decl_dep.name in all_decls: + local_decl = all_decls[decl_dep.name] + if decl_dep.file_path is None and \ + local_decl.decl_info.decl_type in self.supported_declaration_types: + # We need to ensure that it is locally defined. + decl_dep.file_path = fda_rel_path + if decl_dep.module_name is None: + decl_dep.module_name = fda_module_name + if decl_dep.namespace is None: + decl_dep.namespace = decl.decl_info.namespace + + def _preprocess_declarations(self, + project_path: str, + file_dep_analyses: typing.List[FileDependencyAnalysis]) -> None: + # Preprocess declarations to set file paths and module names + for fda in file_dep_analyses: + for decl in fda.declarations: + self._preprocess_declaration(project_path, decl) + + def _preprocess_declaration( + self, + project_path: str, + decl: DeclWithDependencies) -> None: + # Filter all unknown types + if decl.decl_info.decl_type == LeanDeclType.UNKNOWN.value: + # Check if we have docstring or not + if decl.decl_info.doc_string is None or decl.decl_info.doc_string.strip() != "": + parser = LeanDeclParser(decl.decl_info.text) + parse_result = parser.parse() + if parse_result.doc_string is not None: + decl.decl_info.doc_string = parse_result.doc_string.strip() + if parse_result.decl_type != LeanDeclType.UNKNOWN: + decl.decl_info.decl_type = str(parse_result.decl_type) + if parse_result.text is not None: + full_text = [] + if parse_result.text_before is not None: + full_text.append(parse_result.text_before.strip()) + full_text.append(parse_result.text.strip()) + text = "\n".join(full_text) + decl.decl_info.text = text + # Update the proof if not already present + if decl.decl_info.proof is None and parse_result.proof is not None: + decl.decl_info.proof = parse_result.proof.strip() + if parse_result.name is not None: + decl.decl_info.name = parse_result.name.strip() + pass + + + def __call__(self, + training_data: TrainingData, + project_id : str, + lean_executor: SimpleLean4SyncExecutor, + print_coq_executor_callback: typing.Callable[[], SimpleLean4SyncExecutor], + theorems: typing.List[str] = None, + other_args: dict = {}) -> TrainingData: + file_path = lean_executor.main_file + project_path = project_id + rel_file_path = self._get_file_path(project_path, file_path) + file_namespace = rel_file_path.replace('/', '.') + self.logger.info(f"=========================Processing {file_namespace}=========================") + + # Create database connection for 
this Ray actor (if db_path is provided) + db = None + if self.db_path: + db = LeanDeclarationDB(self.db_path) + self.logger.info(f"Connected to database: {self.db_path}") + + try: + if isinstance(theorems, list) and len(theorems) == 1 and theorems[0] == "*": + theorems = None + else: + theorems = set(theorems) if theorems is not None else None + cnt = 0 + temp_dir = os.path.join(training_data.folder, "temp") if self.cache_path is None else self.cache_path + temp_dir = temp_dir.rstrip('/') + if self.cache_path is None: + os.makedirs(temp_dir, exist_ok=True) + json_output_path = f"{temp_dir}/{file_namespace.replace('.', '_')}.lean.deps.json" + already_mapped = False + if self._lock_path is not None: + with FileLock(self._lock_path): + assert self.mapping_path is not None + if not os.path.exists(self.mapping_path): + with open(self.mapping_path, "w") as f: + json.dump({}, f) + with open(self.mapping_path, "r") as f: + mapping_data = json.load(f) + if json_output_path in mapping_data: + already_mapped = True + else: + mapping_data[json_output_path] = True + with open(self.mapping_path, "w") as f: + json.dump(mapping_data, f) + if already_mapped and os.path.exists(json_output_path): + # Read from existing file + self.logger.info(f"Using existing dependency extraction file: {json_output_path}") + with open(json_output_path, 'r') as f: + data = json.load(f) + file_dep_analyses = [FileDependencyAnalysis.model_validate(data)] + else: + file_dep_analyses = lean_executor.extract_all_theorems_and_definitions(json_output_path=json_output_path) + self.logger.info(f"Extracted {len(file_dep_analyses)} FileDependencyAnalysis objects from {file_namespace}") + self.logger.info(f"file_dep_analyses: {file_dep_analyses}") + assert len(file_dep_analyses) == 1, "Expected exactly one FileDependencyAnalysis object" + + last_decl_id = None + self._preprocess_declarations(project_path, file_dep_analyses) + self._remap_local_dependency(project_path, file_dep_analyses) + + for fda in file_dep_analyses: + fda_rel_path = self._get_file_path(project_path, fda.file_path) + # Remove the pp_module prefix from the fda module name + fda_module_name = self._get_module_name(project_path, fda.module_name) + + # Insert file and imports into database (if db is enabled) + if db and fda.imports: + db.insert_file_imports(fda_rel_path, fda_module_name, fda.imports) + self.logger.info(f"Inserted file and {len(fda.imports)} imports for {fda_rel_path}") + + for decl in fda.declarations: + # Get or create decl_id from database (or generate new one if no DB) + if db: + if decl.decl_info.decl_type not in self.supported_declaration_types: + self.logger.info(f"Skipping declaration '{decl.decl_info.name}' of unsupported type '{decl.decl_info.decl_type}'") + continue + decl_id = db.process_declaration( + fda_file_path=fda_rel_path, + fda_module_name=fda_module_name, + decl=decl, + enable_dependency_extraction=self.enable_dependency_extraction + ) + self.logger.info(f"Processed declaration '{decl.decl_info.name}' with ID: {decl_id}") + else: + # Fallback: generate unique ID without database + timestamp = str(int(uuid.uuid1().time_low)) + random_id = str(uuid.uuid4()) + decl_id = f"{timestamp}_{random_id}" + if self.enable_file_export: + new_fda = FileDependencyAnalysis( + file_path=str(fda_rel_path), + module_name=fda_module_name, + imports=fda.imports, + declarations=[]) + line_info = decl.decl_info + if theorems is not None and line_info.name not in theorems: + continue + decl.decl_id = decl_id + new_fda.declarations.append(decl) + 
training_data.merge(new_fda) + cnt += 1 + last_decl_id = decl_id + + if last_decl_id: + training_data.meta.last_proof_id = last_decl_id + self.logger.info(f"===============Finished processing {file_namespace}=====================") + self.logger.info(f"Total declarations processed in this transform: {cnt}") + return training_data + finally: + # Clean up database connection + if db: + db.close() + self.logger.info("Closed database connection") + + +if __name__ == "__main__": + import os + import logging + import time + os.chdir(root_dir) + # project_dir = 'data/test/lean4_proj/' + project_dir = 'data/test/Mathlib' + # file_name = 'data/test/lean4_proj/Lean4Proj/Basic.lean' + # file_name = 'data/test/Mathlib/.lake/packages/mathlib/Mathlib/Algebra/Divisibility/Basic.lean' + home_path = str(Path.home()) + file_name = f'{home_path}/.elan/toolchains/leanprover--lean4---v4.24.0/src/lean/Init/Prelude.lean' + project_id = project_dir #.replace('/', '.') + time_str = time.strftime("%Y%m%d-%H%M%S") + output_path = f".log/local_data_generation_transform/data/{time_str}" + log_path = f".log/local_data_generation_transform/log/{time_str}" + log_file = f"{log_path}/local_data_generation_transform-{time_str}.log" + db_path = f"{log_path}/lean_declarations.db" + os.makedirs(output_path, exist_ok=True) + os.makedirs(log_path, exist_ok=True) + logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') + logger = logging.getLogger(__name__) + def _print_lean_executor_callback(): + search_lean_exec = SimpleLean4SyncExecutor(main_file=file_name, project_root=project_dir) + search_lean_exec.__enter__() + return search_lean_exec + # Create transform with database enabled + transform = Local4DataExtractionTransform(0, buffer_size=1000, db_path=db_path) + logger.info(f"Using database: {db_path}") + training_data = TrainingData( + output_path, + "training_metadata.json", + training_meta=transform.get_meta_object(), + logger=logger, + layout=DataLayoutFormat.DECLARATION_EXTRACTION) + with SimpleLean4SyncExecutor(project_root=project_dir, main_file=file_name, use_human_readable_proof_context=True, suppress_error_log=True) as coq_exec: + transform(training_data, project_id, coq_exec, _print_lean_executor_callback, theorems=["*"]) + save_info = training_data.save() + logger.info(f"Saved training data to {save_info}") \ No newline at end of file diff --git a/src/itp_interface/lean/parsing_helpers.py b/src/itp_interface/lean/parsing_helpers.py new file mode 100644 index 0000000..de16b24 --- /dev/null +++ b/src/itp_interface/lean/parsing_helpers.py @@ -0,0 +1,376 @@ +import sys +import re +from enum import Enum +from dataclasses import dataclass +from typing import Optional, Set + +class LeanDeclType(Enum): + LEMMA = "lemma" + THEOREM = "theorem" + DEF = "def" + STRUCTURE = "structure" + CLASS = "class" + INDUCTIVE = "inductive" + INSTANCE = "instance" + ABBREV = "abbrev" + ABBREVIATION = "abbreviation" + AXIOM = "axiom" + EXAMPLE = "example" + OPAQUE = "opaque" + CONSTANT = "constant" + MUTUAL = "mutual" + UNKNOWN = "unknown" + + def __str__(self) -> str: + return str(self.value).lower() + +@dataclass +class LeanParseResult: + decl_type: LeanDeclType + name: Optional[str] = None + text_before: Optional[str] = None + doc_string: Optional[str] = None + text: Optional[str] = None + proof: Optional[str] = None + +class LeanDeclParser: + """ + Parses Lean 4 declarations to separate context, docstrings, + declaration headers, and proofs/bodies. 
+ """ + + # Keywords that mark the start of a declaration + DECL_KEYWORDS = {m.value for m in LeanDeclType if m != LeanDeclType.UNKNOWN} + + # Types for which we should NOT attempt to extract a proof/body + NO_PROOF_TYPES = { + LeanDeclType.INDUCTIVE, + LeanDeclType.MUTUAL, + LeanDeclType.STRUCTURE, + LeanDeclType.CLASS + } + + # Types that typically don't have a name + NO_NAME_TYPES = { + LeanDeclType.EXAMPLE, + LeanDeclType.MUTUAL, + LeanDeclType.UNKNOWN + } + + def __init__(self, text: str): + self.text = text + self.n = len(text) + self.tokens = [] + self.docstring_range = None # (start, end) + + # Key Indices and Info + self.decl_start = -1 + self.proof_start = -1 + self.decl_type: LeanDeclType = LeanDeclType.UNKNOWN + self.decl_name: Optional[str] = None + + def parse(self) -> LeanParseResult: + self._tokenize() + self._analyze_structure() + return self._construct_result() + + def _tokenize(self): + """ + Scans text to find tokens, respecting comments, strings, and nesting. + """ + i = 0 + # States + NORMAL = 0 + IN_STRING = 1 + IN_CHAR = 2 + + state = NORMAL + nesting = 0 # () [] {} + + while i < self.n: + # 1. Handle Comments (Line and Block) + if state == NORMAL: + if self.text.startswith("--", i): + # Line comment + end_line = self.text.find('\n', i) + if end_line == -1: end_line = self.n + i = end_line + continue + + if self.text.startswith("/-", i): + # Block comment + is_doc = self.text.startswith("/--", i) + start_idx = i + + # Find end of block comment (handle nesting) + depth = 1 + i += 2 + while i < self.n and depth > 0: + if self.text.startswith("/-", i): + depth += 1 + i += 2 + elif self.text.startswith("-/", i): + depth -= 1 + i += 2 + else: + i += 1 + + # Capture the FIRST docstring found + if is_doc and self.docstring_range is None: + self.docstring_range = (start_idx, i) + continue + + # 2. Handle Strings/Chars + if state == NORMAL: + if self.text[i] == '"': + state = IN_STRING + i += 1 + continue + if self.text[i] == "'": + state = IN_CHAR + i += 1 + continue + + elif state == IN_STRING: + if self.text[i] == '\\': i += 2; continue + if self.text[i] == '"': state = NORMAL; i += 1; continue + i += 1 + continue + + elif state == IN_CHAR: + if self.text[i] == '\\': i += 2; continue + if self.text[i] == "'": state = NORMAL; i += 1; continue + i += 1 + continue + + # 3. Handle Structure Tokens in NORMAL state + char = self.text[i] + + # Nesting tracking + if char in "([{": + nesting += 1 + i += 1 + continue + elif char in ")]}": + nesting = max(0, nesting - 1) + i += 1 + continue + + # Token detection (only at top level) + if nesting == 0: + # Check for 'in' keyword (standalone) + if self._is_keyword_at(i, "in"): + self.tokens.append(("IN", i, i+2)) + i += 2 + continue + + # Check for Declaration Keywords + match_kw = self._match_any_keyword(i, self.DECL_KEYWORDS) + if match_kw: + kw, length = match_kw + self.tokens.append(("KW", i, i+length)) + i += length + continue + + # Check for Attribute Start + if self.text.startswith("@[", i): + self.tokens.append(("ATTR", i, i+2)) + i += 2 + nesting += 1 # The '[' counts as nesting + continue + + # Check for Proof Starters + if self.text.startswith(":=", i): + self.tokens.append(("PROOF", i, i+2)) + i += 2 + continue + if self.text.startswith("where", i) and self._is_word_boundary(i+5): + self.tokens.append(("PROOF", i, i+5)) + i += 5 + continue + if char == '|': + self.tokens.append(("PROOF", i, i+1)) + i += 1 + continue + + i += 1 + + def _analyze_structure(self): + """ + Interpret the token stream to find split points. 
+ """ + candidate_decl = -1 + decl_keyword_str = None + decl_keyword_end = -1 + + # Pass 1: Find Declaration Start + for t_type, t_start, t_end in self.tokens: + if t_type == "IN": + candidate_decl = -1 # Reset candidate + decl_keyword_str = None + + elif t_type == "KW": + if candidate_decl == -1: + candidate_decl = t_start + if decl_keyword_str is None: + decl_keyword_str = self.text[t_start:t_end] + decl_keyword_end = t_end + + elif t_type == "ATTR": + if candidate_decl == -1: + candidate_decl = t_start + + if candidate_decl != -1: + self.decl_start = candidate_decl + + # Resolve Enum Type + if decl_keyword_str: + try: + self.decl_type = LeanDeclType(decl_keyword_str) + except ValueError: + self.decl_type = LeanDeclType.UNKNOWN + else: + self.decl_type = LeanDeclType.UNKNOWN + + # Extract Name + if self.decl_type not in self.NO_NAME_TYPES and decl_keyword_end != -1: + self.decl_name = self._extract_name_after(decl_keyword_end) + + # Pass 2: Find Proof Start + skip_proof = self.decl_type in self.NO_PROOF_TYPES + + if not skip_proof: + for t_type, t_start, t_end in self.tokens: + if t_start > self.decl_start and t_type == "PROOF": + self.proof_start = t_start + break + else: + pass + + def _extract_name_after(self, idx: int) -> Optional[str]: + """ + Finds the first identifier after the given index, skipping comments and whitespace. + Returns None if it hits a symbol (e.g. '(', '{', ':') before a name. + """ + i = idx + while i < self.n: + c = self.text[i] + + # Skip Whitespace + if c.isspace(): + i += 1 + continue + + # Skip Line Comments + if self.text.startswith("--", i): + i = self.text.find('\n', i) + if i == -1: return None + continue + + # Skip Block Comments + if self.text.startswith("/-", i): + # Quick skip for simple block comments, logic same as tokenizer + depth = 1 + i += 2 + while i < self.n and depth > 0: + if self.text.startswith("/-", i): + depth += 1 + i += 2 + elif self.text.startswith("-/", i): + depth -= 1 + i += 2 + else: + i += 1 + continue + + # Identifier Start Check + # Lean identifiers can be French-quoted «name» or standard + # If it starts with a symbol like (, {, [, :, it's anonymous + if not (c.isalnum() or c == '_' or c == '«'): + return None + + # Extract + start = i + if c == '«': + end = self.text.find('»', start) + if end != -1: + return self.text[start:end+1] + else: + # Malformed? Just return rest of line + return None + else: + # Standard identifier (alphanum + .
+ _) + while i < self.n: + curr = self.text[i] + if curr.isalnum() or curr == '_' or curr == '.': + i += 1 + else: + break + return self.text[start:i] + + return None + + def _construct_result(self) -> LeanParseResult: + # Case 1: No declaration found + if self.decl_start == -1: + return LeanParseResult( + decl_type=LeanDeclType.UNKNOWN, + text_before=self.text + ) + + # Case 2: Declaration found + split_idx = self.decl_start + + decl_end = self.n + proof_content = None + if self.proof_start != -1: + decl_end = self.proof_start + proof_content = self.text[self.proof_start:].strip() + + raw_pre = self.text[:split_idx] + raw_decl = self.text[split_idx:decl_end] + doc_content = None + + if self.docstring_range: + ds_start, ds_end = self.docstring_range + doc_content = self.text[ds_start:ds_end] + + if ds_start < split_idx: + # Remove docstring from raw_pre + pre_part1 = self.text[:ds_start] + pre_part2 = self.text[ds_end:split_idx] + raw_pre = pre_part1 + pre_part2 + + return LeanParseResult( + decl_type=self.decl_type, + name=self.decl_name, + text_before=raw_pre.strip() or None, + doc_string=doc_content or None, + text=raw_decl.strip() or None, + proof=proof_content or None + ) + + # --- Helpers --- + def _is_keyword_at(self, idx, kw): + if not self.text.startswith(kw, idx): return False + return self._is_word_boundary(idx + len(kw)) + + def _match_any_keyword(self, idx, keywords): + if not self.text[idx].isalpha(): return None + j = idx + while j < self.n and (self.text[j].isalnum() or self.text[j] == '_'): + j += 1 + word = self.text[idx:j] + if word in keywords: + return word, len(word) + return None + + def _is_word_boundary(self, idx): + if idx >= self.n: return True + c = self.text[idx] + return not (c.isalnum() or c == '_') + + +def parse_lean_text(text: str) -> LeanParseResult: + parser = LeanDeclParser(text) + return parser.parse() \ No newline at end of file diff --git a/src/itp_interface/tools/simple_lean4_sync_executor.py b/src/itp_interface/lean/simple_lean4_sync_executor.py similarity index 99% rename from src/itp_interface/tools/simple_lean4_sync_executor.py rename to src/itp_interface/lean/simple_lean4_sync_executor.py index 49250fc..b957aa0 100644 --- a/src/itp_interface/tools/simple_lean4_sync_executor.py +++ b/src/itp_interface/lean/simple_lean4_sync_executor.py @@ -10,7 +10,7 @@ import typing import bisect import subprocess -from itp_interface.tools.tactic_parser import ( +from itp_interface.lean.tactic_parser import ( TacticParser, ErrorInfo, LeanLineInfo, diff --git a/src/itp_interface/tools/tactic_parser.py b/src/itp_interface/lean/tactic_parser.py similarity index 94% rename from src/itp_interface/tools/tactic_parser.py rename to src/itp_interface/lean/tactic_parser.py index 966f781..36cb2c3 100644 --- a/src/itp_interface/tools/tactic_parser.py +++ b/src/itp_interface/lean/tactic_parser.py @@ -106,6 +106,7 @@ class LeanLineInfo(BaseModel): name: Optional[str] = None doc_string: Optional[str] = None namespace: Optional[str] = None + proof: Optional[str] = None # Full proof text (if available) def __repr__(self) -> str: return f"LeanLineInfo(text={self.text!r}, line={self.line}, column={self.column})" @@ -188,7 +189,8 @@ def from_dict(decl_data: Dict) -> 'DeclWithDependencies': decl_type=decl_dict.get('declType'), name=decl_dict.get('name'), doc_string=decl_dict.get('docString'), - namespace=decl_dict.get('namespace') + namespace=decl_dict.get('namespace'), + proof=decl_dict.get('proof') ) # Parse dependencies @@ -248,6 +250,23 @@ class 
FileDependencyAnalysis(BaseModel): def __repr__(self) -> str: return f"FileDependencyAnalysis({self.module_name}, {len(self.declarations)} decls)" + + def to_json(self, indent=0) -> str: + if indent == 0: + return self.model_dump_json() + else: + return self.model_dump_json(indent=indent) + + @staticmethod + def load_from_string(json_text: str) -> 'FileDependencyAnalysis': + return FileDependencyAnalysis.model_validate_json(json_text) + + @staticmethod + def load_from_file(file_path: str) -> 'FileDependencyAnalysis': + with open(file_path, 'r', encoding='utf-8') as f: + data = f.read() + return FileDependencyAnalysis.load_from_string(data) + # Create an enum for parsing request type class RequestType(Enum): @@ -259,8 +278,8 @@ class RequestType(Enum): def get_path_to_tactic_parser_project() -> str: """Get the path to the tactic parser project directory.""" - tools_dir = os.path.dirname(__file__) - tactic_parser_path = os.path.join(tools_dir, "tactic_parser") + lean_dir = os.path.dirname(__file__) + tactic_parser_path = os.path.join(lean_dir, "tactic_parser") abs_path = os.path.abspath(tactic_parser_path) return abs_path @@ -734,10 +753,31 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] else: print(msg) +def print_errors(errors: List[ErrorInfo], logger: Optional[logging.Logger] = None): + for error in errors: + msg = f"Error at Line {error.position.line}, Col {error.position.column}: {error.message}" + if logger: + logger.error(msg) + else: + print(msg) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) project_path = str(Path(__file__).parent.parent.parent / "data" / "test" / "lean4_proj") + with TacticParser() as parser: + # Example 0: Empty proof + lean_code = """theorem test : āˆ€ {a : Nat}, a + 0 = a +| 0 => by simp +| n + 1 => by simp +""" + + print("Parsing example 0...") + tactics, errors = parser.parse(lean_code) + print_tactics(tactics) + if errors: + print(f"Error: {errors}") + with TacticParser() as parser: # Example 1: Simple proof lean_code = "example : True := by trivial" @@ -746,7 +786,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics, errors = parser.parse(lean_code) print_tactics(tactics) if errors: - print(f"Error: {errors}") + print_errors(errors) p_path = "/home/amthakur/Projects/copra/data/test/miniF2F-lean4" with TacticParser(project_path=p_path) as parser: @@ -765,7 +805,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics, errors = parser.parse(lean_code, fail_on_error=False) print_tactics(tactics) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=p_path) as parser: # Example 1a: Simple proof with multiple tactics @@ -786,7 +826,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics, errors = parser.parse(lean_code, fail_on_error=False) print_tactics(tactics) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=p_path) as parser: # Example 1a: Simple proof with multiple tactics @@ -807,7 +847,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics, errors = parser.parse(lean_code, fail_on_error=False) print_tactics(tactics) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser() as parser: # Example 1b: Simple have proofs @@ -824,7 +864,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics, errors = 
parser.parse(lean_code, fail_on_error=False) print_tactics(tactics) if errors: - print(f"Error: {errors}") + print_errors(errors) @@ -836,7 +876,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics2, errors = parser.parse(lean_code2, fail_on_error=False) print_tactics(tactics2) if errors: - print(f"Error: {errors}") + print_errors(errors) # Check if linarith is parsed correctly lean_code3 = """ @@ -853,7 +893,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics3, errors = parser.parse(lean_code3) print_tactics(tactics3) if errors: - print(f"Error: {errors}") + print_errors(errors) file_path = str(Path(__file__).parent.parent.parent / "data" / "test" / "lean4_proj" / "Lean4Proj" / "Basic.lean") @@ -863,7 +903,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics4, errors = parser.parse_file(file_path) print_tactics(tactics4) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=project_path) as parser: # Example 2: Multiline with params @@ -873,7 +913,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics5, errors = parser.parse(lean_code4) print_tactics(tactics5) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=project_path) as parser: # Example 6: Parse tactics from file with multiple theorems @@ -881,7 +921,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics6, errors = parser.parse(lean_code3 + "\n" + lean_code4, parse_type=RequestType.PARSE_TACTICS) print_tactics(tactics6) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=project_path) as parser: # Example 7: Parse tactics which are wrong @@ -890,7 +930,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics7, errors = parser.parse(lean_code5, fail_on_error=False) print_tactics(tactics7) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=project_path) as parser: # Example 8: Parse tactics just before `by` @@ -899,7 +939,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics8, errors = parser.parse(lean_code8, fail_on_error=False) print_tactics(tactics8) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=project_path) as parser: # Example 9: Parse tactics just before `by` @@ -908,7 +948,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics9, errors = parser.parse(lean_code9, fail_on_error=False) print_tactics(tactics9) if errors: - print(f"Error: {errors}") + print_errors(errors) with TacticParser(project_path=project_path) as parser: # Example 10: Test checkpointing @@ -926,7 +966,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] tactics10, errors = parser.parse(lean_code10, fail_on_error=True, parse_type=RequestType.CHKPT_TACTICS) print_tactics(tactics10) if errors: - print(f"Error: {errors}") + print_errors(errors) # Now just execute from the checkpoint lean_code10b = """ theorem temp2: 1 + 2 = 3 := @@ -938,7 +978,7 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] print_tactics(tactics10b) if errors: # The error should contain h_temp - print(f"Error: {errors}") + print_errors(errors) print("\nBreaking checkpoint...") new_lean_code10c = lean_code10 + 
lean_code10b @@ -946,4 +986,4 @@ def print_tactics(tactics: List[LeanLineInfo], logger: Optional[logging.Logger] # ^This will reimport everything all run all theorems from scratch print_tactics(tactics10c) if errors: - print(f"Error: {errors}") \ No newline at end of file + print_errors(errors) \ No newline at end of file diff --git a/src/itp_interface/tools/tactic_parser/TacticParser.lean b/src/itp_interface/lean/tactic_parser/TacticParser.lean similarity index 60% rename from src/itp_interface/tools/tactic_parser/TacticParser.lean rename to src/itp_interface/lean/tactic_parser/TacticParser.lean index da652e7..762a727 100644 --- a/src/itp_interface/tools/tactic_parser/TacticParser.lean +++ b/src/itp_interface/lean/tactic_parser/TacticParser.lean @@ -2,3 +2,5 @@ import TacticParser.Base64 import TacticParser.Types import TacticParser.SyntaxWalker import TacticParser.Main +import TacticParser.SyntaxWalkerMain +import TacticParser.Example.simple diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Base64.lean similarity index 100% rename from src/itp_interface/tools/tactic_parser/TacticParser/Base64.lean rename to src/itp_interface/lean/tactic_parser/TacticParser/Base64.lean diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean b/src/itp_interface/lean/tactic_parser/TacticParser/DependencyParser.lean similarity index 98% rename from src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean rename to src/itp_interface/lean/tactic_parser/TacticParser/DependencyParser.lean index 4e9cd0a..984aea6 100644 --- a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParser.lean +++ b/src/itp_interface/lean/tactic_parser/TacticParser/DependencyParser.lean @@ -3,6 +3,7 @@ import Lean.Data.Json import Lean.Elab.Frontend import TacticParser.Types import TacticParser.LineParser +import TacticParser.ProofExtractor namespace TacticParser @@ -261,6 +262,9 @@ unsafe def analyzeFileDependencies (filepath : System.FilePath) : IO FileDepende -- Parse all declarations using LineParser let decls ← parseFile filepath + -- Extract proofs from theorem/lemma/example declarations + let declsWithProofs ← extractProofsFromDecls decls content + -- Load environment with all imports and elaborate the file Lean.initSearchPath (← Lean.findSysroot) Lean.enableInitializersExecution @@ -282,7 +286,7 @@ unsafe def analyzeFileDependencies (filepath : System.FilePath) : IO FileDepende -- Analyze each declaration let mut declsWithDeps : Array DeclWithDependencies := #[] - for declInfo in decls do + for declInfo in declsWithProofs do -- Construct the fully qualified name for this declaration let declName := match declInfo.namespc with | some ns => diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/DependencyParserMain.lean b/src/itp_interface/lean/tactic_parser/TacticParser/DependencyParserMain.lean similarity index 100% rename from src/itp_interface/tools/tactic_parser/TacticParser/DependencyParserMain.lean rename to src/itp_interface/lean/tactic_parser/TacticParser/DependencyParserMain.lean diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/Example/complex.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Example/complex.lean new file mode 100644 index 0000000..f05e8e1 --- /dev/null +++ b/src/itp_interface/lean/tactic_parser/TacticParser/Example/complex.lean @@ -0,0 +1,25 @@ +import TacticParser.Example.simple + +namespace TacticParser.Example + +theorem additive: +āˆ€ {a 
b: Nat}, addNat a b = a + b := by + intro a b + induction a generalizing b + simp [addNat] + rename_i a ih + simp [addNat, *] + grind + +theorem additive_identity1 : +∀ {a : Nat}, addNat 0 a = a := by + simp [addNat] + +theorem additive_identity2 : +∀ {a : Nat}, addNat a 0 = a := by + simp [additive] + +theorem additive_comm: +∀ {a b : Nat}, addNat a b = addNat b a := by + simp [additive] + grind diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/Example/simple.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Example/simple.lean new file mode 100644 index 0000000..d7b1690 --- /dev/null +++ b/src/itp_interface/lean/tactic_parser/TacticParser/Example/simple.lean @@ -0,0 +1,13 @@ +namespace TacticParser.Example + +theorem test : ∀ {a : Nat}, a + 0 = a := by grind + +theorem test1 : ∀ {a : Nat}, a + 0 = a + | 0 => by simp + | n + 1 => by simp + +def addNat : Nat → Nat → Nat + | 0, m => m + | n + 1, m => addNat n (m + 1) + +end TacticParser.Example diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/LineParser.lean b/src/itp_interface/lean/tactic_parser/TacticParser/LineParser.lean similarity index 100% rename from src/itp_interface/tools/tactic_parser/TacticParser/LineParser.lean rename to src/itp_interface/lean/tactic_parser/TacticParser/LineParser.lean diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/Main.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Main.lean similarity index 100% rename from src/itp_interface/tools/tactic_parser/TacticParser/Main.lean rename to src/itp_interface/lean/tactic_parser/TacticParser/Main.lean diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/ProofExtractor.lean b/src/itp_interface/lean/tactic_parser/TacticParser/ProofExtractor.lean new file mode 100644 index 0000000..368e198 --- /dev/null +++ b/src/itp_interface/lean/tactic_parser/TacticParser/ProofExtractor.lean @@ -0,0 +1,179 @@ +/- +Proof extractor for theorem, lemma, and example declarations. +Uses a validation-based approach: tests candidate delimiters by replacing +the proof with `sorry` and checking if it parses successfully.
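+
+Illustrative example of the heuristic (an informal sketch, not an exhaustive
+specification): for `theorem foo : P := by simpa using h`, every occurrence of
+`:=` is tried first (in textual order), then every `where`, then every `|`;
+each candidate prefix is re-parsed with `sorry` substituted for the tail, and
+the first prefix that parses cleanly wins, so the statement becomes
+`theorem foo : P` and the extracted proof is `:= by simpa using h`.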
+-/ +import Lean +import Lean.Parser +import TacticParser.Types +import TacticParser.SyntaxWalker +import Lean.Elab.Frontend + +namespace TacticParser + +open Lean Parser Elab + +/-- Represents a candidate delimiter position and type -/ +structure DelimiterCandidate where + position : Nat -- Byte position in the text + delimiterType : String -- ":=", "where", or "|" + deriving Repr, BEq + +instance : Ord DelimiterCandidate where + compare a b := compare a.position b.position + +/-- Try to parse a text snippet and return true if it parses without errors -/ +def tryParseSuccessfully (text : String) (cmdState : Option Command.State): IO Bool := do + let chkpt_result ← parseTactics text none cmdState + let parse_res := chkpt_result.1 + let new_cmd_state := chkpt_result.2 + pure (parse_res.errors.size == 0 ∧ new_cmd_state.isSome) + +/-- Check if a substring starts with a given substring at a specific position -/ +def substringStartsWith (text : String) (startPos : Nat) (substr : String) : Bool := + if startPos + substr.length > text.length then + false + else + let extracted := text.drop startPos + extracted.startsWith substr + +/-- Find all occurrences of a substring in a text -/ +def findSubstrOccurences (text : String) (substr : String) : List Nat := + if substr.length > text.length then + [] + else + let all_pos := List.range (text.length - substr.length + 1) + let all_occurences := all_pos.filter (fun pos => substringStartsWith text pos substr) + all_occurences + +/-- Find all candidate delimiter positions -/ +def findCandidateDelimiters (text : String) : List DelimiterCandidate := + let assignPositions := findSubstrOccurences text ":=" + let wherePositions := findSubstrOccurences text "where" + let pipePositions := findSubstrOccurences text "|" + + let assignCandidates := assignPositions.map fun pos => + { position := pos, delimiterType := ":=" } + let whereCandidates := wherePositions.map fun pos => + { position := pos, delimiterType := "where" } + let pipeCandidates := pipePositions.map fun pos => + { position := pos, delimiterType := "|" } + let assignCandidatesSorted := assignCandidates.toArray.qsort (fun a b => a.position < b.position) + let whereCandidatesSorted := whereCandidates.toArray.qsort (fun a b => a.position < b.position) + let pipeCandidatesSorted := pipeCandidates.toArray.qsort (fun a b => a.position < b.position) + + let allCandidates := assignCandidatesSorted.toList ++ whereCandidatesSorted.toList ++ pipeCandidatesSorted.toList + allCandidates + +def is_proof_extraction_needed (declInfo : DeclInfo) : Bool := + declInfo.declType == DeclType.theorem || + declInfo.declType == DeclType.lemma || + declInfo.declType == DeclType.example + +def get_content_before_decl (fileContent : String) (declLineNum : Nat) : String := + let lines := fileContent.splitOn "\n" + let beforeLines := lines.take (declLineNum - 1) + (String.intercalate "\n" beforeLines) ++ "\n" + +def get_in_between_content (fileContent : String) (startLine : Nat) (endLine : Nat) : String := + let lines := fileContent.splitOn "\n" + let betweenLines := (lines.take endLine).drop (startLine - 1) + String.intercalate "\n" betweenLines + +/-- Extract proof from a declaration by testing candidate delimiters -/ +unsafe def extractProofFromDecl + (declInfo : DeclInfo) + (cmdState : Option Command.State) + (extra_content: Option String := none) : IO DeclInfo := do + -- Only process theorem, lemma, example + if !is_proof_extraction_needed declInfo then + panic! 
s!"extractProofFromDecl called on non-proof decl: {declInfo.declType}" + + let text := declInfo.text + + -- Convert Position to byte offset + -- Find all candidate delimiters + let candidates := findCandidateDelimiters text + + -- Test each candidate + for candidate in candidates do + let beforeDelimiter := text.take candidate.position + + -- IO.println s!"beforeDelimiter:\n{beforeDelimiter}\n---" + + -- Build test text with full context + let mut statementOnly := match candidate.delimiterType with + | "|" => beforeDelimiter.trim ++ " := sorry" + | ":=" => beforeDelimiter ++ " := sorry" + | "where" => beforeDelimiter ++ " := sorry" + | _ => beforeDelimiter ++ " := sorry" + + + -- If extra content is provided, prepend it + statementOnly := + match extra_content with + | some extra => extra ++ (if extra.endsWith "\n" then statementOnly else "\n" ++ statementOnly) + | none => statementOnly + + -- IO.println s!"statementOnly:\n{statementOnly}\n---" + -- Try to parse + let success ← tryParseSuccessfully statementOnly cmdState + if success then + -- Found valid split! + -- IO.println s!"Found valid split at position {candidate.position} with delimiter {candidate.delimiterType}" + let proof := text.drop candidate.position + let thrm := text.take candidate.position + return { declInfo with proof := some proof.trim , text := thrm.trim } + -- else + -- IO.println s!"Failed split at position {candidate.position} with delimiter {candidate.delimiterType}" + + -- No valid split found - no proof + return { declInfo with proof := none } + +unsafe def parse_between +(fileContent : String) +(prev: Nat) +(next: Nat) +(cmd_state : Option Command.State) +: IO CheckpointedParseResult := do + -- IO.println s!"Extracting proof for segment between lines {prev} and {next}" + let contextBeforeDecl := get_in_between_content fileContent prev next + -- IO.println s!"Processing declaration at line {decl.startPos.line}-{decl.endPos.line}" + -- IO.println s!"Declaration text:\n{decl.text}\n---" + -- IO.println s!"--- Context Before Decl ----\n{contextBeforeDecl}\n--- Context Before Decl ----\n" + let chkpt_parse_res ← parseTactics contextBeforeDecl none cmd_state + return chkpt_parse_res + +/-- Extract proofs from multiple declarations -/ +unsafe def extractProofsFromDecls (decls : Array DeclInfo) (fileContent : String) : IO (Array DeclInfo) := do + let mut result := #[] + let mut prev := 0 + let mut cmd_state : Option Command.State := none + let mut next := 0 + let mut extra_content : Option String := none + for decl in decls do + if is_proof_extraction_needed decl then + next := decl.startPos.line - 1 + let chkpt_parse_res ← parse_between fileContent prev next cmd_state + -- IO.println s!"--- Context Before Decl ----\n{contextBeforeDecl}\n--- Context Before Decl ----\n" + let parse_res := chkpt_parse_res.parseResult + if parse_res.errors.size > 0 ∨ chkpt_parse_res.chkptState.isNone then + -- supply the extra content to compile from + extra_content := get_in_between_content fileContent prev next + -- IO.println s!"Re-parsing declaration at lines: \n{extra_content.get!}" + -- IO.println s!"\nDeclaration text:\n{decl.text}\n---" + -- DO NOT update cmd_state yet + else + cmd_state := chkpt_parse_res.chkptState + extra_content := none + prev := next + 1 + -- let cmd_st ← match cmd_state with + -- | some st => pure st + -- | none => panic! 
"Failed to get valid cmd_state before processing declaration" + let processed ← extractProofFromDecl decl cmd_state extra_content + result := result.push processed + else + result := result.push decl + return result + +end TacticParser diff --git a/src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean b/src/itp_interface/lean/tactic_parser/TacticParser/SyntaxWalker.lean similarity index 98% rename from src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean rename to src/itp_interface/lean/tactic_parser/TacticParser/SyntaxWalker.lean index 46689ac..5dd8d3a 100644 --- a/src/itp_interface/tools/tactic_parser/TacticParser/SyntaxWalker.lean +++ b/src/itp_interface/lean/tactic_parser/TacticParser/SyntaxWalker.lean @@ -308,6 +308,10 @@ unsafe def parseInCurrentContext (input : String) (filePath : Option String := n let extentStruct := lineExtents.map getInfoNodeStruct -- Go over all line extents and reassign the end_pos of the next node let mut adjusted_trees : Array InfoNodeStruct := #[] + -- IO.println s!"Total line extents found: {lineExtents.size}" + if lineExtents.size == 0 then + let parseResult : ParseResult := { trees := #[], errors := errorInfos } + return { parseResult := parseResult, chkptState := cmdState , lineNum := none } for i in [1:lineExtents.size] do let prev_node := extentStruct[i - 1]!.getD default let curr_node := extentStruct[i]!.getD default @@ -350,6 +354,7 @@ unsafe def parseTacticsWithElaboration (input : String) (filePath : Option Strin -- Initialize Lean from current directory (finds .lake/build if present) Lean.initSearchPath (← Lean.findSysroot) Lean.enableInitializersExecution + -- IO.println "Initialized Lean environment for elaboration-based parsing." return ← parseInCurrentContext input filePath chkptState catch e => let errorInfo := ErrorInfo.mk (s!"Error in parseTacticsWithElaboration: {e}") { line := 0, column := 0 } diff --git a/src/itp_interface/lean/tactic_parser/TacticParser/TacticExtractorMain.lean b/src/itp_interface/lean/tactic_parser/TacticParser/TacticExtractorMain.lean new file mode 100644 index 0000000..5c66a6d --- /dev/null +++ b/src/itp_interface/lean/tactic_parser/TacticParser/TacticExtractorMain.lean @@ -0,0 +1,52 @@ +import TacticParser.SyntaxWalker +import Lean + +open Lean +open TacticParser + +/-- Print usage information -/ +def printUsage : IO Unit := do + IO.println "Usage: dependency_parser " + IO.println "" + IO.println "Arguments:" + IO.println " Path to the Lean file to analyze" + IO.println " Path where JSON output will be written" + IO.println "" + IO.println "Example:" + IO.println " lake env .lake/build/bin/syntax-walker MyFile.lean output.json" + +unsafe def main (args : List String) : IO UInt32 := do + match args with + | + [ + leanFilePath, + jsonOutputPath + ] + => + try + let filepath : System.FilePath := leanFilePath + let jsonPath : System.FilePath := jsonOutputPath + + -- Check if input file exists + if !(← filepath.pathExists) then + IO.eprintln s!"Error: Input file not found: {filepath}" + return 1 + + let fileContent ← IO.FS.readFile filepath + + IO.println s!"Tactic Parsing file: {filepath}" + + -- Analyze the file and export to JSON + let tactics ← parseTactics fileContent none none + let tacticsJson := Lean.ToJson.toJson tactics.parseResult + IO.FS.writeFile jsonPath tacticsJson.compress + return 0 + catch e => + IO.eprintln s!"Error: {e}" + return 1 + + | _ => + IO.eprintln "Error: Invalid number of arguments" + IO.eprintln "" + printUsage + return 1 diff --git 
a/src/itp_interface/tools/tactic_parser/TacticParser/Types.lean b/src/itp_interface/lean/tactic_parser/TacticParser/Types.lean similarity index 98% rename from src/itp_interface/tools/tactic_parser/TacticParser/Types.lean rename to src/itp_interface/lean/tactic_parser/TacticParser/Types.lean index fee67f1..5dbac32 100644 --- a/src/itp_interface/tools/tactic_parser/TacticParser/Types.lean +++ b/src/itp_interface/lean/tactic_parser/TacticParser/Types.lean @@ -71,6 +71,7 @@ structure DeclInfo where text : String docString : Option String -- Extracted documentation comment namespc : Option String -- Current namespace + proof : Option String := none -- Extracted proof (for theorem, lemma, example) deriving Repr instance : ToJson DeclInfo where @@ -83,7 +84,8 @@ instance : ToJson DeclInfo where ("end_column", d.endPos.column), ("text", toJson d.text), ("doc_string", toJson d.docString), - ("namespace", toJson d.namespc) + ("namespace", toJson d.namespc), + ("proof", toJson d.proof) ] /-- Information about an import statement -/ diff --git a/src/itp_interface/tools/tactic_parser/lake-manifest.json b/src/itp_interface/lean/tactic_parser/lake-manifest.json similarity index 100% rename from src/itp_interface/tools/tactic_parser/lake-manifest.json rename to src/itp_interface/lean/tactic_parser/lake-manifest.json diff --git a/src/itp_interface/tools/tactic_parser/lakefile.toml b/src/itp_interface/lean/tactic_parser/lakefile.toml similarity index 53% rename from src/itp_interface/tools/tactic_parser/lakefile.toml rename to src/itp_interface/lean/tactic_parser/lakefile.toml index 16e7605..e9c3958 100644 --- a/src/itp_interface/tools/tactic_parser/lakefile.toml +++ b/src/itp_interface/lean/tactic_parser/lakefile.toml @@ -1,14 +1,22 @@ name = "TacticParser" -defaultTargets = ["tactic-parser", "dependency-parser"] +defaultTargets = ["tactic-parser", "tactic-extractor", "dependency-parser"] [[lean_lib]] name = "TacticParser" +[[lean_lib]] +name = "TacticParser.Example" + [[lean_exe]] name = "tactic-parser" root = "TacticParser.Main" supportInterpreter = true +[[lean_exe]] +name = "tactic-extractor" +root = "TacticParser.TacticExtractorMain" +supportInterpreter = true + [[lean_exe]] name = "dependency-parser" root = "TacticParser.DependencyParserMain" diff --git a/src/itp_interface/tools/tactic_parser/lean-toolchain b/src/itp_interface/lean/tactic_parser/lean-toolchain similarity index 100% rename from src/itp_interface/tools/tactic_parser/lean-toolchain rename to src/itp_interface/lean/tactic_parser/lean-toolchain diff --git a/src/itp_interface/lean/tactic_parser/test_user_example.lean b/src/itp_interface/lean/tactic_parser/test_user_example.lean new file mode 100644 index 0000000..e8b4b94 --- /dev/null +++ b/src/itp_interface/lean/tactic_parser/test_user_example.lean @@ -0,0 +1,26 @@ +import Lean +import TacticParser.DependencyParser +import TacticParser.ProofExtractor +import TacticParser.Types + +open TacticParser + + +unsafe def analyze (filePath : String) : IO Unit := do + let fileDepAnalysis ← analyzeFileDependencies filePath + IO.println s!"Dependency Analysis for {filePath}:" + let json := Lean.toJson fileDepAnalysis + IO.println json.pretty + +def testCodes : List String := [ + "TacticParser/Example/simple.lean", + "TacticParser/Example/complex.lean" +] + +unsafe def main : IO Unit := do + IO.println "Testing User's Example" + IO.println (String.mk (List.replicate 70 '=')) + + for test in testCodes do + let filePath := test + analyze filePath diff --git a/src/itp_interface/main/config.py 
b/src/itp_interface/main/config.py index 5d0c042..86e49f2 100644 --- a/src/itp_interface/main/config.py +++ b/src/itp_interface/main/config.py @@ -79,6 +79,8 @@ class ExtractFile(object): class EvalDataset(object): project: str files: typing.Union[typing.List[EvalFile], typing.List[ExtractFile]] + exclude_files: typing.List[str] = field(default_factory=list) + include_files: typing.List[str] = field(default_factory=list) @dataclass_json @dataclass @@ -135,7 +137,6 @@ def add_theorem_to_maps(self, path: str, theorem: str, proof_result: ProofSearch def parse_config(cfg): - is_extraction_request = False env_settings_cfg = cfg["env_settings"] env_settings = EnvSettings( name=env_settings_cfg["name"], @@ -179,7 +180,6 @@ def parse_config(cfg): eval_files.append(EvalFile( path=file_cfg["path"], theorems=theorems)) - is_extraction_request = False elif "declarations" in file_cfg: declarations = None if type(file_cfg["declarations"]) == str: @@ -189,12 +189,14 @@ def parse_config(cfg): eval_files.append(ExtractFile( path=file_cfg["path"], declarations=declarations)) - is_extraction_request = True else: raise ValueError(f"File config must have either 'theorems' or 'declarations': {file_cfg}") eval_datasets.append(EvalDataset( project=dataset_cfg["project"], - files=eval_files)) + files=eval_files, + exclude_files=dataset_cfg.get("exclude_files", []), + include_files=dataset_cfg.get("include_files", []) + )) language = ProofAction.Language(benchmark_cfg["language"]) benchmark = EvalBenchmark( name=benchmark_cfg["name"], @@ -205,6 +207,6 @@ def parse_config(cfg): few_shot_metadata_filename_for_retrieval=benchmark_cfg["few_shot_metadata_filename_for_retrieval"], dfs_data_path_for_retrieval=benchmark_cfg["dfs_data_path_for_retrieval"], dfs_metadata_filename_for_retrieval=benchmark_cfg["dfs_metadata_filename_for_retrieval"], - is_extraction_request=is_extraction_request and benchmark_cfg.get("is_extraction_request", False), + is_extraction_request=benchmark_cfg.get("is_extraction_request", False), setup_cmds=benchmark_cfg["setup_cmds"] if "setup_cmds" in benchmark_cfg else []) return Experiments(env_settings=env_settings, run_settings=eval_settings, benchmark=benchmark) \ No newline at end of file diff --git a/src/itp_interface/main/configs/benchmark/mathlib_benchmark_lean_ext.yaml b/src/itp_interface/main/configs/benchmark/mathlib_benchmark_lean_ext.yaml new file mode 100644 index 0000000..b2bb43b --- /dev/null +++ b/src/itp_interface/main/configs/benchmark/mathlib_benchmark_lean_ext.yaml @@ -0,0 +1,40 @@ +name: mathlib_benchmark_lean_ext +num_files: 1 +language: LEAN4 +few_shot_data_path_for_retrieval: +few_shot_metadata_filename_for_retrieval: +dfs_data_path_for_retrieval: +dfs_metadata_filename_for_retrieval: +is_extraction_request: true +datasets: + - project: src/data/test/Mathlib + files: [] + exclude_files: + - src/data/test/Mathlib/.lake/packages/aesop + - src/data/test/Mathlib/.lake/packages/batteries + - src/data/test/Mathlib/.lake/packages/Cli + - src/data/test/Mathlib/.lake/packages/importGraph + - src/data/test/Mathlib/.lake/packages/LeanSearchClient + - src/data/test/Mathlib/.lake/packages/plausible + - src/data/test/Mathlib/.lake/packages/proofwidgets + - src/data/test/Mathlib/.lake/packages/Qq + - src/data/test/Mathlib/.lake/packages/mathlib/.devcontainer + - src/data/test/Mathlib/.lake/packages/mathlib/.github + - src/data/test/Mathlib/.lake/packages/mathlib/.docker + - src/data/test/Mathlib/.lake/packages/mathlib/.lake + - src/data/test/Mathlib/.lake/packages/mathlib/.vscode + - 
src/data/test/Mathlib/.lake/packages/mathlib/Archive + - src/data/test/Mathlib/.lake/packages/mathlib/Cache + - src/data/test/Mathlib/.lake/packages/mathlib/docs + - src/data/test/Mathlib/.lake/packages/mathlib/Counterexamples + - src/data/test/Mathlib/.lake/packages/mathlib/DownstreamTest + - src/data/test/Mathlib/.lake/packages/mathlib/LongestPole + - src/data/test/Mathlib/.lake/packages/mathlib/MathlibTest + - src/data/test/Mathlib/.lake/packages/mathlib/scripts + - src/data/test/Mathlib/.lake/packages/mathlib/widget + - src/data/test/Mathlib/.lake/packages/mathlib/Counterexamples.lean + - src/data/test/Mathlib/.lake/packages/mathlib/lakefile.lean + - src/data/test/Mathlib/.lake/packages/mathlib/Mathlib.lean + - src/data/test/Mathlib/.lake/packages/mathlib/Archive.lean + - src/data/test/Mathlib/lakefile.lean + - src/data/test/Mathlib/ReplMathlibTests.lean diff --git a/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml b/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml index 2ed8c1d..586c2f8 100644 --- a/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml +++ b/src/itp_interface/main/configs/benchmark/simple_benchmark_lean_ext.yaml @@ -8,6 +8,6 @@ dfs_metadata_filename_for_retrieval: is_extraction_request: true datasets: - project: src/data/test/lean4_proj - files: - - path: Lean4Proj/Basic.lean - declarations: "*" \ No newline at end of file + files: [] + exclude_files: + - src/data/test/lean4_proj/.lake \ No newline at end of file diff --git a/src/itp_interface/main/configs/benchmark/stdlib_benchmark_lean_ext.yaml b/src/itp_interface/main/configs/benchmark/stdlib_benchmark_lean_ext.yaml new file mode 100644 index 0000000..fde9d37 --- /dev/null +++ b/src/itp_interface/main/configs/benchmark/stdlib_benchmark_lean_ext.yaml @@ -0,0 +1,31 @@ +name: stdlib_benchmark_lean_ext +num_files: 1 +language: LEAN4 +few_shot_data_path_for_retrieval: +few_shot_metadata_filename_for_retrieval: +dfs_data_path_for_retrieval: +dfs_metadata_filename_for_retrieval: +is_extraction_request: true +datasets: + - project: src/data/test/batteries + files: [] + exclude_files: + - src/data/test/batteries/.lake + - src/data/test/batteries/.vscode + - src/data/test/batteries/BatteriesTest + - src/data/test/batteries/Batteries + - src/data/test/batteries/scripts + - src/data/test/batteries/docs + - src/data/test/batteries/Shake + - src/data/test/batteries/Batteries.lean + - src/data/test/batteries/bors.toml + - src/data/test/batteries/lake-manifest.toml + - src/data/test/batteries/.github + - src/data/test/batteries/.docker + - src/data/test/batteries/lakefile.toml + - src/data/test/batteries/.gitpod.yml + - src/data/test/batteries/.github + - src/data/test/batteries/lake-manifest.json + include_files: + - ~/.elan/toolchains/leanprover--lean4---v4.24.0/src/lean/Init + - ~/.elan/toolchains/leanprover--lean4---v4.24.0/src/lean/Std \ No newline at end of file diff --git a/src/itp_interface/main/configs/mathlib_lean_data_extract.yaml b/src/itp_interface/main/configs/mathlib_lean_data_extract.yaml new file mode 100644 index 0000000..14e39be --- /dev/null +++ b/src/itp_interface/main/configs/mathlib_lean_data_extract.yaml @@ -0,0 +1,13 @@ +defaults: + # - benchmark: simple_benchmark_lean_training_data + # - run_settings: default_lean_data_generation_transforms + # - benchmark: simple_benchmark_1 + # - run_settings: default_lean4_data_generation_transforms + - benchmark: mathlib_benchmark_lean_ext + - run_settings: default_lean4_data_generation_transforms + - 
env_settings: no_retrieval + - override hydra/job_logging: 'disabled' + +run_settings: + output_dir: .log/data_generation/benchmark/mathlib_benchmark_lean_ext + pool_size: 12 \ No newline at end of file diff --git a/src/itp_interface/main/configs/simple_lean_data_extract.yaml b/src/itp_interface/main/configs/simple_lean_data_extract.yaml index 4fed34c..590c3e9 100644 --- a/src/itp_interface/main/configs/simple_lean_data_extract.yaml +++ b/src/itp_interface/main/configs/simple_lean_data_extract.yaml @@ -10,4 +10,4 @@ defaults: run_settings: output_dir: .log/data_generation/benchmark/simple_benchmark_lean_ext - pool_size: 2 \ No newline at end of file + pool_size: 1 \ No newline at end of file diff --git a/src/itp_interface/main/configs/simple_lean_data_gen.yaml b/src/itp_interface/main/configs/simple_lean_data_gen.yaml index d93dd0f..882a810 100644 --- a/src/itp_interface/main/configs/simple_lean_data_gen.yaml +++ b/src/itp_interface/main/configs/simple_lean_data_gen.yaml @@ -10,4 +10,4 @@ defaults: run_settings: output_dir: .log/data_generation/benchmark/simple_benchmark_lean - pool_size: 2 \ No newline at end of file + pool_size: 1 \ No newline at end of file diff --git a/src/itp_interface/main/configs/stdlib_lean_data_extract.yaml b/src/itp_interface/main/configs/stdlib_lean_data_extract.yaml new file mode 100644 index 0000000..2269810 --- /dev/null +++ b/src/itp_interface/main/configs/stdlib_lean_data_extract.yaml @@ -0,0 +1,13 @@ +defaults: + # - benchmark: simple_benchmark_lean_training_data + # - run_settings: default_lean_data_generation_transforms + # - benchmark: simple_benchmark_1 + # - run_settings: default_lean4_data_generation_transforms + - benchmark: stdlib_benchmark_lean_ext + - run_settings: default_lean4_data_generation_transforms + - env_settings: no_retrieval + - override hydra/job_logging: 'disabled' + +run_settings: + output_dir: .log/data_generation/benchmark/stdlib_benchmark_lean_ext + pool_size: 12 \ No newline at end of file diff --git a/src/itp_interface/main/init_ray.py b/src/itp_interface/main/init_ray.py new file mode 100644 index 0000000..41c72e8 --- /dev/null +++ b/src/itp_interface/main/init_ray.py @@ -0,0 +1,79 @@ +def main(): + # Start the ray cluster + from filelock import FileLock + import json + import os + import ray + import logging + import time + import sys + import argparse + argument_parser = argparse.ArgumentParser() + argument_parser.add_argument("--num_cpus", type=int, default=10) + argument_parser.add_argument("--object_store_memory", type=int, default=150*2**30) + argument_parser.add_argument("--memory", type=int, default=300*2**30) + argument_parser.add_argument("--metrics_report_interval_ms", type=int, default=3*10**8) + args = argument_parser.parse_args() + root_dir = f"{os.path.abspath(__file__).split('itp_interface')[-2]}" + if root_dir not in sys.path: + sys.path.append(root_dir) + os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1" + os.environ["PYTHONPATH"] = f"{root_dir}:{os.environ.get('PYTHONPATH', '')}" + os.makedirs(".log/locks", exist_ok=True) + os.makedirs(".log/ray", exist_ok=True) + ray_was_started = False + pid = os.getpid() + print("Initializing Ray") + print("PID: ", pid) + ray_session_path = ".log/ray/session_latest" if os.environ.get("RAY_SESSION_PATH") is None else os.environ.get("RAY_SESSION_PATH") + # Try to first acquire the lock + file_path = ".log/locks/ray.lock" + # set RAY_SESSION_PATH environment variable + os.environ["RAY_SESSION_PATH"] = ray_session_path + # set lock file path + 
os.environ["RAY_LOCK_FILE_PATH"] = file_path + temp_lock = FileLock(file_path) + if os.path.exists(ray_session_path): + # try to acquire the lock for reading + try: + temp_lock.acquire(timeout=10) + temp_lock.release() + except: + with open(ray_session_path, "r") as f: + ray_session = f.read() + ray_session = json.loads(ray_session) + ray_address = ray_session["address"] + # ray.init(address=ray_address) + print("Ray was already started") + print("Ray session: ", ray_session) + sys.exit(0) + with FileLock(file_path): + if os.path.exists(ray_session_path): + # Remove the ray_session_path + os.remove(ray_session_path) + os.environ["RAY_INITIALIZED"] = "1" + ray_session = ray.init( + num_cpus=args.num_cpus, + object_store_memory=args.object_store_memory, + _memory=args.memory, + logging_level=logging.CRITICAL, + ignore_reinit_error=False, + log_to_driver=False, + configure_logging=False, + _system_config={"metrics_report_interval_ms": args.metrics_report_interval_ms}) + ray_session = dict(ray_session) + ray_session["main_pid"] = pid + print("Ray session: ", ray_session) + with open(ray_session_path, "w") as f: + f.write(json.dumps(ray_session)) + ray_was_started = True + print("Ray was started") + print("Ray session: ", ray_session) + # Flush the stdout buffer + sys.stdout.flush() + while ray_was_started: + # Keep the ray cluster alive till killed + time.sleep(10000) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/itp_interface/main/install.py b/src/itp_interface/main/install.py index e375f57..0c38ed0 100644 --- a/src/itp_interface/main/install.py +++ b/src/itp_interface/main/install.py @@ -3,7 +3,7 @@ import string import logging import traceback -from itp_interface.tools.tactic_parser import build_tactic_parser_if_needed +from itp_interface.lean.tactic_parser import build_tactic_parser_if_needed file_path = os.path.abspath(__file__) @@ -20,8 +20,8 @@ def generate_random_string(length, allowed_chars=None): def install_itp_interface(): print("Installing itp_interface") itp_dir = os.path.dirname(os.path.dirname(file_path)) - tools_dir = os.path.join(itp_dir, "tools") - tactic_parser_dir = os.path.join(tools_dir, "tactic_parser") + lean_dir = os.path.join(itp_dir, "lean") + tactic_parser_dir = os.path.join(lean_dir, "tactic_parser") assert os.path.exists(tactic_parser_dir), f"tactic_parser_dir: {tactic_parser_dir} does not exist" assert os.path.exists(os.path.join(tactic_parser_dir, "lean-toolchain")), f"lean-toolchain does not exist in {tactic_parser_dir}, build has failed" print("tactic_parser_dir: ", tactic_parser_dir) diff --git a/src/itp_interface/main/run_tool.py b/src/itp_interface/main/run_tool.py index 7fbd7da..24c0366 100644 --- a/src/itp_interface/main/run_tool.py +++ b/src/itp_interface/main/run_tool.py @@ -26,6 +26,7 @@ RayResourcePoolActor = None TimedRayExec = None RayUtils = None +from pathlib import Path from itp_interface.rl.proof_action import ProofAction from itp_interface.tools.isabelle_server import IsabelleServer from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy @@ -33,7 +34,7 @@ from itp_interface.tools.coq_local_data_generation_transform import LocalDataGenerationTransform as CoqLocalDataGenerationTransform from itp_interface.tools.lean_local_data_generation_transform import LocalDataGenerationTransform as LeanLocalDataGenerationTransform from itp_interface.tools.lean4_local_data_generation_transform import Local4DataGenerationTransform -from itp_interface.tools.lean4_local_data_extraction_transform import 
Local4DataExtractionTransform +from itp_interface.lean.lean4_local_data_extraction_transform import Local4DataExtractionTransform from itp_interface.tools.isabelle_local_data_generation_transform import LocalDataGenerationTransform as IsabelleLocalDataGenerationTransform from itp_interface.tools.run_data_generation_transforms import RunDataGenerationTransforms from itp_interface.tools.log_utils import setup_logger @@ -120,14 +121,35 @@ def _get_all_lemmas_impl( logger.info(f"Discovered {len(lemmas_to_prove)} lemmas") return lemmas_to_prove - def _get_all_lean_files_in_folder_recursively( - project_folder: str) -> typing.List[str]: + project_folder: str, exclude_list: typing.List[str], include_list: typing.List[str]) -> typing.List[str]: lean_files = [] + for include in include_list: + # the include list should be a full path + include_path = Path(include) + include_path = include_path.expanduser().resolve() + # Get all files recursively under include_path + for root, dirs, files in os.walk(include_path): + for file in files: + if file.endswith(".lean"): + full_path = os.path.join(root, file) + full_path_obj = Path(full_path) + full_path = full_path_obj.expanduser().resolve() + if str(full_path) not in lean_files: + lean_files.append(str(full_path)) + + project_folder_path = Path(project_folder) for root, dirs, files in os.walk(project_folder): for file in files: if file.endswith(".lean"): - lean_files.append(os.path.join(root, file)) + # Get full file path + full_path = os.path.join(root, file) + # Check if the full path starts with any exclude path + if any(full_path.startswith(exclude) for exclude in exclude_list): + continue + full_path_obj = Path(full_path) + rel_to_project = str(full_path_obj.relative_to(project_folder_path)) + lean_files.append(rel_to_project) return lean_files # Create Ray remote version if Ray is available @@ -248,7 +270,7 @@ def create_yaml(project_to_theorems, name, eval_benchmark: EvalBenchmark, output with open(output_file, 'w') as yaml_file: yaml.dump(data, yaml_file, sort_keys=False) -def add_transform(experiment: Experiments, clone_dir: str, resources: list, transforms: list, logger: logging.Logger = None): +def add_transform(experiment: Experiments, clone_dir: str, resources: list, transforms: list, logger: logging.Logger = None, str_time: str = None): global ray_resource_pool if experiment.run_settings.transform_type == TransformType.LOCAL: if experiment.benchmark.language == ProofAction.Language.LEAN: @@ -260,10 +282,31 @@ def add_transform(experiment: Experiments, clone_dir: str, resources: list, tran os.makedirs(clone_dir, exist_ok=True) elif experiment.benchmark.language == ProofAction.Language.LEAN4: if experiment.benchmark.is_extraction_request: - transform = Local4DataExtractionTransform( + output_dir_path = os.path.join(experiment.run_settings.output_dir, str_time) + db_path = None + if os.environ.get("EXTRACTION_DB_PATH") is not None: + db_path = os.environ.get("EXTRACTION_DB_PATH", "") + if not os.path.exists(db_path) and db_path.strip() != "": + db_path = None + if db_path is None: + os.makedirs(output_dir_path, exist_ok=True) + db_path = os.path.join(output_dir_path, "lean4_extraction_db.sqlite") + transform1 = Local4DataExtractionTransform( experiment.run_settings.dep_depth, buffer_size=experiment.run_settings.buffer_size, - logger=logger) + logger=logger, + db_path=db_path, + enable_file_export=False, + enable_dependency_extraction=False) + transforms.append(transform1) # First just add the definitions + transform = 
Local4DataExtractionTransform( + experiment.run_settings.dep_depth, + buffer_size=experiment.run_settings.buffer_size, + logger=logger, + db_path=db_path, + enable_file_export=True, + enable_dependency_extraction=True + ) # This will be later appended to the transforms list else: transform = Local4DataGenerationTransform( experiment.run_settings.dep_depth, @@ -320,8 +363,10 @@ def get_decl_lemmas_to_parse( if experiment.benchmark.language == ProofAction.Language.LEAN4 \ and experiment.benchmark.is_extraction_request: if len(dataset.files) == 0: + logger.warning(f"No files specified for Lean4 extraction in dataset {dataset.project}, extracting from all Lean4 files in the project") # List all the files recursively in the project folder - files_in_dataset = _get_all_lean_files_in_folder_recursively(dataset.project) + files_in_dataset = _get_all_lean_files_in_folder_recursively(dataset.project, dataset.exclude_files, dataset.include_files) + logger.info(f"Found {len(files_in_dataset)} Lean4 files in the project {dataset.project}") for file_path in files_in_dataset: file_to_theorems[file_path] = ["*"] file_args[file_path] = {} @@ -459,7 +504,7 @@ def run_data_generation_pipeline(experiment: Experiments, log_dir: str, checkpoi transforms = [] str_time = time.strftime("%Y%m%d-%H%M%S") clone_dir = os.path.join(experiment.run_settings.output_dir, "clone{}".format(str_time)) - clone_dir = add_transform(experiment, clone_dir, resources, transforms, logger) + clone_dir = add_transform(experiment, clone_dir, resources, transforms, logger, str_time) # Find all the lemmas to prove project_to_theorems = {} other_args = {} @@ -592,4 +637,6 @@ def main(cfg): if __name__ == "__main__": # from itp_interface.tools.ray_utils import RayUtils # RayUtils.init_ray(num_of_cpus=20, object_store_memory_in_gb=50, memory_in_gb=1, runtime_env={"working_dir": root_dir, "excludes": [".log", "data"]}) + from itp_interface.tools.ray_utils import RayUtils + RayUtils.init_ray(num_of_cpus=20, object_store_memory_in_gb=100, memory_in_gb=50) main() \ No newline at end of file diff --git a/src/itp_interface/tools/dynamic_lean4_proof_exec.py b/src/itp_interface/tools/dynamic_lean4_proof_exec.py index 5417e83..86cd9c9 100644 --- a/src/itp_interface/tools/dynamic_lean4_proof_exec.py +++ b/src/itp_interface/tools/dynamic_lean4_proof_exec.py @@ -6,7 +6,7 @@ import copy import enum import logging -from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor +from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor from itp_interface.tools.training_data_format import Goal, TheoremProvingTrainingDataFormat from itp_interface.tools.lean_parse_utils import LeanLineByLineReader from itp_interface.tools.lean_context_helper import Lean3ContextHelper diff --git a/src/itp_interface/tools/lean4_local_data_extraction_transform.py b/src/itp_interface/tools/lean4_local_data_extraction_transform.py deleted file mode 100644 index 9205638..0000000 --- a/src/itp_interface/tools/lean4_local_data_extraction_transform.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 - -import os -import sys -dir_name = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) -root_dir = os.path.abspath(dir_name) -if root_dir not in sys.path: - sys.path.append(root_dir) -import typing -import uuid -from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor -from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType -from 
itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, ExtractionDataCollection, TheoremProvingTrainingDataFormat -from itp_interface.tools.training_data import TrainingData, DataLayoutFormat - -class Local4DataExtractionTransform(GenericTrainingDataGenerationTransform): - def __init__(self, - depth = None, - max_search_results = None, - buffer_size : int = 10000, - logger = None, - max_parallelism : int = 4): - super().__init__(TrainingDataGenerationType.LOCAL, buffer_size, logger) - self.depth = depth - self.max_search_results = max_search_results - self.max_parallelism = max_parallelism - - def get_meta_object(self) -> TrainingDataMetadataFormat: - return TrainingDataMetadataFormat( - training_data_buffer_size=self.buffer_size, - data_filename_prefix="extraction_data_", - lemma_ref_filename_prefix="extraction_lemma_refs_") - - def get_data_collection_object(self) -> MergableCollection: - return ExtractionDataCollection() - - def load_meta_from_file(self, file_path) -> MergableCollection: - return TrainingDataMetadataFormat.load_from_file(file_path) - - def load_data_from_file(self, file_path) -> MergableCollection: - return ExtractionDataCollection.load_from_file(file_path, self.logger) - - def __call__(self, - training_data: TrainingData, - project_id : str, - lean_executor: SimpleLean4SyncExecutor, - print_coq_executor_callback: typing.Callable[[], SimpleLean4SyncExecutor], - theorems: typing.List[str] = None, - other_args: dict = {}) -> TrainingData: - file_namespace = lean_executor.main_file.replace('/', '.') - self.logger.info(f"=========================Processing {file_namespace}=========================") - theorem_id = str(uuid.uuid4()) - if isinstance(theorems, list) and len(theorems) == 1 and theorems[0] == "*": - theorems = None - else: - theorems = set(theorems) if theorems is not None else None - cnt = 0 - temp_dir = os.path.join(training_data.folder, "temp") - os.makedirs(temp_dir, exist_ok=True) - json_output_path = f"{temp_dir}/{file_namespace.replace('.', '_')}.lean.deps.json" - file_dep_analyses = lean_executor.extract_all_theorems_and_definitions(json_output_path=json_output_path) - self.logger.info(f"Extracted {len(file_dep_analyses)} FileDependencyAnalysis objects from {file_namespace}") - self.logger.info(f"file_dep_analyses: {file_dep_analyses}") - assert len(file_dep_analyses) == 1, "Expected exactly one FileDependencyAnalysis object" - file_dep_analysis = file_dep_analyses[0] - for decls in file_dep_analysis.declarations: - line_info = decls.decl_info - if theorems is not None and line_info.name not in theorems: - continue - training_data.merge(decls) - cnt += 1 - training_data.meta.last_proof_id = theorem_id - self.logger.info(f"===============Finished processing {file_namespace}=====================") - self.logger.info(f"Total declarations processed in this transform: {cnt}") - return training_data - - -if __name__ == "__main__": - import os - import logging - import time - os.chdir(root_dir) - # project_dir = 'data/test/lean4_proj/' - project_dir = 'data/test/Mathlib' - # file_name = 'data/test/lean4_proj/Lean4Proj/Basic.lean' - file_name = 'data/test/Mathlib/.lake/packages/mathlib/Mathlib/Algebra/Divisibility/Basic.lean' - project_id = project_dir.replace('/', '.') - time_str = time.strftime("%Y%m%d-%H%M%S") - output_path = f".log/local_data_generation_transform/data/{time_str}" - log_path = f".log/local_data_generation_transform/log/{time_str}" - log_file = 
f"{log_path}/local_data_generation_transform-{time_str}.log" - os.makedirs(output_path, exist_ok=True) - os.makedirs(log_path, exist_ok=True) - logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') - logger = logging.getLogger(__name__) - def _print_lean_executor_callback(): - search_lean_exec = SimpleLean4SyncExecutor(main_file=file_name, project_root=project_dir) - search_lean_exec.__enter__() - return search_lean_exec - transform = Local4DataExtractionTransform(0, buffer_size=1000) - training_data = TrainingData( - output_path, - "training_metadata.json", - training_meta=transform.get_meta_object(), - logger=logger, - layout=DataLayoutFormat.DECLARATION_EXTRACTION) - with SimpleLean4SyncExecutor(project_root=project_dir, main_file=file_name, use_human_readable_proof_context=True, suppress_error_log=True) as coq_exec: - transform(training_data, project_id, coq_exec, _print_lean_executor_callback, theorems=["*"]) - save_info = training_data.save() - logger.info(f"Saved training data to {save_info}") \ No newline at end of file diff --git a/src/itp_interface/tools/lean4_local_data_generation_transform.py b/src/itp_interface/tools/lean4_local_data_generation_transform.py index 4854b02..8550ff3 100644 --- a/src/itp_interface/tools/lean4_local_data_generation_transform.py +++ b/src/itp_interface/tools/lean4_local_data_generation_transform.py @@ -6,7 +6,7 @@ sys.path.append(root_dir) import typing import uuid -from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor +from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor from itp_interface.tools.lean4_context_helper import Lean4ContextHelper from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType from itp_interface.tools.training_data_format import MergableCollection, TrainingDataMetadataFormat, TheoremProvingTrainingDataCollection, TheoremProvingTrainingDataFormat diff --git a/src/itp_interface/tools/proof_exec_callback.py b/src/itp_interface/tools/proof_exec_callback.py index bcb9744..9c3b031 100644 --- a/src/itp_interface/tools/proof_exec_callback.py +++ b/src/itp_interface/tools/proof_exec_callback.py @@ -15,7 +15,7 @@ from itp_interface.tools.coq_executor import CoqExecutor from itp_interface.tools.lean_cmd_executor import Lean3Executor from itp_interface.tools.lean4_sync_executor import Lean4SyncExecutor -from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor +from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor from itp_interface.tools.isabelle_executor import IsabelleExecutor from itp_interface.tools.dynamic_coq_proof_exec import DynamicProofExecutor as DynamicCoqProofExecutor from itp_interface.tools.dynamic_lean_proof_exec import DynamicProofExecutor as DynamicLeanProofExecutor diff --git a/src/itp_interface/tools/ray_utils.py b/src/itp_interface/tools/ray_utils.py index 807ef7d..d462740 100644 --- a/src/itp_interface/tools/ray_utils.py +++ b/src/itp_interface/tools/ray_utils.py @@ -14,12 +14,41 @@ class RayUtils(object): + @staticmethod + def connect_to_ray(): + if os.environ.get("RAY_INITIALIZED", "0") == "0": + return None + root_dir = f"{os.path.abspath(__file__).split('itp_interface')[-2]}" + os.environ["PYTHONPATH"] = f"{root_dir}:{os.environ.get('PYTHONPATH', '')}" + from filelock import FileLock + import json + os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1" + lock_path = 
os.environ.get("RAY_LOCK_FILE_PATH", ".log/locks/ray.lock") + ray_session_path = os.environ.get("RAY_SESSION_PATH", ".log/ray/session_latest") + temp_lock = FileLock(lock_path) + try: + temp_lock.acquire(timeout=10) + temp_lock.release() + except: + if os.path.exists(ray_session_path): + with open(ray_session_path, "r") as f: + ray_session = f.read() + ray_session = json.loads(ray_session) + ray_address = ray_session["address"] + obj = ray.init(address=ray_address) + return obj + return None + @staticmethod def init_ray(num_of_cpus: int = 10, object_store_memory_in_gb: float = 25, memory_in_gb: float = 0.5, runtime_env: typing.Dict[str, str] = None): gb = 2**30 object_store_memory = int(object_store_memory_in_gb * gb) memory = int(memory_in_gb * gb) + obj = RayUtils.connect_to_ray() + if obj is not None: + return obj os.environ["RAY_INITIALIZED"] = "1" + os.environ["RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE"] = "1" obj = ray.init(num_cpus=num_of_cpus, object_store_memory=object_store_memory, _memory=memory, ignore_reinit_error=True, runtime_env=runtime_env) return obj diff --git a/src/itp_interface/tools/repl b/src/itp_interface/tools/repl index 8fff855..ebeaf40 160000 --- a/src/itp_interface/tools/repl +++ b/src/itp_interface/tools/repl @@ -1 +1 @@ -Subproject commit 8fff8552292860d349b459d6a811e6915671dc0d +Subproject commit ebeaf409fc93d64d5a52652c87e74bfb64805a66 diff --git a/src/itp_interface/tools/run_data_generation_transforms.py b/src/itp_interface/tools/run_data_generation_transforms.py index 4701c36..bacbb25 100644 --- a/src/itp_interface/tools/run_data_generation_transforms.py +++ b/src/itp_interface/tools/run_data_generation_transforms.py @@ -10,6 +10,7 @@ import typing import shutil import gc +from pathlib import Path from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError from itp_interface.tools.training_data import TrainingData, DataLayoutFormat @@ -24,12 +25,12 @@ RayUtils = None from itp_interface.tools.coq_executor import CoqExecutor from itp_interface.tools.lean_cmd_executor import Lean3Executor -from itp_interface.tools.simple_lean4_sync_executor import SimpleLean4SyncExecutor +from itp_interface.lean.simple_lean4_sync_executor import SimpleLean4SyncExecutor from itp_interface.tools.isabelle_executor import IsabelleExecutor from itp_interface.tools.coq_local_data_generation_transform import LocalDataGenerationTransform as CoqLocalDataGenerationTransform from itp_interface.tools.lean_local_data_generation_transform import LocalDataGenerationTransform as LeanLocalDataGenerationTransform from itp_interface.tools.lean4_local_data_generation_transform import Local4DataGenerationTransform as Lean4LocalDataGenerationTransform -from itp_interface.tools.lean4_local_data_extraction_transform import Local4DataExtractionTransform as Lean4LocalDataExtractionTransform +from itp_interface.lean.lean4_local_data_extraction_transform import Local4DataExtractionTransform as Lean4LocalDataExtractionTransform from itp_interface.tools.isabelle_local_data_generation_transform import LocalDataGenerationTransform as IsabelleLocalDataGenerationTransform from itp_interface.tools.coq_training_data_generator import GenericTrainingDataGenerationTransform, TrainingDataGenerationType @@ -290,6 +291,7 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD job_idx = 0 project_names = list(projects.keys()) project_names.sort() + file_index = 0 for project in project_names: # Create temporary directory for each project proj_name = 
os.path.basename(project) @@ -305,9 +307,26 @@ def run_local_transform(self, pool_size: int , transform: typing.Union[CoqLocalD some_files_processed = True job_more_args = file_args.get(file_path, {}) # Create temporary directory for each file - full_file_path = os.path.join(project_path, file_path) - relative_file_path = file_path - relative_file_path = relative_file_path.replace("/", ".").replace(".v", "").replace(".lean", "").replace(".thy", "") + fp = Path(file_path) + pp = Path(project_path) + if fp.is_absolute(): + # then leave it as is + full_file_path = file_path + else: + # Check if the file path is relative to the project path + fp = fp.resolve() + pp = pp.resolve() + # Now check if fp is relative to pp + fp_is_rel_to_pp = fp.is_relative_to(pp) + if fp_is_rel_to_pp: + full_file_path = str(fp.relative_to(pp)) + else: + # Just make it relative to project path + full_file_path = os.path.join(project_path, file_path) + assert os.path.exists(full_file_path), f"File path {full_file_path} does not exist" + file_index += 1 + + relative_file_path = f"transformer_{file_index}" temp_file_dir = os.path.join(temp_project_dir, relative_file_path) os.makedirs(temp_file_dir, exist_ok=True) log_file = os.path.join(self.logging_dir, f"{relative_file_path}.log") @@ -410,10 +429,13 @@ def _transform_output(results): shutil.rmtree(temp_output_dir) def run_all_local_transforms(self, pool_size: int, projects: typing.Dict[str, typing.Dict[str, str]], use_human_readable: bool, new_output_dir: str, log_error: bool, other_args: typing.Dict[str, typing.Dict[str, dict]] = {}): + os.makedirs(new_output_dir, exist_ok=True) for idx, transform in enumerate(self.transforms): last_transform = idx == len(self.transforms) - 1 save_transform = self.save_intermidiate_transforms or last_transform - self.run_local_transform(pool_size, transform, projects, use_human_readable, new_output_dir, log_error, save_transform, preserve_temp=self.save_intermidiate_transforms, other_args=other_args) + temp_new_output_dir = str(Path(new_output_dir) / str(idx)) + os.makedirs(temp_new_output_dir, exist_ok=True) + self.run_local_transform(pool_size, transform, projects, use_human_readable, temp_new_output_dir, log_error, save_transform, preserve_temp=self.save_intermidiate_transforms, other_args=other_args) pass # Create Ray remote version if Ray is available diff --git a/src/itp_interface/tools/simple_sqlite.py b/src/itp_interface/tools/simple_sqlite.py new file mode 100644 index 0000000..fd8aa47 --- /dev/null +++ b/src/itp_interface/tools/simple_sqlite.py @@ -0,0 +1,877 @@ +""" +Simple SQLite database for storing and querying Lean declaration data. + +This module provides a thread-safe SQLite database for storing Lean declarations, +their dependencies, and file metadata. Designed to work with Ray actors processing +files in parallel. +""" + +import sqlite3 +import uuid +from typing import List, Dict, Any, Optional +from contextlib import contextmanager + + +class LeanDeclarationDB: + """ + Thread-safe SQLite database for storing Lean file and declaration information. + + Key features: + - Automatic ID assignment on first discovery (declaration or dependency) + - Simplified dependency storage (edges only: A depends on B) + - Thread-safe operations with WAL mode for concurrent Ray actors + - Idempotent operations for safe parallel execution + """ + + def __init__(self, db_path: str, timeout: float = 30.0): + """ + Initialize the database connection. 
+ + Args: + db_path: Path to the SQLite database file + timeout: Timeout in seconds for database locks (default 30s) + """ + self.db_path = db_path + self.timeout = timeout + self.conn = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False) + # Explicitly set UTF-8 encoding + self.conn.execute("PRAGMA encoding = 'UTF-8'") + self.conn.row_factory = sqlite3.Row + # Enable WAL mode BEFORE creating tables for better concurrency + self.enable_wal_mode() + # Create tables (safe with IF NOT EXISTS even from multiple actors) + self._create_tables() + + def _generate_unique_id(self) -> str: + """ + Generate a unique ID for a declaration. + + Format: {timestamp}_{uuid4} + Same format as used in lean4_local_data_extraction_transform.py + + Returns: + Unique identifier string + """ + timestamp = str(int(uuid.uuid1().time_low)) + random_id = str(uuid.uuid4()) + return f"{timestamp}_{random_id}" + + def enable_wal_mode(self): + """ + Enable Write-Ahead Logging mode for better concurrent write performance. + + WAL mode allows multiple readers and one writer to access the database + simultaneously, which is essential for Ray actors processing files in parallel. + """ + self.conn.execute("PRAGMA journal_mode=WAL") + self.conn.execute("PRAGMA synchronous=NORMAL") + self.conn.commit() + + def _create_tables(self): + """ + Create the database schema with proper indexes and constraints. + + Safe for concurrent execution from multiple Ray actors: + - Uses CREATE TABLE IF NOT EXISTS (idempotent) + - Uses CREATE INDEX IF NOT EXISTS (idempotent) + - WAL mode is enabled before this is called + - 30s timeout handles lock contention + """ + cursor = self.conn.cursor() + + # Files table - stores file metadata + cursor.execute(""" + CREATE TABLE IF NOT EXISTS files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_path TEXT UNIQUE NOT NULL, + module_name TEXT NOT NULL + ) + """) + + # Imports table - stores file import relationships + cursor.execute(""" + CREATE TABLE IF NOT EXISTS imports ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + file_id INTEGER NOT NULL, + end_pos INTEGER, + module_name TEXT, + start_pos INTEGER, + text TEXT, + FOREIGN KEY (file_id) REFERENCES files(id) ON DELETE CASCADE + ) + """) + + # Declarations table - stores all declarations (complete or partial) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS declarations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + decl_id TEXT UNIQUE NOT NULL, + name TEXT NOT NULL, + namespace TEXT, + file_path TEXT, + module_name TEXT, + decl_type TEXT, + text TEXT, + line INTEGER, + column INTEGER, + end_line INTEGER, + end_column INTEGER, + doc_string TEXT, + proof TEXT, + -- A declaration is uniquely identified by its name and location + UNIQUE(name, namespace, file_path, module_name) + ) + """) + + # Simplified dependencies table - stores only edges (A depends on B) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS declaration_dependencies ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + from_decl_id TEXT NOT NULL, + to_decl_id TEXT NOT NULL, + UNIQUE(from_decl_id, to_decl_id), + FOREIGN KEY (from_decl_id) REFERENCES declarations(decl_id) ON DELETE CASCADE, + FOREIGN KEY (to_decl_id) REFERENCES declarations(decl_id) ON DELETE CASCADE + ) + """) + + # Create indexes for faster queries + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_files_path + ON files(file_path) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_files_module + ON files(module_name) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_declarations_name + ON 
declarations(name) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_declarations_namespace + ON declarations(namespace) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_declarations_decl_id + ON declarations(decl_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_declarations_lookup + ON declarations(name, namespace, file_path, module_name) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_dependencies_from + ON declaration_dependencies(from_decl_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_dependencies_to + ON declaration_dependencies(to_decl_id) + """) + + self.conn.commit() + + @contextmanager + def transaction(self): + """ + Context manager for database transactions. + + Usage: + with db.transaction(): + db.insert_something(...) + db.insert_something_else(...) + """ + try: + yield self.conn + self.conn.commit() + except Exception as e: + self.conn.rollback() + raise e + + def get_or_create_decl_id( + self, + name: str, + namespace: Optional[str] = None, + file_path: Optional[str] = None, + module_name: Optional[str] = None, + assert_exists: bool = False + ) -> str: + """ + Get existing decl_id or create a new one for a declaration. + + This is the core method for ID assignment. IDs are assigned as soon as + a declaration is discovered (either as a dependency or as a declaration itself). + + Args: + name: Declaration name (required) + namespace: Namespace (can be None) + file_path: File path (can be None for unresolved dependencies) + module_name: Module name (can be None for unresolved dependencies) + + Returns: + The decl_id (existing or newly created) + """ + cursor = self.conn.cursor() + + # Try to find existing declaration + # Handle NULL values properly in SQL + cursor.execute(""" + SELECT decl_id FROM declarations + WHERE name = ? + AND (namespace IS ? OR (namespace IS NULL AND ? IS NULL) OR (namespace IS NOT NULL AND ? IS NULL)) + AND (file_path IS ? OR (file_path IS NULL AND ? IS NULL) OR (file_path IS NOT NULL AND ? IS NULL)) + AND (module_name IS ? OR (module_name IS NULL AND ? IS NULL) OR (module_name IS NOT NULL AND ? IS NULL)) + """, (name, + namespace, namespace, namespace, + file_path, file_path, file_path, + module_name, module_name, module_name)) + + row = cursor.fetchone() + if row: + return row[0] # Return existing ID + + if assert_exists: + raise ValueError(f"Declaration not found: name={name}, namespace={namespace}, file_path={file_path}, module_name={module_name}") + + # Generate new ID and insert minimal record + new_decl_id = self._generate_unique_id() + + try: + cursor.execute(""" + INSERT INTO declarations (decl_id, name, namespace, file_path, module_name) + VALUES (?, ?, ?, ?, ?) + """, (new_decl_id, name, namespace, file_path, module_name)) + self.conn.commit() + except sqlite3.IntegrityError: + # Race condition: another process inserted it between our SELECT and INSERT + # Query again to get the existing ID + cursor.execute(""" + SELECT decl_id FROM declarations + WHERE name = ? + AND (namespace IS ? OR (namespace IS NULL AND ? IS NULL) OR (namespace IS NOT NULL AND ? IS NULL)) + AND (file_path IS ? OR (file_path IS NULL AND ? IS NULL) OR (file_path IS NOT NULL AND ? IS NULL)) + AND (module_name IS ? OR (module_name IS NULL AND ? IS NULL) OR (module_name IS NOT NULL AND ? 
IS NULL)) + """, (name, + namespace, namespace, namespace, + file_path, file_path, file_path, + module_name, module_name, module_name)) + row = cursor.fetchone() + if row: + return row[0] + else: + # This shouldn't happen, but raise if it does + raise + + return new_decl_id + + def upsert_declaration_full_info( + self, + decl_id: str, + name: str, + namespace: Optional[str], + file_path: str, + module_name: str, + decl_type: Optional[str] = None, + text: Optional[str] = None, + line: Optional[int] = None, + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + doc_string: Optional[str] = None, + proof: Optional[str] = None + ): + """ + Update a declaration with complete information. + + This is called when we process the actual declaration (not just a reference). + Updates the record with full metadata. + + Args: + decl_id: The declaration ID + name: Declaration name + namespace: Namespace + file_path: File path + module_name: Module name + decl_type: Declaration type (theorem, def, etc.) + text: Full declaration text + line: Starting line number + column: Starting column number + end_line: Ending line number + end_column: Ending column number + doc_string: Documentation string + proof: Proof text + """ + cursor = self.conn.cursor() + + cursor.execute(""" + UPDATE declarations + SET decl_type = ?, + text = ?, + line = ?, + column = ?, + end_line = ?, + end_column = ?, + doc_string = ?, + proof = ? + WHERE decl_id = ? + """, (decl_type, text, line, column, end_line, end_column, doc_string, proof, decl_id)) + + self.conn.commit() + + def insert_dependency_edge(self, from_decl_id: str, to_decl_id: str): + """ + Insert a dependency edge: from_decl_id depends on to_decl_id. + + Uses INSERT OR IGNORE for idempotency (safe to call multiple times). + + Args: + from_decl_id: The declaration that has the dependency + to_decl_id: The declaration being depended on + """ + cursor = self.conn.cursor() + + cursor.execute(""" + INSERT OR IGNORE INTO declaration_dependencies (from_decl_id, to_decl_id) + VALUES (?, ?) + """, (from_decl_id, to_decl_id)) + + self.conn.commit() + + def process_fda_list(self, fda_list: List) -> List[str]: + """ + Process a list of FileDependencyAnalysis objects. + + Args: + fda_list: List of FileDependencyAnalysis objects + + Returns: + List of all decl_ids that were processed + """ + all_decl_ids = [] + + for fda in fda_list: + # Insert file and imports first + if fda.imports: + self.insert_file_imports(fda.file_path, fda.module_name, fda.imports) + + # Process each declaration in this file + for decl in fda.declarations: + decl_id = self.process_declaration(fda.file_path, fda.module_name, decl) + all_decl_ids.append(decl_id) + + return all_decl_ids + + def process_declaration( + self, + fda_file_path: str, + fda_module_name: str, + decl, + enable_dependency_extraction: bool = True + ) -> str: + """ + Process a declaration from a FileDependencyAnalysis object. + + This is the main high-level method that: + 1. Gets or creates decl_id for this declaration + 2. Updates full declaration info + 3. 
Processes all dependencies and creates edges + + Args: + fda_file_path: File path from FileDependencyAnalysis + fda_module_name: Module name from FileDependencyAnalysis + decl: DeclWithDependencies object + + Returns: + The decl_id for this declaration + """ + # Get or create ID for this declaration + decl_id = self.get_or_create_decl_id( + name=decl.decl_info.name, + namespace=decl.decl_info.namespace, + file_path=fda_file_path, + module_name=fda_module_name, + # If we are extracting dependencies, we expect the declaration to exist + assert_exists=enable_dependency_extraction + ) + + if not enable_dependency_extraction: + # Update with full declaration info + # This is a new declaration, so we can safely update all info + self.upsert_declaration_full_info( + decl_id=decl_id, + name=decl.decl_info.name, + namespace=decl.decl_info.namespace, + file_path=fda_file_path, + module_name=fda_module_name, + decl_type=decl.decl_info.decl_type, + text=decl.decl_info.text, + line=decl.decl_info.line, + column=decl.decl_info.column, + end_line=decl.decl_info.end_line, + end_column=decl.decl_info.end_column, + doc_string=decl.decl_info.doc_string, + proof=decl.decl_info.proof + ) + else: + # Process dependencies + for dep in decl.dependencies: + # Get or create ID for the dependency + try: + dep_decl_id = self.get_or_create_decl_id( + name=dep.name, + namespace=dep.namespace, + file_path=dep.file_path, + module_name=dep.module_name, + # At this point all dependencies should exist + assert_exists=True + ) + except ValueError as e: + # This dependency is probably from outside the project + # Let's just skip it + continue + + # Propagate the decl_id back to the dependency object + dep.decl_id = dep_decl_id + + # Insert dependency edge + self.insert_dependency_edge(decl_id, dep_decl_id) + + return decl_id + + def get_or_create_file(self, file_path: str, module_name: str) -> int: + """ + Get or create a file record. + + Args: + file_path: The file path + module_name: The module name + + Returns: + The file_id (existing or newly created) + """ + cursor = self.conn.cursor() + + cursor.execute("SELECT id FROM files WHERE file_path = ?", (file_path,)) + row = cursor.fetchone() + if row: + return row[0] + + cursor.execute(""" + INSERT INTO files (file_path, module_name) + VALUES (?, ?) + """, (file_path, module_name)) + self.conn.commit() + + file_id = cursor.lastrowid + if file_id is None: + raise RuntimeError("Failed to get file_id after insert") + return file_id + + def insert_file_imports(self, file_path: str, module_name: str, imports: List[Dict]): + """ + Insert imports for a file. + + Args: + file_path: The file path + module_name: The module name + imports: List of import dictionaries + """ + file_id = self.get_or_create_file(file_path, module_name) + cursor = self.conn.cursor() + + for import_data in imports: + cursor.execute(""" + INSERT INTO imports (file_id, end_pos, module_name, start_pos, text) + VALUES (?, ?, ?, ?, ?) + """, (file_id, import_data.get('end_pos'), import_data.get('module_name'), + import_data.get('start_pos'), import_data.get('text'))) + + self.conn.commit() + + # Query methods + + def get_declaration_by_decl_id(self, decl_id: str) -> Optional[Dict[str, Any]]: + """ + Get a declaration by its decl_id. 
+ + Args: + decl_id: The unique declaration ID + + Returns: + Dictionary containing declaration information or None + """ + cursor = self.conn.cursor() + cursor.execute("SELECT * FROM declarations WHERE decl_id = ?", (decl_id,)) + row = cursor.fetchone() + return dict(row) if row else None + + def get_declarations_by_name( + self, + name: str, + namespace: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Get declarations by name and optionally namespace. + + Args: + name: The declaration name + namespace: Optional namespace to filter by + + Returns: + List of dictionaries containing declaration information + """ + cursor = self.conn.cursor() + + if namespace is not None: + cursor.execute(""" + SELECT * FROM declarations + WHERE name = ? AND namespace = ? + """, (name, namespace)) + else: + cursor.execute(""" + SELECT * FROM declarations + WHERE name = ? + """, (name,)) + + return [dict(row) for row in cursor.fetchall()] + + def get_declarations_by_file(self, file_path: str) -> List[Dict[str, Any]]: + """ + Get all declarations in a specific file. + + Args: + file_path: The file path + + Returns: + List of dictionaries containing declaration information + """ + cursor = self.conn.cursor() + cursor.execute(""" + SELECT * FROM declarations + WHERE file_path = ? + ORDER BY line + """, (file_path,)) + + return [dict(row) for row in cursor.fetchall()] + + def get_declarations_by_module(self, module_name: str) -> List[Dict[str, Any]]: + """ + Get all declarations in a specific module. + + Args: + module_name: The module name + + Returns: + List of dictionaries containing declaration information + """ + cursor = self.conn.cursor() + cursor.execute(""" + SELECT * FROM declarations + WHERE module_name = ? + """, (module_name,)) + + return [dict(row) for row in cursor.fetchall()] + + def get_dependencies(self, decl_id: str) -> List[Dict[str, Any]]: + """ + Get all declarations that this declaration depends on. + + Args: + decl_id: The declaration ID + + Returns: + List of declarations that this declaration depends on + """ + cursor = self.conn.cursor() + cursor.execute(""" + SELECT d.* FROM declarations d + JOIN declaration_dependencies dd ON dd.to_decl_id = d.decl_id + WHERE dd.from_decl_id = ? + """, (decl_id,)) + + return [dict(row) for row in cursor.fetchall()] + + def get_dependents(self, decl_id: str) -> List[Dict[str, Any]]: + """ + Get all declarations that depend on this declaration. + + Args: + decl_id: The declaration ID + + Returns: + List of declarations that depend on this one + """ + cursor = self.conn.cursor() + cursor.execute(""" + SELECT d.* FROM declarations d + JOIN declaration_dependencies dd ON dd.from_decl_id = d.decl_id + WHERE dd.to_decl_id = ? + """, (decl_id,)) + + return [dict(row) for row in cursor.fetchall()] + + def get_dependency_graph( + self, + decl_id: str, + max_depth: int = 10, + direction: str = 'dependencies' + ) -> Dict[str, Any]: + """ + Get the dependency graph for a declaration (recursive). 
+ + Args: + decl_id: The declaration ID to start from + max_depth: Maximum depth to traverse (default 10) + direction: 'dependencies' (what this depends on) or 'dependents' (what depends on this) + + Returns: + Dictionary containing the dependency graph + """ + visited = set() + + def _get_graph_recursive(current_decl_id: str, depth: int) -> Optional[Dict[str, Any]]: + if depth > max_depth or current_decl_id in visited: + return None + + visited.add(current_decl_id) + + decl = self.get_declaration_by_decl_id(current_decl_id) + if not decl: + return None + + if direction == 'dependencies': + related = self.get_dependencies(current_decl_id) + else: + related = self.get_dependents(current_decl_id) + + result = { + 'declaration': decl, + direction: [] + } + + for rel_decl in related: + sub_graph = _get_graph_recursive(rel_decl['decl_id'], depth + 1) + if sub_graph: + result[direction].append(sub_graph) + + return result + + graph = _get_graph_recursive(decl_id, 0) + return graph if graph else {} + + def search_declarations( + self, + name: Optional[str] = None, + namespace: Optional[str] = None, + file_path: Optional[str] = None, + module_name: Optional[str] = None, + decl_type: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Search declarations with multiple optional filters. + + Args: + name: Declaration name (optional) + namespace: Namespace (optional) + file_path: File path (optional) + module_name: Module name (optional) + decl_type: Declaration type (optional) + + Returns: + List of matching declarations + """ + cursor = self.conn.cursor() + + query = "SELECT * FROM declarations WHERE 1=1" + params = [] + + if name is not None: + query += " AND name = ?" + params.append(name) + + if namespace is not None: + query += " AND namespace = ?" + params.append(namespace) + + if file_path is not None: + query += " AND file_path = ?" + params.append(file_path) + + if module_name is not None: + query += " AND module_name = ?" + params.append(module_name) + + if decl_type is not None: + query += " AND decl_type = ?" + params.append(decl_type) + + cursor.execute(query, params) + return [dict(row) for row in cursor.fetchall()] + + def get_statistics(self) -> Dict[str, int]: + """ + Get database statistics. 
+ + Returns: + Dictionary with counts of files, declarations, and dependencies + """ + cursor = self.conn.cursor() + + stats = {} + + try: + cursor.execute("SELECT COUNT(*) FROM files") + result = cursor.fetchone() + stats['total_files'] = result[0] if result else 0 + except Exception as e: + print(f"Warning: Error fetching files count: {e}") + stats['total_files'] = 0 + + try: + cursor.execute("SELECT COUNT(*) FROM declarations") + result = cursor.fetchone() + stats['total_declarations'] = result[0] if result else 0 + except Exception as e: + print(f"Warning: Error fetching declarations count: {e}") + stats['total_declarations'] = 0 + + try: + cursor.execute("SELECT COUNT(*) FROM declaration_dependencies") + result = cursor.fetchone() + stats['total_dependencies'] = result[0] if result else 0 + except Exception as e: + print(f"Warning: Error fetching dependencies count: {e}") + stats['total_dependencies'] = 0 + + try: + cursor.execute("SELECT COUNT(*) FROM declarations WHERE file_path IS NULL") + result = cursor.fetchone() + stats['unresolved_declarations'] = result[0] if result else 0 + except Exception as e: + print(f"Warning: Error fetching unresolved declarations count: {e}") + stats['unresolved_declarations'] = 0 + + try: + cursor.execute("SELECT COUNT(*) FROM imports") + result = cursor.fetchone() + stats['total_imports'] = result[0] if result else 0 + except Exception as e: + print(f"Warning: Error fetching imports count: {e}") + stats['total_imports'] = 0 + + return stats + + def close(self): + """Close the database connection.""" + if self.conn: + self.conn.close() + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.close() + return False + + +if __name__ == "__main__": + # Example usage + print("Creating test database...") + + with LeanDeclarationDB("test_lean_declarations.db") as db: + # Test with file and imports + print("\nTesting file and imports insertion:") + test_imports = [ + { + "end_pos": 292, + "module_name": "Mathlib.Algebra.Group.Basic", + "start_pos": 258, + "text": "import Mathlib.Algebra.Group.Basic" + }, + { + "end_pos": 321, + "module_name": "Mathlib.Tactic.Common", + "start_pos": 293, + "text": "import Mathlib.Tactic.Common" + } + ] + + db.insert_file_imports( + file_path="Mathlib/Algebra/Divisibility/Basic.lean", + module_name="Mathlib.Algebra.Divisibility.Basic", + imports=test_imports + ) + print("Inserted file and imports") + + # Test ID generation + print("\nTesting ID generation:") + decl_id1 = db.get_or_create_decl_id( + name="dvd_trans", + namespace=None, + file_path="Mathlib/Algebra/Divisibility/Basic.lean", + module_name="Mathlib.Algebra.Divisibility.Basic" + ) + print(f"Generated ID for dvd_trans: {decl_id1}") + + # Get same ID again (should return existing) + decl_id2 = db.get_or_create_decl_id( + name="dvd_trans", + namespace=None, + file_path="Mathlib/Algebra/Divisibility/Basic.lean", + module_name="Mathlib.Algebra.Divisibility.Basic" + ) + print(f"Retrieved ID for dvd_trans: {decl_id2}") + print(f"IDs match: {decl_id1 == decl_id2}") + + # Update with full info + db.upsert_declaration_full_info( + decl_id=decl_id1, + name="dvd_trans", + namespace=None, + file_path="Mathlib/Algebra/Divisibility/Basic.lean", + module_name="Mathlib.Algebra.Divisibility.Basic", + decl_type="theorem", + text="@[trans]\ntheorem dvd_trans : a ∣ b → b ∣ c → a ∣ c", + line=63, + column=0, + end_line=68, + end_column=0, + doc_string=None, + proof="| ⟨d, h₁⟩,
⟨e, h₂⟩ => ⟨d * e, h₁ ▸ h₂.trans <| mul_assoc a d e⟩" + ) + print("Updated declaration with full info") + + # Create a dependency + dep_id = db.get_or_create_decl_id( + name="mul_assoc", + namespace=None, + file_path=None, + module_name=None + ) + print(f"\nGenerated ID for dependency mul_assoc: {dep_id}") + + # Insert dependency edge + db.insert_dependency_edge(decl_id1, dep_id) + print("Inserted dependency edge") + + # Query + print("\nQuerying declaration:") + decl = db.get_declaration_by_decl_id(decl_id1) + if decl: + print(f"Declaration: {decl['name']} ({decl['decl_type']})") + else: + print("Declaration not found!") + + print("\nQuerying dependencies:") + deps = db.get_dependencies(decl_id1) + print(f"Dependencies: {[d['name'] for d in deps]}") + + # Statistics + print("\nDatabase statistics:") + stats = db.get_statistics() + for key, value in stats.items(): + print(f" {key}: {value}") + + print("\nTest complete!") diff --git a/src/itp_interface/tools/training_data_format.py b/src/itp_interface/tools/training_data_format.py index e7dabfa..326aee7 100644 --- a/src/itp_interface/tools/training_data_format.py +++ b/src/itp_interface/tools/training_data_format.py @@ -12,7 +12,7 @@ from collections import OrderedDict from typing import List, Optional, Union, runtime_checkable, Protocol from pydantic import BaseModel -from itp_interface.tools.tactic_parser import DeclWithDependencies +from itp_interface.lean.tactic_parser import FileDependencyAnalysis @runtime_checkable class TrainingDataFormat(Protocol): @@ -503,7 +503,7 @@ def load_from_string(json_text: str, logger: logging.Logger = None): return deserialized class ExtractionDataCollection(BaseModel): - training_data: list[DeclWithDependencies] = [] + training_data: list[FileDependencyAnalysis] = [] def __len__(self) -> int: return len(self.training_data) diff --git a/src/test/parsing_helpers_test.py b/src/test/parsing_helpers_test.py new file mode 100644 index 0000000..1726c05 --- /dev/null +++ b/src/test/parsing_helpers_test.py @@ -0,0 +1,69 @@ +import unittest +from itp_interface.lean.parsing_helpers import parse_lean_text, LeanDeclType + + +class ParsingHelpersTest(unittest.TestCase): + def test_lean_declaration_parsing(self): + """Test parsing of various Lean 4 declaration types""" + test_cases = [ + ( + "Simple Theorem", + "@[simp] lemma foo (x : Nat) : x = x := rfl" + ), + ( + "Context and Doc", + "open Algebra.TensorProduct in\n/-- Doc -/\ntheorem left_of_tensor [Module R] : True where\n out := sorry" + ), + ( + "Inductive (No Proof)", + "/-- The base e -/\ninductive ExBase : Type\n| A\n| B" + ), + ( + "Structure (No Proof)", + "structure MyStruct where\n field1 : Nat\n field2 : Int" + ), + ( + "Mutual (No Proof)", + "mutual\n inductive A | a\n inductive B | b\nend" + ), + ( + "Run Cmd (Fallback)", + "open Lean in\nrun_cmd Command.liftTermElabM do\n logInfo \"hi\"" + ), + ( + "Inductive Nat (hard case)", + "set_option genCtorIdx false in\n/--\nThe natural numbers, starting at zero.\n\nThis type is special-cased by both the kernel and the compiler, and overridden with an efficient\nimplementation.
Both use a fast arbitrary-precision arithmetic library (usually\n[GMP](https://gmplib.org/)); at runtime, `Nat` values that are sufficiently small are unboxed.\n-/\ninductive Nat where\n /--\n Zero, the smallest natural number.\n\n Using `Nat.zero` explicitly should usually be avoided in favor of the literal `0`, which is the\n [simp normal form](lean-manual://section/simp-normal-forms).\n -/\n | zero : Nat\n /--\n The successor of a natural number `n`.\n\n Using `Nat.succ n` should usually be avoided in favor of `n + 1`, which is the [simp normal\n form](lean-manual://section/simp-normal-forms).\n -/\n | succ (n : Nat) : Nat\n" + ) + ] + + print(f"{'TYPE':<12} | {'NAME':<10} | {'TEXT BEFORE':<20} | {'DOC':<10} | {'TEXT':<20} | {'PROOF':<15}") + print("-" * 115) + + for test_name, inp in test_cases: + res = parse_lean_text(inp) + + tp = res.decl_type.value + nm = res.name or "" + tb = (res.text_before or "").replace('\n', '\\n') + ds = (res.doc_string or "") + tx = (res.text or "").replace('\n', '\\n') + pf = (res.proof or "").replace('\n', '\\n') + + if len(tb) > 18: tb = tb[:18] + "..." + if len(ds) > 8: ds = "/--...-/" + if len(tx) > 18: tx = tx[:18] + "..." + if len(pf) > 12: pf = pf[:12] + "..." + + print(f"{tp:<12} | {nm:<10} | {tb:<20} | {ds:<10} | {tx:<20} | {pf:<15}") + + # Basic assertions to verify parsing works + self.assertIsNotNone(res.decl_type) + self.assertIsInstance(res.decl_type, LeanDeclType) + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/src/test/simple_data_extract_test.py b/src/test/simple_data_extract_test.py index 36f1dca..07aba42 100644 --- a/src/test/simple_data_extract_test.py +++ b/src/test/simple_data_extract_test.py @@ -17,7 +17,7 @@ def pretty_print_file_contents(dir_path): print(f"Printing all files in the directory: {dir_path}") for f in os.listdir(dir_path): file_path = os.path.join(dir_path, f) - if os.path.isfile(file_path): + if os.path.isfile(file_path) and any(file_path.endswith(ext) for ext in [".json", ".yaml", ".yml", ".txt", ".log"]): print('-'*50) print(f"Contents of {file_path}:") with open(file_path, "r") as file: @@ -63,7 +63,7 @@ def test_lean_data_extract(self): # Print the directory contents last_dir_path = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean_ext", last_dir) print("Last Directory Contents:", os.listdir(last_dir_path)) - train_data = os.path.join(last_dir_path, "train") + train_data = os.path.join(last_dir_path, "train", "1") list_files = os.listdir(train_data) print("Train Directory Contents:", list_files) data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")] diff --git a/src/test/simple_data_gen_test.py b/src/test/simple_data_gen_test.py index e39ce93..48fd226 100644 --- a/src/test/simple_data_gen_test.py +++ b/src/test/simple_data_gen_test.py @@ -62,7 +62,7 @@ def test_proof_step_data_gen(self): # Print the directory contents last_dir_path = os.path.join(".log/data_generation/benchmark/simple_benchmark_lean", last_dir) print("Last Directory Contents:", os.listdir(last_dir_path)) - train_data = os.path.join(last_dir_path, "train") + train_data = os.path.join(last_dir_path, "train", "0") list_files = os.listdir(train_data) print("Train Directory Contents:", list_files) data_files = [f for f in list_files if f.endswith(".json") and f.startswith("local_data_")] diff --git a/src/test/simple_env_coq_test.py b/src/test/simple_env_coq_test.py new file mode 100644 index 0000000..77c5108 --- /dev/null +++ b/src/test/simple_env_coq_test.py 
@@ -0,0 +1,131 @@ +import unittest +import os + +def pretty_print(s1, s2, proof_step, done): + print(f"Current Goal:") + print('-'*30) + for goal in s1.training_data_format.start_goals: + hyps = '\n'.join([hyp for hyp in goal.hypotheses]) + print(hyps) + print('|- ', end='') + print(goal.goal) + print(f'*'*30) + print(f"="*30) + print(f"Action: {proof_step}") + print(f"="*30) + print(f"Next Goal:") + print('-'*30) + if s2 is not None: + for goal in s2.training_data_format.start_goals: + hyps = '\n'.join([hyp for hyp in goal.hypotheses]) + print(hyps) + print('|- ', end='') + print(goal.goal) + print(f'*'*30) + print(f"="*30) + print(f"DONE: {done}") + print('-'*30) + if s2 is None and done: + print("No more goals. Proof Finished!") + +class CoqHelper(): + def __init__(self): + self.current_switch = None + + def build_coq_project(self, project_folder): + try: + with os.popen("opam switch show") as proc: + self.current_switch = proc.read().strip() + except: + self.current_switch = None + # Check if the switch exists + # opam switch create simple_grp_theory 4.14.2 + if os.system("opam switch simple_grp_theory") != 0: + cmds = [ + 'opam switch create simple_grp_theory 4.14.2', + 'opam switch simple_grp_theory', + 'eval $(opam env)', + 'opam repo add coq-released https://coq.inria.fr/opam/released', + 'opam pin add -y coq-lsp 0.1.8+8.18' + ] + final_cmd = ' && '.join(cmds) + os.system(final_cmd) + # IMPORTANT NOTE: Make sure to switch to the correct switch before running the code. + os.system("opam switch simple_grp_theory && eval $(opam env)") + # Clean the project + os.system(f"eval $(opam env) && cd {project_folder} && make clean") + # Build the project + with os.popen(f"eval $(opam env) && cd {project_folder} && make") as proc: + print("Building Coq project...") + print('-'*15 + 'Build Logs' + '-'*15) + print(proc.read()) + print('-'*15 + 'End Build Logs' + '-'*15) + + def switch_to_current_switch(self): + if self.current_switch is not None: + try: + proc = os.popen(f"opam switch {self.current_switch} && eval $(opam env)") + print(proc.read()) + finally: + proc.close() + + +class CoqTest(unittest.TestCase): + def test_simple_coq(self): + from itp_interface.rl.proof_state import ProofState + from itp_interface.rl.proof_action import ProofAction + from itp_interface.rl.simple_proof_env import ProofEnv + from itp_interface.tools.proof_exec_callback import ProofExecutorCallback + from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy + project_folder = "src/data/test/coq/custom_group_theory/theories" + file_path = "src/data/test/coq/custom_group_theory/theories/grpthm.v" + # Build the project + # cd src/data/test/coq/custom_group_theory/theories && make + helper = CoqHelper() + helper.build_coq_project(project_folder) + language = ProofAction.Language.COQ + theorem_name = "algb_identity_sum" + # Theorem algb_identity_sum : + # forall a, algb_add a e = a. + proof_exec_callback = ProofExecutorCallback( + project_folder=project_folder, + file_path=file_path, + language=language, + always_use_retrieval=False, + keep_local_context=True + ) + always_retrieve_thms = False + retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK + env = ProofEnv("test_coq", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) + proof_steps = [ + 'intros.', + 'destruct a.', + '- reflexivity.', + '- reflexivity.' 
+ ] + with env: + for proof_step in proof_steps: + state, _, next_state, _, done, info = env.step(ProofAction( + ProofAction.ActionType.RUN_TACTIC, + language, + tactics=[proof_step])) + if info.error_message is not None: + print(f"Error: {info.error_message}") + # This prints StateChanged, StateUnchanged, Failed, or Done + print(info.progress) + print('-'*30) + if done: + print("Proof Finished!!") + else: + s1 : ProofState = state + s2 : ProofState = next_state + pretty_print(s1, s2, proof_step, done) + helper.switch_to_current_switch() + + +def main(): + unittest.main() + + +if __name__ == '__main__': + main() diff --git a/src/test/simple_env_test.py b/src/test/simple_env_lean_test.py similarity index 81% rename from src/test/simple_env_test.py rename to src/test/simple_env_lean_test.py index 8306fb5..abc6d40 100644 --- a/src/test/simple_env_test.py +++ b/src/test/simple_env_lean_test.py @@ -1,6 +1,6 @@ import unittest import os -from itp_interface.tools.tactic_parser import build_lean4_project, build_tactic_parser_if_needed +from itp_interface.lean.tactic_parser import build_lean4_project, build_tactic_parser_if_needed def pretty_print(s1, s2, proof_step, done): print(f"Current Goal:") @@ -29,10 +29,7 @@ def pretty_print(s1, s2, proof_step, done): if s2 is None and done: print("No more goals. Proof Finished!") -class Helper(): - def __init__(self): - self.current_switch = None - +class LeanHelper(): def build_lean4_project(self, project_folder): build_tactic_parser_if_needed() # Build the project @@ -41,43 +38,6 @@ def build_lean4_project(self, project_folder): build_lean4_project(project_folder) - def build_coq_project(self, project_folder): - try: - with os.popen("opam switch show") as proc: - self.current_switch = proc.read().strip() - except: - self.current_switch = None - # Check if the switch exists - # opam switch create simple_grp_theory 4.14.2 - if os.system("opam switch simple_grp_theory") != 0: - cmds = [ - 'opam switch create simple_grp_theory 4.14.2', - 'opam switch simple_grp_theory', - 'eval $(opam env)', - 'opam repo add coq-released https://coq.inria.fr/opam/released', - 'opam pin add -y coq-lsp 0.1.8+8.18' - ] - final_cmd = ' && '.join(cmds) - os.system(final_cmd) - # IMPORTANT NOTE: Make sure to switch to the correct switch before running the code. 
- os.system("opam switch simple_grp_theory && eval $(opam env)") - # Clean the project - os.system(f"eval $(opam env) && cd {project_folder} && make clean") - # Build the project - with os.popen(f"eval $(opam env) && cd {project_folder} && make") as proc: - print("Building Coq project...") - print('-'*15 + 'Build Logs' + '-'*15) - print(proc.read()) - print('-'*15 + 'End Build Logs' + '-'*15) - - def switch_to_current_switch(self): - if self.current_switch is not None: - try: - proc = os.popen(f"opam switch {self.current_switch} && eval $(opam env)") - print(proc.read()) - finally: - proc.close() - class Lean4Test(unittest.TestCase): def test_simple_lean4(self): from itp_interface.rl.proof_state import ProofState @@ -89,7 +49,7 @@ def test_simple_lean4(self): file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}' @@ -119,8 +79,8 @@ def test_simple_lean4(self): proof_was_finished = False for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") @@ -145,7 +105,7 @@ def test_lean4_backtracking(self): project_folder = "src/data/test/lean4_proj" file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}' @@ -176,20 +136,20 @@ def test_lean4_backtracking(self): print(f"Backtracking at step {idx + 1} i.e. {proof_step}") state, _, next_state, _, done, info = env.step( ProofAction( - ProofAction.ActionType.BACKTRACK, + ProofAction.ActionType.BACKTRACK, language)) assert next_state == prev_state, "Backtracking failed" # Replay the last action last_proof_step = proof_steps[idx-1] state, _, next_state, _, done, info = env.step( ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[last_proof_step])) state, _, next_state, _, done, info = env.step( ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) prev_state = state if done: @@ -197,57 +157,6 @@ def test_lean4_backtracking(self): proof_was_finished = True assert proof_was_finished, "Proof was not finished" - def test_simple_coq(self): - from itp_interface.rl.proof_state import ProofState - from itp_interface.rl.proof_action import ProofAction - from itp_interface.rl.simple_proof_env import ProofEnv - from itp_interface.tools.proof_exec_callback import ProofExecutorCallback - from itp_interface.rl.simple_proof_env import ProofEnvReRankStrategy - project_folder = "src/data/test/coq/custom_group_theory/theories" - file_path = "src/data/test/coq/custom_group_theory/theories/grpthm.v" - # Build the project - # cd src/data/test/coq/custom_group_theory/theories && make - helper = Helper() - helper.build_coq_project(project_folder) - language = ProofAction.Language.COQ - theorem_name = "algb_identity_sum" - # Theorem algb_identity_sum : - # forall a, algb_add a e = a. 
- proof_exec_callback = ProofExecutorCallback( - project_folder=project_folder, - file_path=file_path, - language=language, - always_use_retrieval=False, - keep_local_context=True - ) - always_retrieve_thms = False - retrieval_strategy = ProofEnvReRankStrategy.NO_RE_RANK - env = ProofEnv("test_coq", proof_exec_callback, theorem_name, retrieval_strategy=retrieval_strategy, max_proof_depth=10, always_retrieve_thms=always_retrieve_thms) - proof_steps = [ - 'intros.', - 'destruct a.', - '- reflexivity.', - '- reflexivity.' - ] - with env: - for proof_step in proof_steps: - state, _, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, - tactics=[proof_step])) - if info.error_message is not None: - print(f"Error: {info.error_message}") - # This prints StateChanged, StateUnchanged, Failed, or Done - print(info.progress) - print('-'*30) - if done: - print("Proof Finished!!") - else: - s1 : ProofState = state - s2 : ProofState = next_state - pretty_print(s1, s2, proof_step, done) - helper.switch_to_current_switch() - def test_simple_lean_calc(self): from itp_interface.rl.proof_state import ProofState from itp_interface.rl.proof_action import ProofAction @@ -258,7 +167,7 @@ def test_simple_lean_calc(self): file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}" @@ -290,8 +199,8 @@ def test_simple_lean_calc(self): proof_was_finished = False for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") @@ -319,7 +228,7 @@ def test_simple_lean_calc_with_validation(self): file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}" @@ -351,8 +260,8 @@ def test_simple_lean_calc_with_validation(self): proof_was_finished = False for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") @@ -386,7 +295,7 @@ def test_simple_lean_enforce_done_test(self): file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = "{\"namespace\":\"Lean4Proj1\",\"name\":\"test_calc\"}" @@ -420,8 +329,8 @@ def test_simple_lean_enforce_done_test(self): proof_finished = False for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") @@ -450,7 +359,7 @@ def test_simple_lean4_done_test(self): file_path = 
"src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}' @@ -477,8 +386,8 @@ def test_simple_lean4_done_test(self): with env: for proof_step in proof_steps: state, _, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") @@ -502,7 +411,7 @@ def test_simple_lean4_have_test(self): file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"imo_1959_p1\"}' @@ -540,8 +449,8 @@ def test_simple_lean4_have_test(self): env.set_max_proof_step_length(10000) for proof_step in proof_steps: state, m_action, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") @@ -560,7 +469,7 @@ def test_simple_lean4_have_test(self): s1 : ProofState = state s2 : ProofState = next_state pretty_print(s1, s2, proof_step, done) - + def test_simple_lean4_with_error(self): from itp_interface.rl.proof_state import ProofState from itp_interface.rl.proof_action import ProofAction @@ -571,7 +480,7 @@ def test_simple_lean4_with_error(self): file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"test3\"}' @@ -599,8 +508,8 @@ def test_simple_lean4_with_error(self): with env: for i, proof_step in enumerate(proof_steps): state, _, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) if info.error_message is not None: print(f"Error: {info.error_message}") @@ -631,7 +540,7 @@ def test_simple_lean4_multiline_multigoal(self): file_path = "src/data/test/lean4_proj/Lean4Proj/Basic.lean" # Build the project # cd src/data/test/lean4_proj && lake build - helper = Helper() + helper = LeanHelper() helper.build_lean4_project(project_folder) language = ProofAction.Language.LEAN4 theorem_name = '{\"namespace\":\"Lean4Proj2\",\"name\":\"complicated_have\"}' @@ -659,8 +568,8 @@ def test_simple_lean4_multiline_multigoal(self): proof_was_finished = False for proof_step in proof_steps: state, action, next_state, _, done, info = env.step(ProofAction( - ProofAction.ActionType.RUN_TACTIC, - language, + ProofAction.ActionType.RUN_TACTIC, + language, tactics=[proof_step])) proof_step = action.kwargs.get('tactics', ['INVALID'])[0] if info.error_message is not None: @@ -677,20 +586,10 @@ def test_simple_lean4_multiline_multigoal(self): pretty_print(s1, s2, proof_step, done) assert proof_was_finished, "Proof was not finished" + def main(): - # unittest.main() - # Run only the Lean 4 tests - t = Lean4Test() - # 
t.test_simple_lean4_multiline_multigoal() - # t.test_simple_lean4() - # t.test_lean4_backtracking() - # t.test_simple_lean4_done_test() - # t.test_simple_lean_calc() - # t.test_simple_lean_calc_with_validation() - # t.test_simple_lean4_with_error() - t.test_simple_lean4_have_test() - # t.test_simple_lean_enforce_done_test() + unittest.main() if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/src/test/test_tactic_parser.py b/src/test/test_tactic_parser.py index cb79825..d988769 100644 --- a/src/test/test_tactic_parser.py +++ b/src/test/test_tactic_parser.py @@ -11,7 +11,7 @@ # Add parent directory to path to import tactic_parser sys.path.insert(0, str(Path(__file__).parent.parent / "itp_interface" / "tools")) -from itp_interface.tools.tactic_parser import TacticParser, print_tactics +from itp_interface.lean.tactic_parser import TacticParser, print_tactics project_path = str(Path(__file__).parent.parent / "data" / "test" / "lean4_proj")
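Reviewer note (not part of the patch): a minimal, hypothetical sketch of how the query side of the new LeanDeclarationDB in src/itp_interface/tools/simple_sqlite.py could be exercised once a database has been populated (for example by the extraction transform or by the module's own __main__ block). The database filename and the queried declaration name below are illustrative; only methods added in this diff are used.

# Illustrative usage sketch -- not part of the diff; names and paths are assumptions.
from itp_interface.tools.simple_sqlite import LeanDeclarationDB

with LeanDeclarationDB("test_lean_declarations.db") as db:  # path matches the __main__ example above
    # Find theorem declarations named "dvd_trans" (any namespace/file).
    for decl in db.search_declarations(name="dvd_trans", decl_type="theorem"):
        # Recursively walk what this declaration depends on, up to two levels deep.
        graph = db.get_dependency_graph(decl["decl_id"], max_depth=2, direction="dependencies")
        direct = [g["declaration"]["name"] for g in graph.get("dependencies", [])]
        print(decl["name"], "depends on", direct)
    # Aggregate counts: files, declarations, dependency edges, unresolved declarations, imports.
    print(db.get_statistics())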