diff --git a/kaira/modulations/base.py b/kaira/modulations/base.py index ad04bab..fbd09ee 100644 --- a/kaira/modulations/base.py +++ b/kaira/modulations/base.py @@ -49,6 +49,29 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ pass + def forward_soft(self, x: torch.Tensor, temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Modulate soft bits to symbols in a differentiable manner. + + This method enables differentiability through the modulator using soft bit + probabilities as input. Default implementation calls forward, but subclasses + should override for true differentiability. + + Args: + x: Input tensor of soft bit probabilities with shape (..., K*N), + where K is bits_per_symbol. Values should be in [0, 1] range, + representing P(bit=1). + temp: Temperature parameter for soft decisions (lower = harder) + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Modulated symbols with shape (..., N) + """ + # Default implementation just calls forward with hard decisions + # Subclasses should override this for true differentiability + hard_bits = (x > 0.5).float() + return self.forward(hard_bits, *args, **kwargs) + def plot_constellation(self, **kwargs): """Plot the constellation diagram. @@ -109,6 +132,32 @@ def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor """ pass + def forward_soft(self, y: torch.Tensor, noise_var: Union[float, torch.Tensor], temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Demodulate symbols to soft bit probabilities in a differentiable manner. + + This method enables differentiability through the demodulator. The default + implementation converts LLRs to probabilities, but subclasses should override + this method if a more efficient implementation is available. + + Args: + y: Received symbols with shape (..., N) + noise_var: Noise variance (required) + temp: Temperature parameter for controlling softness of decisions + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Soft bit probabilities with shape (..., N*bits_per_symbol) + Values are in [0, 1] range, representing P(bit=1) + """ + # Default implementation converts LLRs to probabilities + # Subclasses can override for more efficient implementations + llrs = self.forward(y, noise_var, *args, **kwargs) + # Convert LLRs to probabilities with temperature scaling + # P(bit=1) = 1 / (1 + exp(LLR / temp)) + probs = torch.sigmoid(-llrs / temp) + return probs + def reset_state(self) -> None: """Reset any stateful components. diff --git a/kaira/modulations/differentiable.py b/kaira/modulations/differentiable.py new file mode 100644 index 0000000..8e49af8 --- /dev/null +++ b/kaira/modulations/differentiable.py @@ -0,0 +1,102 @@ +"""Differentiable operations for modulation schemes. + +This module provides differentiable alternatives to operations commonly used in digital modulation +that are not naturally differentiable, such as bit mapping and constellation symbol selection. +""" + +import torch +import torch.nn.functional as F + + +def soft_symbol_mapping(soft_bits: torch.Tensor, constellation: torch.Tensor, bit_patterns: torch.Tensor) -> torch.Tensor: + """Map soft bit probabilities to a weighted sum of constellation symbols. + + This function provides a differentiable path from soft bit probabilities to symbols + by computing expectations over the constellation. 
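+
+    Assuming the K soft bits are independent, each constellation point c_m with bit
+    pattern (b_m1, ..., b_mK) is weighted by the joint probability of its pattern,
+    so the output is the expectation
+
+        E[s] = sum_m c_m * prod_k p_k**b_mk * (1 - p_k)**(1 - b_mk)
+
+    where p_k = P(bit_k = 1).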
+
+    Args:
+        soft_bits: Soft bit probabilities with shape (..., K) where K is bits_per_symbol
+            Values should be in [0, 1] range, representing P(bit=1)
+        constellation: Complex tensor of constellation points with shape (M,)
+        bit_patterns: Binary tensor with shape (M, K) representing the bit patterns
+            for each constellation point
+
+    Returns:
+        Complex tensor with shape (...) representing the expected symbol value
+    """
+    # Reshape soft_bits for broadcasting with bit_patterns
+    soft_bits = soft_bits.unsqueeze(-2)  # (..., 1, K)
+
+    # Calculate probabilities of each bit pattern
+    # For each bit position:
+    # - If bit pattern is 1, use soft_bit probability
+    # - If bit pattern is 0, use (1 - soft_bit) probability
+    probs_when_bit_is_1 = soft_bits  # P(bit=1)
+    probs_when_bit_is_0 = 1 - soft_bits  # P(bit=0)
+
+    # Select probabilities based on the bit patterns
+    bit_probs = torch.where(
+        bit_patterns.unsqueeze(0).bool(),  # (1, M, K)
+        probs_when_bit_is_1,  # (..., 1, K)
+        probs_when_bit_is_0,  # (..., 1, K)
+    )  # (..., M, K)
+
+    # Calculate the joint probability of each constellation point
+    # by multiplying probabilities of individual bits
+    symbol_probs = torch.prod(bit_probs, dim=-1)  # (..., M)
+
+    # Calculate the expected symbol
+    expected_symbol = torch.sum(symbol_probs * constellation, dim=-1)  # (...)
+
+    return expected_symbol
+
+
+def soft_bits_to_hard_symbols(soft_bits: torch.Tensor, constellation: torch.Tensor, bit_patterns: torch.Tensor, temp: float = 1.0) -> torch.Tensor:
+    """Convert soft bits to hard symbols with a differentiable approximation.
+
+    Uses a temperature-based softmax approach for approximating the hard decision
+    while maintaining differentiability.
+
+    Args:
+        soft_bits: Soft bit probabilities with shape (..., K) where K is bits_per_symbol
+            Values should be in [0, 1] range, representing P(bit=1)
+        constellation: Complex tensor of constellation points with shape (M,)
+        bit_patterns: Binary tensor with shape (M, K) representing the bit patterns
+            for each constellation point
+        temp: Temperature parameter for softmax (lower = harder decision)
+
+    Returns:
+        Complex tensor with shape (...) representing the selected symbol
+    """
+    # Reshape soft_bits for broadcasting
+    soft_bits = soft_bits.unsqueeze(-2)  # (..., 1, K)
+
+    # Calculate log probabilities for each bit pattern
+    log_probs_when_bit_is_1 = torch.log(soft_bits + 1e-10)
+    log_probs_when_bit_is_0 = torch.log(1 - soft_bits + 1e-10)
+
+    log_bit_probs = torch.where(bit_patterns.unsqueeze(0).bool(), log_probs_when_bit_is_1, log_probs_when_bit_is_0)
+
+    # Sum log probabilities to get joint log probability
+    log_symbol_probs = torch.sum(log_bit_probs, dim=-1)  # (..., M)
+
+    # Apply temperature scaling and softmax
+    symbol_weights = F.softmax(log_symbol_probs / temp, dim=-1)  # (..., M)
+
+    # Calculate the weighted sum of constellation points
+    weighted_symbols = torch.sum(symbol_weights * constellation, dim=-1)
+
+    return weighted_symbols
+
+
+def hard_decisions_with_straight_through(soft_values: torch.Tensor) -> torch.Tensor:
+    """Make hard 0/1 decisions while allowing backpropagation with straight-through estimator.
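+
+    The returned expression hard.detach() - soft.detach() + soft evaluates to the
+    hard threshold of soft_values in the forward pass, while its derivative with
+    respect to soft_values is 1, so upstream gradients pass through the threshold
+    unchanged (the straight-through estimator).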
+ + Args: + soft_values: Soft values typically in range [0, 1] + + Returns: + Hard binary decisions (0 or 1) with gradients passed through unchanged + """ + # Forward pass: hard thresholding + hard_decisions = (soft_values > 0.5).float() + + # Straight-through estimator: pass gradients through unchanged + return hard_decisions.detach() - soft_values.detach() + soft_values diff --git a/kaira/modulations/psk.py b/kaira/modulations/psk.py index 5928527..535cf33 100644 --- a/kaira/modulations/psk.py +++ b/kaira/modulations/psk.py @@ -36,6 +36,8 @@ def __init__(self, *args, **kwargs) -> None: re_part = torch.tensor([1.0, -1.0]) im_part = torch.tensor([0.0, 0.0]) self.register_buffer("constellation", torch.complex(re_part, im_part)) + # Create bit patterns for each constellation point + self.register_buffer("bit_patterns", torch.tensor([[0.0], [1.0]])) self._bits_per_symbol = 1 # BPSK has 1 bit per symbol def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: @@ -52,6 +54,24 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: # Convert binary 0/1 to 1/-1 return torch.complex(1.0 - 2.0 * x.float(), torch.zeros_like(x.float())) + def forward_soft(self, x: torch.Tensor, temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Modulate soft bits to BPSK symbols in a differentiable manner. + + Args: + x: Input tensor of soft bit probabilities with shape (..., N) + Values should be in [0, 1] range, representing P(bit=1) + temp: Temperature parameter for soft decisions + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Complex tensor of BPSK symbols with shape (..., N) + """ + # For BPSK, we can directly calculate the expected symbol + # P(bit=0) * (+1) + P(bit=1) * (-1) = 1 - 2*P(bit=1) + expected_symbol = torch.complex(1.0 - 2.0 * x.float(), torch.zeros_like(x.float())) + return expected_symbol + def plot_constellation(self, **kwargs) -> plt.Figure: """Plot the BPSK constellation diagram. @@ -111,6 +131,38 @@ def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor # LLR = log(P(y|b=0)/P(y|b=1)) = log(exp(-(y-1)²/2σ²)/exp(-(y+1)²/2σ²)) = 2y/σ² return -2.0 * y_real / noise_var + def forward_soft(self, y: torch.Tensor, noise_var: Union[float, torch.Tensor], temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Demodulate BPSK symbols to soft bit probabilities. + + Args: + y: Received symbols with shape (..., N) + noise_var: Noise variance (required) + temp: Temperature parameter for controlling softness of decisions + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Soft bit probabilities with shape (..., N) + Values are in [0, 1] range, representing P(bit=1) + """ + # Calculate LLRs + llrs = self.forward(y, noise_var, *args, **kwargs) + + # Convert LLRs to probabilities with temperature scaling + # P(bit=1) = 1 / (1 + exp(LLR/temp)) + return torch.sigmoid(-llrs / temp) + + def plot_constellation(self, **kwargs) -> plt.Figure: + """Plot the BPSK constellation diagram. 
+ + Args: + **kwargs: Additional arguments passed to plot_constellation + + Returns: + Matplotlib figure object + """ + return plot_constellation(self.constellation, labels=["0", "1"], title="BPSK Constellation", **kwargs) + @ModulationRegistry.register_modulator() class QPSKModulator(BaseModulator): @@ -151,15 +203,13 @@ def __init__(self, normalize: bool = True, *args, **kwargs) -> None: [0.0, 1.0], # Fourth quadrant [1.0, 0.0], # Second quadrant [1.0, 1.0], # Third quadrant - ], - dtype=torch.float, + ] ) self.register_buffer("bit_patterns", bit_patterns) - self._bits_per_symbol = 2 # QPSK has 2 bits per symbol def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """Modulate bit pairs to QPSK symbols. + """Modulate binary inputs to QPSK symbols. Args: x: Input tensor of bits with shape (..., 2*N) @@ -169,24 +219,51 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: Returns: Complex tensor of QPSK symbols with shape (..., N) """ - # Ensure input length is even batch_shape = x.shape[:-1] - bit_len = x.shape[-1] - if bit_len % 2 != 0: - raise ValueError("Input bit length must be even for QPSK modulation") + num_bits = x.shape[-1] - # Reshape to pairs of bits - x_reshaped = x.reshape(*batch_shape, -1, 2) + if num_bits % 2 != 0: + raise ValueError(f"Number of input bits ({num_bits}) must be even for QPSK modulation") - # Convert bit pairs to indices using Gray coding pattern - indices = x_reshaped[..., 0].to(torch.long) * 2 + x_reshaped[..., 1].to(torch.long) + # Reshape to (..., N, 2) + x_pairs = x.reshape(*batch_shape, -1, 2) - # Handle empty tensor case - if indices.numel() == 0: - return torch.empty((*batch_shape, 0), dtype=torch.complex64, device=x.device) + # Map bit pairs to symbol indices + indices = x_pairs[..., 0] * 2 + x_pairs[..., 1] # Convert bit pairs to indices - # Map indices to symbols - return self.constellation[indices] + # Map indices to constellation symbols + symbols = self.constellation[indices.long()] + + return symbols + + def forward_soft(self, x: torch.Tensor, temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Modulate soft bits to QPSK symbols in a differentiable manner. + + Args: + x: Input tensor of soft bit probabilities with shape (..., 2*N) + Values should be in [0, 1] range, representing P(bit=1) + temp: Temperature parameter for soft decisions + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Complex tensor of QPSK symbols with shape (..., N) + """ + from .differentiable import soft_symbol_mapping + + batch_shape = x.shape[:-1] + num_bits = x.shape[-1] + + if num_bits % 2 != 0: + raise ValueError(f"Number of input bits ({num_bits}) must be even for QPSK modulation") + + # Reshape to (..., N, 2) + x_pairs = x.reshape(*batch_shape, -1, 2) + + # Use differentiable symbol mapping + symbols = soft_symbol_mapping(x_pairs, self.constellation, self.bit_patterns) + + return symbols def plot_constellation(self, **kwargs) -> plt.Figure: """Plot the QPSK constellation diagram. 
@@ -197,12 +274,7 @@ def plot_constellation(self, **kwargs) -> plt.Figure: Returns: Matplotlib figure object """ - labels = [] - for i in range(4): - bit_pattern = self.bit_patterns[i] - labels.append(f"{int(bit_pattern[0])}{int(bit_pattern[1])}") - - return plot_constellation(self.constellation, labels=labels, title="QPSK Constellation", **kwargs) + return plot_constellation(self.constellation, labels=["00", "01", "10", "11"], title="QPSK Constellation", **kwargs) @ModulationRegistry.register_demodulator() @@ -275,56 +347,60 @@ def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor llrs = torch.zeros((*batch_shape, symbol_shape, 2), device=y.device) # For each bit position, compute the LLR using max-log approximation - for bit_idx in range(2): - # Separate constellation points for bit=0 and bit=1 - bit_0_mask = self.modulator.bit_patterns[:, bit_idx] == 0 - bit_1_mask = ~bit_0_mask + for bit_idx in range(2): # QPSK has 2 bits per symbol + # Get constellation points corresponding to bit=0 and bit=1 + bit_0_indices = (self.modulator.bit_patterns[:, bit_idx] == 0).nonzero().squeeze(1) + bit_1_indices = (self.modulator.bit_patterns[:, bit_idx] == 1).nonzero().squeeze(1) + + const_bit_0 = self.modulator.constellation[bit_0_indices] # Points with bit=0 + const_bit_1 = self.modulator.constellation[bit_1_indices] # Points with bit=1 + + # Compute min distance to points with bit=0 and bit=1 + dist_bit_0 = torch.min( + torch.abs(y.unsqueeze(-1) - const_bit_0.unsqueeze(0).unsqueeze(0)), + dim=-1, + )[0] + dist_bit_1 = torch.min( + torch.abs(y.unsqueeze(-1) - const_bit_1.unsqueeze(0).unsqueeze(0)), + dim=-1, + )[0] + + # Compute LLR = log(P(y|b=0)/P(y|b=1)) + # Using max-log approximation: LLR ≈ (min_dist_b1^2 - min_dist_b0^2)/(2*noise_var) + llrs[..., bit_idx] = (dist_bit_1**2 - dist_bit_0**2) / (2 * noise_var) - # Get corresponding constellation points - const_bit_0 = self.modulator.constellation[bit_0_mask] - const_bit_1 = self.modulator.constellation[bit_1_mask] - - # Calculate minimum distances - min_dist_0 = self._min_distance_to_points(y, const_bit_0, noise_var) - min_dist_1 = self._min_distance_to_points(y, const_bit_1, noise_var) - - # LLR = log(P(bit=0|y)/P(bit=1|y)) - llrs[..., bit_idx] = min_dist_1 - min_dist_0 - - # Reshape to final sequence + # Reshape to final LLR sequence return llrs.reshape(*batch_shape, -1) - def _min_distance_to_points(self, y: torch.Tensor, points: torch.Tensor, noise_var: torch.Tensor) -> torch.Tensor: - """Calculate minimum (negative) distance to a set of constellation points. - - Uses max-log approximation for computational efficiency. + def forward_soft(self, y: torch.Tensor, noise_var: Union[float, torch.Tensor], temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Demodulate QPSK symbols to soft bit probabilities in a differentiable manner. Args: y: Received symbols with shape (..., N) - points: Constellation points to compare against with shape (M,) - noise_var: Noise variance with shape (..., N) + noise_var: Noise variance (required) + temp: Temperature parameter for controlling softness of decisions + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. 
Returns: - Minimum negative distance for each symbol in y + Soft bit probabilities with shape (..., N*2) + Values are in [0, 1] range, representing P(bit=1) """ batch_shape = y.shape[:-1] symbol_shape = y.shape[-1] - num_points = points.shape[0] - # Reshape inputs for broadcasting - y.unsqueeze(-1) # (..., N, 1) + # Get LLRs + llrs = self.forward(y, noise_var, *args, **kwargs) - # Fix the dimension mismatch by directly calculating distances for each point - distances = torch.zeros((*batch_shape, symbol_shape, num_points), device=y.device) + # Reshape LLRs to (..., N, 2) + llrs = llrs.reshape(*batch_shape, symbol_shape, 2) - for i in range(num_points): - point = points[i] - # Calculate squared distance between each symbol and this point - distances[..., i] = -torch.abs(y - point) ** 2 / noise_var + # Convert LLRs to probabilities with temperature scaling + # P(bit=1) = 1 / (1 + exp(LLR/temp)) + probs = torch.sigmoid(-llrs / temp) - # Return maximum (least negative) value for each symbol - max_values, _ = torch.max(distances, dim=-1) - return max_values + # Reshape to final probability sequence + return probs.reshape(*batch_shape, -1) @ModulationRegistry.register_modulator() diff --git a/kaira/modulations/qam.py b/kaira/modulations/qam.py index 922de99..6fad211 100644 --- a/kaira/modulations/qam.py +++ b/kaira/modulations/qam.py @@ -101,7 +101,7 @@ def _create_constellation(self) -> None: self.register_buffer("bit_patterns", bit_patterns) def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """Modulate bit groups to QAM symbols. + """Modulate binary inputs to QAM symbols. Args: x: Input tensor of bits with shape (..., K*N), where K is bits_per_symbol @@ -111,27 +111,52 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: Returns: Complex tensor of QAM symbols with shape (..., N) """ - # Ensure input length is divisible by bits_per_symbol batch_shape = x.shape[:-1] - bit_len = x.shape[-1] - if bit_len % self._bits_per_symbol != 0: - raise ValueError(f"Input bit length must be divisible by {self._bits_per_symbol}") + num_bits = x.shape[-1] - # Reshape to groups of bits_per_symbol - x_reshaped = x.reshape(*batch_shape, -1, self._bits_per_symbol) + if num_bits % self._bits_per_symbol != 0: + raise ValueError(f"Number of input bits ({num_bits}) must be divisible by bits_per_symbol ({self._bits_per_symbol})") - # For each group of bits, find the matching constellation point - symbols = torch.zeros((*batch_shape, x_reshaped.shape[-2]), dtype=torch.complex64, device=x.device) + # Reshape to (..., N, K) + x_groups = x.reshape(*batch_shape, -1, self._bits_per_symbol) - # Search through bit_patterns for each group of bits to find the matching constellation point - for i in range(self.order): - # Create a mask for where the current bit pattern matches the input bits - # Need to compare across the bits_per_symbol dimension - pattern = self.bit_patterns[i].to(x.device) - mask = torch.all(torch.eq(x_reshaped, pattern), dim=-1) + # Map bit groups to symbol indices + indices = torch.zeros((*batch_shape, x_groups.shape[-2]), dtype=torch.long, device=x.device) + for i in range(self._bits_per_symbol): + indices = indices * 2 + x_groups[..., i].long() - # Assign the corresponding constellation point - symbols[mask] = self.constellation[i] + # Map indices to constellation symbols + symbols = self.constellation[indices] + + return symbols + + def forward_soft(self, x: torch.Tensor, temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Modulate soft bits to QAM symbols in a 
differentiable manner.
+
+        Args:
+            x: Input tensor of soft bit probabilities with shape (..., K*N),
+                where K is bits_per_symbol. Values should be in [0, 1] range,
+                representing P(bit=1)
+            temp: Temperature parameter for soft decisions
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            Complex tensor of QAM symbols with shape (..., N)
+        """
+        from .differentiable import soft_symbol_mapping
+
+        batch_shape = x.shape[:-1]
+        num_bits = x.shape[-1]
+
+        if num_bits % self._bits_per_symbol != 0:
+            raise ValueError(f"Number of input bits ({num_bits}) must be divisible by bits_per_symbol ({self._bits_per_symbol})")
+
+        # Reshape to (..., N, K)
+        x_groups = x.reshape(*batch_shape, -1, self._bits_per_symbol)
+
+        # Use differentiable symbol mapping
+        symbols = soft_symbol_mapping(x_groups, self.constellation, self.bit_patterns)

         return symbols

@@ -144,12 +169,8 @@ def plot_constellation(self, **kwargs) -> plt.Figure:
         Returns:
             Matplotlib figure object
         """
-        labels = []
-        for i in range(self.order):
-            bit_pattern = self.bit_patterns[i]
-            bit_str = "".join(str(int(bit)) for bit in bit_pattern)
-            labels.append(bit_str)
-
+        # Format labels as binary strings
+        labels = [format(i, f"0{self._bits_per_symbol}b") for i in range(self.order)]
         return plot_constellation(self.constellation, labels=labels, title=f"{self.order}-QAM Constellation", **kwargs)


@@ -226,62 +247,65 @@ def forward(self, y: torch.Tensor, noise_var: Optional[Union[float, torch.Tensor
             noise_var = noise_var.expand(*batch_shape, symbol_shape)

         # Calculate LLRs for each bit position
+        bit_patterns = self.modulator.bit_patterns.to(y.device)  # (order, bits_per_symbol)
         llrs = torch.zeros((*batch_shape, symbol_shape, self._bits_per_symbol), device=y.device)

-        # For each bit position
-        for bit_idx in range(self._bits_per_symbol):
-            # Create masks for symbols where bit is 0 or 1
-            bit_0_mask = self.modulator.bit_patterns[:, bit_idx] == 0
-            bit_1_mask = ~bit_0_mask
+        # Expand constellation for vectorized calculation of distances
+        expanded_y = y.unsqueeze(-1)  # (..., N, 1)
+        expanded_const = constellation.expand(*([1] * len(batch_shape)), symbol_shape, self.order)  # (..., N, order)
+
+        # Calculate squared Euclidean distances to every constellation point;
+        # the squared distances enter the Gaussian log-likelihood exponent
+        # -d^2/(2*sigma^2) computed below
+        distances = torch.abs(expanded_y - expanded_const) ** 2  # (..., N, order)

-            # Get constellation points for each bit value
-            const_bit_0 = constellation[bit_0_mask]
-            const_bit_1 = constellation[bit_1_mask]
+        # Apply -dist^2/(2*sigma^2) to get log-likelihoods (up to a constant)
+        log_likelihoods = -distances / (2 * noise_var.unsqueeze(-1))  # (..., N, order)

-            # Calculate minimum squared Euclidean distance for each bit value
-            # For LLR calculation, smaller distance means higher probability
-            dist_0 = self._min_squared_distance(y, const_bit_0)
-            dist_1 = self._min_squared_distance(y, const_bit_1)
+        # For each bit position, calculate LLR = log(P(y|b=0)/P(y|b=1))
+        for bit_idx in range(self._bits_per_symbol):
+            # Get constellation points corresponding to bit=0 and bit=1
+            bit_0_mask = bit_patterns[:, bit_idx] == 0  # Binary mask for bit=0 symbols
+            bit_1_mask = bit_patterns[:, bit_idx] == 1  # Binary mask for bit=1 symbols
+
+            # Apply max-log approximation to compute LLRs
+            # LLR ≈ max(log(P(y|x_i))) over i with b_i=0 - max(log(P(y|x_j))) over j with b_j=1
+            # This avoids numerical issues with very large exponents
+            max_ll_bit_0 = 
log_likelihoods.masked_fill(~bit_0_mask.unsqueeze(0).unsqueeze(0), -float("inf")).max(dim=-1)[0] + max_ll_bit_1 = log_likelihoods.masked_fill(~bit_1_mask.unsqueeze(0).unsqueeze(0), -float("inf")).max(dim=-1)[0] - # Calculate LLR as log(P(bit=0)/P(bit=1)) - # For AWGN channel: LLR = (dist_1 - dist_0)/(2*noise_var) - # Positive LLR means bit 0 is more likely - llrs[..., bit_idx] = (dist_1 - dist_0) / (2 * noise_var) + # LLR = log(P(b=0|y)/P(b=1|y)) = log(P(y|b=0)/P(y|b=1)) = max_ll_bit_0 - max_ll_bit_1 + llrs[..., bit_idx] = max_ll_bit_0 - max_ll_bit_1 + # Reshape to final bit sequence return llrs.reshape(*batch_shape, -1) - def _min_squared_distance(self, y: torch.Tensor, points: torch.Tensor) -> torch.Tensor: - """Calculate minimum squared Euclidean distance to constellation points. + def forward_soft(self, y: torch.Tensor, noise_var: Union[float, torch.Tensor], temp: float = 1.0, *args, **kwargs) -> torch.Tensor: + """Demodulate QAM symbols to soft bit probabilities in a differentiable manner. Args: y: Received symbols with shape (..., N) - points: Constellation points to compare against with shape (M,) + noise_var: Noise variance (required) + temp: Temperature parameter for controlling softness of decisions + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. Returns: - Minimum squared distance for each symbol in y + Soft bit probabilities with shape (..., N*bits_per_symbol) + Values are in [0, 1] range, representing P(bit=1) """ batch_shape = y.shape[:-1] symbol_shape = y.shape[-1] - num_points = points.shape[0] - # Handle different tensor shapes correctly - if batch_shape: - # Multi-dimensional tensors - y_expanded = y.unsqueeze(-1).expand(*batch_shape, symbol_shape, num_points) + # Get LLRs + llrs = self.forward(y, noise_var, *args, **kwargs) - # Properly reshape points for broadcasting - points_expanded = points.reshape(*([1] * len(batch_shape)), 1, -1) - points_expanded = points_expanded.expand(*batch_shape, symbol_shape, num_points) - else: - # 1D tensors - y_expanded = y.unsqueeze(-1).expand(symbol_shape, num_points) - points_expanded = points.reshape(1, -1).expand(symbol_shape, num_points) - - # Calculate squared Euclidean distances - # For complex numbers: |a - b|^2 = (a - b) * conj(a - b) - diff = y_expanded - points_expanded - squared_distances = torch.real(diff * torch.conj(diff)) - - # Find minimum distance across all points - min_distances, _ = torch.min(squared_distances, dim=-1) - return min_distances + # Reshape LLRs to (..., N, bits_per_symbol) + llrs = llrs.reshape(*batch_shape, symbol_shape, self._bits_per_symbol) + + # Convert LLRs to probabilities with temperature scaling + # P(bit=1) = 1 / (1 + exp(LLR/temp)) + probs = torch.sigmoid(-llrs / temp) + + # Reshape to final probability sequence + return probs.reshape(*batch_shape, -1) diff --git a/tests/modulations/test_differentiable.py b/tests/modulations/test_differentiable.py new file mode 100644 index 0000000..e58b9b6 --- /dev/null +++ b/tests/modulations/test_differentiable.py @@ -0,0 +1,288 @@ +"""Tests for differentiable modulation operations.""" + +import torch + +from kaira.modulations import ( + BPSKDemodulator, + BPSKModulator, + QAMDemodulator, + QAMModulator, + QPSKDemodulator, + QPSKModulator, +) +from kaira.modulations.differentiable import ( + hard_decisions_with_straight_through, + soft_bits_to_hard_symbols, + soft_symbol_mapping, +) + + +class TestDifferentiableOperations: + """Test suite for differentiable modulation operations.""" + + def test_soft_symbol_mapping(self): + 
"""Test soft symbol mapping function.""" + # Create a simple constellation with 2 points + constellation = torch.tensor([1 + 0j, -1 + 0j]) + bit_patterns = torch.tensor([[0.0], [1.0]]) + + # Test with hard probabilities (0 and 1) + soft_bits = torch.tensor([[0.0], [1.0]]) + symbols = soft_symbol_mapping(soft_bits, constellation, bit_patterns) + assert symbols.shape == torch.Size([2]) + assert torch.isclose(symbols[0], torch.tensor(1 + 0j)) + assert torch.isclose(symbols[1], torch.tensor(-1 + 0j)) + + # Test with soft probabilities + soft_bits = torch.tensor([[0.3], [0.7]]) + symbols = soft_symbol_mapping(soft_bits, constellation, bit_patterns) + assert symbols.shape == torch.Size([2]) + # Expected: 0.3 * (-1) + 0.7 * 1 = 0.4 for first symbol + # Expected: 0.7 * (-1) + 0.3 * 1 = -0.4 for second symbol + assert torch.isclose(symbols[0], torch.tensor(0.4 + 0j), atol=1e-6) + assert torch.isclose(symbols[1], torch.tensor(-0.4 + 0j), atol=1e-6) + + def test_soft_bits_to_hard_symbols(self): + """Test soft to hard symbol conversion with differentiability.""" + # Create a simple constellation with 2 points + constellation = torch.tensor([1 + 0j, -1 + 0j]) + bit_patterns = torch.tensor([[0.0], [1.0]]) + + # Test with different temperatures + soft_bits = torch.tensor([[0.3], [0.7]]) + symbols_temp_1 = soft_bits_to_hard_symbols(soft_bits, constellation, bit_patterns, temp=1.0) + symbols_temp_01 = soft_bits_to_hard_symbols(soft_bits, constellation, bit_patterns, temp=0.1) + + # Lower temperature should make decisions harder (closer to constellation points) + assert torch.abs(symbols_temp_01[0] - constellation[0]) < torch.abs(symbols_temp_1[0] - constellation[0]) + assert torch.abs(symbols_temp_01[1] - constellation[1]) < torch.abs(symbols_temp_1[1] - constellation[1]) + + def test_hard_decisions_with_straight_through(self): + """Test hard decision with straight-through estimator.""" + # Create input requiring gradients + soft_values = torch.tensor([0.3, 0.6, 0.8], requires_grad=True) + + # Apply hard decision with straight-through estimator + hard_values = hard_decisions_with_straight_through(soft_values) + + # Check forward pass (hard decisions) + assert hard_values.detach().tolist() == [0.0, 1.0, 1.0] + + # Check that gradients can flow through + loss = hard_values.sum() + loss.backward() + assert soft_values.grad is not None + # The gradient should be all ones since we're using straight-through + assert torch.allclose(soft_values.grad, torch.ones_like(soft_values)) + + +class TestDifferentiableModulators: + """Test suite for differentiable modulators.""" + + def test_bpsk_modulators_diff(self): + """Test differentiable BPSK modulation.""" + modulator = BPSKModulator() + + # Create soft bits with gradients + soft_bits = torch.tensor([0.1, 0.4, 0.6, 0.9], requires_grad=True) + + # Test forward_soft + soft_symbols = modulator.forward_soft(soft_bits) + + # Verify shapes + assert soft_symbols.shape == soft_bits.shape + + # Compute loss and check gradients + loss = soft_symbols.real.sum() + loss.backward() + + # Verify gradients exist + assert soft_bits.grad is not None + # For BPSK the gradient should be -2 for all inputs + assert torch.allclose(soft_bits.grad, torch.tensor([-2.0, -2.0, -2.0, -2.0])) + + def test_qpsk_modulators_diff(self): + """Test differentiable QPSK modulation.""" + modulator = QPSKModulator() + + # Create soft bits with gradients (needs to be even number for QPSK) + soft_bits = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.7, 0.3], requires_grad=True) + + # Test forward_soft + soft_symbols = 
modulator.forward_soft(soft_bits) + + # Verify shapes - QPSK has 2 bits per symbol + assert soft_symbols.shape == torch.Size([3]) + + # Compute loss and check gradients + loss = soft_symbols.abs().sum() + loss.backward() + + # Verify gradients exist + assert soft_bits.grad is not None + + def test_qam_modulators_diff(self): + """Test differentiable QAM modulation.""" + modulator = QAMModulator(order=16) + + # Create soft bits with gradients (16-QAM has 4 bits per symbol) + soft_bits = torch.rand(8, requires_grad=True) + + # Test forward_soft + soft_symbols = modulator.forward_soft(soft_bits) + + # Verify shapes - 16-QAM has 4 bits per symbol + assert soft_symbols.shape == torch.Size([2]) + + # Compute loss and check gradients + loss = soft_symbols.abs().sum() + loss.backward() + + # Verify gradients exist + assert soft_bits.grad is not None + + +class TestDifferentiableDemodulators: + """Test suite for differentiable demodulators.""" + + def test_bpsk_demodulator_diff(self): + """Test differentiable BPSK demodulation.""" + demodulator = BPSKDemodulator() + + # Create symbols with gradients + symbols = torch.tensor([0.5 + 0j, -0.2 + 0j, 1.5 + 0j], requires_grad=True) + + # Test forward_soft with noise variance + noise_var = 0.1 + soft_bits = demodulator.forward_soft(symbols, noise_var) + + # Verify shapes + assert soft_bits.shape == symbols.shape + + # Verify values are between 0 and 1 (probabilities) + assert (soft_bits >= 0).all() and (soft_bits <= 1).all() + + # Compute loss and check gradients + loss = soft_bits.sum() + loss.backward() + + # Verify gradients exist + assert symbols.grad is not None + + def test_qpsk_demodulator_diff(self): + """Test differentiable QPSK demodulation.""" + demodulator = QPSKDemodulator() + + # Create symbols with gradients + symbols = torch.tensor([0.5 + 0.5j, -0.5 - 0.5j], requires_grad=True) + + # Test forward_soft with noise variance + noise_var = 0.1 + soft_bits = demodulator.forward_soft(symbols, noise_var) + + # Verify shapes - QPSK has 2 bits per symbol + assert soft_bits.shape == torch.Size([4]) + + # Verify values are between 0 and 1 (probabilities) + assert (soft_bits >= 0).all() and (soft_bits <= 1).all() + + # Compute loss and check gradients + loss = soft_bits.sum() + loss.backward() + + # Verify gradients exist + assert symbols.grad is not None + + def test_qam_demodulator_diff(self): + """Test differentiable QAM demodulation.""" + demodulator = QAMDemodulator(order=16) + + # Create symbols with gradients + symbols = torch.tensor([1.0 + 1.0j, -1.0 - 1.0j], requires_grad=True) + + # Test forward_soft with noise variance + noise_var = 0.1 + soft_bits = demodulator.forward_soft(symbols, noise_var) + + # Verify shapes - 16-QAM has 4 bits per symbol + assert soft_bits.shape == torch.Size([8]) + + # Verify values are between 0 and 1 (probabilities) + assert (soft_bits >= 0).all() and (soft_bits <= 1).all() + + # Compute loss and check gradients + loss = soft_bits.sum() + loss.backward() + + # Verify gradients exist + assert symbols.grad is not None + + +class TestEndToEndDifferentiability: + """Test end-to-end differentiability of modulation and demodulation.""" + + def test_bpsk_end_to_end(self): + """Test end-to-end differentiability with BPSK.""" + modulator = BPSKModulator() + demodulator = BPSKDemodulator() + + # Create soft bits with gradients + soft_bits = torch.tensor([0.1, 0.4, 0.6, 0.9], requires_grad=True) + + # Apply modulation + symbols = modulator.forward_soft(soft_bits) + + # Apply demodulation + noise_var = 0.1 + decoded_bits = 
demodulator.forward_soft(symbols, noise_var) + + # Compute loss between original and decoded bits + loss = torch.nn.functional.binary_cross_entropy(decoded_bits, soft_bits) + loss.backward() + + # Verify gradients exist + assert soft_bits.grad is not None + + def test_qpsk_end_to_end(self): + """Test end-to-end differentiability with QPSK.""" + modulator = QPSKModulator() + demodulator = QPSKDemodulator() + + # Create soft bits with gradients + soft_bits = torch.tensor([0.1, 0.9, 0.2, 0.8], requires_grad=True) + + # Apply modulation + symbols = modulator.forward_soft(soft_bits) + + # Apply demodulation + noise_var = 0.1 + decoded_bits = demodulator.forward_soft(symbols, noise_var) + + # Compute loss between original and decoded bits + loss = torch.nn.functional.binary_cross_entropy(decoded_bits, soft_bits) + loss.backward() + + # Verify gradients exist + assert soft_bits.grad is not None + + def test_qam_end_to_end(self): + """Test end-to-end differentiability with 16-QAM.""" + modulator = QAMModulator(order=16) + demodulator = QAMDemodulator(order=16) + + # Create soft bits with gradients (16-QAM has 4 bits per symbol) + soft_bits = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.7, 0.3, 0.4, 0.6], requires_grad=True) + + # Apply modulation + symbols = modulator.forward_soft(soft_bits) + + # Apply demodulation + noise_var = 0.1 + decoded_bits = demodulator.forward_soft(symbols, noise_var) + + # Compute loss between original and decoded bits + loss = torch.nn.functional.binary_cross_entropy(decoded_bits, soft_bits) + loss.backward() + + # Verify gradients exist + assert soft_bits.grad is not None
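
Usage note: a minimal end-to-end training sketch showing how the forward_soft API
introduced in this patch composes into a differentiable pipeline. The encoder
network, optimizer settings, and the AWGN line below are illustrative assumptions,
not code from this change:

    import torch
    from kaira.modulations import QAMDemodulator, QAMModulator

    modulator = QAMModulator(order=16)
    demodulator = QAMDemodulator(order=16)

    # Hypothetical upstream network emitting P(bit=1) values in [0, 1]
    encoder = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Sigmoid())
    optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

    bits = torch.randint(0, 2, (32, 8)).float()  # target hard bits
    noise_var = 0.1

    soft_bits = encoder(bits)                    # (32, 8) soft bits
    symbols = modulator.forward_soft(soft_bits)  # (32, 2) complex 16-QAM symbols

    # Illustrative complex AWGN with total variance noise_var
    noise = (noise_var / 2) ** 0.5 * torch.complex(torch.randn_like(symbols.real), torch.randn_like(symbols.real))
    probs = demodulator.forward_soft(symbols + noise, noise_var)  # (32, 8) P(bit=1)

    loss = torch.nn.functional.binary_cross_entropy(probs, bits)
    loss.backward()  # gradients reach the encoder through modulation and demodulation
    optimizer.step()

Because every step uses forward_soft rather than hard decisions, the whole chain
stays differentiable and the encoder can be trained end to end.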