forked from facebookresearch/torchdim
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path setup.py
64 lines (55 loc) · 1.89 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import setup
import os.path
import os.path
from torch.utils.cpp_extension import (
CppExtension,
BuildExtension
)
from subprocess import run
import glob
# When True, functorch's C++ sources (from the third_party/functorch checkout)
# are compiled directly into torchdim._C; otherwise an installed functorch._C
# library is linked against instead.
build_functorch = True
srcs = [
    'torchdim/csrc/dim.cpp',
]
extra_libraries = []
if build_functorch:
    this_dir = os.path.dirname(os.path.abspath(__file__))
    ft_home = os.path.join(this_dir, "third_party", "functorch")
    extensions_dir = os.path.join(ft_home, "functorch", "csrc")
    # The presence of '#if 0' in init.cpp is used as the marker that
    # functorch.diff has already been applied, so the patch is applied at most
    # once (presumably the diff introduces that '#if 0' — TODO confirm).
    # Paths are derived from this file's location rather than the current
    # working directory, matching how ft_home is computed.
    with open(os.path.join(extensions_dir, "init.cpp"), "r", encoding="utf-8") as init_cpp:
        already_patched = '#if 0' in init_cpp.read()
    if not already_patched:
        print("PATCHING FUNCTORCH")
        # check=True: a failed patch must stop the build here instead of
        # surfacing later as confusing compile errors in an unpatched tree.
        run(['git', 'apply', os.path.join(this_dir, 'functorch.diff')],
            cwd=ft_home, check=True)
    # glob already yields full paths (the pattern is absolute); sort them so the
    # compile order is reproducible across runs (a set's order is hash-randomized).
    srcs.extend(sorted(glob.glob(os.path.join(extensions_dir, "*.cpp"))))
else:
    # Use an installed functorch: ft_home becomes the parent directory of the
    # functorch package (later passed as an include dir), and the compiled
    # functorch._C library is handed to the linker.
    import functorch._C
    ft_home = os.path.dirname(os.path.dirname(os.path.abspath(functorch.__file__)))
    extra_libraries.append(functorch._C.__file__)
# C++ extension module torchdim._C, built from `srcs` (torchdim's own source
# plus, when build_functorch is set, functorch's C++ sources).
_cxx_warning_flags = ["-Wno-write-strings", "-Wno-sign-compare"]
mintorch_C = CppExtension(
    name='torchdim._C',
    sources=srcs,
    # Both the directory containing this setup.py and ft_home are searched
    # for headers.
    include_dirs=[os.path.dirname(os.path.abspath(__file__)), ft_home],
    extra_compile_args={"cxx": _cxx_warning_flags},
    extra_link_args=extra_libraries,
)
# Package metadata. build_ext is torch's BuildExtension, configured so the
# built library's filename carries no Python ABI suffix.
_build_ext_cmd = BuildExtension.with_options(no_python_abi_suffix=True)
setup(
    name='torchdim',
    version='1.0',
    description='first class dimensions',
    author='',
    author_email='',
    url='',
    packages=['torchdim'],
    ext_modules=[mintorch_C],
    cmdclass={"build_ext": _build_ext_cmd},
)
# Export compile_commands.json (for clangd and similar tooling) from the ninja
# build tree under build/temp.*. The temp directory's name depends on the
# platform, the Python version and the setuptools version, so it is discovered
# by globbing rather than hard-coded. Nothing is written when no ninja build
# exists (e.g. a command that compiled nothing, or a non-ninja build).
_ninja_files = glob.glob(os.path.join('build', 'temp.*', 'build.ninja'))
if _ninja_files:
    # Several build trees can coexist (e.g. for different Python versions);
    # use the most recently written one.
    _build_temp = os.path.dirname(max(_ninja_files, key=os.path.getmtime))
    _compdb = run(['ninja', '-C', _build_temp, '-t', 'compdb'],
                  capture_output=True, text=True)
    # Write only on success so a failed export never clobbers an existing
    # compile_commands.json with an empty/invalid file.
    if _compdb.returncode == 0:
        with open('compile_commands.json', 'w') as cc:
            cc.write(_compdb.stdout)
    else:
        print(f"warning: could not generate compile_commands.json: {_compdb.stderr.strip()}")