Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions sunbeam/ai/prompts.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,46 @@
# System prompt sent to the LLM when generating Snakemake rules for a Sunbeam
# extension. It has three parts: output constraints, Sunbeam naming/logging
# conventions, and the Sunbeam variables that generated rules may reference.
# The leading backslash suppresses the newline after the opening quotes so the
# text starts directly with the first sentence; the text ends with a newline.
RULE_GEN_SYSTEM_PROMPT = """\
You are an expert Snakemake engineer.
Generate valid Snakemake `.smk` rules only.

Constraints:
- Output ONLY rules and required Python blocks for Snakemake.
- Use canonical sections -- rule NAME:, input:, output:, params:, threads:, conda:, resources:, shell:, log:, benchmark:.
- Do not include prose, markdown, or triple backticks. However, each rule should include a docstring with the rule's purpose and any additional context.
- Prefer stable, portable shell commands and reference existing Sunbeam conventions if mentioned.
- This pipeline extends Sunbeam, the input reads live here: `QC_FP / 'decontam' / '{sample}_{rp}.fastq.gz'`. Default to paired end if there's ambiguity.
- Name conda envs according to the tool they install or their purpose if there are multiple tools. Always use the `.yaml` extension.

Some common Sunbeam conventions:
- Other extensions may use similar rules and rule names; avoid collisions by prefixing each rule name with the extension name (e.g., `myext_rule_name`).
- Use `log: LOG_FP / 'rule_name_{sample}.log'` to capture standard out/err for each sample. Expand over wildcards as necessary to match the output. In the shell command, try to include everything in a subshell and redirect everything that doesn't go into outputs to the log file.
- Use `benchmark: BENCHMARK_FP / 'rule_name_{sample}.tsv'` to capture resource usage for each sample.
- You should create a target rule named `myext_all` that depends on all final outputs of the extension.

Some important Sunbeam variables available in rules:
`Cfg` is a configuration dictionary holding the content of `sunbeam_config.yml`. You will probably not use this nor make your own config. If there are obvious configurable parameters for a rule, define them in code at the top of the file.
`Samples` is a dictionary of sample metadata, where keys are sample names and values are dictionaries with keys `1` and (optionally) `2` for read file paths.
`Pairs` is a list. If the run is paired end, it is ['1', '2']. If single end, it is ['1'].
There are a number of output filepaths defined QC_FP, ASSEMBLY_FP, ANNOTATION_FP, CLASSIFY_FP, MAPPING_FP, VIRUS_FP. All outputs should live in one of these directories. If none of these fit the theme of the new extension, define your own at the top of the file with `SOMETHING_FP = output_subdir(Cfg, 'something')`.
"""


CONDA_ENV_GEN_SYSTEM_PROMPT = """
def RULE_GEN_SYSTEM_PROMPT(ext_name: str) -> str:
    """Build the system prompt for generating an extension's Snakemake rules.

    The extension name is used to namespace the rule names, the target rule
    and the output-directory variable that the LLM is told to produce.

    Args:
        ext_name: Name of the extension. Characters that cannot appear in a
            Python identifier (e.g. ``-`` or ``.``) are replaced with ``_``
            wherever the name is used as a rule-name prefix or variable name;
            the unmodified name is still used for the output subdirectory.

    Returns:
        The prompt text. It begins and ends with a newline.
    """
    import re  # Local import so this block does not depend on file-level imports.

    # Snakemake rule names and the `<NAME>_FP` variable must be valid Python
    # identifiers. A name such as "my-ext" or "16s" would otherwise make the
    # prompt ask for invalid code. Names that are already identifiers are
    # left untouched, so the prompt is unchanged for them.
    ident = re.sub(r"\W", "_", ext_name)
    if ident[:1].isdigit():
        ident = f"_{ident}"

    # Doubled braces ({{sample}}) are literal braces in the f-string: they are
    # Snakemake wildcards that must reach the LLM verbatim.
    return f"""
You are an expert Snakemake engineer.
Generate valid Snakemake `.smk` rules only.

Constraints:
- Output ONLY rules and required Python blocks for Snakemake.
- Use canonical sections -- rule NAME:, input:, output:, params:, threads:, conda:, resources:, shell:, log:, benchmark:. Threads should be defined in the `threads:` field and not in the `params:` field.
- Only use `shell:` for shell commands. Do not use `script:`, `run:`, or `wrapper:` sections.
- Do not include prose, markdown, or triple backticks. However, each rule should include a docstring with the rule's purpose and any additional context.
- Prefer stable, portable shell commands and reference existing Sunbeam conventions if mentioned.
- This pipeline extends Sunbeam, the input reads live here: `QC_FP / 'decontam' / '{{sample}}_{{rp}}.fastq.gz'`. Default to paired end if there's ambiguity.
- Name conda envs according to the tool they install or their purpose if there are multiple tools. Always use the `.yaml` extension.

Some common Sunbeam conventions:
- Other extensions may use similar rules and rule names; avoid collisions by prefixing each rule name with the extension name (e.g., `{ident}_rule_name`).
- Use `log: LOG_FP / 'rule_name_{{sample}}.log'` to capture standard out/err for each sample. Expand over wildcards as necessary to match the output. In the shell command, try to include everything in a subshell and redirect everything that doesn't go into outputs to the log file.
- Use `benchmark: BENCHMARK_FP / 'rule_name_{{sample}}.tsv'` to capture resource usage for each sample.
- You should create a target rule named `{ident}_all` that depends on all final outputs of the extension.

Some important Sunbeam variables available in rules:
`Cfg` is a configuration dictionary holding the content of `sunbeam_config.yml`. You will probably not use this nor make your own config. If there are obvious configurable parameters for a rule, define them in code at the top of the file.
`Samples` is a dictionary of sample metadata, where keys are sample names and values are dictionaries with keys `1` and (optionally) `2` for read file paths.
`Pairs` is a list. If the run is paired end, it is ['1', '2']. If single end, it is ['1'].
At the top of the file, create a new output directory with `{ident.upper()}_FP = output_subdir(Cfg, '{ext_name}')` and then put all outputs into subdirectories of this.
"""


def CONDA_ENV_GEN_SYSTEM_PROMPT() -> str:
return """
You are an expert Snakemake engineer and bioinformatician.
Generate a valid conda environment YAML file to satisfy the dependencies of the following Snakemake rule.

Constraints:
- Output ONLY a valid conda environment YAML file.
- Do not include prose, markdown, or triple backticks.
- Include a name field matching the conda environment name used in the rule's `conda:` section.
- Include the `defaults`, `conda-forge`, and `bioconda` channels.
- Use bioconda packages where possible.

Examples:
Examples (backticks are for clarity; do not include them in your output):

```yaml
```
name: blast
channels:
- defaults
Expand All @@ -43,7 +50,7 @@
- blast
```

```yaml
```
name: assembly
channels:
- defaults
Expand Down
41 changes: 28 additions & 13 deletions sunbeam/ai/rule_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _default_model_name() -> str:


def create_rules_from_prompt(
ext_name: str,
prompt: str,
*,
context_files: Optional[Sequence[Path]] = None,
Expand All @@ -58,6 +59,7 @@ def create_rules_from_prompt(
lazily and raises a clear error if they are not installed.

Args:
ext_name: Name of the extension (used for namespacing rules).
prompt: User's high-level request describing desired workflow behavior.
context_files: Optional paths whose contents provide additional context
to the LLM (e.g., existing rules or configs). Read as text.
Expand All @@ -76,27 +78,30 @@ def create_rules_from_prompt(
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage

system = RULE_GEN_SYSTEM_PROMPT
system = RULE_GEN_SYSTEM_PROMPT(ext_name)

# Assemble context payloads.
context_blobs: list[str] = []
if context_files:
for p in context_files:
try:
blob = Path(p).read_text()
blob = p.read_text()
except Exception:
continue
context_blobs.append(f"CONTEXT FILE {Path(p).name}:\n{blob}")
context_blobs.append(f"CONTEXT FILE {p.name}:\n{blob}")

context_text = ("\n\n".join(context_blobs)).strip()
context_text = f"Extension name: {ext_name}\n\n{context_text}"
user_content = (
prompt if not context_text else (f"Context:\n{context_text}\n\nTask:\n{prompt}")
)

model_name = model or _default_model_name()
llm = ChatOpenAI(model=model_name, api_key=api_key)
llm = ChatOpenAI(model=model_name, api_key=api_key, cache=False)

# Run rule generation
print("System prompt: ", system)
print("User prompt: ", user_content)
Comment on lines +103 to +104
Copy link

Copilot AI Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug print statements should be removed from production code or replaced with proper logging using the logger module.

Copilot uses AI. Check for mistakes.
resp = llm.invoke(
[
SystemMessage(content=system),
Expand All @@ -116,7 +121,7 @@ def create_rules_from_prompt(
context = f"Generate a conda environment YAML file for the following Snakemake rule:\n\n{rule}"
resp = llm.invoke(
[
SystemMessage(content=CONDA_ENV_GEN_SYSTEM_PROMPT),
SystemMessage(content=CONDA_ENV_GEN_SYSTEM_PROMPT()),
HumanMessage(content=context),
]
)
Expand All @@ -132,9 +137,13 @@ def create_rules_from_prompt(
out_path.write_text(rules_text)
# Write envs
envs_path = out_path.parent / "envs"
print("Out path: ", out_path)
print("Envs path: ", envs_path)
Comment on lines +140 to +141
Copy link

Copilot AI Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug print statements should be removed from production code or replaced with proper logging using the logger module.

Copilot uses AI. Check for mistakes.
envs_path.mkdir(exist_ok=True)
for env_name, env_text in env_texts.items():
(envs_path / f"{env_name}.yaml").write_text(env_text)
env_fp = envs_path / f"{env_name}.yaml"
print("Writing env: ", env_fp)
Copy link

Copilot AI Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug print statements should be removed from production code or replaced with proper logging using the logger module.

Copilot uses AI. Check for mistakes.
env_fp.write_text(env_text)
written_path = out_path

return RuleCreationResult(
Expand Down Expand Up @@ -170,13 +179,19 @@ def get_envs_from_rules(
# Extract env names
envs = {}
for rule in rules:
for line in rule.splitlines():
lines = rule.splitlines()
for i, line in enumerate(lines):
if line.strip().startswith("conda:"):
parts = line.split(":", 1)
if len(parts) == 2:
env_name = parts[1].strip().strip('"').strip("'").strip(".yaml")
envs[env_name] = rule
else:
print("Warning: Malformed conda line:", line)
env_line = lines[i + 1]
env_name = (
env_line.strip()
.replace('"', "")
.replace("'", "")
.replace(".yaml", "")
.replace(",", "")
.split("/")[-1]
)
print(env_name)
Copy link

Copilot AI Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Debug print statements should be removed from production code or replaced with proper logging using the logger module.

Copilot uses AI. Check for mistakes.
envs[env_name] = rule

return envs
12 changes: 11 additions & 1 deletion sunbeam/scripts/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@ def main(argv=sys.argv):
raise SystemExit(f"Extension directory {ext_dir} already exists")

rules_path = ext_dir / f"sbx_{ruleset_name}.smk"
result = create_rules_from_prompt(prompt, write_to=rules_path)

# Default context files to include
context_files = [
Path(__file__).parent.parent / "workflow" / "Snakefile",
Path(__file__).parent.parent / "workflow" / "rules" / "qc.smk",
Path(__file__).parent.parent / "workflow" / "rules" / "decontam.smk",
]

result = create_rules_from_prompt(
ruleset_name, prompt, context_files=context_files, write_to=rules_path
)

logger.info(f"Created extension scaffold at {ext_dir}")

Expand Down