diff --git a/api/llm_client.py b/api/llm_client.py index 31b9b23..3f185d5 100644 --- a/api/llm_client.py +++ b/api/llm_client.py @@ -1,41 +1,44 @@ -"""LLM client wrapper for API chat functionality using Google Gemini.""" - import os import time -from typing import Any, Dict, List +import requests +import json +from typing import Dict, List +from dotenv import load_dotenv -try: - import google.generativeai as genai -except ModuleNotFoundError: - genai = None +SYSTEM_PROMPTS = { + "default": """You are BAIO, an expert bioinformatics assistant specialized in DNA sequence classification and pathogen detection. + +You help researchers: +- Understand classification results (Virus vs Host predictions) +- Interpret confidence scores and risk levels +- Explain k-mer analysis and model predictions +- Provide guidance on next steps for validation + +Be concise, helpful, and scientific in your responses. Use emojis sparingly.""", + "analysis_helper": "You are analyzing metagenomic sequencing data with BAIO. Help interpret the classification results and suggest next steps.", + "technical_expert": "You are a technical expert on BAIO's architecture, focusing on RandomForest models, k-mer tokenization, and TF-IDF features.", +} class LLMClient: - """Wrapper class for Google Gemini LLM API calls.""" - def __init__(self, provider: str = "google", model: str = "gemini-1.5-flash"): - self.provider = provider + def __init__(self, model: str = "liquid/lfm-2.5-1.2b-instruct:free"): + load_dotenv() self.model = model - self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - self.client: Any = None - - if self.api_key and genai is not None: - try: - genai.configure(api_key=self.api_key) - self.client = genai.GenerativeModel(model) - except Exception as e: - print(f"Failed to initialize Gemini: {e}", flush=True) - elif self.api_key and genai is None: + self.api_key = os.getenv("OPENROUTER_API_KEY") + if self.api_key is None: print( - "google-generativeai is not installed; falling back to mock responses.", + "OpenRouter api key not found; falling back to mock responses.", flush=True, ) def generate_response( - self, messages: List[Dict[str, str]], system_prompt: str + self, + messages: List[Dict[str, str]], + system_prompt: str = SYSTEM_PROMPTS["default"], ) -> str: """ - Generate response from Gemini LLM. + Generate response from the LLM or fallback to mock if API key missing/error. Args: messages: List of conversation messages @@ -44,25 +47,39 @@ def generate_response( Returns: Generated response text """ - try: - if self.client is None: - return self._mock_response(messages) - history = [] - for msg in messages[:-1]: - role = "user" if msg["role"] == "user" else "model" - history.append({"role": role, "parts": [msg["content"]]}) + if self.api_key is None: + return "OpenRouter api key not found; falling back to mock responses." - chat = self.client.start_chat(history=history) + # Build the payload for OpenRouter API + payload = { + "model": self.model, + "messages": [ + {"role": "system", "content": system_prompt}, + *messages, + ], + } - last_message = messages[-1]["content"] if messages else "" - full_prompt = f"{system_prompt}\n\nUser: {last_message}" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } - response = chat.send_message(full_prompt) - return response.text + try: + response = requests.post( + url="https://openrouter.ai/api/v1/chat/completions", + headers=headers, + data=json.dumps(payload), + timeout=10, # seconds + ) + response.raise_for_status() + data = response.json() + + # Extract the assistant's reply + return data["choices"][0]["message"]["content"] except Exception as e: - print(f"Gemini API error: {str(e)}") + print(f"API error: {e}. Falling back to mock response.", flush=True) return self._mock_response(messages) def _mock_response(self, messages: List[Dict[str, str]]) -> str: @@ -90,18 +107,3 @@ def _mock_response(self, messages: List[Dict[str, str]]) -> str: "virus/host detection, confidence scores, and the analysis pipeline. " "What would you like to know?" ) - - -SYSTEM_PROMPTS = { - "default": """You are BAIO, an expert bioinformatics assistant specialized in DNA sequence classification and pathogen detection. - -You help researchers: -- Understand classification results (Virus vs Host predictions) -- Interpret confidence scores and risk levels -- Explain k-mer analysis and model predictions -- Provide guidance on next steps for validation - -Be concise, helpful, and scientific in your responses. Use emojis sparingly.""", - "analysis_helper": "You are analyzing metagenomic sequencing data with BAIO. Help interpret the classification results and suggest next steps.", - "technical_expert": "You are a technical expert on BAIO's architecture, focusing on RandomForest models, k-mer tokenization, and TF-IDF features.", -} diff --git a/environment.yml b/environment.yml index ec2c592..fd20722 100644 --- a/environment.yml +++ b/environment.yml @@ -2,67 +2,77 @@ name: baio channels: - conda-forge - bioconda + - defaults dependencies: - # Core + # === Core === - python>=3.12,<3.13 - uvicorn + - requests + - python-dotenv>=1.0 + - click + - rich + - jsonschema + - python-json-logger>=3.0 + - pyyaml>=6.0 - # Numeric stack compatible with Py3.12 + # === Numeric / Scientific stack === - numpy=2.2 - scipy=1.14 - numba=0.61 - llvmlite=0.44 - - # ML / DS - - scikit-learn=1.5 - pandas=2.2 - matplotlib=3.9 - seaborn=0.13 - plotly - joblib>=1.3 - # Torch (CPU/MPS on Apple Silicon) + # === ML / AI === + - scikit-learn=1.5 - pytorch=2.5.1 - # I/O & storage - - pyarrow=17 + # === Bioinformatics / Clustering === + - biopython>=1.85 + - hdbscan=0.8.39 + - umap-learn=0.5.7 + + # === I/O / Storage === + - pyarrow>=17.0 - h5py - zarr - # App / tooling - - tqdm - - click - - python-dotenv - - pyyaml - - jsonschema - - rich - - requests - - # Dev / QA - - pytest - - pytest-cov - - black - - flake8 - - isort - - mypy - - # Jupyter + # === Jupyter / Notebook === - jupyter - jupyterlab - ipywidgets + - ipykernel>=6.0 - # Install without conflicts - - hdbscan=0.8.39 - - umap-learn==0.5.7 + # === Dev / QA === + - pytest>=8.0 + - pytest-cov>=4.1 + - pytest-mock>=3.12 + - pytest-asyncio>=0.23 + - black>=24.0 + - flake8>=7.0 + - isort>=5.13 + - mypy>=1.10 + - watchdog>=3.0 - # Install via pip + # === Pip-only packages === - pip - pip: + # runtime / HuggingFace - transformers==4.56.1 - tokenizers==0.22.0 - accelerate>=0.30 - datasets>=2.19 - - types-requests>=2.31.0 - fastapi>=0.115.0 - - pandas-stubs==2.3.2.250926 \ No newline at end of file + + # dev / profiling / type stubs + - line-profiler>=4.1 + - memory-profiler>=0.61 + - pandas-stubs==2.3.2.250926 + - types-requests>=2.31.0 + - types-PyYAML>=6.0 + - types-jsonschema + - pre-commit>=3.3 \ No newline at end of file diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 96ab165..26b6a3f 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -1566,9 +1566,9 @@ } }, "node_modules/ajv": { - "version": "6.12.6", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", - "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "version": "6.14.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.14.0.tgz", + "integrity": "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==", "dev": true, "license": "MIT", "dependencies": { diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 4078e89..e1bbb7e 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -5,6 +5,7 @@ import Header from './components/Header' import SequenceInput from './components/SequenceInput' import ConfigPanel from './components/ConfigPanel' import ResultsDashboard from './components/ResultsDashboard' +import ChatWidget from './components/ChatWidget' import type { ChatMessage, ClassificationResponse, @@ -73,10 +74,11 @@ function App() { const [chatMessages, setChatMessages] = useState([ { - role: 'assistant', - content: + "role": 'assistant', + "content": 'Hi! Paste FASTA sequences, run classification, and ask questions here.', }, + ]) const [chatInput, setChatInput] = useState('') const [chatLoading, setChatLoading] = useState(false) @@ -167,11 +169,6 @@ function App() { healthOk={healthOk} darkMode={darkMode} toggleDarkMode={() => setDarkMode(!darkMode)} - chatMessages={chatMessages} - chatInput={chatInput} - onChatInputChange={setChatInput} - onChatSend={handleChatSend} - chatLoading={chatLoading} /> {error && ( @@ -261,6 +258,14 @@ function App() { + + ) } diff --git a/frontend/src/types.ts b/frontend/src/types.ts index e4252b5..2f6f9b7 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -40,7 +40,7 @@ export type ClassificationResponse = { } export type ChatMessage = { - role: 'user' | 'assistant' + role: 'user' | 'system' | 'assistant' content: string } diff --git a/metaseq/train.py b/metaseq/train.py index 086b07f..b26d35f 100644 --- a/metaseq/train.py +++ b/metaseq/train.py @@ -2,7 +2,7 @@ from typing import List, Tuple, Optional, Dict, Any import os import json -import yaml # type: ignore[import-untyped] +import yaml import pandas as pd from sklearn.model_selection import train_test_split # type: ignore[import-untyped] from sklearn.metrics import classification_report # type: ignore[import-untyped] diff --git a/pyproject.toml b/pyproject.toml index 8e2cdb5..bb8bbfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,22 +11,34 @@ license = {text = "MIT"} requires-python = ">=3.10" dependencies = [ - # ---- Core / Python scientific stack (match environment.yml) ---- + # ---- Core / Python scientific stack ---- "numpy==2.2.*", "scipy==1.14.*", "numba==0.61.*", "llvmlite==0.44.*", - # ML / DS + # ML / Data Science "scikit-learn==1.5.*", "pandas==2.2.*", "matplotlib==3.9.*", "seaborn==0.13.*", "plotly", + "joblib>=1.3", # Torch (CPU/MPS) "torch==2.8.0", + # HuggingFace / ML + "transformers==4.56.1", + "tokenizers==0.22.0", + "accelerate>=0.30", + "datasets>=2.19", + + # Bioinformatics + "biopython>=1.85", + "hdbscan==0.8.39", + "umap-learn==0.5.7", + # I/O & storage "pyarrow==17.*", "h5py", @@ -35,44 +47,51 @@ dependencies = [ # App / tooling "tqdm", "click", - "python-dotenv", - "pyyaml", + "python-dotenv>=1.0", + "pyyaml>=6.0", "jsonschema", "rich", - - # Clustering / DR - "hdbscan==0.8.39", - "umap-learn==0.5.7", - - # Jupyter - "jupyter", - "jupyterlab", - "ipywidgets", - - # Pip-only libs from env.yml - "transformers==4.56.1", - "tokenizers==0.22.0", - "accelerate>=0.30", - "datasets>=2.19", - - # --- add stubs for mypy --- - "pandas-stubs==2.3.2.250926", - "types-requests>=2.31.0", - + "requests>=2.32", + "python-json-logger>=3.0", # FastAPI "fastapi>=0.115.0", ] [project.optional-dependencies] +# Development dependencies dev = [ - "pytest", - "pytest-cov", - "black", - "flake8", - "isort", - "mypy", + # Testing + "pytest>=8.0", + "pytest-cov>=4.1", + "pytest-mock>=3.12", + "pytest-asyncio>=0.23", + + # Linting / formatting + "black>=24.0", + "flake8>=7.0", + "isort>=5.13", + "mypy>=1.10", "flake8-pyproject", + + # Type stubs + "pandas-stubs==2.3.2.250926", + "types-requests>=2.31.0", + "types-PyYAML>=6.0", + "types-jsonschema", + + # Dev utilities + "watchdog>=3.0", + "line-profiler>=4.1", + "memory-profiler>=0.61", +] + +# Notebook / interactive environment +notebook = [ + "jupyter>=1.0", + "jupyterlab>=4.0", + "ipywidgets", + "ipykernel>=6.0", ] [project.urls] @@ -84,7 +103,6 @@ Documentation = "https://github.com/oss-slu/baio/blob/main/docs/" [tool.setuptools] packages = ["metaseq", "prompting", "data_processing"] - [tool.setuptools.package-data] metaseq = ["*.yml", "*.yaml", "*.json"] @@ -95,7 +113,6 @@ target-version = ['py312'] include = '\\.pyi?$' extend-exclude = ''' /( - # directories \.eggs | \.git | \.hg @@ -117,7 +134,7 @@ known_first_party = ["metaseq", "app"] [tool.flake8] max-line-length = 88 -extend-ignore = ["E203", "W503", "E501"] # Add E501 here +extend-ignore = ["E203", "W503", "E501"] exclude = [ ".git", "__pycache__", @@ -147,7 +164,7 @@ explicit_package_bases = true [[tool.mypy.overrides]] module = [ "biopython.*", - "pysam.*", + "pysam.*", "hdbscan.*", "umap.*", "plotly.*", @@ -209,4 +226,4 @@ exclude_lines = [ 'if __name__ == .__main__.:', 'class .*\bProtocol\b.*:', '@(abc\\.)?abstractmethod', -] +] \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index ad18e3e..f90b577 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,32 +1,35 @@ -# Include runtime deps --r requirements.txt +# NOTE +# Should not be necessary if environment activated correctly -# === Testing === -pytest>=8.0 -pytest-cov>=4.1 -pytest-mock>=3.12 -pytest-asyncio>=0.23 +# -r requirements.txt -# === Linting / Code Quality === -black>=24.0 -flake8>=7.0 -isort>=5.13 -mypy>=1.10 -pandas-stubs>=2.3 -types-requests>=2.31 -types-PyYAML>=6.0 -types-jsonschema +# # === Testing === +# pytest>=8.0 +# pytest-cov>=4.1 +# pytest-mock>=3.12 +# pytest-asyncio>=0.23 -# === Dev Tools === -rich>=13.0 -watchdog>=3.0 -python-dotenv>=1.0 +# # === Linting / Code Quality === +# black>=24.0 +# flake8>=7.0 +# isort>=5.13 +# mypy>=1.10 +# pandas-stubs>=2.3 +# types-requests>=2.31 +# types-PyYAML>=6.0 +# types-jsonschema -# === Jupyter === -jupyter>=1.0 -jupyterlab>=4.0 -ipykernel>=6.0 +# # === Jupyter / Notebook === +# jupyter>=1.0 +# jupyterlab>=4.0 +# ipykernel>=6.0 +# ipywidgets -# === Profiling / Utilities === -line-profiler>=4.1 -memory-profiler>=0.61 +# # === Dev Utilities === +# rich>=13.0 +# watchdog>=3.0 +# python-dotenv>=1.0 + +# # === Profiling / Performance === +# line-profiler>=4.1 +# memory-profiler>=0.61 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 1b2fbef..50e4d1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,31 +1,40 @@ -# === Core Runtime Dependencies === -python-dotenv>=1.0 -numpy>=2.2 -pandas>=2.2 -scikit-learn>=1.5 -matplotlib>=3.9 -seaborn>=0.13 -plotly>=5.20 -tqdm>=4.67 -pyyaml>=6.0 -requests>=2.32 +# NOTE +# Should not be necessary if environment activated correctly -# === Bioinformatics === -biopython>=1.85 -hdbscan>=0.8.39 -umap-learn==0.5.7 +# # === Core Runtime Dependencies === +# python-dotenv>=1.0 +# numpy>=2.2 +# pandas>=2.2 +# scikit-learn>=1.5 +# matplotlib>=3.9 +# seaborn>=0.13 +# plotly>=5.20 +# tqdm>=4.67 +# pyyaml>=6.0 +# requests>=2.32 +# python-json-logger>=3.0 +# joblib>=1.3 -# === HuggingFace / ML === -torch>=2.8.0 -transformers==4.56.1 -tokenizers==0.22.0 -accelerate>=0.30 -datasets>=2.19 -joblib>=1.3 +# # === Bioinformatics === +# biopython>=1.85 +# hdbscan>=0.8.39 +# umap-learn>=0.5.7 -# === Data I/O === -pyarrow>=17.0 -zarr>=2.17 +# # === ML / HuggingFace === +# torch>=2.8 +# transformers==4.56.1 +# tokenizers==0.22.0 +# accelerate>=0.30 +# datasets>=2.19 -# === Other Utilities === -python-json-logger>=3.0 +# # === I/O / Storage === +# pyarrow>=17.0 +# zarr>=2.17 +# h5py + +# # === Web / API === +# fastapi>=0.115.0 +# uvicorn>=0.22 +# click +# rich +# jsonschema \ No newline at end of file