-
Notifications
You must be signed in to change notification settings - Fork 44
AI rule generation #576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
AI rule generation #576
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,7 @@ def _default_model_name() -> str: | |
|
|
||
|
|
||
| def create_rules_from_prompt( | ||
| ext_name: str, | ||
| prompt: str, | ||
| *, | ||
| context_files: Optional[Sequence[Path]] = None, | ||
|
|
@@ -58,6 +59,7 @@ def create_rules_from_prompt( | |
| lazily and raises a clear error if they are not installed. | ||
|
|
||
| Args: | ||
| ext_name: Name of the extension (used for namespacing rules). | ||
| prompt: User's high-level request describing desired workflow behavior. | ||
| context_files: Optional paths whose contents provide additional context | ||
| to the LLM (e.g., existing rules or configs). Read as text. | ||
|
|
@@ -76,27 +78,30 @@ def create_rules_from_prompt( | |
| from langchain_openai import ChatOpenAI | ||
| from langchain_core.messages import SystemMessage, HumanMessage | ||
|
|
||
| system = RULE_GEN_SYSTEM_PROMPT | ||
| system = RULE_GEN_SYSTEM_PROMPT(ext_name) | ||
|
|
||
| # Assemble context payloads. | ||
| context_blobs: list[str] = [] | ||
| if context_files: | ||
| for p in context_files: | ||
| try: | ||
| blob = Path(p).read_text() | ||
| blob = p.read_text() | ||
| except Exception: | ||
| continue | ||
| context_blobs.append(f"CONTEXT FILE {Path(p).name}:\n{blob}") | ||
| context_blobs.append(f"CONTEXT FILE {p.name}:\n{blob}") | ||
|
|
||
| context_text = ("\n\n".join(context_blobs)).strip() | ||
| context_text = f"Extension name: {ext_name}\n\n{context_text}" | ||
| user_content = ( | ||
| prompt if not context_text else (f"Context:\n{context_text}\n\nTask:\n{prompt}") | ||
| ) | ||
|
|
||
| model_name = model or _default_model_name() | ||
| llm = ChatOpenAI(model=model_name, api_key=api_key) | ||
| llm = ChatOpenAI(model=model_name, api_key=api_key, cache=False) | ||
|
|
||
| # Run rule generation | ||
| print("System prompt: ", system) | ||
| print("User prompt: ", user_content) | ||
| resp = llm.invoke( | ||
| [ | ||
| SystemMessage(content=system), | ||
|
|
@@ -116,7 +121,7 @@ def create_rules_from_prompt( | |
| context = f"Generate a conda environment YAML file for the following Snakemake rule:\n\n{rule}" | ||
| resp = llm.invoke( | ||
| [ | ||
| SystemMessage(content=CONDA_ENV_GEN_SYSTEM_PROMPT), | ||
| SystemMessage(content=CONDA_ENV_GEN_SYSTEM_PROMPT()), | ||
| HumanMessage(content=context), | ||
| ] | ||
| ) | ||
|
|
@@ -132,9 +137,13 @@ def create_rules_from_prompt( | |
| out_path.write_text(rules_text) | ||
| # Write envs | ||
| envs_path = out_path.parent / "envs" | ||
| print("Out path: ", out_path) | ||
| print("Envs path: ", envs_path) | ||
|
Comment on lines
+140
to
+141
|
||
| envs_path.mkdir(exist_ok=True) | ||
| for env_name, env_text in env_texts.items(): | ||
| (envs_path / f"{env_name}.yaml").write_text(env_text) | ||
| env_fp = envs_path / f"{env_name}.yaml" | ||
| print("Writing env: ", env_fp) | ||
|
||
| env_fp.write_text(env_text) | ||
| written_path = out_path | ||
|
|
||
| return RuleCreationResult( | ||
|
|
@@ -170,13 +179,19 @@ def get_envs_from_rules( | |
| # Extract env names | ||
| envs = {} | ||
| for rule in rules: | ||
| for line in rule.splitlines(): | ||
| lines = rule.splitlines() | ||
| for i, line in enumerate(lines): | ||
| if line.strip().startswith("conda:"): | ||
| parts = line.split(":", 1) | ||
| if len(parts) == 2: | ||
| env_name = parts[1].strip().strip('"').strip("'").strip(".yaml") | ||
| envs[env_name] = rule | ||
| else: | ||
| print("Warning: Malformed conda line:", line) | ||
| env_line = lines[i + 1] | ||
| env_name = ( | ||
| env_line.strip() | ||
| .replace('"', "") | ||
| .replace("'", "") | ||
| .replace(".yaml", "") | ||
| .replace(",", "") | ||
| .split("/")[-1] | ||
| ) | ||
| print(env_name) | ||
|
||
| envs[env_name] = rule | ||
|
|
||
| return envs | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Debug print statements should be removed from production code or replaced with proper logging using the logger module.