2 changes: 1 addition & 1 deletion data_processing/wai_processing/configs/moge/default.yaml
@@ -3,7 +3,7 @@ root: # needs to be set via argument: root=<wai_dataset_root_path>
 model_path: /ai4rl/fsx/xrtech/checkpoints/MoGe/moge-2-vitl-normal/model.pt
 model_name: moge2
 out_path: moge/v0
-device: cuda
+device: auto
 random_scene_processing_order: true

 scene_filters:
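For reference, `auto` is not interpreted by the YAML itself; the processing scripts below resolve it at runtime. A minimal sketch of that resolution, assuming an OmegaConf/dict-style `cfg` and a hypothetical `resolve_device` wrapper around the `get_device` helper added later in this diff (`mapanything/utils/device.py`):

from mapanything.utils.device import get_device


def resolve_device(cfg: dict):
    # "auto" defers to get_device's CUDA > MPS > CPU detection; any other value
    # is treated as an explicit torch device string ("cuda", "mps", "cpu", ...).
    device_str = cfg.get("device", "cuda")
    return get_device(preferred=device_str if device_str != "auto" else None)


print(resolve_device({"device": "auto"}))  # e.g. device(type='mps') on Apple silicon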
6 changes: 4 additions & 2 deletions data_processing/wai_processing/scripts/covisibility.py
@@ -18,9 +18,10 @@
     load_scene_data,
     project_points_to_views,
     sample_depths_at_reprojections,
-)
+)
 from wai_processing.utils.state import SceneProcessLock, set_processing_state

+from mapanything.utils.device import get_device
 from mapanything.utils.wai.core import load_data, store_data
 from mapanything.utils.wai.scene_frame import get_scene_names

@@ -177,7 +178,8 @@ def compute_covisibility(cfg, scene_name: str, overwrite=False):
     if overwrite:
         logger.warning("Careful: Overwrite enabled!")

-    device = cfg.get("device", "cuda")
+    device_str = cfg.get("device", "cuda")
+    device = get_device(preferred=device_str if device_str != "auto" else None)
     scene_names = get_scene_names(
         cfg, shuffle=cfg.get("random_scene_processing_order", True)
     )
@@ -17,18 +17,19 @@
     load_scene_data,
     project_points_to_views,
     sample_depths_at_reprojections,
-)
+)
 from wai_processing.utils.state import (
     set_processing_state,
-)
+)

+from mapanything.utils.device import get_device
 from mapanything.utils.wai.core import (
     get_frame,
     load_data,
     nest_modality,
     set_frame,
     store_data,
-)
+)
 from mapanything.utils.wai.scene_frame import get_scene_names

 logger = logging.getLogger("covisibility-confidence")
@@ -170,7 +171,8 @@ def compute_covisibility_map(
     Returns:
         None. Results are saved to disk.
     """
-    device = cfg.get("device", "cuda")
+    device_str = cfg.get("device", "cuda")
+    device = get_device(preferred=device_str if device_str != "auto" else None)

     # Setup scene data
     scene_root = Path(cfg.root) / scene_name
19 changes: 5 additions & 14 deletions mapanything/models/mapanything/model.py
@@ -7,14 +7,15 @@
 MapAnything model class defined using UniCeption modules.
 """

-import warnings
 from functools import partial
 from typing import Any, Callable, Dict, List, Tuple, Type, Union

 import torch
 import torch.nn as nn
 from huggingface_hub import PyTorchModelHubMixin

+from mapanything.utils.device import get_amp_dtype, get_autocast_device_type
+
 from mapanything.utils.geometry import (
     apply_log_to_norm,
     convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
@@ -2105,18 +2106,7 @@ def infer(
         """
         # Determine the mixed precision floating point type
        if use_amp:
-            if amp_dtype == "fp16":
-                amp_dtype = torch.float16
-            elif amp_dtype == "bf16":
-                if torch.cuda.is_bf16_supported():
-                    amp_dtype = torch.bfloat16
-                else:
-                    warnings.warn(
-                        "bf16 is not supported on this device. Using fp16 instead."
-                    )
-                    amp_dtype = torch.float16
-            elif amp_dtype == "fp32":
-                amp_dtype = torch.float32
+            amp_dtype, _ = get_amp_dtype(self.device, amp_dtype)
         else:
             amp_dtype = torch.float32

@@ -2157,7 +2147,8 @@
         )

         # Run the model
-        with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
+        device_type = get_autocast_device_type(self.device)
+        with torch.autocast(device_type, enabled=bool(use_amp), dtype=amp_dtype):
             preds = self.forward(
                 processed_views,
                 memory_efficient_inference=memory_efficient_inference,
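The inline bf16 fallback deleted above is now centralized in `get_amp_dtype` (defined in `mapanything/utils/device.py`, next in this diff). Note that it returns a `(dtype, name)` pair, hence the tuple unpack at the call site, and that it downgrades bf16 to fp16 without emitting the old warning. A quick behavior sketch:

import torch

from mapanything.utils.device import get_amp_dtype

# bf16 is kept only on CUDA devices that report bf16 support; elsewhere the
# request degrades to fp16, while fp32 always passes through unchanged.
print(get_amp_dtype(torch.device("cpu"), "bf16"))  # (torch.float16, 'fp16')
print(get_amp_dtype(torch.device("cpu"), "fp32"))  # (torch.float32, 'fp32')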
122 changes: 122 additions & 0 deletions mapanything/utils/device.py
@@ -86,3 +86,125 @@ def to_cuda(x):
         Data with the same structure but with tensors moved to GPU
     """
     return to_device(x, "cuda")
+
+
+def get_device(preferred=None):
+    """Auto-detect best available computation device.
+
+    Args:
+        preferred: Optional preferred device type ('cuda', 'mps', 'cpu').
+            If None, auto-detects in order: CUDA > MPS > CPU.
+
+    Returns:
+        torch.device: The best available device
+    """
+    if preferred is not None:
+        if isinstance(preferred, str):
+            return torch.device(preferred)
+        return preferred
+
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")
+
+
+def get_device_capabilities(device):
+    """Query device-specific capabilities.
+
+    Args:
+        device: torch.device to query
+
+    Returns:
+        dict: Device capabilities including:
+            - 'bf16_supported': bool - whether bfloat16 is supported
+            - 'device_type': str - device type for autocast ('cuda', 'mps', 'cpu')
+    """
+    device_type = device.type if hasattr(device, "type") else str(device).split(":")[0]
+
+    capabilities = {
+        "device_type": device_type,
+    }
+
+    if device_type == "cuda":
+        capabilities["bf16_supported"] = torch.cuda.is_bf16_supported()
+    elif device_type == "mps":
+        capabilities["bf16_supported"] = False
+    else:
+        capabilities["bf16_supported"] = False
+
+    return capabilities
+
+
+def is_memory_query_supported(device):
+    """Check if memory query operations are supported for the device.
+
+    Args:
+        device: torch.device to check
+
+    Returns:
+        bool: True if memory operations like mem_get_info() are supported
+    """
+    device_type = device.type if hasattr(device, "type") else str(device).split(":")[0]
+    return device_type == "cuda"
+
+
+def empty_cache(device=None):
+    """Clear GPU cache if backend supports it.
+
+    Args:
+        device: torch.device to clear cache for. If None, uses auto-detected device.
+    """
+    if device is None:
+        device = get_device()
+
+    device_type = device.type if hasattr(device, "type") else str(device).split(":")[0]
+
+    if device_type == "cuda" and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def get_amp_dtype(device, requested_dtype="bf16"):
+    """Determine the best available dtype for mixed precision.
+
+    Args:
+        device: torch.device to query
+        requested_dtype: str or torch.dtype - preferred dtype ('bf16', 'fp16', 'fp32')
+
+    Returns:
+        tuple: (dtype_used: torch.dtype, dtype_name: str)
+            dtype_used is the actual torch.dtype that will be used
+            dtype_name is the string representation for logging
+    """
+    if isinstance(requested_dtype, str):
+        requested_dtype = requested_dtype.lower()
+
+    if requested_dtype in ["fp32", "float32"]:
+        return torch.float32, "fp32"
+
+    capabilities = get_device_capabilities(device)
+    if requested_dtype in ["bf16", "bfloat16"]:
+        if capabilities["bf16_supported"]:
+            return torch.bfloat16, "bf16"
+        else:
+            return torch.float16, "fp16"
+
+    if requested_dtype in ["fp16", "float16"]:
+        return torch.float16, "fp16"
+
+    return torch.float32, "fp32"
+
+
+def get_autocast_device_type(device):
+    """Get the device type string for torch.autocast.
+
+    Args:
+        device: torch.device or string
+
+    Returns:
+        str: Device type for autocast ('cuda', 'mps', 'cpu')
+    """
+    if hasattr(device, "type"):
+        return device.type
+    return str(device).split(":")[0]
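Taken together, these helpers are designed to compose at inference call sites. A minimal end-to-end sketch mirroring the autocast pattern used in `model.infer` and `loss_of_one_batch_multi_view` (note that fp16 autocast support on CPU/MPS varies across torch versions):

import torch

from mapanything.utils.device import (
    get_amp_dtype,
    get_autocast_device_type,
    get_device,
)

device = get_device()                                # CUDA > MPS > CPU
amp_dtype, amp_name = get_amp_dtype(device, "bf16")  # bf16 on capable CUDA, else fp16
device_type = get_autocast_device_type(device)       # "cuda", "mps", or "cpu"

x = torch.randn(4, 4, device=device)
with torch.autocast(device_type, enabled=True, dtype=amp_dtype):
    y = x @ x  # matmul runs under autocast in amp_dtype where supported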
20 changes: 5 additions & 15 deletions mapanything/utils/inference.py
@@ -7,12 +7,12 @@
 Inference utilities.
 """

-import warnings
 from typing import Any, Dict, List

 import numpy as np
 import torch

+from mapanything.utils.device import get_amp_dtype, get_autocast_device_type
 from mapanything.utils.geometry import (
     depth_edge,
     get_rays_in_camera_frame,
@@ -98,25 +98,15 @@ def loss_of_one_batch_multi_view(

     # Determine the mixed precision floating point type
     if use_amp:
-        if amp_dtype == "fp16":
-            amp_dtype = torch.float16
-        elif amp_dtype == "bf16":
-            if torch.cuda.is_bf16_supported():
-                amp_dtype = torch.bfloat16
-            else:
-                warnings.warn(
-                    "bf16 is not supported on this device. Using fp16 instead."
-                )
-                amp_dtype = torch.float16
-        elif amp_dtype == "fp32":
-            amp_dtype = torch.float32
+        amp_dtype, _ = get_amp_dtype(device, amp_dtype)
     else:
         amp_dtype = torch.float32

     # Run model and compute loss
-    with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
+    device_type = get_autocast_device_type(device)
+    with torch.autocast(device_type, enabled=bool(use_amp), dtype=amp_dtype):
         preds = model(batch)
-        with torch.autocast("cuda", enabled=False):
+        with torch.autocast(device_type, enabled=False):
             loss = criterion(batch, preds) if criterion is not None else None

     result = {f"view{i + 1}": view for i, view in enumerate(batch)}
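One detail worth calling out: the loss is computed inside a nested `torch.autocast(..., enabled=False)` block, so criterion ops are not automatically downcast even though the forward pass ran under autocast. A minimal sketch of the nesting, with a hypothetical stand-in criterion:

import torch


def criterion(pred):  # hypothetical stand-in for the real loss
    return (pred - 1.0).abs().mean()


x = torch.randn(8, 8)
with torch.autocast("cpu", enabled=True, dtype=torch.bfloat16):
    pred = x @ x                           # may run in bfloat16 under autocast
    with torch.autocast("cpu", enabled=False):
        loss = criterion(pred)             # no autocast: runs at pred's native dtype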
8 changes: 4 additions & 4 deletions scripts/convert_hf_to_benchmark_checkpoint.py
@@ -22,6 +22,7 @@
 import torch

 from mapanything.models import MapAnything
+from mapanything.utils.device import get_device


 def parse_args():
@@ -57,13 +58,12 @@ def parse_args():
 def main():
     args = parse_args()

-    # Determine device
+    # Determine device (auto-detects CUDA > MPS > CPU if args.device == "auto")
     if args.device == "auto":
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = get_device()
     else:
-        device = args.device
+        device = torch.device(args.device)

-    device = torch.device(device)
     print(f"Using device: {device}")

     # Use Apache model if requested
9 changes: 5 additions & 4 deletions scripts/demo_colmap.py
@@ -40,10 +40,11 @@
 from mapanything.utils.misc import seed_everything
 from mapanything.utils.viz import predictions_to_glb

-# Configure CUDA settings
-torch.backends.cudnn.enabled = True
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.deterministic = False
+# Configure CUDA settings (only if available)
+if hasattr(torch.backends, "cudnn"):
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cudnn.deterministic = False


 def parse_args():
6 changes: 3 additions & 3 deletions scripts/demo_images_only_inference.py
@@ -17,9 +17,9 @@

 import numpy as np
 import rerun as rr
-import torch

 from mapanything.models import MapAnything
+from mapanything.utils.device import get_device
 from mapanything.utils.geometry import depthmap_to_world_frame
 from mapanything.utils.image import load_images
 from mapanything.utils.viz import (
@@ -146,8 +146,8 @@
     )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
     args = parser.parse_args()

-    # Get inference device
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    # Get inference device (auto-detects CUDA > MPS > CPU)
+    device = get_device()
     print(f"Using device: {device}")

     # Initialize model from HuggingFace
4 changes: 2 additions & 2 deletions scripts/demo_inference_on_colmap_outputs.py
@@ -380,8 +380,8 @@ def main():
     )  # Options: --headless, --connect, --serve, --addr, --save, --stdout
     args = parser.parse_args()

-    # Get inference device
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    # Get inference device (auto-detects CUDA > MPS > CPU)
+    device = get_device()
     print(f"Using device: {device}")

     # Initialize model from HuggingFace
4 changes: 2 additions & 2 deletions scripts/demo_local_weight.py
@@ -18,8 +18,8 @@
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

 import numpy as np
-import torch

+from mapanything.utils.device import get_device
 from mapanything.utils.geometry import depthmap_to_world_frame
 from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_local
 from mapanything.utils.image import load_images
@@ -87,7 +87,7 @@ def main() -> None:
     parser = get_parser()
     args = parser.parse_args()

-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = get_device()
    print(f"Using device: {device}")

     print(