Skip to content

Commit fecf03f

Browse files
desertfire authored and pytorchmergebot committed
[AOTI][reland] Emit a CMakeLists.txt when package_cpp_only (pytorch#143680)
Summary: Emit a CMakeLists.txt with compile and link options when package_cpp_only is specified. After unzipping the AOTI-generated .pt2 package file, users can manually build the generated model code in their local environment. Pull Request resolved: pytorch#143680 Approved by: https://github.com/huydhn
1 parent b5e1592 commit fecf03f

File tree

3 files changed

+156
-16
lines changed

3 files changed

+156
-16
lines changed

test/inductor/test_aot_inductor_package.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,14 @@
22
import copy
33
import functools
44
import io
5+
import os
6+
import shutil
7+
import subprocess
58
import sys
69
import tempfile
710
import unittest
11+
import zipfile
12+
from pathlib import Path
813
from typing import Callable
914

1015
from parameterized import parameterized_class
@@ -14,7 +19,12 @@
1419
from torch._inductor.test_case import TestCase
1520
from torch._inductor.utils import fresh_inductor_cache
1621
from torch.export import Dim
17-
from torch.testing._internal.common_utils import IS_FBCODE, TEST_CUDA
22+
from torch.testing._internal.common_utils import (
23+
IS_FBCODE,
24+
skipIfRocm,
25+
skipIfXpu,
26+
TEST_CUDA,
27+
)
1828
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
1929

2030

@@ -176,6 +186,67 @@ def forward(self, x, y):
176186
)
177187
self.check_model(Model(), example_inputs)
178188

189+
@unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode")
190+
@skipIfRocm # build system may be different
191+
@skipIfXpu # build system may be different
192+
def test_compile_after_package(self):
193+
if not self.package_cpp_only:
194+
raise unittest.SkipTest("Only meant to test cpp package")
195+
if shutil.which("cmake") is None:
196+
raise unittest.SkipTest("cmake is not available")
197+
if shutil.which("make") is None:
198+
raise unittest.SkipTest("make is not available")
199+
200+
class Model(torch.nn.Module):
201+
def __init__(self) -> None:
202+
super().__init__()
203+
self.linear = torch.nn.Linear(10, 10)
204+
205+
def forward(self, x, y):
206+
return x + self.linear(y)
207+
208+
with torch.no_grad():
209+
example_inputs = (
210+
torch.randn(10, 10, device=self.device),
211+
torch.randn(10, 10, device=self.device),
212+
)
213+
model = Model().to(device=self.device)
214+
expected = model(*example_inputs)
215+
216+
options = {
217+
"aot_inductor.package_cpp_only": self.package_cpp_only,
218+
}
219+
ep = torch.export.export(model, example_inputs)
220+
package_path = torch._inductor.aoti_compile_and_package(
221+
ep, inductor_configs=options
222+
)
223+
with tempfile.TemporaryDirectory() as tmp_dir, zipfile.ZipFile(
224+
package_path, "r"
225+
) as zip_ref:
226+
zip_ref.extractall(tmp_dir)
227+
tmp_path = Path(tmp_dir) / "data" / "aotinductor" / "model"
228+
self.assertTrue(tmp_path.exists())
229+
build_path = tmp_path / "build"
230+
self.assertTrue(not build_path.exists())
231+
232+
# Create a build directory to run cmake
233+
build_path.mkdir()
234+
custom_env = os.environ.copy()
235+
custom_env["CMAKE_PREFIX_PATH"] = str(Path(torch.__file__).parent)
236+
subprocess.run(
237+
["cmake", ".."],
238+
cwd=build_path,
239+
env=custom_env,
240+
)
241+
subprocess.run(["make"], cwd=build_path)
242+
243+
# Check if the .so file was built successfully
244+
so_path = build_path / "libaoti_model.so"
245+
self.assertTrue(so_path.exists())
246+
optimized = torch._export.aot_load(str(so_path), self.device)
247+
actual = optimized(*example_inputs)
248+
self.assertTrue(torch.allclose(actual, expected))
249+
179250
def test_metadata(self):
180251
class Model(torch.nn.Module):
181252
def __init__(self) -> None:
@@ -444,6 +515,5 @@ def forward(self, a):
444515
if __name__ == "__main__":
445516
from torch._inductor.test_case import run_tests
446517

447-
# cpp_extension N/A in fbcode
448518
if HAS_GPU or sys.platform == "darwin":
449519
run_tests(needs="filelock")

torch/_inductor/codecache.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,7 +1476,10 @@ def compile(
14761476
# is scoped to the 'key', so make sure the consts_s is protected
14771477
# by the same lock:
14781478
cpp_path_operator = Path(cpp_path)
1479-
consts_specified_dir = os.path.join(cpp_path_operator.parent, key)
1479+
specified_sub_dir = cpp_path_operator.parent / key
1480+
if not specified_sub_dir.exists():
1481+
specified_sub_dir.mkdir(exist_ok=True)
1482+
cmake_path = str(Path(specified_sub_dir) / "CMakeLists.txt")
14801483

14811484
def _compile_consts(consts: bytes, platform: str) -> str:
14821485
if platform == "linux":
@@ -1518,7 +1521,7 @@ def _compile_consts(consts: bytes, platform: str) -> str:
15181521
_, consts_s = write(
15191522
consts_asm,
15201523
"S",
1521-
specified_dir=consts_specified_dir,
1524+
specified_dir=str(specified_sub_dir),
15221525
)
15231526
consts_s = Path(consts_s)
15241527
object_build_options = CppTorchDeviceOptions(
@@ -1557,6 +1560,10 @@ def _compile_consts(consts: bytes, platform: str) -> str:
15571560
while pos < len(consts):
15581561
rc = f.write(consts[pos:])
15591562
pos += rc
1563+
1564+
# Remove the .S file to save space
1565+
os.remove(consts_s)
1566+
15601567
return consts_o
15611568

15621569
from torch.utils._filelock import FileLock
@@ -1665,22 +1672,25 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
16651672
output_o = object_builder.get_target_file_path()
16661673

16671674
log.debug("aot compilation command: %s", compile_cmd)
1668-
if not config.aot_inductor.package_cpp_only:
1669-
if fbcode_aot_cpu_re:
1670-
output_o = str(cpp_path_operator.with_suffix(".o"))
1671-
compile_file(cpp_path, output_o, compile_cmd.split())
1672-
os.chmod(output_o, 0o644)
1673-
else:
1674-
run_command_and_check(compile_cmd)
1675-
16761675
if config.aot_inductor.package_cpp_only:
1676+
# Not doing the actual compilation here
16771677
compile_flags = str(
16781678
cpp_path_operator.with_name(
16791679
f"{cpp_path_operator.stem}_compile_flags.json"
16801680
)
16811681
)
1682-
object_build_options.save_flags_to_file(compile_flags)
1682+
object_build_options.save_flags_to_json(compile_flags)
16831683
generated_files.append(compile_flags)
1684+
object_builder.save_compile_cmd_to_cmake(cmake_path)
1685+
object_builder.save_src_to_cmake(cmake_path, cpp_path)
1686+
generated_files.append(cmake_path)
1687+
else:
1688+
if fbcode_aot_cpu_re:
1689+
output_o = str(cpp_path_operator.with_suffix(".o"))
1690+
compile_file(cpp_path, output_o, compile_cmd.split())
1691+
os.chmod(output_o, 0o644)
1692+
else:
1693+
run_command_and_check(compile_cmd)
16841694

16851695
if not use_mmap_weights:
16861696
aot_constants = serialized_weights
@@ -1733,7 +1743,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
17331743
f"{cpp_path_operator.stem}_linker_flags.json"
17341744
)
17351745
)
1736-
so_build_options.save_flags_to_file(linker_flags)
1746+
so_build_options.save_flags_to_json(linker_flags)
17371747
generated_files.append(linker_flags)
17381748

17391749
# If we only want to package the cpp, then we need to save the
@@ -1754,6 +1764,10 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
17541764
generated_files.append(consts_o)
17551765
generated_files.append(kernels_o)
17561766

1767+
so_builder.save_src_to_cmake(cmake_path, consts_o)
1768+
for kernel_o in kernels_o.split():
1769+
so_builder.save_src_to_cmake(cmake_path, kernel_o)
1770+
so_builder.save_link_cmd_to_cmake(cmake_path)
17571771
else:
17581772
if fbcode_aot_cpu_re:
17591773
output_so = (
@@ -1769,7 +1783,6 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes:
17691783
for o_file in [
17701784
output_o,
17711785
consts_o,
1772-
str(Path(consts_o).with_suffix(".S")),
17731786
]:
17741787
# Remove these as they are not needed anymore
17751788
os.remove(o_file)

torch/_inductor/cpp_builder.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import subprocess
1515
import sys
1616
import sysconfig
17+
import textwrap
1718
import warnings
1819
from ctypes import cdll
1920
from ctypes.util import find_library
@@ -437,7 +438,7 @@ def get_use_absolute_path(self) -> bool:
437438
def get_compile_only(self) -> bool:
438439
return self._compile_only
439440

440-
def save_flags_to_file(self, file: str) -> None:
441+
def save_flags_to_json(self, file: str) -> None:
441442
attrs = {
442443
"compiler": self.get_compiler(),
443444
"definitions": self.get_definations(),
@@ -1400,6 +1401,7 @@ def __init__(
14001401
self._name = name
14011402

14021403
# Code starts here; initialize the internal variables first.
1404+
self._build_option = BuildOption
14031405
self._compiler = BuildOption.get_compiler()
14041406
self._use_absolute_path = BuildOption.get_use_absolute_path()
14051407
self._aot_mode = BuildOption.get_aot_mode()
@@ -1533,3 +1535,58 @@ def build(self) -> None:
15331535
build_cmd = self.get_command_line()
15341536
run_compile_cmd(build_cmd, cwd=_build_tmp_dir)
15351537
_remove_dir(_build_tmp_dir)
1538+
1539+
def save_compile_cmd_to_cmake(
1540+
self,
1541+
cmake_path: str,
1542+
) -> None:
1543+
definitions = " ".join(self._build_option.get_definations())
1544+
contents = textwrap.dedent(
1545+
f"""
1546+
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
1547+
project(aoti_model LANGUAGES CXX)
1548+
set(CMAKE_CXX_STANDARD 17)
1549+
1550+
# May need to point CMAKE_PREFIX_PATH to the right torch location
1551+
find_package(Torch REQUIRED)
1552+
1553+
# Set a shared library target
1554+
add_library(aoti_model SHARED)
1555+
1556+
# Add macro definitions
1557+
target_compile_definitions(aoti_model PRIVATE {definitions})
1558+
1559+
# Add compile flags
1560+
target_compile_options(aoti_model PRIVATE {self._cflags_args})
1561+
# Backend specific flags
1562+
target_compile_options(aoti_model PRIVATE {self._passthrough_parameters_args} -c)
1563+
1564+
"""
1565+
)
1566+
with open(cmake_path, "w") as f:
1567+
f.write(contents)
1568+
1569+
def save_src_to_cmake(self, cmake_path: str, src_path: str) -> None:
1570+
# Remove the directory part of src_path
1571+
src_path = "${CMAKE_CURRENT_SOURCE_DIR}/" + Path(src_path).name
1572+
with open(cmake_path, "a") as f:
1573+
f.write(f"target_sources(aoti_model PRIVATE {src_path})\n")
1574+
1575+
def save_link_cmd_to_cmake(self, cmake_path: str) -> None:
1576+
lflags = " ".join(self._build_option.get_ldflags())
1577+
libs = " ".join(self._build_option.get_libraries())
1578+
contents = textwrap.dedent(
1579+
f"""
1580+
# Add linker flags
1581+
target_link_options(aoti_model PRIVATE {lflags})
1582+
1583+
# Add libraries
1584+
target_link_libraries(aoti_model PRIVATE {libs})
1585+
"""
1586+
)
1587+
1588+
assert os.path.exists(
1589+
cmake_path
1590+
), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
1591+
with open(cmake_path, "a") as f:
1592+
f.write(contents)

0 commit comments

Comments
 (0)