From 5b4845ef7eb37de2e29dc90aae6abbca06ad8d11 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 25 Mar 2026 10:33:59 +0100 Subject: [PATCH 1/3] refactor: `BaseIO` class and `prefer_env_var` - add `BaseIO` class from which all the external data sources inherit - add `prefer_env_var` init arg to implement data_dir and env var precedence --- swvo/io/__init__.py | 1 + swvo/io/base.py | 118 ++++++++++++++++++++++++++++++++++ swvo/io/dst/omni.py | 12 ---- swvo/io/dst/wdc.py | 18 +----- swvo/io/f10_7/omni.py | 13 ---- swvo/io/f10_7/swpc.py | 22 +------ swvo/io/hp/ensemble.py | 45 ++++++++----- swvo/io/hp/gfz.py | 45 ++++++++----- swvo/io/kp/ensemble.py | 20 +++--- swvo/io/kp/niemegk.py | 17 +---- swvo/io/kp/omni.py | 13 ---- swvo/io/kp/swpc.py | 16 +---- swvo/io/omni/omni_high_res.py | 17 +---- swvo/io/omni/omni_low_res.py | 23 +------ swvo/io/sme/supermag.py | 21 ++---- swvo/io/solar_wind/ace.py | 18 +----- swvo/io/solar_wind/dscovr.py | 18 +----- swvo/io/solar_wind/omni.py | 14 +--- swvo/io/solar_wind/swift.py | 24 ++++--- swvo/io/symh/omni.py | 13 ---- 20 files changed, 232 insertions(+), 256 deletions(-) create mode 100644 swvo/io/base.py diff --git a/swvo/io/__init__.py b/swvo/io/__init__.py index cfd70607..d7cce249 100755 --- a/swvo/io/__init__.py +++ b/swvo/io/__init__.py @@ -13,3 +13,4 @@ solar_wind as solar_wind, sme as sme, ) +from swvo.io.base import BaseIO as BaseIO diff --git a/swvo/io/base.py b/swvo/io/base.py new file mode 100644 index 00000000..1fdffb6b --- /dev/null +++ b/swvo/io/base.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Base class for all IO modules. +""" + +import logging +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional + +import pandas as pd + +logger = logging.getLogger(__name__) + + +class BaseIO(ABC): + """Abstract base class for all IO classes. 
+ + This base class defines the common interface for external data I/O operations, + including initialization, reading, and downloading/processing data. + + Subclasses can implement flexible signatures for `read()` and `download_and_process()` + methods to accommodate different data sources and requirements. + + Parameters + ---------- + data_dir : Path | None + Data directory for storing downloaded/processed data. + If not provided, it will be read from the environment variable + defined by the subclass's `ENV_VAR_NAME`. + + Raises + ------ + ValueError + Raises `ValueError` if necessary environment variable is not set + and `data_dir` is not provided. + """ + + ENV_VAR_NAME: str = "" # Must be set by subclasses + LABEL: str = "" # Must be set by subclasses + + def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + """Initialize the BaseIO class. + + Parameters + ---------- + data_dir : Path | None + Data directory for storing data. If not provided, it will be read + from the environment variable defined by ENV_VAR_NAME. + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. + If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. + Raises + ------ + ValueError + If data_dir is None and ENV_VAR_NAME is not set in environment, + or if prefer_env_var is True and ENV_VAR_NAME is not set. 
+ """ + if prefer_env_var and self.ENV_VAR_NAME in os.environ: + data_dir = Path(os.environ[self.ENV_VAR_NAME]) + elif data_dir is None: + if not self.ENV_VAR_NAME or self.ENV_VAR_NAME not in os.environ: + raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") + data_dir = Path(os.environ[self.ENV_VAR_NAME]) + + self.data_dir: Path = Path(data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"{self.__class__.__name__} data directory: {self.data_dir}") + + @abstractmethod + def read(self, *args, **kwargs) -> pd.DataFrame | list[pd.DataFrame]: + """Read data. + + Subclasses should implement this method with their specific signature. + Common parameters include: + - start_time: datetime + Start time of the data to read. Must be timezone-aware. + - end_time: datetime + End time of the data to read. Must be timezone-aware. + - download: bool, optional + Download data on the go if not available locally. + - Additional parameters specific to each data source. + + Returns + ------- + pd.DataFrame or list[pd.DataFrame] + Data for the specified parameters. + """ + pass + + @abstractmethod + def download_and_process(self, *args, **kwargs) -> None: + """Download and process data. + + Subclasses should implement this method with their specific signature. + Common parameters include: + - start_time: datetime + Start time of the data to download. Must be timezone-aware. + - end_time: datetime + End time of the data to download. Must be timezone-aware. + - target_date: datetime + Target date for data (for single-day sources). + - request_time: datetime + Request time for data (for streaming sources). + - reprocess_files: bool, optional + If True, re-download and re-process existing files. + - Additional parameters specific to each data source. 
+ + Returns + ------- + None + """ + pass diff --git a/swvo/io/dst/omni.py b/swvo/io/dst/omni.py index 8cc54c4d..4e3b6d4b 100644 --- a/swvo/io/dst/omni.py +++ b/swvo/io/dst/omni.py @@ -12,7 +12,6 @@ import warnings from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Optional import numpy as np import pandas as pd @@ -31,17 +30,6 @@ class DSTOMNI(OMNILowRes): Inherits the `download_and_process`, other private methods and attributes from OMNILowRes. """ - def __init__(self, data_dir: Optional[Path] = None) -> None: - """ - Initialize a DSTOMNI object. - - Parameters - ---------- - data_dir : Path | None - Data directory for the Dst OMNI data. If not provided, it will be read from the environment variable - """ - super().__init__(data_dir=data_dir) - # data is downloaded along with OMNI data, check file name in parent class def read(self, start_time: datetime, end_time: datetime, download: bool = False) -> pd.DataFrame: """ diff --git a/swvo/io/dst/wdc.py b/swvo/io/dst/wdc.py index 6f489a9b..58a3c8d0 100644 --- a/swvo/io/dst/wdc.py +++ b/swvo/io/dst/wdc.py @@ -7,18 +7,18 @@ """ import logging -import os import re import warnings from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ logging.captureWarnings(True) -class DSTWDC: +class DSTWDC(BaseIO): """This is a class for the WDC Dst data. 
Parameters @@ -50,18 +50,6 @@ class DSTWDC: URL = "https://wdc.kugi.kyoto-u.ac.jp/dst_realtime/YYYYMM/" LABEL = "wdc" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"WDC Dst data directory: {self.data_dir}") - def download_and_process(self, start_time: datetime, end_time: datetime, reprocess_files: bool = False) -> None: """Download and process WDC Dst data files. diff --git a/swvo/io/f10_7/omni.py b/swvo/io/f10_7/omni.py index 22ff22ca..942e1f53 100644 --- a/swvo/io/f10_7/omni.py +++ b/swvo/io/f10_7/omni.py @@ -10,8 +10,6 @@ import logging from datetime import datetime, timedelta -from pathlib import Path -from typing import Optional import numpy as np import pandas as pd @@ -31,17 +29,6 @@ class F107OMNI(OMNILowRes): Inherits the :func:`download_and_process`, other private methods and attributes from :class:`OMNILowRes`. """ - def __init__(self, data_dir: Optional[Path] = None) -> None: - """ - Initialize a F107OMNI object. - - Parameters - ---------- - data_dir : Path | None - Data directory for the OMNI Kp data. 
If not provided, it will be read from the environment variable - """ - super().__init__(data_dir=data_dir) - # data is downloaded along with OMNI data, check file name in parent class def read(self, start_time: datetime, end_time: datetime, download: bool = False) -> pd.DataFrame: """ diff --git a/swvo/io/f10_7/swpc.py b/swvo/io/f10_7/swpc.py index 9c5a8962..4428a60f 100644 --- a/swvo/io/f10_7/swpc.py +++ b/swvo/io/f10_7/swpc.py @@ -9,17 +9,16 @@ from __future__ import annotations import logging -import os import shutil import warnings from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Optional import numpy as np import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) @@ -27,14 +26,9 @@ logging.captureWarnings(True) -class F107SWPC: +class F107SWPC(BaseIO): """This is a class for the SWPC F107 data. - Parameters - ---------- - data_dir : Path | None - Data directory for the OMNI Low Resolution data. If not provided, it will be read from the environment variable - Methods ------- download_and_process @@ -52,18 +46,6 @@ class F107SWPC: LABEL = "swpc" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - msg = f"Necessary environment variable {self.ENV_VAR_NAME} not set!" - raise ValueError(msg) - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"SWPC F10.7 data directory: {self.data_dir}") - def _is_within_download_range(self, target_date: datetime) -> bool: """Check if a date is within the last 30 days. 
diff --git a/swvo/io/hp/ensemble.py b/swvo/io/hp/ensemble.py index 955600c1..2f178f3f 100755 --- a/swvo/io/hp/ensemble.py +++ b/swvo/io/hp/ensemble.py @@ -30,6 +30,9 @@ class HpEnsemble: Hp index Possible options are: hp30, hp60. data_dir : Path | None Data directory for the Hp data. If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. + If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. Methods ------- @@ -43,28 +46,38 @@ class HpEnsemble: Returns `FileNotFoundError` if the data directory does not exist. """ - ENV_VAR_NAME = "PLACEHOLDER; SEE DERIVED CLASSES BELOW" - LABEL = "ensemble" + ENV_VAR_NAME: str = "" # Must be set by subclasses + LABEL: str = "ensemble" - def __init__(self, index: str, data_dir: Optional[Path] = None) -> None: + def __init__(self, index: str, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + """Initialize HpEnsemble. + + Parameters + ---------- + index : str + Hp index. Possible options are: hp30, hp60. + data_dir : Path | None + Data directory for the Hp data. If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument, by default False + """ self.index = index if self.index not in ("hp30", "hp60"): msg = "Encountered invalid index: {self.index}. Possible options are: hp30, hp60!" 
raise ValueError(msg) + if prefer_env_var and self.ENV_VAR_NAME in os.environ: + data_dir = Path(os.environ[self.ENV_VAR_NAME]) + elif data_dir is None: + if not self.ENV_VAR_NAME or self.ENV_VAR_NAME not in os.environ: + raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") + data_dir = Path(os.environ[self.ENV_VAR_NAME]) - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - msg = f"Necessary environment variable {self.ENV_VAR_NAME} not set!" - raise ValueError(msg) - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] + self.data_dir = Path(data_dir) logger.info(f"{self.index.upper()} Ensemble data directory: {self.data_dir}") - if not self.data_dir.exists(): msg = f"Data directory {self.data_dir} does not exist! Impossible to retrive data!" + logger.error(msg) raise FileNotFoundError(msg) self.index_number: int = int(index[2:]) @@ -315,8 +328,8 @@ class Hp30Ensemble(HpEnsemble): ENV_VAR_NAME = "HP30_ENSEMBLE_FORECAST_DIR" - def __init__(self, data_dir: Optional[Path] = None) -> None: - super().__init__("hp30", data_dir) + def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + super().__init__("hp30", data_dir, prefer_env_var) def read_with_horizon(self, start_time: datetime, end_time: datetime, horizon: float) -> list[pd.DataFrame]: """Read Ensemble Hp30 forecast data for a given time range and forecast horizon. 
@@ -356,8 +369,8 @@ class Hp60Ensemble(HpEnsemble): ENV_VAR_NAME = "HP60_ENSEMBLE_FORECAST_DIR" - def __init__(self, data_dir: Optional[Path] = None) -> None: - super().__init__("hp60", data_dir) + def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + super().__init__("hp60", data_dir, prefer_env_var) def read_with_horizon(self, start_time: datetime, end_time: datetime, horizon: int) -> list[pd.DataFrame]: """Read Ensemble Hp60 forecast data for a given time range and forecast horizon. diff --git a/swvo/io/hp/gfz.py b/swvo/io/hp/gfz.py index de861363..9338e997 100755 --- a/swvo/io/hp/gfz.py +++ b/swvo/io/hp/gfz.py @@ -6,7 +6,6 @@ import json import logging -import os from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree @@ -16,12 +15,13 @@ import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) -class HpGFZ: +class HpGFZ(BaseIO): """This is a base class for HpGFZ data. Parameters @@ -30,6 +30,9 @@ class HpGFZ: Hp index. Possible options are: hp30, hp60. data_dir : Path | None Data directory for the Hp data. If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. + If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. Methods ------- @@ -48,23 +51,27 @@ class HpGFZ: API_URL = "https://kp.gfz.de/app/json/" LABEL = "gfz" - def __init__(self, index: str, data_dir: Optional[Path] = None) -> None: + def __init__(self, index: str, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + """Initialize HpGFZ. + + Parameters + ---------- + index : str + Hp index. Possible options are: hp30, hp60. + data_dir : Path | None + Data directory for the Hp data. 
If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument, by default False + """ self.index = index if self.index not in ("hp30", "hp60"): msg = f"Encountered invalid index: {self.index}. Possible options are: hp30, hp60!" raise ValueError(msg) - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - msg = f"Necessary environment variable {self.ENV_VAR_NAME} not set!" - raise ValueError(msg) - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) + super().__init__(data_dir=data_dir, prefer_env_var=prefer_env_var) self.index_number: int = int(index[2:]) + self.data_dir.mkdir(parents=True, exist_ok=True) logger.info(f"{self.index.upper()} GFZ data directory: {self.data_dir}") (self.data_dir / str(self.index)).mkdir(exist_ok=True) @@ -331,11 +338,14 @@ class Hp30GFZ(HpGFZ): ---------- data_dir : str | Path, optional Data directory for the Hp30 data. If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. + If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. """ - def __init__(self, data_dir: Optional[Path] = None) -> None: - super().__init__("hp30", data_dir) + def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + super().__init__("hp30", data_dir, prefer_env_var=prefer_env_var) class Hp60GFZ(HpGFZ): @@ -345,8 +355,11 @@ class Hp60GFZ(HpGFZ): ---------- data_dir : str | Path, optional Data directory for the Hp30 data. 
If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. + If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. """ - def __init__(self, data_dir: Optional[Path] = None) -> None: - super().__init__("hp60", data_dir) + def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + super().__init__("hp60", data_dir, prefer_env_var=prefer_env_var) diff --git a/swvo/io/kp/ensemble.py b/swvo/io/kp/ensemble.py index 2cf78321..4f3aeab0 100755 --- a/swvo/io/kp/ensemble.py +++ b/swvo/io/kp/ensemble.py @@ -31,6 +31,10 @@ class KpEnsemble: data_dir : Path | None Data directory for the Hp data. If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. + If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. 
+ Methods ------- read @@ -46,17 +50,17 @@ class KpEnsemble: ENV_VAR_NAME = "KP_ENSEMBLE_OUTPUT_DIR" LABEL = "ensemble" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: + def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: + if prefer_env_var and self.ENV_VAR_NAME in os.environ: + data_dir = Path(os.environ[self.ENV_VAR_NAME]) + elif data_dir is None: + if not self.ENV_VAR_NAME or self.ENV_VAR_NAME not in os.environ: raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") + data_dir = Path(os.environ[self.ENV_VAR_NAME]) - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - - logger.info(f"Kp Ensemble data directory: {self.data_dir}") + self.data_dir = Path(data_dir) + logger.info(f"{self.__class__.__name__} data directory: {self.data_dir}") if not self.data_dir.exists(): msg = f"Data directory {self.data_dir} does not exist! Impossible to retrive data!" logger.error(msg) diff --git a/swvo/io/kp/niemegk.py b/swvo/io/kp/niemegk.py index bad31a83..5dd13db6 100755 --- a/swvo/io/kp/niemegk.py +++ b/swvo/io/kp/niemegk.py @@ -12,12 +12,13 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ logging.captureWarnings(True) -class KpNiemegk: +class KpNiemegk(BaseIO): """A class to handle Niemegk Kp data. 
Parameters @@ -52,18 +53,6 @@ class KpNiemegk: DAYS_TO_SAVE_EACH_FILE = 3 LABEL = "niemegk" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"Kp Niemegk data directory: {self.data_dir}") - def download_and_process(self, start_time: datetime, end_time: datetime, reprocess_files: bool = False) -> None: """Download and process Niemegk Kp data file. diff --git a/swvo/io/kp/omni.py b/swvo/io/kp/omni.py index 937a1ea4..6bfa827e 100755 --- a/swvo/io/kp/omni.py +++ b/swvo/io/kp/omni.py @@ -10,8 +10,6 @@ import logging from datetime import datetime, timedelta -from pathlib import Path -from typing import Optional import pandas as pd @@ -30,17 +28,6 @@ class KpOMNI(OMNILowRes): Inherits the :func:`download_and_process`, other private methods and attributes from :class:`OMNILowRes`. """ - def __init__(self, data_dir: Optional[Path] = None) -> None: - """ - Initialize a KpOMNI object. - - Parameters - ---------- - data_dir : Path | None - Data directory for the OMNI Kp data. If not provided, it will be read from the environment variable - """ - super().__init__(data_dir=data_dir) - def read(self, start_time: datetime, end_time: datetime, download: bool = False) -> pd.DataFrame: """ Extract Kp data from OMNI Low Resolution files. 
diff --git a/swvo/io/kp/swpc.py b/swvo/io/kp/swpc.py index faa9d6f6..85a8cd23 100755 --- a/swvo/io/kp/swpc.py +++ b/swvo/io/kp/swpc.py @@ -7,7 +7,6 @@ """ import logging -import os import re import warnings from datetime import datetime, timedelta, timezone @@ -19,6 +18,7 @@ import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ logging.captureWarnings(True) -class KpSWPC: +class KpSWPC(BaseIO): """ A class for handling SWPC Kp data. In SWPC data, the file for current day always contains the forecast for the next 3 days. Keep this in mind when using the `read` and `download_and_process` methods. @@ -54,18 +54,6 @@ class KpSWPC: LABEL = "swpc" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"Kp SWPC data directory: {self.data_dir}") - def download_and_process(self, target_date: datetime, reprocess_files: bool = False) -> None: """ Download and process SWPC Kp data file. 
diff --git a/swvo/io/omni/omni_high_res.py b/swvo/io/omni/omni_high_res.py index f0444ad9..5b34ca13 100644 --- a/swvo/io/omni/omni_high_res.py +++ b/swvo/io/omni/omni_high_res.py @@ -8,21 +8,20 @@ import calendar import logging -import os import re from datetime import datetime, timedelta, timezone -from pathlib import Path from typing import List, Optional, Tuple import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) -class OMNIHighRes: +class OMNIHighRes(BaseIO): """This is a class for the OMNI High Resolution data. Parameters @@ -48,18 +47,6 @@ class OMNIHighRes: START_YEAR = 1981 LABEL = "omni" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"OMNI high resolution data directory: {self.data_dir}") - def download_and_process( self, start_time: datetime, diff --git a/swvo/io/omni/omni_low_res.py b/swvo/io/omni/omni_low_res.py index 8f013850..36d4b342 100755 --- a/swvo/io/omni/omni_low_res.py +++ b/swvo/io/omni/omni_low_res.py @@ -7,17 +7,17 @@ """ import logging -import os import warnings from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) @@ -25,14 +25,9 @@ logging.captureWarnings(True) -class OMNILowRes: +class OMNILowRes(BaseIO): """This is a class for the OMNI Low Resolution data. 
- Parameters - ---------- - data_dir : Path | None - Data directory for the OMNI Low Resolution data. If not provided, it will be read from the environment variable - Methods ------- download_and_process @@ -107,18 +102,6 @@ class OMNILowRes: "magnetosonic_mach_n", ] - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"OMNI Low Res data directory: {self.data_dir}") - def download_and_process(self, start_time: datetime, end_time: datetime, reprocess_files: bool = False) -> None: """Download and process OMNI Low Resolution data files. diff --git a/swvo/io/sme/supermag.py b/swvo/io/sme/supermag.py index 4f18c4f6..30ace391 100644 --- a/swvo/io/sme/supermag.py +++ b/swvo/io/sme/supermag.py @@ -9,18 +9,18 @@ import json import logging -import os import re import warnings from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ logging.captureWarnings(True) -class SMESuperMAG: +class SMESuperMAG(BaseIO): """Class for SuperMAG SME data. 
Parameters @@ -50,21 +50,12 @@ class SMESuperMAG: """ ENV_VAR_NAME = "SUPERMAG_STREAM_DIR" + LABEL = "supermag" - def __init__(self, username: str, data_dir: Optional[Path] = None) -> None: + def __init__(self, username: str, data_dir: Path | None = None) -> None: + super().__init__(data_dir) self.username = username - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - msg = f"Necessary environment variable {self.ENV_VAR_NAME} not set!" - raise ValueError(msg) - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"SuperMAG SME data directory: {self.data_dir}") - def download_and_process(self, start_time: datetime, end_time: datetime, reprocess_files: bool = False) -> None: """Download and process SuperMAG SME data files. diff --git a/swvo/io/solar_wind/ace.py b/swvo/io/solar_wind/ace.py index 77c74b4a..08b902d5 100644 --- a/swvo/io/solar_wind/ace.py +++ b/swvo/io/solar_wind/ace.py @@ -7,17 +7,17 @@ """ import logging -import os import warnings from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone, sw_mag_propagation logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ logging.captureWarnings(True) -class SWACE: +class SWACE(BaseIO): """This is a class for the ACE Solar Wind data. 
Parameters @@ -57,18 +57,6 @@ class SWACE: LABEL = "ace" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"ACE data directory: {self.data_dir}") - def download_and_process(self, request_time: datetime) -> None: """ Download and process ACE data, splitting data across midnight into appropriate day files. diff --git a/swvo/io/solar_wind/dscovr.py b/swvo/io/solar_wind/dscovr.py index 4a611fbd..88bd8fad 100644 --- a/swvo/io/solar_wind/dscovr.py +++ b/swvo/io/solar_wind/dscovr.py @@ -7,17 +7,17 @@ """ import logging -import os import warnings from datetime import datetime, timedelta, timezone from pathlib import Path from shutil import rmtree -from typing import List, Optional, Tuple +from typing import List, Tuple import numpy as np import pandas as pd import requests +from swvo.io.base import BaseIO from swvo.io.utils import enforce_utc_timezone, sw_mag_propagation logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ logging.captureWarnings(True) -class DSCOVR: +class DSCOVR(BaseIO): """This is a class for the DSCOVR Solar Wind data. 
Parameters @@ -55,18 +55,6 @@ class DSCOVR: LABEL = "dscovr" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] - - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] - self.data_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"DSCOVR data directory: {self.data_dir}") - def download_and_process(self, request_time: datetime) -> None: """ Download and process DSCOVR data, splitting data across midnight into appropriate day files. diff --git a/swvo/io/solar_wind/omni.py b/swvo/io/solar_wind/omni.py index cd9cf066..a2dc34f7 100644 --- a/swvo/io/solar_wind/omni.py +++ b/swvo/io/solar_wind/omni.py @@ -6,9 +6,6 @@ Module handling SW data from OMNI High Resolution files. """ -from pathlib import Path -from typing import Optional - from swvo.io.omni import OMNIHighRes @@ -18,13 +15,4 @@ class SWOMNI(OMNIHighRes): Inherits the :func:`download_and_process`, other private methods and attributes from :class:`OMNIHighRes`. """ - def __init__(self, data_dir: Optional[Path] = None) -> None: - """ - Initialize a SWOMNI object. - - Parameters - ---------- - data_dir : Path | None - Data directory for the OMNI SW data. If not provided, it will be read from the environment variable - """ - super().__init__(data_dir=data_dir) + pass diff --git a/swvo/io/solar_wind/swift.py b/swvo/io/solar_wind/swift.py index 9efe7095..ccdbbd45 100644 --- a/swvo/io/solar_wind/swift.py +++ b/swvo/io/solar_wind/swift.py @@ -33,6 +33,10 @@ class SWSWIFTEnsemble: ---------- data_dir : Path | None Data directory for the SWIFT Ensemble data. If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. 
+ If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. + Methods ------- @@ -51,19 +55,21 @@ class SWSWIFTEnsemble: ENV_VAR_NAME = "SWIFT_ENSEMBLE_OUTPUT_DIR" LABEL = "swift" - def __init__(self, data_dir: Optional[Path] = None) -> None: - if data_dir is None: - if self.ENV_VAR_NAME not in os.environ: - raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") - - data_dir = os.environ.get(self.ENV_VAR_NAME) # ty: ignore[invalid-assignment] + def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False) -> None: - self.data_dir: Path = Path(data_dir) # ty:ignore[invalid-argument-type] + if prefer_env_var and self.ENV_VAR_NAME in os.environ: + data_dir = Path(os.environ[self.ENV_VAR_NAME]) + elif data_dir is None: + if not self.ENV_VAR_NAME or self.ENV_VAR_NAME not in os.environ: + raise ValueError(f"Necessary environment variable {self.ENV_VAR_NAME} not set!") + data_dir = Path(os.environ[self.ENV_VAR_NAME]) + self.data_dir = Path(data_dir) logger.info(f"SWIFT ensemble data directory: {self.data_dir}") - if not self.data_dir.exists(): - raise FileNotFoundError(f"Data directory {self.data_dir} does not exist! Impossible to retrieve data!") + msg = f"Data directory {self.data_dir} does not exist! Impossible to retrieve data!" + logger.error(msg) + raise FileNotFoundError(msg) def read( self, diff --git a/swvo/io/symh/omni.py b/swvo/io/symh/omni.py index e321784b..4dc2f607 100644 --- a/swvo/io/symh/omni.py +++ b/swvo/io/symh/omni.py @@ -10,8 +10,6 @@ import logging from datetime import datetime, timedelta -from pathlib import Path -from typing import Optional import pandas as pd @@ -29,17 +27,6 @@ class SymhOMNI(OMNIHighRes): Inherits the `download_and_process`, other private methods and attributes from OMNIHighRes. """ - def __init__(self, data_dir: Optional[Path] = None) -> None: - """ - Initialize a SymhOMNI object. 
- - Parameters - ---------- - data_dir : Path | None - Data directory for the SYM-H OMNI data. If not provided, it will be read from the environment variable - """ - super().__init__(data_dir=data_dir) - def read( self, start_time: datetime, From 8b2c0f3f3a9a2e37e48c73bd2838684afc16d325 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 25 Mar 2026 10:36:08 +0100 Subject: [PATCH 2/3] test: add test for new `BaseIO` --- tests/io/test_base.py | 232 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 tests/io/test_base.py diff --git a/tests/io/test_base.py b/tests/io/test_base.py new file mode 100644 index 00000000..3ba52389 --- /dev/null +++ b/tests/io/test_base.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: 2026 GFZ Helmholtz Centre for Geosciences +# SPDX-FileContributor: Sahil Jhawar +# +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from pathlib import Path +from unittest.mock import patch + +import pandas as pd +import pytest + +from swvo.io.base import BaseIO + + +class TestBaseIO(BaseIO): + """Implementation of BaseIO for testing.""" + + ENV_VAR_NAME = "TEST_DATA_DIR" + LABEL = "test_label" + + def read(self, *args, **kwargs): + """Implementation of read method.""" + return pd.DataFrame({"data": [1, 2, 3]}) + + def download_and_process(self, *args, **kwargs): + """Implementation of download_and_process method.""" + pass + + +class TestBaseIOInitialization: + """Test BaseIO initialization with various configurations.""" + + def test_initialization_with_data_dir(self, tmp_path): + """Test initialization with explicitly provided data_dir.""" + io = TestBaseIO(data_dir=tmp_path) + assert io.data_dir == tmp_path + assert io.data_dir.exists() + + def test_initialization_creates_directory(self, tmp_path): + """Test that initialization creates data directory if it doesn't exist.""" + data_dir = tmp_path / "new_dir" / "nested" + assert not data_dir.exists() + + io = TestBaseIO(data_dir=data_dir) + + 
assert data_dir.exists() + assert io.data_dir == data_dir + + def test_initialization_with_env_var(self, tmp_path): + """Test initialization using environment variable.""" + env_dir = tmp_path / "env_data" + with patch.dict(os.environ, {"TEST_DATA_DIR": str(env_dir)}): + io = TestBaseIO() + assert io.data_dir == env_dir + assert env_dir.exists() + + def test_initialization_prefer_env_var_true(self, tmp_path): + """Test that prefer_env_var=True prioritizes environment variable.""" + env_dir = tmp_path / "env_data" + passed_dir = tmp_path / "passed_data" + + with patch.dict(os.environ, {"TEST_DATA_DIR": str(env_dir)}): + io = TestBaseIO(data_dir=passed_dir, prefer_env_var=True) + assert io.data_dir == env_dir + + def test_initialization_prefer_env_var_false(self, tmp_path): + """Test that prefer_env_var=False uses passed data_dir when available.""" + env_dir = tmp_path / "env_data" + passed_dir = tmp_path / "passed_data" + + with patch.dict(os.environ, {"TEST_DATA_DIR": str(env_dir)}): + io = TestBaseIO(data_dir=passed_dir, prefer_env_var=False) + assert io.data_dir == passed_dir + + def test_initialization_no_data_dir_no_env_var(self): + """Test that ValueError is raised when no data_dir and no env_var.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="Necessary environment variable TEST_DATA_DIR not set!"): + TestBaseIO() + + def test_initialization_no_data_dir_missing_env_var(self): + """Test ValueError when data_dir is None and ENV_VAR_NAME not in os.environ.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="Necessary environment variable TEST_DATA_DIR not set!"): + TestBaseIO(data_dir=None) + + def test_initialization_prefer_env_var_without_env_var_set(self, tmp_path): + """Test that passed data_dir is used when prefer_env_var=True but env var not set. 
+ + When prefer_env_var=True and env var is not set, the fallback is to use + the passed data_dir (doesn't raise error if data_dir is provided). + """ + passed_dir = tmp_path / "passed_data" + + with patch.dict(os.environ, {}, clear=True): + io = TestBaseIO(data_dir=passed_dir, prefer_env_var=True) + assert io.data_dir == passed_dir + + def test_initialization_logs_message(self, tmp_path, caplog): + """Test that initialization logs the data directory.""" + with caplog.at_level(logging.INFO): + TestBaseIO(data_dir=tmp_path) + + assert "TestBaseIO data directory:" in caplog.text + assert str(tmp_path) in caplog.text + + +class TestBaseIOAbstractMethods: + """Test abstract method behavior.""" + + def test_baseio_cannot_be_instantiated_directly(self): + """Test that BaseIO cannot be instantiated directly.""" + with pytest.raises(TypeError): + BaseIO() + + def test_baseio_implementation_must_implement_read(self): + """Test that subclass must implement read method.""" + + class IncompleteIO(BaseIO): + ENV_VAR_NAME = "TEST_DATA_DIR" + LABEL = "test" + + def download_and_process(self, *args, **kwargs): + pass + + with pytest.raises(TypeError): + IncompleteIO(data_dir=Path("/tmp")) + + def test_baseio_implementation_must_implement_download_and_process(self): + """Test that subclass must implement download_and_process method.""" + + class IncompleteIO(BaseIO): + ENV_VAR_NAME = "TEST_DATA_DIR" + LABEL = "test" + + def read(self, *args, **kwargs): + return pd.DataFrame() + + with pytest.raises(TypeError): + IncompleteIO(data_dir=Path("/tmp")) + + def test_read_method_can_be_called(self, tmp_path): + """Test that read method can be called with flexible signature.""" + io = TestBaseIO(data_dir=tmp_path) + result = io.read() + assert isinstance(result, pd.DataFrame) + + def test_read_method_with_args_kwargs(self, tmp_path): + """Test that read method accepts arbitrary args and kwargs.""" + io = TestBaseIO(data_dir=tmp_path) + # Should not raise any exception + result = io.read(1, 
2, 3, key1="value1", key2="value2") + assert isinstance(result, pd.DataFrame) + + def test_download_and_process_method_can_be_called(self, tmp_path): + """Test that download_and_process method can be called with flexible signature.""" + io = TestBaseIO(data_dir=tmp_path) + # Should not raise any exception + io.download_and_process() + + def test_download_and_process_method_with_args_kwargs(self, tmp_path): + """Test that download_and_process accepts arbitrary args and kwargs.""" + io = TestBaseIO(data_dir=tmp_path) + # Should not raise any exception + io.download_and_process(1, 2, 3, key1="value1", key2="value2") + + +class TestBaseIOAttributes: + """Test BaseIO class and instance attributes.""" + + def test_env_var_name_attribute(self): + """Test that ENV_VAR_NAME class attribute is set correctly.""" + assert TestBaseIO.ENV_VAR_NAME == "TEST_DATA_DIR" + + def test_label_attribute(self): + """Test that LABEL class attribute is set correctly.""" + assert TestBaseIO.LABEL == "test_label" + + def test_data_dir_instance_attribute(self, tmp_path): + """Test that data_dir is an instance attribute.""" + io = TestBaseIO(data_dir=tmp_path) + assert hasattr(io, "data_dir") + assert isinstance(io.data_dir, Path) + + def test_data_dir_is_path_instance(self, tmp_path): + """Test that data_dir is converted to Path instance.""" + io = TestBaseIO(data_dir=str(tmp_path)) + assert isinstance(io.data_dir, Path) + + def test_multiple_instances_have_independent_data_dirs(self, tmp_path): + """Test that multiple instances can have different data_dirs.""" + dir1 = tmp_path / "dir1" + dir2 = tmp_path / "dir2" + + io1 = TestBaseIO(data_dir=dir1) + io2 = TestBaseIO(data_dir=dir2) + + assert io1.data_dir == dir1 + assert io2.data_dir == dir2 + assert io1.data_dir != io2.data_dir + + +class TestBaseIOEnvironmentVariablePrecedence: + """Test environment variable precedence logic.""" + + def test_env_var_not_set_uses_passed_dir(self, tmp_path): + """Test that passed data_dir is used when env 
var is not set.""" + passed_dir = tmp_path / "passed" + + with patch.dict(os.environ, {}, clear=True): + io = TestBaseIO(data_dir=passed_dir) + assert io.data_dir == passed_dir + + def test_env_var_set_and_data_dir_passed_uses_data_dir(self, tmp_path): + """Test that passed data_dir takes precedence when prefer_env_var is False.""" + env_dir = tmp_path / "env" + passed_dir = tmp_path / "passed" + + with patch.dict(os.environ, {"TEST_DATA_DIR": str(env_dir)}): + io = TestBaseIO(data_dir=passed_dir, prefer_env_var=False) + assert io.data_dir == passed_dir + + def test_prefer_env_var_true_overrides_data_dir(self, tmp_path): + """Test that env var takes precedence when prefer_env_var is True.""" + env_dir = tmp_path / "env" + passed_dir = tmp_path / "passed" + + with patch.dict(os.environ, {"TEST_DATA_DIR": str(env_dir)}): + io = TestBaseIO(data_dir=passed_dir, prefer_env_var=True) + assert io.data_dir == env_dir From 60bdc77bb09e1c71c00baf2bad144834d8fb5a91 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 25 Mar 2026 11:35:16 +0100 Subject: [PATCH 3/3] chore: fix typos --- swvo/io/base.py | 11 ++++++----- swvo/io/hp/ensemble.py | 4 ++-- swvo/io/kp/ensemble.py | 2 +- swvo/io/sme/supermag.py | 7 +++++-- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/swvo/io/base.py b/swvo/io/base.py index 1fdffb6b..3a611b43 100644 --- a/swvo/io/base.py +++ b/swvo/io/base.py @@ -50,15 +50,16 @@ def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False ---------- data_dir : Path | None Data directory for storing data. If not provided, it will be read - from the environment variable defined by ENV_VAR_NAME. + from the environment variable defined by `ENV_VAR_NAME`. prefer_env_var : bool, optional - If True, the environment variable takes precedence over the passed data_dir argument. - If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. 
+ If True, the environment variable takes precedence over the passed `data_dir` argument. + If False (default), the passed `data_dir` is used if provided, otherwise the environment variable is used. + Raises ------ ValueError - If data_dir is None and ENV_VAR_NAME is not set in environment, - or if prefer_env_var is True and ENV_VAR_NAME is not set. + If `data_dir` is None and `ENV_VAR_NAME` is not set in environment + (if `prefer_env_var` is True but `ENV_VAR_NAME` is unset, a passed `data_dir` is used). """ if prefer_env_var and self.ENV_VAR_NAME in os.environ: data_dir = Path(os.environ[self.ENV_VAR_NAME]) diff --git a/swvo/io/hp/ensemble.py b/swvo/io/hp/ensemble.py index 2f178f3f..7643fa50 100755 --- a/swvo/io/hp/ensemble.py +++ b/swvo/io/hp/ensemble.py @@ -63,7 +63,7 @@ def __init__(self, index: str, data_dir: Optional[Path] = None, prefer_env_var: """ self.index = index if self.index not in ("hp30", "hp60"): - msg = "Encountered invalid index: {self.index}. Possible options are: hp30, hp60!" + msg = f"Encountered invalid index: {self.index}. Possible options are: hp30, hp60!" raise ValueError(msg) if prefer_env_var and self.ENV_VAR_NAME in os.environ: data_dir = Path(os.environ[self.ENV_VAR_NAME]) @@ -76,7 +76,7 @@ def __init__(self, index: str, data_dir: Optional[Path] = None, prefer_env_var: logger.info(f"{self.index.upper()} Ensemble data directory: {self.data_dir}") if not self.data_dir.exists(): - msg = f"Data directory {self.data_dir} does not exist! Impossible to retrive data!" + msg = f"Data directory {self.data_dir} does not exist! Impossible to retrieve data!" 
logger.error(msg) raise FileNotFoundError(msg) diff --git a/swvo/io/kp/ensemble.py b/swvo/io/kp/ensemble.py index 4f3aeab0..a42370ac 100755 --- a/swvo/io/kp/ensemble.py +++ b/swvo/io/kp/ensemble.py @@ -62,7 +62,7 @@ def __init__(self, data_dir: Optional[Path] = None, prefer_env_var: bool = False logger.info(f"{self.__class__.__name__} data directory: {self.data_dir}") if not self.data_dir.exists(): - msg = f"Data directory {self.data_dir} does not exist! Impossible to retrive data!" + msg = f"Data directory {self.data_dir} does not exist! Impossible to retrieve data!" logger.error(msg) raise FileNotFoundError(msg) diff --git a/swvo/io/sme/supermag.py b/swvo/io/sme/supermag.py index 30ace391..3a6a41d4 100644 --- a/swvo/io/sme/supermag.py +++ b/swvo/io/sme/supermag.py @@ -37,6 +37,9 @@ class SMESuperMAG(BaseIO): SuperMAG username used for authenticated data access (register at the SuperMAG website to obtain one) data_dir : Path | None Data directory for the SuperMAG SME data. If not provided, it will be read from the environment variable + prefer_env_var : bool, optional + If True, the environment variable takes precedence over the passed data_dir argument. + If False (default), the passed data_dir is used if provided, otherwise the environment variable is used. Methods ------- @@ -52,8 +55,8 @@ class SMESuperMAG(BaseIO): ENV_VAR_NAME = "SUPERMAG_STREAM_DIR" LABEL = "supermag" - def __init__(self, username: str, data_dir: Path | None = None) -> None: - super().__init__(data_dir) + def __init__(self, username: str, data_dir: Path | None = None, prefer_env_var: bool = False) -> None: + super().__init__(data_dir, prefer_env_var) self.username = username def download_and_process(self, start_time: datetime, end_time: datetime, reprocess_files: bool = False) -> None: