Skip to content

Commit c45022d

Browse files
authored
Merge pull request #55 from yucongalicechen/input_dir3
input file list
2 parents 495b394 + 13b1e4f commit c45022d

File tree

3 files changed

+83
-45
lines changed

3 files changed

+83
-45
lines changed

src/diffpy/labpdfproc/labpdfprocapp.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import sys
22
from argparse import ArgumentParser
3-
from pathlib import Path
43

54
from diffpy.labpdfproc.functions import apply_corr, compute_cve
6-
from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength
5+
from diffpy.labpdfproc.tools import (
6+
expand_list_file,
7+
known_sources,
8+
load_user_metadata,
9+
set_input_lists,
10+
set_output_directory,
11+
set_wavelength,
12+
)
713
from diffpy.utils.parsers.loaddata import loadData
814
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
915

@@ -21,7 +27,7 @@ def get_args(override_cli_inputs=None):
2127
"data-files in that directory will be processed. Examples of valid "
2228
"inputs are 'file.xy', 'data/file.xy', 'file.xy, data/file.xy', "
2329
"'.' (load everything in the current directory), 'data' (load"
24-
"everything in the folder ./data', 'data/file_list.txt' (load"
30+
"everything in the folder ./data), 'data/file_list.txt' (load"
2531
" the list of files contained in the text-file called "
2632
"file_list.txt that can be found in the folder ./data).",
2733
)
@@ -89,45 +95,47 @@ def get_args(override_cli_inputs=None):
8995

9096
def main():
9197
args = get_args()
98+
args = expand_list_file(args)
99+
args = set_input_lists(args)
92100
args.output_directory = set_output_directory(args)
93101
args.wavelength = set_wavelength(args)
94102
args = load_user_metadata(args)
95103

96-
filepath = Path(args.input_file)
97-
outfilestem = filepath.stem + "_corrected"
98-
corrfilestem = filepath.stem + "_cve"
99-
outfile = args.output_directory / (outfilestem + ".chi")
100-
corrfile = args.output_directory / (corrfilestem + ".chi")
104+
for filepath in args.input_directory:
105+
outfilestem = filepath.stem + "_corrected"
106+
corrfilestem = filepath.stem + "_cve"
107+
outfile = args.output_directory / (outfilestem + ".chi")
108+
corrfile = args.output_directory / (corrfilestem + ".chi")
101109

102-
if outfile.exists() and not args.force_overwrite:
103-
sys.exit(
104-
f"Output file {str(outfile)} already exists. Please rerun "
105-
f"specifying -f if you want to overwrite it."
106-
)
107-
if corrfile.exists() and args.output_correction and not args.force_overwrite:
108-
sys.exit(
109-
f"Corrections file {str(corrfile)} was requested and already "
110-
f"exists. Please rerun specifying -f if you want to overwrite it."
111-
)
110+
if outfile.exists() and not args.force_overwrite:
111+
sys.exit(
112+
f"Output file {str(outfile)} already exists. Please rerun "
113+
f"specifying -f if you want to overwrite it."
114+
)
115+
if corrfile.exists() and args.output_correction and not args.force_overwrite:
116+
sys.exit(
117+
f"Corrections file {str(corrfile)} was requested and already "
118+
f"exists. Please rerun specifying -f if you want to overwrite it."
119+
)
112120

113-
input_pattern = Diffraction_object(wavelength=args.wavelength)
114-
xarray, yarray = loadData(args.input_file, unpack=True)
115-
input_pattern.insert_scattering_quantity(
116-
xarray,
117-
yarray,
118-
"tth",
119-
scat_quantity="x-ray",
120-
name=str(args.input_file),
121-
metadata={"muD": args.mud, "anode_type": args.anode_type},
122-
)
121+
input_pattern = Diffraction_object(wavelength=args.wavelength)
122+
xarray, yarray = loadData(args.input_file, unpack=True)
123+
input_pattern.insert_scattering_quantity(
124+
xarray,
125+
yarray,
126+
"tth",
127+
scat_quantity="x-ray",
128+
name=str(args.input_file),
129+
metadata={"muD": args.mud, "anode_type": args.anode_type},
130+
)
123131

124-
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
125-
corrected_data = apply_corr(input_pattern, absorption_correction)
126-
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
127-
corrected_data.dump(f"{outfile}", xtype="tth")
132+
absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
133+
corrected_data = apply_corr(input_pattern, absorption_correction)
134+
corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
135+
corrected_data.dump(f"{outfile}", xtype="tth")
128136

129-
if args.output_correction:
130-
absorption_correction.dump(f"{corrfile}", xtype="tth")
137+
if args.output_correction:
138+
absorption_correction.dump(f"{corrfile}", xtype="tth")
131139

132140

133141
if __name__ == "__main__":

src/diffpy/labpdfproc/tests/test_tools.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from diffpy.labpdfproc.labpdfprocapp import get_args
88
from diffpy.labpdfproc.tools import (
9+
expand_list_file,
910
known_sources,
1011
load_user_metadata,
1112
set_input_lists,
@@ -49,10 +50,6 @@
4950
"input_dir/binary.pkl",
5051
],
5152
),
52-
( # file_list.txt list of files provided
53-
["input_dir/file_list.txt"],
54-
["good_data.chi", "good_data.xy", "good_data.txt"],
55-
),
5653
( # file_list_example2.txt list of files provided in different directories
5754
["input_dir/file_list_example2.txt"],
5855
["input_dir/good_data.chi", "good_data.xy", "input_dir/good_data.txt"],
@@ -68,8 +65,9 @@ def test_set_input_lists(inputs, expected, user_filesystem):
6865

6966
cli_inputs = ["2.5"] + inputs
7067
actual_args = get_args(cli_inputs)
68+
actual_args = expand_list_file(actual_args)
7169
actual_args = set_input_lists(actual_args)
72-
assert list(actual_args.input_paths).sort() == expected_paths.sort()
70+
assert sorted(actual_args.input_paths) == sorted(expected_paths)
7371

7472

7573
# This test covers non-existing single input file or directory, in this case we raise an error with message
@@ -87,6 +85,10 @@ def test_set_input_lists(inputs, expected, user_filesystem):
8785
["good_data.chi", "good_data.xy", "unreadable_file.txt", "missing_file.txt"],
8886
"Cannot find missing_file.txt. Please specify valid input file(s) or directories.",
8987
),
88+
( # file_list.txt list of files provided (with missing files)
89+
["input_dir/file_list.txt"],
90+
"Cannot find missing_file.txt. Please specify valid input file(s) or directories.",
91+
),
9092
]
9193

9294

@@ -96,6 +98,7 @@ def test_set_input_files_bad(inputs, msg, user_filesystem):
9698
os.chdir(base_dir)
9799
cli_inputs = ["2.5"] + inputs
98100
actual_args = get_args(cli_inputs)
101+
actual_args = expand_list_file(actual_args)
99102
with pytest.raises(FileNotFoundError, match=msg[0]):
100103
actual_args = set_input_lists(actual_args)
101104

src/diffpy/labpdfproc/tools.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ def set_output_directory(args):
2828
return output_dir
2929

3030

31+
def expand_list_file(args):
    """
    Expand the list of inputs by adding files from file lists and removing the file list.

    Every entry of ``args.input`` whose name contains ``"file_list"`` is opened
    as a text file holding one input name per line. The names read from it are
    appended to ``args.input`` and the file-list entry itself is removed.

    Parameters
    ----------
    args : argparse.Namespace
        The arguments from the parser. ``args.input`` is modified in place.

    Returns
    -------
    argparse.Namespace
        The same ``args`` object, with the modified input list.

    Raises
    ------
    OSError
        If a file-list entry cannot be opened for reading.
    """
    # Snapshot the file-list entries first so that args.input can be mutated
    # safely inside the loop below.
    file_list_inputs = [input_name for input_name in args.input if "file_list" in input_name]
    for file_list_input in file_list_inputs:
        with open(file_list_input, "r") as f:
            # Drop blank / whitespace-only lines: an empty string is a valid
            # path to the current directory, so keeping it would silently add
            # every file in the working directory to the inputs.
            stripped_lines = (line.strip() for line in f)
            file_inputs = [input_name for input_name in stripped_lines if input_name]
        args.input.extend(file_inputs)
        args.input.remove(file_list_input)
    return args
52+
53+
3154
def set_input_lists(args):
3255
"""
3356
Set input directory and files.
@@ -47,20 +70,24 @@ def set_input_lists(args):
4770
"""
4871

4972
input_paths = []
50-
for input in args.input:
51-
input_path = Path(input).resolve()
73+
for input_name in args.input:
74+
input_path = Path(input_name).resolve()
5275
if input_path.exists():
5376
if input_path.is_file():
5477
input_paths.append(input_path)
5578
elif input_path.is_dir():
5679
input_files = input_path.glob("*")
57-
input_files = [file.resolve() for file in input_files if file.is_file()]
80+
input_files = [
81+
file.resolve() for file in input_files if file.is_file() and "file_list" not in file.name
82+
]
5883
input_paths.extend(input_files)
5984
else:
60-
raise FileNotFoundError(f"Cannot find {input}. Please specify valid input file(s) or directories.")
85+
raise FileNotFoundError(
86+
f"Cannot find {input_name}. Please specify valid input file(s) or directories."
87+
)
6188
else:
62-
raise FileNotFoundError(f"Cannot find {input}")
63-
setattr(args, "input_paths", input_paths)
89+
raise FileNotFoundError(f"Cannot find {input_name}")
90+
setattr(args, "input_paths", list(set(input_paths)))
6491
return args
6592

6693

0 commit comments

Comments
 (0)