3
3
import os
4
4
import sys
5
5
import re
6
- import shutil
6
+ import platform
7
+ import sysconfig
7
8
from distutils import ccompiler
8
9
from setuptools import setup
9
10
from setuptools .extension import Extension
10
- from numpy import get_include
11
+ import numpy
11
12
12
13
cwd = os .path .abspath (os .path .dirname (__file__ ))
13
14
fftwdir = os .path .join (cwd , 'mpi4py_fft' , 'fftw' )
14
15
prec_map = {'float' : 'f' , 'double' : '' , 'long double' : 'l' }
16
+ triplet = sysconfig .get_config_var ('MULTIARCH' ) or ''
17
+ bits = platform .architecture ()[0 ][:- 3 ]
15
18
16
- def get_library_dirs ():
17
- lib_dirs = [os .path .join (sys .prefix , 'lib' )]
18
- for f in ('FFTW_ROOT' , 'FFTW_DIR' ):
19
- if f in os .environ :
20
- lib_dirs .append (os .path .join (os .environ [f ], 'lib' ))
21
- return lib_dirs
19
+ def append (dirlist , * args ):
20
+ entry = os .path .join (* args )
21
+ entry = os .path .normpath (entry )
22
+ if os .path .isdir (entry ):
23
+ if entry not in dirlist :
24
+ dirlist .append (entry )
25
+
26
+ def get_prefix_dirs ():
27
+ dirs = []
28
+ for envvar in ('FFTW_ROOT' , 'FFTW_DIR' ):
29
+ if envvar in os .environ :
30
+ prefix = os .environ [envvar ]
31
+ append (dirs , prefix )
32
+ append (dirs , sys .prefix )
33
+ return dirs
22
34
23
35
def get_include_dirs ():
24
- inc_dirs = [get_include (), os .path .join (sys .prefix , 'include' )]
25
- for f in ('FFTW_ROOT' , 'FFTW_DIR' ):
26
- if f in os .environ :
27
- inc_dirs .append (os .path .join (os .environ [f ], 'include' ))
28
- return inc_dirs
36
+ dirs = []
37
+ if 'FFTW_INCLUDE_DIR' in os .environ :
38
+ entry = os .environ ['FFTW_INCLUDE_DIR' ]
39
+ append (dirs , entry )
40
+ for prefix in get_prefix_dirs ():
41
+ append (dirs , prefix , 'include' , triplet )
42
+ append (dirs , prefix , 'include' )
43
+ dirs .append (numpy .get_include ())
44
+ return dirs
45
+
46
+ def get_library_dirs ():
47
+ dirs = []
48
+ if 'FFTW_LIBRARY_DIR' in os .environ :
49
+ entry = os .environ ['FFTW_LIBRARY_DIR' ]
50
+ append (dirs , entry )
51
+ for prefix in get_prefix_dirs ():
52
+ append (dirs , prefix , 'lib' + bits )
53
+ append (dirs , prefix , 'lib' , triplet )
54
+ append (dirs , prefix , 'lib' )
55
+ return dirs
29
56
30
57
def get_fftw_libs ():
31
58
"""Return FFTW libraries"""
@@ -34,14 +61,14 @@ def get_fftw_libs():
34
61
libs = {}
35
62
for d in ('float' , 'double' , 'long double' ):
36
63
lib = 'fftw3' + prec_map [d ]
64
+ tlib = lib + '_threads'
37
65
if compiler .find_library_file (library_dirs , lib ):
38
66
libs [d ] = [lib ]
39
- tlib = '_' .join ((lib , 'threads' ))
40
67
if compiler .find_library_file (library_dirs , tlib ):
41
68
libs [d ].append (tlib )
42
- if sys . platform in ( 'unix' , 'darwin' ) :
69
+ if os . name == 'posix' :
43
70
libs [d ].append ('m' )
44
- assert len (libs ) > 0 , "No FFTW libraries found in {} {} " .format (library_dirs , sys . prefix )
71
+ assert len (libs ) > 0 , "No FFTW libraries found in {}" .format (library_dirs )
45
72
return libs
46
73
47
74
def generate_extensions (fftwlibs ):
@@ -50,12 +77,20 @@ def generate_extensions(fftwlibs):
50
77
if d == 'double' :
51
78
continue
52
79
p = 'fftw' + prec_map [d ]+ '_'
53
- for fl in ('fftw_planxfftn.h' , 'fftw_planxfftn.c' , 'fftw_xfftn.pyx' , 'fftw_xfftn.pxd' ):
54
- fp = fl .replace ('fftw_' , p )
55
- shutil .copy (os .path .join (fftwdir , fl ), os .path .join (fftwdir , fp ))
56
- sedcmd = "sed -i ''" if sys .platform == 'darwin' else "sed -i''"
57
- os .system (sedcmd + " 's/fftw_/{0}/g' {1}" .format (p , os .path .join (fftwdir , fp )))
58
- os .system (sedcmd + " 's/double/{0}/g' {1}" .format (d , os .path .join (fftwdir , fp )))
80
+ for fname in (
81
+ 'fftw_planxfftn.h' ,
82
+ 'fftw_planxfftn.c' ,
83
+ 'fftw_xfftn.pyx' ,
84
+ 'fftw_xfftn.pxd' ,
85
+ ):
86
+ src = os .path .join (fftwdir , fname )
87
+ dst = os .path .join (fftwdir , fname .replace ('fftw_' , p ))
88
+ with open (src , 'r' ) as fin :
89
+ code = fin .read ()
90
+ code = re .sub ('fftw_' , p , code )
91
+ code = re .sub ('double' , d , code )
92
+ with open (dst , 'w' ) as fout :
93
+ fout .write (code )
59
94
60
95
def get_extensions (fftwlibs ):
61
96
"""Return list of extension modules"""
0 commit comments