@@ -35,6 +35,23 @@ def main_cli():
3535 break
3636 current_path = current_path .parent
3737
38+ parent_parser = argparse .ArgumentParser (
39+ add_help = False , description = "Common arguments for all subcommands"
40+ )
41+ parent_parser .add_argument (
42+ "--header-path" ,
43+ help = "Path of C headers to analyze (default: %(default)s)" ,
44+ default = str (default_c_header_path ),
45+ )
46+ parent_parser .add_argument (
47+ "--include-file" ,
48+ help = "root header file to examine (default: %(default)s)" ,
49+ default = "cuvs/core/all.h" ,
50+ )
51+ parent_parser .add_argument (
52+ "--dlpack-include-path" , help = "path of dlpack header include"
53+ )
54+
3855 parser = argparse .ArgumentParser (
3956 description = "Analyze C headers for breaking ABI changes"
4057 )
@@ -43,44 +60,28 @@ def main_cli():
4360 )
4461
4562 parser_extract = subparsers .add_parser (
46- "extract" , help = "Extract the ABI from a set of header files"
63+ "extract" ,
64+ parents = [parent_parser ],
65+ help = "Extract the ABI from a set of header files" ,
4766 )
4867 parser_extract .add_argument (
4968 "--output-file" ,
5069 type = str ,
5170 help = "The file to output the ABI into (default: %(default)s)" ,
5271 default = "c_abi.json.gz" ,
5372 )
54- parser_extract .add_argument (
55- "--header-path" ,
56- help = "Path of C headers to extract the ABI from (default: %(default)s)" ,
57- default = str (default_c_header_path ),
58- )
59- parser_extract .add_argument (
60- "--include-file" ,
61- help = "root header file to examine (default: %(default)s)" ,
62- default = "cuvs/core/all.h" ,
63- )
6473
6574 parser_analyze = subparsers .add_parser (
66- "analyze" , help = "Analyze a set of header files for breaking changes"
75+ "analyze" ,
76+ parents = [parent_parser ],
77+ help = "Analyze a set of header files for breaking changes" ,
6778 )
6879 parser_analyze .add_argument (
6980 "--abi-file" ,
7081 type = str ,
7182 help = "The extracted ABI file to compare against (default: %(default)s)" ,
7283 default = "c_abi.json.gz" ,
7384 )
74- parser_analyze .add_argument (
75- "--header-path" ,
76- help = "Path of C headers to analyze (default: %(default)s)" ,
77- default = str (default_c_header_path ),
78- )
79- parser_analyze .add_argument (
80- "--include-file" ,
81- help = "root header file to examine (default: %(default)s)" ,
82- default = "cuvs/core/all.h" ,
83- )
8485
8586 args = parser .parse_args ()
8687 if not args .command :
@@ -89,23 +90,34 @@ def main_cli():
8990
9091 header_path = pathlib .Path (args .header_path )
9192
92- # TODO: better way of specifying the dlpack header source, since missing the dlpack.h
93- # header means that we all dlpack types get treated as 'int' which could be misleading
94- # when looking for differences in the ABI (like if we change a field from `DLDataType` to
95- # `int` without specifying the dlpack include directory, we won't know that the type has
96- # changed)
97- dlpack_header_path = (
98- header_path .parent .parent
99- / "cpp"
100- / "build"
101- / "_deps"
102- / "dlpack-src"
103- / "include"
104- )
105- if not dlpack_header_path .is_dir ():
106- raise ValueError (f"dlpack header { dlpack_header_path } not found" )
93+ if args .dlpack_include_path :
94+ dlpack_include_path = pathlib .Path (args .dlpack_include_path ).resolve ()
95+
96+ else :
97+ # try getting from the cmake build directory dependencies if we
98+ # haven't specified the include directory
99+ dlpack_include_path = (
100+ header_path .parent .parent
101+ / "cpp"
102+ / "build"
103+ / "_deps"
104+ / "dlpack-src"
105+ / "include"
106+ )
107+
108+ if not dlpack_include_path .is_dir ():
109+ raise ValueError (
110+ f"dlpack header path '{ dlpack_include_path } ' not found"
111+ )
112+
113+ if not (dlpack_include_path / "dlpack" / "dlpack.h" ).is_file ():
114+ raise ValueError (
115+ f"dlpack header 'dlpack/dlpack.h' not found in '{ dlpack_include_path } '"
116+ )
117+
118+ print (f"using dlpack from { dlpack_include_path } " )
107119
108- extra_clang_args = [f"-I{ str (dlpack_header_path )} " ]
120+ extra_clang_args = [f"-I{ str (dlpack_include_path )} " ]
109121
110122 if args .command == "extract" :
111123 abi = Abi .from_include_path (
0 commit comments