Skip to content
Merged
239 changes: 238 additions & 1 deletion swvo/io/plasmasphere/read_plasmasphere.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,106 @@
# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences
# SPDX-FileContributor: Stefano Bianco
# SPDX-FileContributor: Sahil Jhawar
#
# SPDX-License-Identifier: Apache-2.0

import logging
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)


@dataclass(frozen=True, repr=False)
class PlasmasphereDensityCube:
"""A structured container for plasmaspheric electron density data.

Attributes
-------
time : np.ndarray[datetime]
Array of Python datetime values.
l : np.ndarray
Array of L-values.
mlt : np.ndarray
Array of MLT-values.
l_grid : np.ndarray
Grid of L-values. L x MLT shape.
mlt_grid : np.ndarray
Grid of MLT-values. L x MLT shape.
density_grid : np.ndarray or list[np.ndarray]
Grid of electron density values. time x L x MLT shape if `density_column` is a single column, otherwise a list of such arrays (one per density column).
density_column : str or list[str]
Name(s) of the column(s) containing electron density data.
"""

time: np.ndarray[datetime] # ty: ignore[invalid-type-arguments]
l: np.ndarray # noqa: E741
mlt: np.ndarray
l_grid: np.ndarray
mlt_grid: np.ndarray
density_grid: np.ndarray | list[np.ndarray]
density_column: str | list[str]

def __repr__(self) -> str:
return self.__str__()

def __str__(self) -> str:
"""Readable summary for logging and printing."""
num_times = len(self.time)

l_range = f"[{self.l.min():.2f}, {self.l.max():.2f}]"
mlt_range = f"[{self.mlt.min():.2f}, {self.mlt.max():.2f}]"

summary = [
"--- Plasmasphere Density Cube ---",
f"Temporal Span : {num_times} steps ({self.time[0]} to {self.time[-1]})",
f"Spatial L-Bins: {len(self.l)} {l_range}",
f"Spatial MLT-Bins: {len(self.mlt)} {mlt_range}",
f"Density Grid Geometry per Time Step : {self.density_grid[0].shape if isinstance(self.density_grid, list) else self.density_grid.shape} (Time x L x MLT)",
Comment thread
sahiljhawar marked this conversation as resolved.
f"Data Columns : {self.density_column}",
"----------------------------------",
]
return "\n".join(summary)

def __eq__(self, other: object) -> bool:
if not isinstance(other, PlasmasphereDensityCube):
return NotImplemented

if not np.array_equal(self.time, other.time):
return False
if not np.array_equal(self.l, other.l):
return False
if not np.array_equal(self.mlt, other.mlt):
return False
if not np.array_equal(self.l_grid, other.l_grid):
return False
if not np.array_equal(self.mlt_grid, other.mlt_grid):
return False

if isinstance(self.density_grid, np.ndarray) and isinstance(other.density_grid, np.ndarray):
if not np.array_equal(self.density_grid, other.density_grid):
return False
elif isinstance(self.density_grid, list) and isinstance(other.density_grid, list):
if len(self.density_grid) != len(other.density_grid):
return False
for grid_self, grid_other in zip(self.density_grid, other.density_grid):
if not np.array_equal(grid_self, grid_other):
return False
else:
return False

if self.density_column != other.density_column:
return False

return True


class PlasmaspherePredictionReader:
"""Reads one of the available PAGER plasmasphere density predictions.

Expand Down Expand Up @@ -73,7 +161,7 @@ def read(self, requested_date: datetime | None = None) -> pd.DataFrame | None:

requested_date = requested_date.replace(minute=0, second=0, microsecond=0)

file_name = f"plasmasphere_density_{requested_date.year}{str(requested_date.month).zfill(2)}{str(requested_date.day).zfill(2)}T{str(requested_date.hour).zfill(2)}00.csv"
file_name = f"plasmasphere_density_{requested_date.strftime('%Y%m%dT%H00')}.csv"

file_path = os.path.join(self.data_dir, file_name)
logger.info(f"Looking for file {file_path} for date {requested_date}")
Expand All @@ -86,3 +174,152 @@ def read(self, requested_date: datetime | None = None) -> pd.DataFrame | None:
data["t"] = data["date"]
data.drop(labels=["date"], axis=1, inplace=True)
return data

def _validate_data(self, data: pd.DataFrame) -> None:
    """Check that ``data`` is a usable prediction table.

    The checks run in this order and the first failure is logged and raised:
    ``data`` is a pandas DataFrame, the columns ``L``, ``MLT`` and ``t`` are
    present, the frame is not empty, and column ``t`` has a datetime64 dtype.

    Raises
    ------
    TypeError
        If ``data`` is not a DataFrame or ``t`` is not a datetime column.
    ValueError
        If a required column is missing or the frame is empty.
    """
    failure: Exception | None = None

    if not isinstance(data, pd.DataFrame):
        failure = TypeError(f"data must be an instance of a pandas dataframe, instead it is of type {type(data)}")
    else:
        # Only the first missing column (in L, MLT, t order) is reported.
        missing = next((name for name in ("L", "MLT", "t") if name not in data.columns), None)
        if missing is not None:
            failure = ValueError(f"column {missing} is missing")
        elif data.empty:
            failure = ValueError("data dataframe is empty")
        elif not pd.api.types.is_datetime64_any_dtype(data["t"]):
            failure = TypeError("values of date column must be datetime objects")

    if failure is not None:
        logger.error(str(failure))
        raise failure

def _get_density_columns(self, data: pd.DataFrame) -> list[str]:
    """Return every column name of ``data`` that contains ``predicted_densities``.

    Names are returned in the frame's column order.

    Raises
    ------
    ValueError
        If no column name contains ``predicted_densities``.
    """
    marker = "predicted_densities"
    matching = [name for name in data.columns if marker in name]
    if matching:
        return matching

    msg = "no columns matching 'predicted_densities' were found"
    logger.error(msg)
    raise ValueError(msg)

def _resolve_density_column(self, data: pd.DataFrame, density_column: str | None) -> str:
    """Pick the density column to use from ``data``.

    With ``density_column=None`` the first column reported by
    ``_get_density_columns`` is returned; otherwise the requested name is
    returned after checking that it is one of those columns.

    Raises
    ------
    ValueError
        If ``density_column`` is given but is not an available density column
        (or if ``_get_density_columns`` finds no density column at all).
    """
    available = self._get_density_columns(data)

    if density_column is None:
        return available[0]
    if density_column in available:
        return density_column

    msg = f"density_column '{density_column}' is not valid. Available columns: {available}"
    logger.error(msg)
    raise ValueError(msg)

def _legacy_reshape_2d(self, df_date: pd.DataFrame, density_column: str) -> tuple:
    """Reshape the rows of one timestamp into L x MLT grids.

    The flat ``L``, ``MLT`` and density columns are reshaped column-major
    (``order="F"``) into ``(n_L, n_MLT)`` arrays, i.e. consecutive rows fill
    the L direction first. Rows are used in the order they appear in
    ``df_date``; they are not sorted here.

    Returns
    -------
    tuple
        ``(l_axis, mlt_axis, l_grid, mlt_grid, density_2d)`` where the axes
        are the sorted unique L / MLT values and the grids have shape
        ``(n_L, n_MLT)``. ``density_2d`` has float dtype.

    Raises
    ------
    ValueError
        If the number of rows differs from ``n_L * n_MLT``.
    """
    flat_l = df_date["L"].to_numpy()
    flat_mlt = df_date["MLT"].to_numpy()
    flat_density = df_date[density_column].to_numpy(dtype=float)

    l_axis, mlt_axis = np.unique(flat_l), np.unique(flat_mlt)
    grid_shape = (len(l_axis), len(mlt_axis))

    if grid_shape[0] * grid_shape[1] != len(df_date):
        msg = "data for a single timestamp does not form a complete L-MLT grid. Expected n_L * n_MLT rows."
        logger.error(msg)
        raise ValueError(msg)

    l_grid, mlt_grid, density_2d = (
        np.reshape(flat, grid_shape, order="F") for flat in (flat_l, flat_mlt, flat_density)
    )
    return l_axis, mlt_axis, l_grid, mlt_grid, density_2d

def build_density_cube(
    self,
    requested_date: datetime | None = None,
    density_column: str | None = None,
) -> Optional[PlasmasphereDensityCube]:
    """
    Build density tensor with shape time x L x MLT.

    Parameters
    ----------
    requested_date : datetime.datetime or None
        Date of plasma density prediction that we want to read up to hour precision.
        Passed unchanged to `read`.
    density_column : str or None
        Name of the density column to use. If None, every column whose name
        contains ``predicted_densities`` is used.

    Returns
    -------
    PlasmasphereDensityCube or None
        If exactly one density column is used (requested explicitly, or the
        only one found), `density_grid` has shape (n_time, n_L, n_MLT) and
        `density_column` is that column name. If several density columns are
        used, `density_grid` is a list of arrays with that same shape (one
        per density column) and `density_column` is the list of names.
        `time` is sorted ascending and matches the first axis of
        `density_grid`.

        If no data is available for the requested date, returns None.

    Raises
    ------
    TypeError, ValueError
        Propagated from the validation, column-resolution and reshape helpers.
    ValueError
        If the L or MLT axis differs between timestamps or density columns.
    RuntimeError
        If no slice could be built from the data.
    """
    data = self.read(requested_date=requested_date)
    if data is None:
        msg = f"No data available for the requested date {requested_date}"
        logger.error(msg)
        return None
    self._validate_data(data)

    if density_column is None:
        resolved_density_columns = self._get_density_columns(data)
    else:
        resolved_density_columns = [self._resolve_density_column(data, density_column)]

    # Convert once; reused for the unique timestamps and for every row mask.
    timestamps = pd.to_datetime(data["t"])

    # Sort BEFORE deriving the returned time axis so that `time[i]` always
    # belongs to `density_grid[i]`, even when rows are not chronological.
    dates = pd.to_datetime(timestamps.unique()).sort_values()
    dates_to_return = dates.to_pydatetime()

    density_slices_by_column = {column: [] for column in resolved_density_columns}

    # (l_axis, mlt_axis, l_grid, mlt_grid) of the first slice. Later slices
    # are checked against the axes only; the grids of the first slice are
    # the ones returned.
    reference = None

    for date in dates:
        df_date = data[timestamps == date]
        for column in resolved_density_columns:
            l_axis, mlt_axis, l_grid, mlt_grid, density_2d = self._legacy_reshape_2d(df_date, column)

            if reference is None:
                reference = (l_axis, mlt_axis, l_grid, mlt_grid)
            elif not np.array_equal(reference[0], l_axis) or not np.array_equal(reference[1], mlt_axis):
                msg = "Inconsistent L/MLT axes across timestamps."
                logger.error(msg)
                raise ValueError(msg)

            density_slices_by_column[column].append(density_2d)

    if reference is None:
        msg = "Unable to build density cube axes from input data."
        logger.error(msg)
        raise RuntimeError(msg)

    l_axis_ref, mlt_axis_ref, l_grid_ref, mlt_grid_ref = reference

    if len(resolved_density_columns) == 1:
        resolved_density_column: str | list[str] = resolved_density_columns[0]
        density_grid: np.ndarray | list[np.ndarray] = np.stack(
            density_slices_by_column[resolved_density_columns[0]],
            axis=0,
        )
    else:
        resolved_density_column = resolved_density_columns
        density_grid = [np.stack(density_slices_by_column[column], axis=0) for column in resolved_density_columns]

    return PlasmasphereDensityCube(
        time=dates_to_return,
        l=l_axis_ref,
        mlt=mlt_axis_ref,
        l_grid=l_grid_ref,
        mlt_grid=mlt_grid_ref,
        density_grid=density_grid,
        density_column=resolved_density_column,
    )
Loading