Skip to content

Commit 91dd001

Browse files
committed
Made the config and RemotePathIterator much more robust and safe
1 parent 1353c5e commit 91dd001

File tree

3 files changed

+138
-191
lines changed

3 files changed

+138
-191
lines changed

src/pyremotedata/config.py

+42-51
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,30 @@
44
# Interanl imports
55
from pyremotedata import main_logger, module_logger
66

7-
config = None
8-
97
def ask_user(question, interactive=True):
108
if not interactive:
11-
raise RuntimeError("Cannot ask user for input when interactive=False")
9+
raise RuntimeError("Cannot ask user for input when interactive=False: " + question)
1210
return input(question)
1311

12+
def get_environment_variables(interactive=True):
13+
remote_username = os.getenv('PYREMOTEDATA_REMOTE_USERNAME', None) or ask_user("PYREMOTEDATA_REMOTE_USERNAME not set. Enter your remote name: ", interactive)
14+
remote_uri = os.getenv('PYREMOTEDATA_REMOTE_URI', None) or (ask_user("PYREMOTEDATA_REMOTE_URI not set. Enter your remote URI (leave empty for 'io.erda.au.dk'): ", interactive) or 'io.erda.au.dk')
15+
local_directory = os.getenv('PYREMOTEDATA_LOCAL_DIRECTORY', "")
16+
remote_directory = os.getenv('PYREMOTEDATA_REMOTE_DIRECTORY', None) or ask_user("PYREMOTEDATA_REMOTE_DIRECTORY not set. Enter your remote directory: ", interactive)
17+
18+
return remote_username, remote_uri, local_directory, remote_directory
19+
20+
CONFIG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pyremotedata_config.yaml')
21+
22+
def remove_config():
23+
if os.path.exists(CONFIG_PATH):
24+
os.remove(CONFIG_PATH)
25+
module_logger.info("Removed config file at {}".format(CONFIG_PATH))
26+
else:
27+
module_logger.info("No config file found at {}".format(CONFIG_PATH))
28+
1429
def create_default_config(interactive=True):
15-
base_dir = os.path.dirname(os.path.abspath(__file__))
16-
config_path = os.path.join(base_dir, 'pyremotedata_config.yaml')
17-
18-
# Check for environment variables or ask for user input
19-
remote_username = os.getenv('PYREMOTEDATA_REMOTE_USERNAME', None) or ask_user("Enter your remote name: ", interactive)
20-
remote_uri = os.getenv('PYREMOTEDATA_REMOTE_URI', None) or (ask_user("Enter your remote URI (leave empty for 'io.erda.au.dk'): ", interactive) or 'io.erda.au.dk')
21-
local_dir = os.getenv('PYREMOTEDATA_LOCAL_DIR', "")
22-
remote_directory = os.getenv('PYREMOTEDATA_REMOTE_DIRECTORY', None) or ask_user("Enter your remote directory: ", interactive)
23-
if isinstance(local_dir, str) and local_dir != "":
24-
local_dir = f'"{local_dir}"'
25-
if isinstance(remote_directory, str) and remote_directory != "":
26-
remote_directory = f'"{remote_directory}"'
30+
remote_username, remote_uri, local_directory, remote_directory = get_environment_variables(interactive)
2731

2832
# TODO: Remove unnecessary config options!
2933
yaml_content = f"""
@@ -50,8 +54,8 @@ def create_default_config(interactive=True):
5054
# Remote configuration
5155
user: "{remote_username}"
5256
remote: "{remote_uri}"
53-
local_dir: {local_dir} # Leave empty to use the default local directory
54-
default_remote_dir : {remote_directory}
57+
local_dir: "{local_directory}" # Leave empty to use the default local directory
58+
default_remote_dir : "{remote_directory}"
5559
5660
# Lftp configuration (Can be left as-is)
5761
lftp:
@@ -74,45 +78,48 @@ def create_default_config(interactive=True):
7478
7579
"""
7680

77-
with open(config_path, "w") as config_file:
81+
with open(CONFIG_PATH, "w") as config_file:
7882
config_file.write(yaml_content)
7983

80-
module_logger.info("Created default config file at {}".format(config_path))
84+
module_logger.info("Created default config file at {}".format(CONFIG_PATH))
8185
module_logger.info("OBS: It is **strongly** recommended that you **check the config file** and make sure that it is correct before using pyRemoteData.")
8286

83-
def get_config():
84-
base_dir = os.path.dirname(os.path.abspath(__file__))
85-
config_path = os.path.join(base_dir, 'pyremotedata_config.yaml')
86-
87-
if not os.path.exists(config_path):
87+
def get_config():
88+
if not os.path.exists(CONFIG_PATH):
8889
interactive = os.getenv("PYREMOTEDATA_AUTO", "no").lower().strip() != "yes"
8990
if not interactive or ask_user("Config file not found. Create default config file? (y/n): ", interactive).lower().strip() == 'y':
9091
create_default_config(interactive)
9192
else:
92-
raise FileNotFoundError("Config file not found at {}".format(config_path))
93+
raise FileNotFoundError("Config file not found at {}".format(CONFIG_PATH))
9394

94-
with open(config_path, 'r') as stream:
95+
with open(CONFIG_PATH, 'r') as stream:
9596
try:
9697
config_data = yaml.safe_load(stream)
9798
except yaml.YAMLError as exc:
9899
module_logger.error(exc)
99100
return None
100101

101-
return config_data
102+
# Check if environment variables match config (config/cache invalidation)
103+
invalid = False
104+
for k, ek, v in zip(["user", "remote", "local_dir", "default_remote_dir"], ["PYREMOTEDATA_REMOTE_USERNAME", "PYREMOTEDATAA_REMOTE_URI", "PYREMOTEDATA_LOCAL_DIRECTORY", "PYREMOTEDATA_REMOTE_DIRECTORY"], get_environment_variables()):
105+
expected = config_data["implicit_mount"][k]
106+
if expected != v and not (expected is None and v == ""):
107+
module_logger.warning(f"Invalid config detected, auto regenerating from scratch: Expected '{expected}' for '{k}' ({ek}), but got '{v}'.")
108+
invalid = True
109+
if invalid:
110+
remove_config()
111+
return get_config()
102112

103-
config = get_config()
113+
return config_data
104114

105115
def get_this_config(this):
106-
global config
107116
if not isinstance(this, str):
108117
raise TypeError("Expected string, got {}".format(type(this)))
109-
# Check if config is loaded
110-
if config is None:
111-
# Load config
112-
config = get_config()
113-
if this not in config:
118+
# Load config
119+
cfg = get_config()
120+
if this not in cfg:
114121
raise ValueError("Key {} not found in config".format(this))
115-
return config[this]
122+
return cfg[this]
116123

117124
def get_mount_config():
118125
return get_this_config('mount')
@@ -148,19 +155,3 @@ def deparse_args(config, what):
148155

149156
return arg_str
150157

151-
def remove_config():
152-
global config
153-
base_dir = os.path.dirname(os.path.abspath(__file__))
154-
config_path = os.path.join(base_dir, 'pyremotedata_config.yaml')
155-
if os.path.exists(config_path):
156-
os.remove(config_path)
157-
module_logger.info("Removed config file at {}".format(config_path))
158-
else:
159-
module_logger.info("No config file found at {}".format(config_path))
160-
# Reset global config
161-
config = None
162-
163-
def config_path():
164-
base_dir = os.path.dirname(os.path.abspath(__file__))
165-
config_path = os.path.join(base_dir, 'pyremotedata_config.yaml')
166-
return config_path

0 commit comments

Comments
 (0)