Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,46 @@ pip install -e .
```
</details>

<details>
<summary style="font-size: 1.3em; font-weight: 600;">
Click for Intel GPU (XPU) installation instructions
</summary>

BoltzGen supports Intel GPUs (XPU), including Data Center GPUs, discrete GPUs, and integrated GPUs.

### 1 - Set up your Python environment

First, create and activate a Python environment using Miniconda, Conda, or uv (see the "detailed installation instructions" section above for Miniconda setup). For example:

```bash
conda create -n bg python=3.12
conda activate bg
```

### 2 - Install PyTorch with XPU support

Before installing BoltzGen, install an XPU-enabled build of PyTorch from the PyTorch XPU wheel index:

```bash
pip install torch --index-url https://download.pytorch.org/whl/xpu
```

### 3 - Install BoltzGen

```bash
pip install boltzgen
```

> **Note:** The CUDA-specific packages (`cuequivariance`, `nvidia-ml-py`) will be installed but are not used on XPU devices. BoltzGen automatically detects XPU and uses pure PyTorch fallbacks.

### 4 - Verify XPU is detected

```bash
python -c "import torch; print(f'XPU available: {torch.xpu.is_available()}')"
```

</details>

<details>
<summary style="font-size: 1.3em; font-weight: 600;">
Click for optional Docker instructions if you prefer Docker
Expand Down
121 changes: 79 additions & 42 deletions src/boltzgen/cli/boltzgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,35 @@

quiet_startup()

import collections
import huggingface_hub
import argparse
from dataclasses import dataclass
import shlex
import subprocess
import os
import time
import collections
import math
import os
import re
import shlex
import shutil
import subprocess
import sys
import numpy as np
import pandas as pd
import time
from dataclasses import dataclass
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as pkg_version
from pathlib import Path
from typing import Any, Dict, List, Tuple
import yaml

import huggingface_hub
import hydra
import numpy as np
import omegaconf
import pandas as pd
import torch
import yaml

from boltzgen.data import const
from boltzgen.data.mol import load_canonicals
from boltzgen.data.parse.schema import YamlDesignParser
from boltzgen.data.write.mmcif import to_mmcif
from boltzgen.task.task import Task
from importlib.metadata import PackageNotFoundError, version as pkg_version

### Paths and constants ####
# Get the path to the project root (where main.py and configs/ are located)
Expand Down Expand Up @@ -161,6 +163,13 @@ def add_configure_arguments(
type=int,
help="Number of devices to use. Default is all devices available.",
)
p.add_argument(
"--accelerator",
type=str,
choices=["auto", "gpu", "xpu", "cpu", "cuda"],
default="auto",
help="Accelerator to use. Default: %(default)s.",
)
p.add_argument(
"--num_workers",
type=int,
Expand All @@ -170,7 +179,7 @@ def add_configure_arguments(
p.add_argument(
"--config_dir",
type=Path,
help=f"Path to the directory of default config files. Default: %(default)s",
help="Path to the directory of default config files. Default: %(default)s",
default=config_dir,
)
p.add_argument(
Expand Down Expand Up @@ -756,7 +765,7 @@ def execute_command(args: argparse.Namespace) -> None:
)

# Load steps from steps.yaml
with open(steps_yaml_path, "r") as f:
with open(steps_yaml_path) as f:
steps_data = yaml.safe_load(f)

if not isinstance(steps_data, dict) or "steps" not in steps_data:
Expand Down Expand Up @@ -908,17 +917,25 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
)

# Handle use_kernels argument
device_capability = torch.cuda.get_device_capability()
use_kernels = None
if args.use_kernels == "auto":
use_kernels = device_capability[0] >= 8
elif args.use_kernels == "true":
use_kernels = True
elif args.use_kernels == "false":
use_kernels = False
# Kernels are CUDA-specific (cuequivariance), disable on XPU/other devices
if torch.cuda.is_available():
device_capability = torch.cuda.get_device_capability()
use_kernels = None
if args.use_kernels == "auto":
use_kernels = device_capability[0] >= 8
elif args.use_kernels == "true":
use_kernels = True
elif args.use_kernels == "false":
use_kernels = False
else:
raise ValueError(f"Invalid use_kernels value: {args.use_kernels}")
print(
f"Using kernels: {use_kernels} [device capability: {device_capability}]"
)
else:
raise ValueError(f"Invalid use_kernels value: {args.use_kernels}")
print(f"Using kernels: {use_kernels} [device capability: {device_capability}]")
# XPU or other non-CUDA device - kernels not supported
use_kernels = False
print(f"Using kernels: {use_kernels} [non-CUDA device, kernels disabled]")

protocol_config = protocol_configs[protocol]
print(f"Config overrides for protocol {protocol}: {protocol_config}")
Expand All @@ -928,11 +945,29 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
protocol_config, args.config, step_names
)

devices = (
args.devices if args.devices is not None else torch.cuda.device_count()
)
# Detect available devices (CUDA or XPU)
if args.devices is not None:
devices = args.devices
elif torch.cuda.is_available():
devices = torch.cuda.device_count()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
devices = torch.xpu.device_count()
else:
devices = 1 # CPU fallback
print(f"Using {devices} devices")

# Determine accelerator
if args.accelerator == "auto":
if torch.cuda.is_available():
accelerator = "gpu"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
accelerator = "xpu"
else:
accelerator = "cpu"
else:
accelerator = args.accelerator
print(f"Using accelerator: {accelerator}")

self.steps = []

# Design generation
Expand All @@ -956,7 +991,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
f"override.checkpoints.first_checkpoint_num_samples={fraction_per_checkpoint}"
)
checkpoint_args.append(
f"override.checkpoints.checkpoint_list=["
"override.checkpoints.checkpoint_list=["
+ ",".join(
f"{{'checkpoint': {{'num_samples': {fraction_per_checkpoint}, 'path': '{get_artifact_path(args, checkpoint)}'}}}}"
for checkpoint in args.design_checkpoints[1:]
Expand All @@ -970,14 +1005,14 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
)
# Also disable the schedule when applying a fixed step scale
design_step_and_noise_scale_args.append(
f"override.step_scale_schedule=null"
"override.step_scale_schedule=null"
)
if args.noise_scale is not None:
design_step_and_noise_scale_args.append(
f"override.diffusion_process_args.noise_scale={args.noise_scale}"
)
design_step_and_noise_scale_args.append(
f"override.noise_scale_schedule=null"
"override.noise_scale_schedule=null"
)

if args.only_inverse_fold:
Expand All @@ -990,6 +1025,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
f"output={output_dir}",
f"data.cfg.yaml_path=[{', '.join(str(s) for s in args.design_spec)}]",
f"trainer.devices={devices}",
f"trainer.accelerator={accelerator}",
f"data.cfg.multiplicity={getattr(args, 'inverse_fold_num_sequences', 10)}",
f"override.use_kernels={use_kernels}",
f"checkpoint={get_artifact_path(args, args.inverse_fold_checkpoint)}",
Expand All @@ -1008,6 +1044,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
f"output={output_dir}",
f"data.cfg.yaml_path=[{', '.join(str(s) for s in args.design_spec)}]",
f"trainer.devices={devices}",
f"trainer.accelerator={accelerator}",
f"data.num_workers={args.num_workers}",
f"data.cfg.skip_existing={args.reuse}",
f"data.cfg.multiplicity={num_batches}",
Expand Down Expand Up @@ -1055,11 +1092,12 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
f"data.cfg.multiplicity={args.inverse_fold_num_sequences}",
f"data.cfg.num_workers={args.num_workers}",
f"data.skip_existing={args.reuse}",
f"data.skip_existing_kind=inverse_fold",
"data.skip_existing_kind=inverse_fold",
f"override.use_kernels={use_kernels}",
f"checkpoint={get_artifact_path(args, args.inverse_fold_checkpoint)}",
f"data.cfg.moldir={moldir}",
f"trainer.devices={devices}",
f"trainer.accelerator={accelerator}",
f"override.inverse_fold_args.inverse_fold_restriction=[{', '.join(exclude_residues)}]",
]
+ config_args_by_step["inverse_folding"],
Expand All @@ -1076,9 +1114,10 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
f"output={output_dir}",
f"data.design_dir={input_dir}",
f"trainer.devices={devices}",
f"trainer.accelerator={accelerator}",
f"data.cfg.num_workers={args.num_workers}",
f"data.skip_existing={args.reuse}",
f"data.skip_existing_kind=folded",
"data.skip_existing_kind=folded",
f"override.use_kernels={use_kernels}",
f"checkpoint={get_artifact_path(args, args.folding_checkpoint)}",
f"data.cfg.moldir={moldir}",
Expand All @@ -1099,14 +1138,15 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
f"output={output_dir}",
f"data.design_dir={input_dir}",
f"trainer.devices={devices}",
f"trainer.accelerator={accelerator}",
f"data.cfg.num_workers={args.num_workers}",
f"data.skip_existing={args.reuse}",
f"data.skip_existing_kind=design_folded",
"data.skip_existing_kind=design_folded",
f"override.use_kernels={use_kernels}",
f"checkpoint={get_artifact_path(args, args.folding_checkpoint)}",
f"data.cfg.moldir={moldir}",
f"writer.designfolding=True",
f"data.cfg.return_designfolding=True",
"writer.designfolding=True",
"data.cfg.return_designfolding=True",
]
+ config_args_by_step["design_folding"],
)
Expand All @@ -1123,9 +1163,10 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
f"output={output_dir}",
f"data.design_dir={input_dir}",
f"trainer.devices={devices}",
f"trainer.accelerator={accelerator}",
f"data.cfg.num_workers={args.num_workers}",
f"data.skip_existing={args.reuse}",
f"data.skip_existing_kind=affinity",
"data.skip_existing_kind=affinity",
f"override.use_kernels={use_kernels}",
f"checkpoint={get_artifact_path(args, args.affinity_checkpoint)}",
f"data.cfg.moldir={moldir}",
Expand All @@ -1142,7 +1183,7 @@ def __init__(self, args: argparse.Namespace, moldir: Path):
args=[
f"design_dir={input_dir}",
f"data.skip_existing={args.reuse}",
f"data.skip_existing_kind=analyzed",
"data.skip_existing_kind=analyzed",
f"data.cfg.moldir={moldir}",
f"designfolding_metrics={do_design_folding}",
f"delta_sasa_original={args.skip_inverse_folding}",
Expand Down Expand Up @@ -1317,7 +1358,7 @@ def check_design_spec(
output_path = design_spec.stem + ".cif"
with open(output_path, "w") as f:
f.write(mmcif)
print(f"Design specification visualization is written to {str(output_path)}")
print(f"Design specification visualization is written to {output_path!s}")


def get_artifact_path(
Expand Down Expand Up @@ -1462,8 +1503,7 @@ def parse_size_buckets(value_list):
raise ValueError(
f"Invalid size_buckets format: '{item}'. All values must be integers. Use 'min-max:count' format."
)
else:
raise e
raise e
else:
raise ValueError(
f"Invalid size_buckets format: '{item}'. Use 'min-max:count' format."
Expand Down Expand Up @@ -1620,18 +1660,15 @@ def _copy_design_files(
required=False,
)


def _make_new_file_name(original_file: str, new_id: str) -> str:
path = Path(original_file)
suffix = "".join(path.suffixes)
return f"{new_id}{suffix}" if suffix else new_id


def _slugify_run_tag(path: Path, index: int) -> str:
slug = re.sub(r"[^0-9A-Za-z]+", "-", path.name).strip("-").lower()
return slug or f"run{index}"


def _copy_path(src: Path, dst: Path, *, required: bool) -> None:
if src.exists():
dst.parent.mkdir(parents=True, exist_ok=True)
Expand Down
6 changes: 3 additions & 3 deletions src/boltzgen/model/layers/attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch
from einops.layers.torch import Rearrange
from torch import Tensor, nn

import boltzgen.model.layers.initialize as init
from boltzgen.model.modules.utils import get_autocast_device_type


class AttentionPairBias(nn.Module):
Expand Down Expand Up @@ -86,7 +88,6 @@ def forward(
The output sequence tensor.

"""

B = s.shape[0]

# Compute projections
Expand Down Expand Up @@ -118,7 +119,7 @@ def forward(
attn_mask = (1 - mask[:, None, None].float()) * -self.inf
attn_mask = attn_mask + bias.float()

with torch.autocast("cuda", enabled=False):
with torch.autocast(get_autocast_device_type(), enabled=False):
# Compute attention weights
o = torch.nn.functional.scaled_dot_product_attention(
q.float(),
Expand All @@ -131,4 +132,3 @@ def forward(
o = o * g
o = self.proj_o(o)
return o

6 changes: 4 additions & 2 deletions src/boltzgen/model/layers/confidence_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from torch import nn

from boltzgen.data import const
from boltzgen.model.modules.utils import get_autocast_device_type


def compute_collinear_mask(v1, v2):
Expand All @@ -22,7 +24,7 @@ def compute_frame_pred(
resolved_mask=None,
inference=False,
):
with torch.amp.autocast("cuda", enabled=False):
with torch.amp.autocast(get_autocast_device_type(), enabled=False):
asym_id_token = feats["asym_id"]
asym_id_atom = torch.bmm(
feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()
Expand Down Expand Up @@ -297,7 +299,7 @@ def compute_ptms(logits, x_preds, feats, multiplicity):
is_chain_design_token.sum() > 0
and (1 - is_chain_design_token.int()).sum() > 0
):
atom_chain_design_mask = (torch.bmm( feats["atom_to_token"].bfloat16(), feats["chain_design_mask"].bfloat16().unsqueeze(-1),).squeeze(-1).bool())
atom_chain_design_mask = (torch.bmm( feats["atom_to_token"].bfloat16(), feats["chain_design_mask"].bfloat16().unsqueeze(-1)).squeeze(-1).bool())
x_target = x_pred_half[~atom_chain_design_mask]
x_design = x_pred_half[atom_chain_design_mask]
dists = torch.cdist(x_target, x_design)
Expand Down
Loading