[AOTI][reland] Emit a CMakeLists.txt when package_cpp_only (pytorch#143680)

Summary: Emit a CMakeLists.txt with compile and link options when package_cpp_only is specified. After unzipping the AOTI-generated .pt2 package file, users can manually build the generated model code in their local environment.
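
For illustration, a condensed, self-contained sketch of that workflow, modeled on the new test_compile_after_package test in the diff below (the data/aotinductor/model layout, the libaoti_model.so target, and the CMAKE_PREFIX_PATH setup are taken from that test; the toy model, device, and shapes here are placeholders):

import os
import subprocess
import tempfile
import zipfile
from pathlib import Path

import torch
import torch._inductor


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


device = "cpu"  # placeholder; the test also exercises GPU devices
example_inputs = (
    torch.randn(10, 10, device=device),
    torch.randn(10, 10, device=device),
)
ep = torch.export.export(M().to(device), example_inputs)
package_path = torch._inductor.aoti_compile_and_package(
    ep, inductor_configs={"aot_inductor.package_cpp_only": True}
)

with tempfile.TemporaryDirectory() as tmp_dir:
    # A .pt2 package is a zip archive; the generated sources live under data/aotinductor/<model>
    with zipfile.ZipFile(package_path, "r") as zip_ref:
        zip_ref.extractall(tmp_dir)
    model_dir = Path(tmp_dir) / "data" / "aotinductor" / "model"
    build_dir = model_dir / "build"
    build_dir.mkdir()

    # The emitted CMakeLists.txt calls find_package(Torch), so point CMake at the installed torch
    env = os.environ.copy()
    env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
    subprocess.run(["cmake", ".."], cwd=build_dir, env=env, check=True)
    subprocess.run(["make"], cwd=build_dir, check=True)

    # Load the locally built shared library and run it
    optimized = torch._export.aot_load(str(build_dir / "libaoti_model.so"), device)
    out = optimized(*example_inputs)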

Pull Request resolved: pytorch#143680
Approved by: https://github.com/huydhn
desertfire authored and pytorchmergebot committed Dec 21, 2024
1 parent b5e1592 commit fecf03f
Showing 3 changed files with 156 additions and 16 deletions.
74 changes: 72 additions & 2 deletions test/inductor/test_aot_inductor_package.py
@@ -2,9 +2,14 @@
import copy
import functools
import io
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
import zipfile
from pathlib import Path
from typing import Callable

from parameterized import parameterized_class
@@ -14,7 +19,12 @@
from torch._inductor.test_case import TestCase
from torch._inductor.utils import fresh_inductor_cache
from torch.export import Dim
from torch.testing._internal.common_utils import IS_FBCODE, TEST_CUDA
from torch.testing._internal.common_utils import (
IS_FBCODE,
skipIfRocm,
skipIfXpu,
TEST_CUDA,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


@@ -176,6 +186,67 @@ def forward(self, x, y):
)
self.check_model(Model(), example_inputs)

@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
@skipIfRocm # build system may be different
@skipIfXpu # build system may be different
def test_compile_after_package(self):
if not self.package_cpp_only:
raise unittest.SkipTest("Only meant to test cpp package")
if shutil.which("cmake") is None:
raise unittest.SkipTest("cmake is not available")
if shutil.which("make") is None:
raise unittest.SkipTest("make is not available")

class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(10, 10)

def forward(self, x, y):
return x + self.linear(y)

with torch.no_grad():
example_inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model().to(device=self.device)
expected = model(*example_inputs)

options = {
"aot_inductor.package_cpp_only": self.package_cpp_only,
}
ep = torch.export.export(model, example_inputs)
package_path = torch._inductor.aoti_compile_and_package(
ep, inductor_configs=options
)
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
package_path, "r"
) as zip_ref:
zip_ref.extractall(tmp_dir)
tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
self.assertTrue(tmp_path.exists())
build_path = tmp_path / "build"
self.assertTrue(not build_path.exists())

# Create a build directory to run cmake
build_path.mkdir()
custom_env = os.environ.copy()
custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
subprocess.run(
["cmake", ".."],
cwd=build_path,
env=custom_env,
)
subprocess.run(["make"], cwd=build_path)

# Check if the .so file was built successfully
so_path = build_path / "libaoti_model.so"
self.assertTrue(so_path.exists())
optimized = torch._export.aot_load(str(so_path), self.device)
actual = optimized(*example_inputs)
self.assertTrue(torch.allclose(actual, expected))

def test_metadata(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
@@ -444,6 +515,5 @@ def forward(self, a):
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

# cpp_extension N/A in fbcode
if HAS_GPU or sys.platform == "darwin":
run_tests(needs="filelock")
39 changes: 26 additions & 13 deletions torch/_inductor/codecache.py
@@ -1476,7 +1476,10 @@ def compile(
# is scoped to the 'key', so make sure the consts_s is protected
# by the same lock:
cpp_path_operator = Path(cpp_path)
consts_specified_dir = os.path.join(cpp_path_operator.parent, key)
specified_sub_dir = cpp_path_operator.parent / key
if not specified_sub_dir.exists():
specified_sub_dir.mkdir(exist_ok=True)
cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt")

def _compile_consts(consts: bytes, platform: str) -> str:
if platform == "linux":
@@ -1518,7 +1521,7 @@ def _compile_consts(consts: bytes, platform: str) -> str:
_, consts_s = write(
consts_asm,
"S",
specified_dir=consts_specified_dir,
specified_dir=str(specified_sub_dir),
)
consts_s = Path(consts_s)
object_build_options = CppTorchDeviceOptions(
@@ -1557,6 +1560,10 @@ def _compile_consts(consts: bytes, platform: str) -> str:
while pos < len(consts):
rc = f.write(consts[pos:])
pos += rc

# Remove the .S file to save space
os.remove(consts_s)

return consts_o

from torch.utils._filelock import FileLock
@@ -1665,22 +1672,25 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
output_o = object_builder.get_target_file_path()

log.debug("aot compilation command: %s", compile_cmd)
if not config.aot_inductor.package_cpp_only:
if fbcode_aot_cpu_re:
output_o = str(cpp_path_operator.with_suffix(".o"))
compile_file(cpp_path, output_o, compile_cmd.split())
os.chmod(output_o, 0o644)
else:
run_command_and_check(compile_cmd)

if config.aot_inductor.package_cpp_only:
# Not doing the actual compilation here
compile_flags = str(
cpp_path_operator.with_name(
f"{cpp_path_operator.stem}_compile_flags.json"
)
)
object_build_options.save_flags_to_file(compile_flags)
object_build_options.save_flags_to_json(compile_flags)
generated_files.append(compile_flags)
object_builder.save_compile_cmd_to_cmake(cmake_path)
object_builder.save_src_to_cmake(cmake_path, cpp_path)
generated_files.append(cmake_path)
else:
if fbcode_aot_cpu_re:
output_o = str(cpp_path_operator.with_suffix(".o"))
compile_file(cpp_path, output_o, compile_cmd.split())
os.chmod(output_o, 0o644)
else:
run_command_and_check(compile_cmd)

if not use_mmap_weights:
aot_constants = serialized_weights
@@ -1733,7 +1743,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
f"{cpp_path_operator.stem}_linker_flags.json"
)
)
so_build_options.save_flags_to_file(linker_flags)
so_build_options.save_flags_to_json(linker_flags)
generated_files.append(linker_flags)

# If we only want to package the cpp, then we need to save the
@@ -1754,6 +1764,10 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
generated_files.append(consts_o)
generated_files.append(kernels_o)

so_builder.save_src_to_cmake(cmake_path, consts_o)
for kernel_o in kernels_o.split():
so_builder.save_src_to_cmake(cmake_path, kernel_o)
so_builder.save_link_cmd_to_cmake(cmake_path)
else:
if fbcode_aot_cpu_re:
output_so = (
Expand All @@ -1769,7 +1783,6 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
for o_file in [
output_o,
consts_o,
str(Path(consts_o).with_suffix(".S")),
]:
# Remove these as they are not needed anymore
os.remove(o_file)
59 changes: 58 additions & 1 deletion torch/_inductor/cpp_builder.py
@@ -14,6 +14,7 @@
import subprocess
import sys
import sysconfig
import textwrap
import warnings
from ctypes import cdll
from ctypes.util import find_library
@@ -437,7 +438,7 @@ def get_use_absolute_path(self) -> bool:
def get_compile_only(self) -> bool:
return self._compile_only

def save_flags_to_file(self, file: str) -> None:
def save_flags_to_json(self, file: str) -> None:
attrs = {
"compiler": self.get_compiler(),
"definitions": self.get_definations(),
@@ -1400,6 +1401,7 @@ def __init__(
self._name = name

# Code starts here; initialize internal variables first.
self._build_option = BuildOption
self._compiler = BuildOption.get_compiler()
self._use_absolute_path = BuildOption.get_use_absolute_path()
self._aot_mode = BuildOption.get_aot_mode()
@@ -1533,3 +1535,58 @@ def build(self) -> None:
build_cmd = self.get_command_line()
run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
_remove_dir(_build_tmp_dir)

def save_compile_cmd_to_cmake(
self,
cmake_path: str,
) -> None:
definitions = " ".join(self._build_option.get_definations())
contents = textwrap.dedent(
f"""
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(aoti_model LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
# May need to point CMAKE_PREFIX_PATH to the right torch location
find_package(Torch REQUIRED)
# Set a shared library target
add_library(aoti_model SHARED)
# Add macro definitions
target_compile_definitions(aoti_model PRIVATE {definitions})
# Add compile flags
target_compile_options(aoti_model PRIVATE {self._cflags_args})
# Backend specific flags
target_compile_options(aoti_model PRIVATE {self._passthrough_parameters_args} -c)
"""
)
with open(cmake_path, "w") as f:
f.write(contents)

def save_src_to_cmake(self, cmake_path: str, src_path: str) -> None:
# Remove the directory part of src_path
src_path = "${CMAKE_CURRENT_SOURCE_DIR}/" + Path(src_path).name
with open(cmake_path, "a") as f:
f.write(f"target_sources(aoti_model PRIVATE {src_path})\n")

def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
lflags = " ".join(self._build_option.get_ldflags())
libs = " ".join(self._build_option.get_libraries())
contents = textwrap.dedent(
f"""
# Add linker flags
target_link_options(aoti_model PRIVATE {lflags})
# Add libraries
target_link_libraries(aoti_model PRIVATE {libs})
"""
)

assert os.path.exists(
cmake_path
), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
with open(cmake_path, "a") as f:
f.write(contents)
