diff --git a/examples/models/plot_channel_aware_model.py b/examples/models/plot_channel_aware_model.py
new file mode 100644
index 0000000..336132d
--- /dev/null
+++ b/examples/models/plot_channel_aware_model.py
@@ -0,0 +1,397 @@
"""
=====================================================================
Channel-Aware Model Implementation
=====================================================================

This example demonstrates how to implement and use channel-aware models in Kaira.

Channel-aware models adapt their processing based on the current state of the
communication channel. The ChannelAwareBaseModel provides standardized handling
of Channel State Information (CSI), ensuring consistent usage across different
model implementations.

This example shows:
1. How to implement simple channel-aware models
2. How to properly handle CSI in different formats
3. How to use the utility methods provided by the base class
4. How channel-aware models can be composed
"""

# sphinx_gallery_thumbnail_number = 1
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from kaira.channels import AWGNChannel
from kaira.constraints import PowerNormalization
from kaira.models.base import ChannelAwareBaseModel, CSIFormat

# %%
# Basic Channel-Aware Model Implementation
# ----------------------------------------
#
# First, we'll implement a simple channel-aware model that adapts its processing
# based on CSI. This model applies a channel-aware gain factor that compensates
# for channel attenuation.


class SimpleChannelAwareModel(ChannelAwareBaseModel):
    """A simple channel-aware model that applies adaptive gain based on CSI.

    This model demonstrates basic CSI usage by scaling the input based on the
    provided channel state information. It implements adaptive gain to compensate
    for channel attenuation.
    """

    def __init__(self):
        """Initialize the model."""
        super().__init__(
            expected_csi_dims=(-1, 1),  # Expected CSI shape: [batch_size, 1] (-1 matches any batch size)
            expected_csi_format=CSIFormat.LINEAR,  # Expect linear-scale CSI
            csi_min_value=0.001,  # Minimum expected CSI value
            csi_max_value=10.0,  # Maximum expected CSI value
            strict_validation=False,  # Normalize out-of-range CSI instead of raising
        )

        # Simple layer for processing
        self.dense = nn.Linear(10, 10)

    def forward(self, x, csi, *args, **kwargs):
        """Apply channel-aware processing to input.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, 10]
            csi (torch.Tensor): Channel state information of shape [batch_size, 1]

        Returns:
            torch.Tensor: Processed output tensor
        """
        # Validate and, if necessary, normalize CSI into the expected linear format
        csi = self.get_normalized_csi(csi)

        # Apply adaptive gain based on CSI (compensate for channel attenuation):
        # for weak channels (low CSI values), apply higher gain
        gain = 1.0 / torch.clamp(csi, min=0.01)

        # Process input
        out = self.dense(x)

        # Apply adaptive gain
        out = out * gain

        return out
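

# %%
# Before composing larger models, it is worth seeing what CSI normalization does
# to raw inputs. The quick check below is illustrative only (it is not part of
# the pipeline built later in this example): dB-scale CSI is detected, converted
# to the linear scale this model expects, and clamped to
# [csi_min_value, csi_max_value].

demo_model = SimpleChannelAwareModel()
raw_csi_db = torch.tensor([[-20.0], [0.0], [20.0]])  # dB-scale inputs
print(demo_model.normalize_csi(raw_csi_db))
# -> approximately [[0.01], [1.0], [10.0]] (20 dB = 100.0 is clamped to csi_max_value)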

# %%
# Channel-Aware Composite Model
# -----------------------------
#
# Now, we'll implement a more complex model that consists of multiple channel-aware
# components. This demonstrates how to compose channel-aware models and propagate
# CSI to submodules.


class ChannelAwareEncoder(ChannelAwareBaseModel):
    """A channel-aware encoder model."""

    def __init__(self, in_dim=10, latent_dim=5):
        """Initialize the encoder.

        Args:
            in_dim (int): Input dimension
            latent_dim (int): Latent dimension
        """
        super().__init__(expected_csi_format=CSIFormat.DB)

        self.net = nn.Sequential(
            nn.Linear(in_dim, 20),
            nn.ReLU(),
            nn.Linear(20, latent_dim),
        )

    def forward(self, x, csi, *args, **kwargs):
        """Encode the input into a latent representation.

        Args:
            x (torch.Tensor): Input tensor
            csi (torch.Tensor): Channel state information

        Returns:
            torch.Tensor: Encoded representation
        """
        # Normalize CSI to dB scale
        csi = self.normalize_csi(csi)

        # Adjust encoding based on CSI:
        # - For poor channels (low SNR/CSI), make the encoding more robust
        # - For good channels (high SNR/CSI), focus on fidelity
        robustness_factor = torch.sigmoid(-csi)  # Higher for lower SNR

        # Get the basic encoding
        z = self.net(x)

        # Apply robustness scaling based on channel conditions,
        # limiting the dynamic range for poor channels
        z_scaled = torch.tanh(z * (1.0 - robustness_factor))

        return z_scaled


class ChannelAwareDecoder(ChannelAwareBaseModel):
    """A channel-aware decoder model."""

    def __init__(self, latent_dim=5, out_dim=10):
        """Initialize the decoder.

        Args:
            latent_dim (int): Latent dimension
            out_dim (int): Output dimension
        """
        super().__init__(expected_csi_format=CSIFormat.LINEAR)

        self.net = nn.Sequential(
            nn.Linear(latent_dim, 20),
            nn.ReLU(),
            nn.Linear(20, out_dim),
        )

    def forward(self, z, csi, *args, **kwargs):
        """Decode the latent representation.

        Args:
            z (torch.Tensor): Latent representation
            csi (torch.Tensor): Channel state information

        Returns:
            torch.Tensor: Decoded output
        """
        # Normalize CSI to linear scale
        csi = self.normalize_csi(csi)

        # Apply adaptive processing based on CSI:
        # - For poor channels, apply more aggressive denoising
        # - For good channels, perform lighter processing
        denoise_strength = 1.0 / torch.clamp(csi, min=0.01)

        # Basic decoding
        out = self.net(z)

        # Apply denoising based on channel quality (simple illustrative example):
        # clip values more aggressively for poor channels
        out_denoised = torch.tanh(out * torch.sigmoid(denoise_strength))

        return out_denoised


class CompositeChannelAwareModel(ChannelAwareBaseModel):
    """A composite model that combines encoder, channel, and decoder."""

    def __init__(self):
        """Initialize the composite model."""
        # Work internally with linear-scale CSI and normalize whatever format
        # the caller provides, so submodule conversions start from a known scale
        super().__init__(
            expected_csi_format=CSIFormat.LINEAR,
            csi_min_value=0.0,
            strict_validation=False,
        )

        self.encoder = ChannelAwareEncoder()
        self.decoder = ChannelAwareDecoder()
        self.constraint = PowerNormalization()
        self.channel = AWGNChannel(snr_db=10)

    def forward(self, x, csi, *args, **kwargs):
        """Process input through the full pipeline.

        Args:
            x (torch.Tensor): Input tensor
            csi (torch.Tensor): Channel state information

        Returns:
            torch.Tensor: Reconstructed output
        """
        # Bring CSI into this model's linear format first
        csi = self.get_normalized_csi(csi)

        # Format CSI appropriately for the encoder (which expects dB format)
        encoder_csi = self.format_csi_for_submodules(csi, self.encoder)

        # Encode
        z = self.encoder(x, encoder_csi)

        # Apply power constraint
        z_constrained = self.constraint(z)

        # Pass through channel
        z_received = self.channel(z_constrained)

        # Format CSI appropriately for the decoder (which expects linear format)
        decoder_csi = self.format_csi_for_submodules(csi, self.decoder)

        # Decode
        out = self.decoder(z_received, decoder_csi)

        return out
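

# %%
# The composite model relies on ``format_csi_for_submodules`` to translate
# between its own linear-scale CSI and the encoder's dB-scale expectation.
# The check below is purely illustrative:

demo_composite = CompositeChannelAwareModel()
linear_csi_demo = torch.tensor([[0.01], [1.0], [100.0]])
print(demo_composite.format_csi_for_submodules(linear_csi_demo, demo_composite.encoder))
# -> approximately [[-20.0], [0.0], [20.0]] dB, since the encoder expects CSIFormat.DB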

# %%
# Testing and Visualization
# -------------------------
#
# Now, let's test our models and visualize how they adapt to different channel
# conditions.

# Create test data
batch_size = 5
data = torch.randn(batch_size, 10)

# Create CSI at different qualities
csi_db_range = torch.tensor([-20.0, -10.0, 0.0, 10.0, 20.0]).reshape(batch_size, 1)
csi_linear = 10 ** (csi_db_range / 10)  # Convert dB to linear scale

# Create models
simple_model = SimpleChannelAwareModel()
composite_model = CompositeChannelAwareModel()

# Process data through the models
with torch.no_grad():
    # Process with different CSI formats to demonstrate conversion
    simple_output_db = simple_model(data, csi_db_range)
    simple_output_linear = simple_model(data, csi_linear)

    composite_output_db = composite_model(data, csi_db_range)
    composite_output_linear = composite_model(data, csi_linear)

# %%
# Visualize how outputs change with CSI quality
plt.figure(figsize=(12, 8))

# Plot simple model outputs
plt.subplot(2, 1, 1)
for i in range(batch_size):
    plt.plot(simple_output_db[i].numpy(), label=f'SNR: {csi_db_range[i].item()} dB')
plt.title('Simple Channel-Aware Model Output')
plt.xlabel('Feature Index')
plt.ylabel('Output Value')
plt.legend()
plt.grid(True)

# Plot composite model outputs
plt.subplot(2, 1, 2)
for i in range(batch_size):
    plt.plot(composite_output_db[i].numpy(), label=f'SNR: {csi_db_range[i].item()} dB')
plt.title('Composite Channel-Aware Model Output')
plt.xlabel('Feature Index')
plt.ylabel('Output Value')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# %%
# Channel-Aware Sequential Processing
# -----------------------------------
#
# The ChannelAwareBaseModel provides utility methods to pass CSI through sequential
# modules. Let's demonstrate this functionality.


class SequentialChannelAwareModel(ChannelAwareBaseModel):
    """A model that demonstrates sequential processing with CSI."""

    def __init__(self):
        """Initialize the model."""
        # Work with linear-scale CSI internally; normalize whatever the caller provides
        super().__init__(
            expected_csi_format=CSIFormat.LINEAR,
            csi_min_value=0.0,
            strict_validation=False,
        )

        # Create a mix of channel-aware and regular modules
        self.ca_module1 = SimpleChannelAwareModel()
        self.regular_module = nn.Linear(10, 10)
        self.ca_module2 = ChannelAwareEncoder(in_dim=10, latent_dim=10)

        # Collect the modules in a list for sequential processing
        self.modules_list = [
            self.ca_module1,
            self.regular_module,
            self.ca_module2,
        ]

    def forward(self, x, csi, *args, **kwargs):
        """Process input sequentially with CSI.

        Args:
            x (torch.Tensor): Input tensor
            csi (torch.Tensor): Channel state information

        Returns:
            torch.Tensor: Processed output
        """
        # Use the utility method to forward through sequential modules; it
        # normalizes the CSI once and reformats it for each channel-aware module
        return self.forward_csi_to_sequential(x, self.modules_list, csi, *args, **kwargs)


# Create and test the sequential model
sequential_model = SequentialChannelAwareModel()
with torch.no_grad():
    sequential_output = sequential_model(data, csi_db_range)

# %%
# Visualize the sequential model output
plt.figure(figsize=(10, 6))
for i in range(batch_size):
    plt.plot(sequential_output[i].numpy(), label=f'SNR: {csi_db_range[i].item()} dB')
plt.title('Sequential Channel-Aware Model Output')
plt.xlabel('Feature Index')
plt.ylabel('Output Value')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
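
# %%
# ``forward_csi_to_sequential`` decides per module whether to pass CSI along.
# A portable way to perform the same check in user code is ``inspect.signature``;
# the helper below is illustrative and is not part of Kaira's API.

import inspect


def accepts_csi(module: nn.Module) -> bool:
    """Return True if a module's forward method declares a ``csi`` parameter."""
    return "csi" in inspect.signature(module.forward).parameters


print([accepts_csi(m) for m in sequential_model.modules_list])  # [True, False, True]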

# %%
# Extracting CSI from Channel Output
# ----------------------------------
#
# Some channels return CSI along with the processed signal. Let's demonstrate how to
# extract and use CSI from channel outputs.

# Create data and channel
data = torch.randn(batch_size, 10)
channel = AWGNChannel(snr_db=10)

# Simulate a channel that reports CSI alongside its output
channel_output = {
    'signal': channel(data),
    'snr': torch.tensor([5.0, 7.5, 10.0, 12.5, 15.0]).reshape(batch_size, 1)
}

# Create and use a model with the extracted CSI
model = SimpleChannelAwareModel()
with torch.no_grad():
    # Extract CSI from the channel output
    extracted_csi = model.extract_csi_from_channel_output(channel_output)

    # Use the extracted CSI for processing
    output = model(channel_output['signal'], extracted_csi)

# %%
# Visualize the output with extracted CSI
plt.figure(figsize=(10, 6))
for i in range(batch_size):
    plt.plot(output[i].numpy(), label=f'SNR: {extracted_csi[i].item()} dB')
plt.title('Model Output with Extracted CSI')
plt.xlabel('Feature Index')
plt.ylabel('Output Value')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# %%
# Conclusion
# ----------
#
# This example demonstrated how to implement and use channel-aware models in Kaira.
# The ChannelAwareBaseModel provides standardized handling of CSI, ensuring consistent
# usage across different model implementations. The key features demonstrated include:
#
# 1. Creating simple and composite channel-aware models
# 2. Handling CSI in different formats (dB, linear)
# 3. Sequential processing with CSI
# 4. Extracting CSI from channel outputs
#
# Channel-aware models are particularly useful in wireless communication systems,
# where adaptive processing based on channel conditions can significantly improve
# performance.
diff --git a/kaira/models/base.py b/kaira/models/base.py
index 543b6d9..01de391 100644
--- a/kaira/models/base.py
+++ b/kaira/models/base.py
@@ -6,8 +6,10 @@
 """

 from abc import ABC, abstractmethod
-from typing import Any, List
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union

+import torch
 from torch import nn

@@ -52,6 +54,474 @@
        raise NotImplementedError("Subclasses must implement forward method")


class CSIFormat(Enum):
    """Enumeration of different CSI (Channel State Information) formats.

    This enum represents the different formats that CSI data can take in channel-aware models.
    """

    LINEAR = "linear"  # Linear scale (e.g., 0.001 to 10.0)
    DB = "db"  # Decibel scale (e.g., -30 to 10 dB)
    NORMALIZED = "normalized"  # Normalized to a specific range (e.g., 0.0 to 1.0)
    COMPLEX = "complex"  # Complex-valued CSI with magnitude and phase
    VECTOR = "vector"  # Multi-dimensional vector of channel coefficients
    MATRIX = "matrix"  # Matrix representation (e.g., for MIMO channels)
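
# Illustrative correspondence between formats for a single SNR value
# (comment-only example; the conversions mirror the helper methods defined below):
#   CSIFormat.DB:         10.0 dB
#   CSIFormat.LINEAR:     10 ** (10.0 / 10) == 10.0 (linear power ratio)
#   CSIFormat.NORMALIZED: the same value mapped into [0, 1] relative to the
#                         minimum and maximum of the batch it belongs to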


class ChannelAwareBaseModel(BaseModel, ABC):
    """Abstract base class for channel-aware models that require CSI.

    This class standardizes how Channel State Information (CSI) is handled
    in channel-aware models, ensuring that CSI is explicitly required
    rather than being an optional parameter. It provides utilities for
    validating, normalizing, and transforming CSI data to ensure consistent
    usage across different model implementations.

    Channel-aware models are neural networks that adapt their processing based
    on the current state of the communication channel, which may include properties
    like signal-to-noise ratio (SNR), fading coefficients, or other channel quality
    indicators.

    Attributes:
        expected_csi_dims (Tuple[int, ...]): Expected dimensions for the CSI tensor
        expected_csi_format (CSIFormat): Expected format for CSI values
        csi_min_value (float): Minimum expected value for valid CSI
        csi_max_value (float): Maximum expected value for valid CSI
        auto_normalize_csi (bool): Whether to automatically normalize CSI in sequential forwarding
        strict_validation (bool): Whether to raise errors for invalid CSI or silently fix it
    """

    def __init__(
        self,
        expected_csi_dims: Optional[Tuple[int, ...]] = None,
        expected_csi_format: CSIFormat = CSIFormat.LINEAR,
        csi_min_value: float = float('-inf'),
        csi_max_value: float = float('inf'),
        auto_normalize_csi: bool = True,
        strict_validation: bool = True,
        *args: Any,
        **kwargs: Any
    ):
        """Initialize the channel-aware model.

        Args:
            expected_csi_dims (Optional[Tuple[int, ...]]): Expected dimensions for the CSI
                tensor, including the batch dimension (use -1 for a dimension of any size).
                If None, dimensions are not validated.
            expected_csi_format (CSIFormat): Expected format for CSI values.
                Defaults to CSIFormat.LINEAR.
            csi_min_value (float): Minimum expected value for valid CSI.
                Defaults to negative infinity.
            csi_max_value (float): Maximum expected value for valid CSI.
                Defaults to positive infinity.
            auto_normalize_csi (bool): Whether forward_csi_to_sequential automatically
                normalizes CSI before use. Defaults to True.
            strict_validation (bool): Whether to raise errors for invalid CSI (True)
                or silently fix it (False). Defaults to True.
            *args: Variable positional arguments passed to the base class.
            **kwargs: Variable keyword arguments passed to the base class.
        """
        super().__init__(*args, **kwargs)
        self.expected_csi_dims = expected_csi_dims
        self.expected_csi_format = expected_csi_format
        self.csi_min_value = csi_min_value
        self.csi_max_value = csi_max_value
        self.auto_normalize_csi = auto_normalize_csi
        self.strict_validation = strict_validation
        self._last_csi = None  # Cache for debugging and visualization

    @abstractmethod
    def forward(self, x: torch.Tensor, csi: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Transform the input using channel state information.

        Args:
            x (torch.Tensor): The input tensor to process
            csi (torch.Tensor): Channel state information tensor
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            torch.Tensor: Processed output tensor
        """
        # Implementations are responsible for normalizing incoming CSI, typically:
        #     csi = self.get_normalized_csi(csi)
        pass
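
    # A minimal concrete subclass looks like this (illustrative comment only;
    # the class name is hypothetical):
    #
    #     class GainModel(ChannelAwareBaseModel):
    #         def forward(self, x, csi, *args, **kwargs):
    #             csi = self.get_normalized_csi(csi)
    #             return x * csi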

    def _validate_csi_impl(self, csi: torch.Tensor) -> Tuple[bool, Optional[str]]:
        """Implementation of the CSI validation logic.

        Args:
            csi (torch.Tensor): The CSI tensor to validate

        Returns:
            Tuple[bool, Optional[str]]: Validation result and error message if invalid
        """
        # Basic type checking
        if not isinstance(csi, torch.Tensor):
            return False, f"CSI must be a torch.Tensor, got {type(csi)}"

        # Check for NaNs or infinities
        if torch.isnan(csi).any():
            return False, "CSI contains NaN values"

        if torch.isinf(csi).any():
            return False, "CSI contains infinite values"

        # Dimension validation
        if self.expected_csi_dims is not None:
            if len(csi.shape) != len(self.expected_csi_dims):
                return False, (f"CSI dimensions mismatch: expected {len(self.expected_csi_dims)} "
                               f"dimensions, got {len(csi.shape)}")

            # Check whether the dimensions match, ignoring the batch size (first dim)
            for i in range(1, len(csi.shape)):
                if i < len(self.expected_csi_dims) and self.expected_csi_dims[i] != -1:
                    if csi.shape[i] != self.expected_csi_dims[i]:
                        return False, (f"CSI shape mismatch at dimension {i}: expected "
                                       f"{self.expected_csi_dims[i]}, got {csi.shape[i]}")

        # Value range validation
        if not torch.is_complex(csi):  # Only check the range for real-valued CSI
            if (csi < self.csi_min_value).any():
                return False, f"CSI contains values below minimum {self.csi_min_value}"

            if (csi > self.csi_max_value).any():
                return False, f"CSI contains values above maximum {self.csi_max_value}"

        return True, None

    def validate_csi(self, csi: torch.Tensor) -> bool:
        """Validate that the CSI tensor meets this model's requirements.

        Args:
            csi (torch.Tensor): The CSI tensor to validate

        Returns:
            bool: True if CSI is valid, False otherwise

        Raises:
            ValueError: If CSI is invalid and strict_validation is True
        """
        valid, error_message = self._validate_csi_impl(csi)

        if not valid and self.strict_validation and error_message:
            raise ValueError(f"CSI validation error: {error_message}")

        return valid

    def get_normalized_csi(self, csi: torch.Tensor) -> torch.Tensor:
        """Get properly normalized CSI according to the model's requirements.

        This is a convenience wrapper that validates and normalizes CSI
        in a single call, with appropriate error handling based on
        the strict_validation setting.

        Args:
            csi (torch.Tensor): The input CSI tensor

        Returns:
            torch.Tensor: Normalized CSI tensor

        Raises:
            ValueError: If CSI validation fails and strict_validation is True
        """
        try:
            if not self.validate_csi(csi):
                return self.normalize_csi(csi)
            return csi
        except ValueError:
            if self.strict_validation:
                raise
            # If not strict, try to recover and normalize anyway
            return self.normalize_csi(csi)
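
    # Illustrative behavior (comment-only example; assumes a concrete subclass
    # created with expected_csi_format=CSIFormat.LINEAR, csi_min_value=0.0,
    # and strict_validation=False):
    #
    #     model.validate_csi(torch.tensor([[5.0]]))          # -> True
    #     model.get_normalized_csi(torch.tensor([[-10.0]]))  # -> tensor([[0.1]])
    #     # -10 fails validation (below csi_min_value), is read as dB, and is
    #     # converted to the linear scale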

    def normalize_csi(self, csi: torch.Tensor) -> torch.Tensor:
        """Normalize CSI to the format expected by this model.

        Args:
            csi (torch.Tensor): The CSI tensor to normalize

        Returns:
            torch.Tensor: Normalized CSI tensor
        """
        # Handle non-tensor input gracefully
        if not isinstance(csi, torch.Tensor):
            try:
                csi = torch.tensor(csi, dtype=torch.float32)
            except Exception:
                raise ValueError(f"Cannot convert CSI of type {type(csi)} to tensor")

        # Handle NaNs and infs by replacing them with reasonable values
        if torch.isnan(csi).any() or torch.isinf(csi).any():
            csi = torch.nan_to_num(csi, nan=0.0, posinf=self.csi_max_value, neginf=self.csi_min_value)

        # Convert to the expected format if needed
        if self.expected_csi_format == CSIFormat.LINEAR:
            # Heuristic: treat the input as dB if it contains negative values
            # or lies within a typical dB range
            if csi.min() < 0 or (csi.max() <= 30 and csi.min() >= -30):
                csi = self._db_to_linear(csi)

        elif self.expected_csi_format == CSIFormat.DB:
            # Heuristic: treat the input as linear scale if it is all positive
            # with values too large for a typical dB range
            if csi.min() >= 0 and csi.max() > 30:
                csi = self._linear_to_db(csi)

        elif self.expected_csi_format == CSIFormat.NORMALIZED:
            # Normalize to the [0, 1] range
            csi = self._normalize_to_range(csi, 0.0, 1.0)

        elif self.expected_csi_format == CSIFormat.COMPLEX:
            # If we have real values but expect complex, add a zero imaginary part
            if not torch.is_complex(csi):
                csi = torch.complex(csi, torch.zeros_like(csi))

        # Reshape if the dimensions don't match the expected shape
        if self.expected_csi_dims is not None and len(csi.shape) != len(self.expected_csi_dims):
            try:
                # Keep the batch dimension, reshape the rest to match
                if csi.numel() == csi.shape[0] * torch.prod(torch.tensor(self.expected_csi_dims[1:])):
                    new_shape = (csi.shape[0],) + self.expected_csi_dims[1:]
                    csi = csi.reshape(new_shape)
            except Exception:
                # If the reshape fails, leave as is and let validation handle it
                pass

        # Clamp values to the expected range for real-valued tensors
        if not torch.is_complex(csi):
            csi = torch.clamp(csi, min=self.csi_min_value, max=self.csi_max_value)

        # Cache the normalized CSI for debugging
        self._last_csi = csi

        return csi

    def _db_to_linear(self, csi_db: torch.Tensor) -> torch.Tensor:
        """Convert CSI from dB to linear scale.

        Args:
            csi_db (torch.Tensor): CSI in decibels

        Returns:
            torch.Tensor: CSI in linear scale
        """
        return 10.0 ** (csi_db / 10.0)

    def _linear_to_db(self, csi_linear: torch.Tensor) -> torch.Tensor:
        """Convert CSI from linear to dB scale.

        Args:
            csi_linear (torch.Tensor): CSI in linear scale

        Returns:
            torch.Tensor: CSI in decibels
        """
        # Add a small epsilon to prevent log of zero
        return 10.0 * torch.log10(csi_linear + 1e-10)

    def _normalize_to_range(
        self,
        csi: torch.Tensor,
        target_min: float = 0.0,
        target_max: float = 1.0
    ) -> torch.Tensor:
        """Normalize CSI values to a target range.

        Args:
            csi (torch.Tensor): CSI tensor to normalize
            target_min (float): Target minimum value
            target_max (float): Target maximum value

        Returns:
            torch.Tensor: Normalized CSI
        """
        csi_min = csi.min()
        csi_max = csi.max()

        # Degenerate case: a constant tensor maps to the target minimum
        if csi_min == csi_max:
            return torch.ones_like(csi) * target_min

        normalized = (csi - csi_min) / (csi_max - csi_min)
        return normalized * (target_max - target_min) + target_min
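
    # Worked examples for the helpers above (comment-only; values are exact up
    # to the 1e-10 epsilon used in _linear_to_db):
    #
    #     _db_to_linear(torch.tensor([-10.0, 0.0, 10.0]))  -> tensor([0.1, 1.0, 10.0])
    #     _linear_to_db(torch.tensor([0.1, 1.0, 10.0]))    -> tensor([-10.0, 0.0, 10.0])
    #     _normalize_to_range(torch.tensor([-10.0, 0.0, 10.0]), 0.0, 1.0)
    #         -> tensor([0.0, 0.5, 1.0])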

    def extract_csi_from_channel_output(
        self,
        channel_output: Union[Dict[str, Any], torch.Tensor]
    ) -> torch.Tensor:
        """Extract CSI from a channel's output dictionary or tensor.

        This utility helps standardize how CSI is extracted from channel outputs,
        which may include the transformed signal along with CSI and other metadata.

        Args:
            channel_output (Union[Dict[str, Any], torch.Tensor]):
                Output from a channel, either as a dictionary or a tensor

        Returns:
            torch.Tensor: Extracted CSI tensor

        Raises:
            ValueError: If CSI cannot be found in the channel output
        """
        # If the output is already a tensor, return it as is (assuming it is the CSI)
        if isinstance(channel_output, torch.Tensor):
            return channel_output

        # Look for CSI under common keys in dictionary output
        if isinstance(channel_output, dict):
            # Check the most common keys first for efficiency
            for key in ["csi", "h", "snr", "channel_state", "channel_coefficients", "channel_info"]:
                if key in channel_output:
                    return channel_output[key]

            # Try a case-insensitive search as a fallback
            lowercase_keys = {k.lower(): k for k in channel_output.keys()}
            for key in ["csi", "h", "snr", "channel_state"]:
                if key in lowercase_keys:
                    return channel_output[lowercase_keys[key]]

        # No recognized CSI key was found
        raise ValueError(
            "Could not extract CSI from channel output. Channel output must contain "
            "one of: 'csi', 'h', 'channel_state', 'channel_coefficients', 'snr', or be a tensor."
        )

    def format_csi_for_submodules(
        self,
        csi: torch.Tensor,
        submodule: nn.Module
    ) -> torch.Tensor:
        """Format CSI appropriately for a specific submodule.

        Transforms CSI to match the requirements of a particular submodule,
        based on its type and expected format. The conversions assume that
        ``csi`` is currently in this model's own expected format.

        Args:
            csi (torch.Tensor): The original CSI tensor
            submodule (nn.Module): The submodule that will receive the CSI

        Returns:
            torch.Tensor: Formatted CSI appropriate for the submodule
        """
        # For ChannelAwareBaseModel submodules, use their expected format
        if isinstance(submodule, ChannelAwareBaseModel):
            if submodule.expected_csi_format != self.expected_csi_format:
                if submodule.expected_csi_format == CSIFormat.LINEAR:
                    return self._db_to_linear(csi)
                elif submodule.expected_csi_format == CSIFormat.DB:
                    return self._linear_to_db(csi)
                elif submodule.expected_csi_format == CSIFormat.NORMALIZED:
                    return self._normalize_to_range(csi, 0.0, 1.0)
                elif submodule.expected_csi_format == CSIFormat.COMPLEX and not torch.is_complex(csi):
                    return torch.complex(csi, torch.zeros_like(csi))

            # If the dimensions don't match but the format does, try to reshape
            if (submodule.expected_csi_dims is not None and
                    csi.dim() != len(submodule.expected_csi_dims)):
                # Try to adapt the dimensions (only if the batch size stays the same)
                try:
                    new_shape = [csi.shape[0]]  # Keep the batch dimension
                    for dim in submodule.expected_csi_dims[1:]:
                        new_shape.append(dim if dim != -1 else 1)
                    return csi.reshape(new_shape)
                except Exception:
                    # If the reshape fails, return as is
                    pass

        # For other modules, return as is
        return csi
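
    # Illustrative usage (comment-only example; ``y`` stands for a received
    # signal and ``model.encoder`` for a hypothetical channel-aware submodule):
    #
    #     out = {"signal": y, "snr": torch.tensor([[10.0]])}
    #     csi = model.extract_csi_from_channel_output(out)  # returns the "snr" entry
    #     csi = model.format_csi_for_submodules(csi, model.encoder)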

    def forward_csi_to_sequential(
        self,
        x: torch.Tensor,
        modules: List[nn.Module],
        csi: torch.Tensor,
        *args: Any,
        **kwargs: Any
    ) -> torch.Tensor:
        """Forward input through a sequence of modules with CSI.

        This utility helps standardize how CSI is forwarded through sequential
        modules in a channel-aware model. It handles both channel-aware and
        regular modules appropriately.

        Args:
            x (torch.Tensor): Input tensor
            modules (List[nn.Module]): List of modules to process the input sequentially
            csi (torch.Tensor): Channel state information
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments

        Returns:
            torch.Tensor: Output after sequential processing
        """
        result = x

        # Auto-normalize CSI on first use
        if self.auto_normalize_csi:
            csi = self.get_normalized_csi(csi)

        # Process through the modules
        for module in modules:
            if isinstance(module, ChannelAwareBaseModel):
                # Format CSI for this submodule and pass it explicitly
                formatted_csi = self.format_csi_for_submodules(csi, module)
                result = module(result, formatted_csi, *args, **kwargs)
            elif hasattr(module, 'forward') and 'csi' in module.forward.__code__.co_varnames:
                # Duck-typed check: the module's forward accepts a csi parameter
                # but doesn't inherit from ChannelAwareBaseModel. This handles
                # legacy or third-party modules with Python-level forward methods.
                result = module(result, csi, *args, **kwargs)
            else:
                # For regular modules, just pass the input
                result = module(result, *args, **kwargs)

        return result

    def get_last_csi(self) -> Optional[torch.Tensor]:
        """Get the most recently used normalized CSI tensor.

        This is useful for debugging and visualization purposes.

        Returns:
            Optional[torch.Tensor]: The last normalized CSI tensor used,
            or None if no CSI has been processed yet.
        """
        return self._last_csi

    @staticmethod
    def detect_csi_format(csi: torch.Tensor) -> CSIFormat:
        """Detect the most likely format of a CSI tensor.

        The checks below are heuristics; ambiguous inputs default to LINEAR.

        Args:
            csi (torch.Tensor): The CSI tensor to analyze

        Returns:
            CSIFormat: The detected format
        """
        if torch.is_complex(csi):
            return CSIFormat.COMPLEX

        # Likely dB scale: contains negative values within a typical dB range
        if csi.min() < 0 and csi.max() < 50:
            return CSIFormat.DB

        # Likely normalized: all values in [0, 1]
        if csi.min() >= 0 and csi.max() <= 1:
            return CSIFormat.NORMALIZED

        # Matrix format (e.g., for MIMO channels)
        if csi.dim() >= 3:
            return CSIFormat.MATRIX

        # Vector format: per-sample channel coefficient vectors
        if csi.dim() == 2 and csi.shape[1] > 1:
            return CSIFormat.VECTOR

        # Default to linear
        return CSIFormat.LINEAR
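
    # Illustrative detections (comment-only example):
    #
    #     detect_csi_format(torch.tensor([-5.0, 10.0]))   -> CSIFormat.DB
    #     detect_csi_format(torch.tensor([0.2, 0.8]))     -> CSIFormat.NORMALIZED
    #     detect_csi_format(torch.ones(4, 8, 8) * 2.0)    -> CSIFormat.MATRIX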


class ConfigurableModel(BaseModel):
    """Model that supports dynamically adding and removing steps.
diff --git a/tests/models/test_models_channel_aware.py b/tests/models/test_models_channel_aware.py
new file mode 100644
index 0000000..4ce33fc
--- /dev/null
+++ b/tests/models/test_models_channel_aware.py
@@ -0,0 +1,255 @@
import pytest
import torch

from kaira.models.base import ChannelAwareBaseModel, CSIFormat


# Helper classes for testing
class SimpleTestModel(ChannelAwareBaseModel):
    """A simple model for testing the ChannelAwareBaseModel functionality."""

    def __init__(self, expected_format=CSIFormat.LINEAR):
        super().__init__(expected_csi_format=expected_format)

    def forward(self, x, csi, *args, **kwargs):
        """Simple forward pass that multiplies the input by the CSI."""
        return x * csi


class ConcreteTestModel(ChannelAwareBaseModel):
    """A minimal concrete subclass.

    ChannelAwareBaseModel declares forward as an abstractmethod and cannot be
    instantiated directly, so the utility-method tests below use this
    pass-through subclass instead.
    """

    def forward(self, x, csi, *args, **kwargs):
        """Identity forward pass; only the base-class utilities are under test."""
        return x


class NestedTestModel(ChannelAwareBaseModel):
    """A model that contains another channel-aware model as a component."""

    def __init__(self):
        super().__init__()
        self.submodel = SimpleTestModel(expected_format=CSIFormat.DB)

    def forward(self, x, csi, *args, **kwargs):
        """Forward pass that formats CSI and passes it to the submodel."""
        formatted_csi = self.format_csi_for_submodules(csi, self.submodel)
        return self.submodel(x, formatted_csi)


# Test initialization and basic properties
def test_channel_aware_base_model_init():
    """Test initializing the ChannelAwareBaseModel with various parameters."""
    # Test with default parameters
    model = SimpleTestModel()
    assert model.expected_csi_format == CSIFormat.LINEAR
    assert model.expected_csi_dims is None
    assert model.csi_min_value == float('-inf')
    assert model.csi_max_value == float('inf')

    # Test with custom parameters
    model = SimpleTestModel(expected_format=CSIFormat.DB)
    assert model.expected_csi_format == CSIFormat.DB

    # Test with all parameters specified
    model = ConcreteTestModel(
        expected_csi_dims=(1, 2),
        expected_csi_format=CSIFormat.NORMALIZED,
        csi_min_value=0.0,
        csi_max_value=1.0
    )
    assert model.expected_csi_dims == (1, 2)
    assert model.expected_csi_format == CSIFormat.NORMALIZED
    assert model.csi_min_value == 0.0
    assert model.csi_max_value == 1.0


# Test validation functionality
def test_validate_csi():
    """Test the CSI validation functionality."""
    # A lenient model returns False for invalid CSI instead of raising
    lenient_model = ConcreteTestModel(
        expected_csi_dims=(-1, 1),  # [batch_size, 1]
        expected_csi_format=CSIFormat.LINEAR,
        csi_min_value=0.0,
        csi_max_value=10.0,
        strict_validation=False,
    )

    # Valid CSI
    valid_csi = torch.tensor([[5.0]])
    assert lenient_model.validate_csi(valid_csi)

    # Invalid dimensions: a 1-D tensor where [batch_size, 1] is expected
    invalid_dim_csi = torch.tensor([5.0, 6.0])
    assert not lenient_model.validate_csi(invalid_dim_csi)

    # Invalid value range
    invalid_range_csi = torch.tensor([[-1.0]])
    assert not lenient_model.validate_csi(invalid_range_csi)

    # A strict model (the default) raises instead of returning False
    strict_model = ConcreteTestModel(
        expected_csi_dims=(-1, 1),
        expected_csi_format=CSIFormat.LINEAR,
        csi_min_value=0.0,
        csi_max_value=10.0,
    )

    # Test with NaN - should raise ValueError
    nan_csi = torch.tensor([[float('nan')]])
    with pytest.raises(ValueError):
        strict_model.validate_csi(nan_csi)

    # Test with wrong type - should raise ValueError
    with pytest.raises(ValueError):
        strict_model.validate_csi(5.0)  # Not a tensor
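

# Suggested additional coverage: the static format detector. Expected values
# follow the heuristics implemented in ChannelAwareBaseModel.detect_csi_format.
def test_detect_csi_format():
    """Test heuristic CSI format detection."""
    # Negative values within a typical dB range -> DB
    assert ChannelAwareBaseModel.detect_csi_format(torch.tensor([-5.0, 10.0])) == CSIFormat.DB

    # All values in [0, 1] -> NORMALIZED
    assert ChannelAwareBaseModel.detect_csi_format(torch.tensor([0.2, 0.8])) == CSIFormat.NORMALIZED

    # Three or more dimensions -> MATRIX
    assert ChannelAwareBaseModel.detect_csi_format(torch.ones(2, 3, 4) * 2.0) == CSIFormat.MATRIX

    # Complex-valued CSI -> COMPLEX
    complex_csi = torch.complex(torch.ones(2), torch.ones(2))
    assert ChannelAwareBaseModel.detect_csi_format(complex_csi) == CSIFormat.COMPLEX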


# Test normalization functionality
def test_normalize_csi():
    """Test the CSI normalization functionality."""
    # Model expecting linear format
    linear_model = ConcreteTestModel(expected_csi_format=CSIFormat.LINEAR)

    # Test converting from dB to linear
    db_csi = torch.tensor([0.0, 10.0, 20.0])
    normalized_csi = linear_model.normalize_csi(db_csi)
    expected_linear = torch.tensor([1.0, 10.0, 100.0])
    assert torch.allclose(normalized_csi, expected_linear)

    # Model expecting dB format
    db_model = ConcreteTestModel(expected_csi_format=CSIFormat.DB)

    # Test converting from linear to dB
    linear_csi = torch.tensor([1.0, 10.0, 100.0])
    normalized_csi = db_model.normalize_csi(linear_csi)
    expected_db = torch.tensor([0.0, 10.0, 20.0])
    assert torch.allclose(normalized_csi, expected_db, atol=1e-5)

    # Model expecting normalized format
    norm_model = ConcreteTestModel(expected_csi_format=CSIFormat.NORMALIZED)

    # Test normalizing to the [0, 1] range
    unnorm_csi = torch.tensor([-10.0, 0.0, 10.0])
    normalized_csi = norm_model.normalize_csi(unnorm_csi)
    expected_norm = torch.tensor([0.0, 0.5, 1.0])
    assert torch.allclose(normalized_csi, expected_norm)


# Test conversion utility functions
def test_conversion_functions():
    """Test the CSI conversion utility functions."""
    model = ConcreteTestModel()

    # Test dB to linear conversion
    db_values = torch.tensor([-10.0, 0.0, 10.0, 20.0])
    linear_values = model._db_to_linear(db_values)
    expected_linear = torch.tensor([0.1, 1.0, 10.0, 100.0])
    assert torch.allclose(linear_values, expected_linear)

    # Test linear to dB conversion
    linear_values = torch.tensor([0.1, 1.0, 10.0, 100.0])
    db_values = model._linear_to_db(linear_values)
    expected_db = torch.tensor([-10.0, 0.0, 10.0, 20.0])
    assert torch.allclose(db_values, expected_db, atol=1e-5)

    # Test normalizing to a range
    values = torch.tensor([-10.0, 0.0, 10.0])
    normalized = model._normalize_to_range(values, 0.0, 1.0)
    expected = torch.tensor([0.0, 0.5, 1.0])
    assert torch.allclose(normalized, expected)

    # Test normalizing a constant tensor
    single_value = torch.tensor([5.0])
    normalized = model._normalize_to_range(single_value, 0.0, 1.0)
    expected = torch.tensor([0.0])  # When min == max, returns target_min
    assert torch.allclose(normalized, expected)


# Test the format_csi_for_submodules function
def test_format_csi_for_submodules():
    """Test formatting CSI for different submodules."""
    parent_model = ConcreteTestModel(expected_csi_format=CSIFormat.LINEAR)

    # Format for a channel-aware submodule expecting dB
    db_submodel = SimpleTestModel(expected_format=CSIFormat.DB)
    linear_csi = torch.tensor([1.0, 10.0, 100.0])
    formatted_csi = parent_model.format_csi_for_submodules(linear_csi, db_submodel)
    expected_db = torch.tensor([0.0, 10.0, 20.0])
    assert torch.allclose(formatted_csi, expected_db, atol=1e-5)

    # Format for a channel-aware submodule expecting normalized CSI
    norm_submodel = SimpleTestModel(expected_format=CSIFormat.NORMALIZED)
    linear_csi = torch.tensor([0.1, 1.0, 10.0])
    formatted_csi = parent_model.format_csi_for_submodules(linear_csi, norm_submodel)
    # Min-max scaling: (x - 0.1) / (10.0 - 0.1)
    expected_norm = torch.tensor([0.0, 0.9 / 9.9, 1.0])
    assert torch.allclose(formatted_csi, expected_norm, atol=1e-5)

    # Format for a non-channel-aware submodule
    regular_submodule = torch.nn.Linear(10, 10)
    linear_csi = torch.tensor([1.0, 10.0, 100.0])
    formatted_csi = parent_model.format_csi_for_submodules(linear_csi, regular_submodule)
    # Should return the CSI unchanged
    assert torch.allclose(formatted_csi, linear_csi)
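

# Suggested additional coverage: get_normalized_csi combines validation and
# normalization; with strict_validation=False, out-of-format input is repaired.
def test_get_normalized_csi():
    """Test the combined validate-and-normalize helper."""
    model = ConcreteTestModel(
        expected_csi_format=CSIFormat.LINEAR,
        csi_min_value=0.0,
        csi_max_value=100.0,
        strict_validation=False,
    )

    # dB-scale input fails validation (negative value) and is converted to linear
    repaired = model.get_normalized_csi(torch.tensor([[-10.0]]))
    assert torch.allclose(repaired, torch.tensor([[0.1]]))

    # Already-valid linear input is returned unchanged
    valid = torch.tensor([[5.0]])
    assert torch.allclose(model.get_normalized_csi(valid), valid)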


# Test forward_csi_to_sequential
def test_forward_csi_to_sequential():
    """Test forwarding CSI through a sequence of modules."""
    model = ConcreteTestModel()

    # Create a mix of channel-aware and regular modules
    ca_module1 = SimpleTestModel(expected_format=CSIFormat.LINEAR)
    regular_module = torch.nn.Linear(2, 2)
    ca_module2 = SimpleTestModel(expected_format=CSIFormat.DB)

    # Initialize the regular module with known weights
    regular_module.weight.data = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
    regular_module.bias.data = torch.zeros(2)

    modules = [ca_module1, regular_module, ca_module2]

    # Input and CSI
    x = torch.tensor([[1.0, 2.0]])
    csi = torch.tensor([[10.0]])  # Linear scale

    # Forward through the modules
    result = model.forward_csi_to_sequential(x, modules, csi)

    # Expected result:
    # 1. ca_module1: x * csi = [1, 2] * 10 = [10, 20]
    # 2. regular_module: [10, 20] @ [[2, 0], [0, 2]] = [20, 40]
    # 3. ca_module2: the CSI is converted to dB (10 linear = 10 dB), so [20, 40] * 10 = [200, 400]
    expected = torch.tensor([[200.0, 400.0]])

    assert torch.allclose(result, expected)


# Test extract_csi_from_channel_output
def test_extract_csi_from_channel_output():
    """Test extracting CSI from channel output dictionaries."""
    model = ConcreteTestModel()

    # Test extracting from the 'csi' key
    output1 = {'signal': torch.tensor([1.0]), 'csi': torch.tensor([10.0])}
    csi1 = model.extract_csi_from_channel_output(output1)
    assert torch.allclose(csi1, torch.tensor([10.0]))

    # Test extracting from the 'snr' key
    output2 = {'signal': torch.tensor([1.0]), 'snr': torch.tensor([20.0])}
    csi2 = model.extract_csi_from_channel_output(output2)
    assert torch.allclose(csi2, torch.tensor([20.0]))

    # Test extracting from the 'h' key (channel coefficients)
    output3 = {'signal': torch.tensor([1.0]), 'h': torch.tensor([0.5])}
    csi3 = model.extract_csi_from_channel_output(output3)
    assert torch.allclose(csi3, torch.tensor([0.5]))

    # Test extracting when no CSI key is present
    output4 = {'signal': torch.tensor([1.0])}
    with pytest.raises(ValueError):
        model.extract_csi_from_channel_output(output4)


# Test nested models with different CSI format expectations
def test_nested_channel_aware_models():
    """Test nested channel-aware models with different CSI format expectations."""
    # The outer model expects linear CSI; the inner model expects dB CSI
    model = NestedTestModel()

    # Input and linear CSI
    x = torch.tensor([1.0, 2.0])
    linear_csi = torch.tensor([10.0])  # Linear scale

    # Process the data
    result = model(x, linear_csi)

    # Expected result:
    # 1. The linear CSI is converted to dB for the submodel: 10 -> 10 dB
    # 2. The submodel multiplies: [1, 2] * 10 = [10, 20]
    expected = torch.tensor([10.0, 20.0])

    assert torch.allclose(result, expected)
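

# Suggested additional coverage: normalize_csi caches its result so it can be
# inspected later via get_last_csi.
def test_get_last_csi_cache():
    """Test that the most recently normalized CSI is cached for debugging."""
    model = ConcreteTestModel(expected_csi_format=CSIFormat.LINEAR)
    assert model.get_last_csi() is None

    csi = torch.tensor([[-10.0]])  # dB-scale input, converted to linear internally
    normalized = model.normalize_csi(csi)
    assert torch.allclose(model.get_last_csi(), normalized)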