Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile() silently does not run MACE backward pass, cuE v0.2.0 #77

Open
vbharadwaj-bk opened this issue Jan 25, 2025 · 0 comments

Comments

@vbharadwaj-bk
Copy link

vbharadwaj-bk commented Jan 25, 2025

Hello! Thank you for the great kernel package. I am attempting to benchmark MACE with cuE v0.2. That release introduced breaking changes related to the constraint that all inputs must have 2 dimensions; I patched MACE accordingly, and it runs as expected without torch.compile.

With torch.compile and cuE combined (and force prediction enabled), the backward pass silently does not run; no error is raised. The predicted force tensor stays at zero, and profiling with torch.profiler shows that only the kernels expected in the forward pass are dispatched.

Perhaps this is a question for the MACE repository, but I am posting it here since the backward pass runs fine if any of the following conditions are satisfied:

a) The model uses cuE but is NOT compiled: the predicted force tensor is populated, and all expected kernels appear in the profiler.
b) The model uses the default e3nn backend

This leads me to believe the issue may lie in cuE. The code is below; I can provide the data file if needed.

@ilyes319 for visibility.

import sys, json, time, pathlib

import argparse
import logging
from pathlib import Path

import ase.io
import numpy as np
import torch
from e3nn import o3
from mace import data, modules, tools
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.tools import torch_geometric
from torch.utils.benchmark import Timer
from mace.calculators import mace_mp
from torch.profiler import profile, record_function, ProfilerActivity

import warnings

# Suppress every warning for the rest of the process to keep benchmark output terse.
warnings.filterwarnings("ignore")

# Record whether the optional cuEquivariance package can be imported.
CUET_AVAILABLE = False
try:
    import cuequivariance as cue  # pylint: disable=unused-import
except ImportError:
    pass
else:
    CUET_AVAILABLE = True

def create_model(hidden_irreps, max_ell, device, cueq_config=None, atomic_numbers=(8, 82, 53, 55)):
    """Build a ``ScaleShiftMACE`` model with a fixed benchmark configuration.

    Args:
        hidden_irreps: Hidden irreps; wrapped with ``o3.Irreps``.
        max_ell: Forwarded to MACE as ``max_ell``.
        device: Device the finished model is moved to.
        cueq_config: Optional cuEquivariance configuration forwarded to MACE.
        atomic_numbers: Atomic numbers the model is built for. The default keeps
            the previously hard-coded ``[8, 82, 53, 55]``.
            NOTE(review): ``main`` encodes its data with ``[6, 82, 53, 55]``;
            confirm which element list matches the dataset.

    Returns:
        The ``ScaleShiftMACE`` model on ``device``.
    """
    table = tools.AtomicNumberTable(list(atomic_numbers))
    model_config = {
        "r_max": 6.0,
        "num_bessel": 8,
        "num_polynomial_cutoff": 6,
        "max_ell": max_ell,
        "interaction_cls": modules.interaction_classes["RealAgnosticResidualInteractionBlock"],
        "interaction_cls_first": modules.interaction_classes["RealAgnosticResidualInteractionBlock"],
        "num_interactions": 2,
        # Single source of truth: the original also had a later duplicate
        # ``"num_elements": 4`` entry that silently overrode this one.
        "num_elements": len(table),
        "hidden_irreps": o3.Irreps(hidden_irreps),
        "MLP_irreps": o3.Irreps("16x0e"),
        "gate": torch.nn.functional.silu,
        # One placeholder energy per element, created with the current default dtype.
        "atomic_energies": torch.ones(len(table)),
        "avg_num_neighbors": 8,
        "atomic_numbers": table.zs,
        "correlation": 3,
        "radial_type": "bessel",
        "cueq_config": cueq_config,
        "atomic_inter_scale": 1.0,
        "atomic_inter_shift": 0.0,
    }
    return modules.ScaleShiftMACE(**model_config).to(device)

def benchmark_model(model, batch, num_iterations=100, warmup=100, label=None, output_folder=None):
    """Time repeated ``model(batch, training=True)`` calls and return the measurement.

    Args:
        model: Callable invoked as ``model(batch, training=True)``.
        batch: Input passed unchanged to every call.
        num_iterations: Number of calls per timed run.
        warmup: Number of untimed calls made before timing starts.
        label: Optional label attached to the returned measurement.
        output_folder: Currently unused; accepted for call-site compatibility.

    Returns:
        The ``torch.utils.benchmark.Measurement`` of the second timed run.
    """
    def run_inference():
        out = model(batch, training=True)
        # Wait for queued GPU work so the timer measures it. Guarded because
        # torch.cuda.synchronize() raises on machines without CUDA, which made
        # the advertised ``--device cpu`` option unusable there.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return out

    for _ in range(warmup):
        run_inference()

    timer = Timer(
        stmt="run_inference()",
        globals={"run_inference": run_inference},
        label=label,
    )
    # The first timed run is deliberately discarded as an extra warm-up; only
    # the second run is reported.
    timer.timeit(num_iterations)
    measurement = timer.timeit(num_iterations)
    # Print one output so the results (e.g. the predicted forces) can be inspected.
    print(run_inference())

    return measurement

def create_model_cueq(hidden_irreps, max_ell, device, cueq_config=None):
    """Build the e3nn MACE model, convert it with ``run_e3nn_to_cueq``, and move it to ``device``."""
    converted = run_e3nn_to_cueq(
        create_model(hidden_irreps, max_ell, device, cueq_config)
    )
    return converted.to(device)

def _resolve_output_folder(output_folder):
    """Return ``output_folder`` as a Path, or a timestamped default when it is None.

    The default is ``outputs/<milliseconds since epoch>`` next to this script.
    """
    if output_folder is not None:
        return pathlib.Path(output_folder)
    # The original default referenced an undefined ``package_root`` (NameError);
    # anchor it on this script's directory instead.
    package_root = pathlib.Path(__file__).resolve().parent
    millis_since_epoch = round(time.time() * 1000)
    return package_root / "outputs" / str(millis_since_epoch)


def _load_batch(atoms_list, table, batch_size, device):
    """Encode the first ``batch_size`` structures and collate them into one batch on ``device``."""
    # Only the first batch is used and the loader does not shuffle, so encoding
    # the remaining structures would be wasted work.
    dataset = [
        data.AtomicData.from_config(
            data.config_from_atoms(atoms),
            z_table=table,
            cutoff=6.0,
        )
        for atoms in atoms_list[:batch_size]
    ]
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )
    return next(iter(data_loader)).to(device)


def main():
    """Benchmark MACE forward calls (``training=True``) for three model variants.

    For float32 and then float64 default dtypes, one batch is built from
    ``xyz_file`` and three variants are timed: the e3nn model, the model
    converted to cuEquivariance, and the converted model wrapped with
    ``torch.compile``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("xyz_file", type=str, help="Path to xyz file")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"])
    parser.add_argument("--num_iters", type=int, default=100)
    parser.add_argument("--max_ell", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--hidden_irreps", type=str, default="128x0e + 128x1o + 128x2e")
    parser.add_argument("--output_folder", type=str, default=None)
    args = parser.parse_args()

    output_folder = _resolve_output_folder(args.output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)
    # Created for traces; nothing below writes into it yet.
    traces_folder = output_folder / "traces"
    traces_folder.mkdir(parents=True, exist_ok=True)

    device = torch.device(args.device)
    hidden_irreps = o3.Irreps(args.hidden_irreps)
    # Parsed structures do not depend on the torch default dtype, so read once.
    atoms_list = ase.io.read(args.xyz_file, index=":")
    batch_size = min(len(atoms_list), args.batch_size)
    # NOTE(review): create_model() builds the model for [8, 82, 53, 55], while the
    # data here is encoded with [6, 82, 53, 55]. Deriving the table from the
    # structures' atomic numbers would avoid the mismatch; confirm intent.
    table = tools.AtomicNumberTable([6, 82, 53, 55])

    for dtype_str, dtype in [("f32", torch.float32), ("f64", torch.float64)]:
        torch.set_default_dtype(dtype)
        # Re-encoded for each dtype, after set_default_dtype has been called.
        batch = _load_batch(atoms_list, table, batch_size, device)
        batch_dict = batch.to_dict()

        print("\nBenchmarking Configuration:")
        print(f"Number of atoms: {len(atoms_list[0])}")
        print(f"Number of edges: {batch['edge_index'].shape[1]}")
        print(f"Batch size: {batch_size}")
        print(f"Device: {args.device}")
        print(f"Hidden irreps: {hidden_irreps}")
        print(f"Number of iterations: {args.num_iters}\n")

        # Reported to run fine.
        model_e3nn = create_model(hidden_irreps, args.max_ell, device)
        measurement_e3nn = benchmark_model(model_e3nn, batch_dict, args.num_iters, label=f"e3nn_{dtype_str}", output_folder=output_folder)
        print(f"E3NN Measurement:\n{measurement_e3nn}")

        # Reported to run fine.
        model_cueq = create_model_cueq(hidden_irreps, args.max_ell, device)
        measurement_cueq = benchmark_model(model_cueq, batch_dict, args.num_iters, label=f"cueq_{dtype_str}", output_folder=output_folder)
        print(f"\nCUET Measurement:\n{measurement_cueq}")
        print(f"\nSpeedup: {measurement_e3nn.mean / measurement_cueq.mean:.2f}x")

        # Reported issue: with torch.compile the backward (force) pass does not
        # appear to run.
        tmp_model = tools.compile.prepare(create_model_cueq)(hidden_irreps, args.max_ell, device)
        model_cueq = torch.compile(tmp_model, mode="default")
        measurement_cueq = benchmark_model(model_cueq, batch_dict, args.num_iters, label=f"ours_{dtype_str}", output_folder=output_folder)
        print(f"\nCUET Measurement:\n{measurement_cueq}")
        print(f"\nSpeedup: {measurement_e3nn.mean / measurement_cueq.mean:.2f}x")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant