Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ dev = [
"snakefmt==0.11.0",
"pytest==8.4.1",
]
ai = [
"openai==1.88.0",
]

[project.urls]
"Homepage" = "https://github.com/sunbeam-labs/sunbeam"
Expand Down
2 changes: 1 addition & 1 deletion sunbeam/bfx/decontam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_mapped_reads(fp: str, min_pct_id: float, min_len_frac: float) -> Iterato


def _get_pct_identity(
read: Dict[str, Union[int, float, str, Tuple[int, str]]]
read: Dict[str, Union[int, float, str, Tuple[int, str]]],
) -> float:
edit_dist = read.get("NM", 0)
pct_mm = float(edit_dist) / len(read["SEQ"])
Expand Down
20 changes: 20 additions & 0 deletions sunbeam/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,26 @@
from pathlib import Path


class StreamToLogger:
    """File-like object that forwards writes to a logger, one record per line.

    Intended for use with ``contextlib.redirect_stdout``/``redirect_stderr``
    so that output from libraries that write directly to a stream ends up in
    the pipeline log.
    """

    def __init__(self, logger, level=logging.INFO):
        # logger: destination logging.Logger; level: level for every record
        self.logger = logger
        self.level = level
        # Holds the trailing partial line between write() calls.
        self.buffer = ""

    def write(self, message):
        """Accumulate *message* and emit one log record per completed line.

        Only text up to the last newline is flushed; a trailing partial line
        stays buffered until the next write() or flush(). (Previously the
        partial tail was flushed too, which split lines that arrived across
        multiple write() calls into separate records.)
        """
        self.buffer += message
        while "\n" in self.buffer:
            line, self.buffer = self.buffer.split("\n", 1)
            # Skip blank lines so bare "\n" writes don't create empty records.
            if line.strip():
                self.logger.log(self.level, line.strip())

    def flush(self):
        """Emit any buffered partial line as a final log record."""
        if self.buffer.strip():
            self.logger.log(self.level, self.buffer.strip())
        self.buffer = ""


class ConditionalLevelFormatter(logging.Formatter):
def format(self, record):
# For WARNING and above, include "LEVELNAME: message"
Expand Down
70 changes: 65 additions & 5 deletions sunbeam/scripts/run.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,57 @@
import argparse
import contextlib
import datetime
import logging
import os
import sys
from pathlib import Path
from snakemake.cli import main as snakemake_main
from sunbeam import __version__
from sunbeam.logging import get_pipeline_logger
from sunbeam.logging import get_pipeline_logger, StreamToLogger


def analyze_run(log: str, logger: logging.Logger, ai: bool) -> None:
"""Analyze the run log and provide insights or suggestions."""
# We could do some rule-based analysis here but I'd rather lean into the AI features and see how far they can take us
if ai:
try:
import openai
except ImportError: # pragma: no cover - this is a soft dependency
logger.error(
"AI analysis requested, but the 'openai' package is not installed. Try `pip install -e sunbeamlib[ai]`.\n"
)
return

api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
logger.error("OPENAI_API_KEY not set; skipping AI analysis.\n")
return

try:
client = openai.OpenAI(api_key=api_key)
resp = client.chat.completions.create(
model="gpt-4.1-nano",
messages=[
{
"role": "system",
"content": "You diagnose errors from Sunbeam pipeline runs. If there are problems, suggest possible causes and solutions. Keep the answer short and sweet. If there are relevant file paths for debugging (like log files), mention them.",
},
{
"role": "user",
"content": f"Sunbeam ran with the following output:\n{log}\n",
},
],
max_tokens=1500,
)
logger.info(
"\n\nAI diagnosis:\n"
+ resp.choices[0].message.content
+ "\nCheck out the Sunbeam documentation (https://sunbeam.readthedocs.io/en/stable/) and the GitHub issues page (https://github.com/sunbeam-labs/sunbeam/issues) for more information or to open a new issue.\n"
)
except (
Exception
) as exc: # pragma: no cover - network errors are non-deterministic
logger.error(f"AI analysis failed: {exc}\n")


def main(argv: list[str] = sys.argv):
Expand All @@ -27,6 +73,8 @@ def main(argv: list[str] = sys.argv):
# You could argue it would make more sense to start this at the actual snakemake call
# but this way we can log some relevant setup information that might be useful on post-mortem analysis
logger = get_pipeline_logger(log_file)
logger.debug("Sunbeam pipeline logger initialized.")
print(log_file.exists())

snakefile = Path(__file__).parent.parent / "workflow" / "Snakefile"
if not snakefile.exists():
Expand Down Expand Up @@ -73,10 +121,17 @@ def main(argv: list[str] = sys.argv):
logger.info("Running: " + " ".join(snakemake_args))

try:
snakemake_main(snakemake_args)
except Exception as e:
logger.exception("An error occurred while running Sunbeam")
sys.exit(1)
stream_logger = StreamToLogger(logger, level=logging.INFO)

with contextlib.redirect_stderr(stream_logger):
snakemake_main(snakemake_args)
finally:
# Show all files in log_file directory
print(list(log_file.parent.glob("*")))
print(log_file)
print(log_file.exists())
with open(log_file, "r") as f:
analyze_run(f.read(), logger, args.ai)


def main_parser():
Expand Down Expand Up @@ -128,6 +183,11 @@ def main_parser():
default=__version__,
help="The tag to use when pulling docker images for the core pipeline environments, defaults to sunbeam's current version, a good alternative is 'latest' for the latest stable release",
)
parser.add_argument(
"--ai",
action="store_true",
help="Use OpenAI to diagnose failures after the run",
)
parser.add_argument(
"--log_file",
default=None,
Expand Down
1 change: 1 addition & 0 deletions sunbeam/workflow/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ from sunbeam.project.post import compile_benchmarks


logger = get_pipeline_logger()
logger.debug("Sunbeam pipeline starting...")

MIN_MEM_MB = int(os.getenv("SUNBEAM_MIN_MEM_MB", 8000))
MIN_RUNTIME = int(os.getenv("SUNBEAM_MIN_RUNTIME", 15))
Expand Down
47 changes: 47 additions & 0 deletions tests/e2e/test_sunbeam_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
import subprocess as sp
import sys
import types
from sunbeam.scripts.init import main as Init
from sunbeam.scripts.run import main as Run

Expand Down Expand Up @@ -113,3 +115,48 @@ def test_sunbeam_run_with_target_after_exclude(tmp_path, DATA_DIR, capsys):
assert ret.returncode == 0
assert "clean_qc" in ret.stderr.decode("utf-8")
assert "filter_reads" not in ret.stderr.decode("utf-8")


def test_sunbeam_run_ai_option(tmp_path, monkeypatch, DATA_DIR):
    """Verify that --ai triggers a chat-completion call after the run."""
    project_dir = tmp_path / "test"

    invocation = {"flag": False}

    def fake_create(**kwargs):
        # Record that the AI client was actually invoked.
        invocation["flag"] = True
        message = types.SimpleNamespace(content="analysis")
        return types.SimpleNamespace(
            choices=[types.SimpleNamespace(message=message)]
        )

    def fake_client(*args, **kwargs):
        completions = types.SimpleNamespace(create=fake_create)
        return types.SimpleNamespace(
            chat=types.SimpleNamespace(completions=completions)
        )

    # Stand in for the real openai package so no network access is needed.
    stub_openai = types.SimpleNamespace(OpenAI=fake_client)
    monkeypatch.setitem(sys.modules, "openai", stub_openai)
    monkeypatch.setenv("OPENAI_API_KEY", "token")

    Init([str(project_dir), "--data_fp", str(DATA_DIR / "reads")])

    with pytest.raises(SystemExit):
        Run(["--profile", str(project_dir), "--ai", "-n"])

    assert invocation["flag"]
Loading