diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 00000000..b9b8926b
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,15 @@
+# Changelog
+
+## Unreleased
+
+### Added
+- Support for loading and operating on external PyTorch `torch.utils.data.Dataset` objects (#382)
+  - New `load_external_dataset` function in `data_processing.py`
+  - Updated `perform_training` in `baler.py` to handle external datasets
+  - Updated `train` function in `training.py` to accept DataLoaders directly
+  - Added documentation and examples for using external datasets
+  - Added unit tests for the new functionality
+
+### Changed
+
+### Fixed
\ No newline at end of file
diff --git a/baler/baler.py b/baler/baler.py
index f7bcfb2e..54ac1730 100644
--- a/baler/baler.py
+++ b/baler/baler.py
@@ -15,8 +15,12 @@
 import os
 import time
 from math import ceil
+import importlib
+from typing import Optional, Union
 
 import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
 
 from .modules import helper
 import gzip
@@ -95,6 +99,86 @@ def perform_training(output_path, config, verbose: bool):
     Raises:
         NameError: Baler currently only supports 1D (e.g. HEP) or 2D (e.g. CFD) data as inputs.
     """
+    # Check if an external dataset is provided
+    if hasattr(config, "external_dataset") and config.external_dataset is not None:
+        # Load the external dataset module
+        if verbose:
+            print(f"Using external dataset from {config.external_dataset}")
+
+        try:
+            # Import the external dataset
+            if isinstance(config.external_dataset, str):
+                # Assuming external_dataset is a module path like "mymodule.mydataset"
+                module_path, class_name = config.external_dataset.rsplit(".", 1)
+                module = importlib.import_module(module_path)
+                dataset_class = getattr(module, class_name)
+
+                # Initialize the dataset
+                if hasattr(config, "dataset_args") and config.dataset_args is not None:
+                    external_dataset = dataset_class(**config.dataset_args)
+                else:
+                    external_dataset = dataset_class()
+            else:
+                # Assuming external_dataset is already a Dataset instance
+                external_dataset = config.external_dataset
+
+            if not isinstance(external_dataset, Dataset):
+                raise ValueError("The provided external_dataset is not a PyTorch Dataset instance")
+
+            # Create DataLoaders
+            from .modules import data_processing
+            train_loader, val_loader = data_processing.load_external_dataset(
+                dataset=external_dataset,
+                test_size=config.test_size,
+                batch_size=config.batch_size,
+                shuffle=True,
+                random_state=42 if config.deterministic_algorithm else None,
+                deterministic=config.deterministic_algorithm
+            )
+
+            # Initialize model
+            model_object = helper.model_init(config.model_name)
+
+            # Get an example batch to determine feature size
+            device = helper.get_device()
+            example_batch = next(iter(train_loader)).to(device)
+
+            # Determine input dimensions based on the model type and data dimensions
+            if config.data_dimension == 2:
+                if config.model_type == "dense":
+                    n_features = example_batch.shape[1] * example_batch.shape[2]
+                elif config.model_type == "convolutional":
+                    # Get the flattened size from convolutional features
+                    n_features = example_batch.shape[1] * example_batch.shape[2] * example_batch.shape[3]
+            else:  # 1D data
+                n_features = example_batch.shape[1]
+
+            # Calculate latent space size based on compression ratio
+            z_dim = int(n_features * config.compression_ratio)
+
+            variables = {"n_features": n_features, "z_dim": z_dim}
+
+            if verbose:
+                print(f"Input features: {n_features}, Latent dimension: {z_dim}")
+
+            model = model_object(n_features, z_dim)
+
+            # Train the model with the external dataset
+            from .modules import training
+            trained_model = training.train(
+                model=model,
+                variables=variables,
+                train_loader=train_loader,
+                val_loader=val_loader,
+                project_path=output_path,
+                config=config
+            )
+
+            return
+        except (ImportError, AttributeError, ValueError) as e:
+            raise ValueError(f"Failed to load external dataset: {str(e)}")
+
+    # Original code for numpy array-based datasets
     (
         train_set_norm,
         test_set_norm,
diff --git a/baler/modules/data_processing.py b/baler/modules/data_processing.py
index d14d9958..a4434ec4 100644
--- a/baler/modules/data_processing.py
+++ b/baler/modules/data_processing.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Tuple
+from typing import List, Tuple, Optional, Union
 
 import numpy as np
 import torch
 from numpy import ndarray
 from sklearn.model_selection import train_test_split
+from torch.utils.data import Dataset, DataLoader
+import random
 
 from ..modules import helper
 from ..modules import models
@@ -201,3 +203,90 @@ def renormalize_func(norm_data: ndarray, min_list: List, range_list: List) -> nd
     min_list = np.array(min_list)
     range_list = np.array(range_list)
     return norm_data * range_list + min_list
+
+
+def load_external_dataset(dataset: Dataset, test_size: float = 0.2,
+                          batch_size: int = 128, shuffle: bool = True,
+                          random_state: int = 42,
+                          deterministic: bool = False) -> Tuple[DataLoader, DataLoader]:
+    """Load an external PyTorch Dataset and split it into training and validation sets.
+
+    Args:
+        dataset (Dataset): An instance of a PyTorch Dataset.
+        test_size (float): Proportion of the dataset to include in the validation split.
+        batch_size (int): How many samples per batch to load.
+        shuffle (bool): Whether to shuffle the data before splitting and in the DataLoader.
+        random_state (int): Controls the shuffling applied to the data before splitting.
+        deterministic (bool): If True, sets the random seed for reproducibility.
+
+    Returns:
+        Tuple[DataLoader, DataLoader]: Tuple containing training and validation DataLoaders.
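+
+    Example:
+        A minimal sketch of the intended call, assuming ``my_dataset`` is a
+        map-style ``Dataset`` whose ``__getitem__`` returns a feature tensor::
+
+            train_loader, val_loader = load_external_dataset(
+                my_dataset, test_size=0.2, batch_size=64, deterministic=True
+            )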
+    """
+    # Set the seed for reproducibility if deterministic is True
+    if deterministic:
+        torch.manual_seed(random_state)
+        np.random.seed(random_state)
+        random.seed(random_state)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(random_state)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+    # Handle edge cases for test_size
+    total_size = len(dataset)
+    if test_size <= 0.0:
+        # All data goes to train set
+        train_indices = list(range(total_size))
+        val_indices = []
+    elif test_size >= 1.0:
+        # All data goes to validation set
+        train_indices = []
+        val_indices = list(range(total_size))
+    else:
+        # Regular split using train_test_split
+        indices = list(range(total_size))
+        train_indices, val_indices = train_test_split(
+            indices, test_size=test_size, random_state=random_state if shuffle else None, shuffle=shuffle
+        )
+
+    # Create subset datasets
+    train_subset = torch.utils.data.Subset(dataset, train_indices)
+    val_subset = torch.utils.data.Subset(dataset, val_indices)
+
+    # Define worker_init_fn for reproducibility
+    worker_init_fn = seed_worker if deterministic else None
+    generator = torch.Generator().manual_seed(random_state) if deterministic else None
+
+    # Create DataLoaders with appropriate batch size
+    train_batch_size = min(batch_size, len(train_subset)) if train_indices else 1
+    val_batch_size = min(batch_size, len(val_subset)) if val_indices else 1
+
+    # Create DataLoaders, ensuring we don't try to shuffle empty datasets
+    train_loader = DataLoader(
+        train_subset,
+        batch_size=train_batch_size,
+        shuffle=shuffle and len(train_subset) > 0,
+        worker_init_fn=worker_init_fn,
+        generator=generator
+    )
+
+    val_loader = DataLoader(
+        val_subset,
+        batch_size=val_batch_size,
+        shuffle=False,  # Usually validation data is not shuffled
+        worker_init_fn=worker_init_fn,
+        generator=generator
+    )
+
+    return train_loader, val_loader
+
+
+def seed_worker(worker_id):
+    """Function to seed DataLoader workers for reproducibility.
+
+    Args:
+        worker_id: The ID of the worker
+    """
+    worker_seed = torch.initial_seed() % 2**32
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
diff --git a/baler/modules/training.py b/baler/modules/training.py
index ca189def..c4800d2f 100644
--- a/baler/modules/training.py
+++ b/baler/modules/training.py
@@ -147,19 +147,28 @@ def seed_worker(worker_id):
     random.seed(worker_seed)
 
 
-def train(model, variables, train_data, test_data, project_path, config):
+def train(model, variables, train_data=None, test_data=None, project_path=None, config=None,
+          train_loader=None, val_loader=None):
     """Does the entire training loop by calling the `fit()` and `validate()`. Appart from this, this is the main
     function where the data is converted to the correct type for it to be trained, via `torch.Tensor()`. Furthermore,
     the batching is also done here, based on `config.batch_size`, and it is the `torch.utils.data.DataLoader` doing
     the splitting. Applying either `EarlyStopping` or `LR Scheduler` is also done here, all based on their respective
     `config` arguments. For reproducibility, the seeds can also be fixed in this function.
+
+    The function now supports both traditional numpy arrays and PyTorch DataLoaders directly:
+    - If train_loader and val_loader are provided, they will be used directly
+    - Otherwise, train_data and test_data arrays will be converted to tensors and DataLoaders
+
     Args:
         model (modelObject): The model you wish to train
-        variables (_type_): _description_
-        train_set (ndarray): Array consisting of the train set
-        test_set (ndarray): Array consisting of the test set
+        variables (dict): Dictionary containing model parameters
+        train_data (ndarray, optional): Array consisting of the train set. Required if train_loader is None.
+        test_data (ndarray, optional): Array consisting of the test set. Required if val_loader is None.
         project_path (string): Path to the project directory
         config (dataClass): Base class selecting user inputs
+        train_loader (DataLoader, optional): Pre-configured DataLoader for training data. If provided, train_data is ignored.
+        val_loader (DataLoader, optional): Pre-configured DataLoader for validation data. If provided, test_data is ignored.
+
     Returns:
         modelObject: fully trained model ready to perform compression and decompression
     """
@@ -180,7 +189,7 @@ def train(model, variables, train_data, test_data, project_path, config):
     rho = config.RHO
     l1 = config.l1
     epochs = config.epochs
-    latent_space_size = config.latent_space_size
+    latent_space_size = config.latent_space_size if hasattr(config, "latent_space_size") else variables["z_dim"]
     intermittent_model_saving = config.intermittent_model_saving
     intermittent_saving_patience = config.intermittent_saving_patience
 
@@ -190,77 +199,80 @@ def train(model, variables, train_data, test_data, project_path, config):
     device = helper.get_device()
     model = model.to(device)
 
-    # Converting data to tensors
-    if config.data_dimension == 2:
-        if config.model_type == "dense":
-            # print(train_data.shape)
-            # print(test_data.shape)
-            # sys.exit()
-            train_ds = torch.tensor(
-                train_data, dtype=torch.float32, device=device
-            ).view(train_data.shape[0], train_data.shape[1] * train_data.shape[2])
-            valid_ds = torch.tensor(test_data, dtype=torch.float32, device=device).view(
-                test_data.shape[0], test_data.shape[1] * test_data.shape[2]
+    # Check if DataLoaders are provided directly
+    if train_loader is not None and val_loader is not None:
+        # Use the provided DataLoaders
+        train_dl = train_loader
+        valid_dl = val_loader
+    else:
+        # Converting data to tensors
+        if config.data_dimension == 2:
+            if config.model_type == "dense":
+                train_ds = torch.tensor(
+                    train_data, dtype=torch.float32, device=device
+                ).view(train_data.shape[0], train_data.shape[1] * train_data.shape[2])
+                valid_ds = torch.tensor(test_data, dtype=torch.float32, device=device).view(
+                    test_data.shape[0], test_data.shape[1] * test_data.shape[2]
+                )
+            elif config.model_type == "convolutional" and config.model_name == "Conv_AE_3D":
+                train_ds = torch.tensor(
+                    train_data, dtype=torch.float32, device=device
+                ).view(
+                    train_data.shape[0] // bs,
+                    1,
+                    bs,
+                    train_data.shape[1],
+                    train_data.shape[2],
+                )
+                valid_ds = torch.tensor(test_data, dtype=torch.float32, device=device).view(
+                    test_data.shape[0] // bs,
+                    1,
+                    bs,
+                    test_data.shape[1],
+                    test_data.shape[2],
+                )
+            elif config.model_type == "convolutional":
+                train_ds = torch.tensor(
+                    train_data, dtype=torch.float32, device=device
+                ).view(train_data.shape[0], 1, train_data.shape[1], train_data.shape[2])
+                valid_ds = torch.tensor(test_data, dtype=torch.float32, device=device).view(
+                    test_data.shape[0], 1, test_data.shape[1], test_data.shape[2]
+                )
+        elif config.data_dimension == 1:
+            train_ds = torch.tensor(train_data, dtype=torch.float64, device=device)
+            valid_ds = torch.tensor(test_data, dtype=torch.float64, device=device)
+
+        # Pushing input data into the torch-DataLoader object and combines into one DataLoader object (a basic wrapper
+        # around several DataLoader objects).
+
+        if config.deterministic_algorithm:
+            train_dl = DataLoader(
+                train_ds,
+                batch_size=bs,
+                shuffle=False,
+                worker_init_fn=seed_worker,
+                generator=g,
+                drop_last=False,
             )
-        elif config.model_type == "convolutional" and config.model_name == "Conv_AE_3D":
-            train_ds = torch.tensor(
-                train_data, dtype=torch.float32, device=device
-            ).view(
-                train_data.shape[0] // bs,
-                1,
-                bs,
-                train_data.shape[1],
-                train_data.shape[2],
+            valid_dl = DataLoader(
+                valid_ds,
+                batch_size=bs,
+                worker_init_fn=seed_worker,
+                generator=g,
+                drop_last=False,
             )
-            valid_ds = torch.tensor(test_data, dtype=torch.float32, device=device).view(
-                train_data.shape[0] // bs,
-                1,
-                bs,
-                train_data.shape[1],
-                train_data.shape[2],
+        else:
+            train_dl = DataLoader(
+                train_ds,
+                batch_size=bs,
+                shuffle=False,
+                drop_last=False,
             )
-        elif config.model_type == "convolutional":
-            train_ds = torch.tensor(
-                train_data, dtype=torch.float32, device=device
-            ).view(train_data.shape[0], 1, train_data.shape[1], train_data.shape[2])
-            valid_ds = torch.tensor(test_data, dtype=torch.float32, device=device).view(
-                train_data.shape[0], 1, train_data.shape[1], train_data.shape[2]
+            valid_dl = DataLoader(
+                valid_ds,
+                batch_size=bs,
+                drop_last=False,
             )
-    elif config.data_dimension == 1:
-        train_ds = torch.tensor(train_data, dtype=torch.float64, device=device)
-        valid_ds = torch.tensor(test_data, dtype=torch.float64, device=device)
-
-    # Pushing input data into the torch-DataLoader object and combines into one DataLoader object (a basic wrapper
-    # around several DataLoader objects).
-
-    if config.deterministic_algorithm:
-        train_dl = DataLoader(
-            train_ds,
-            batch_size=bs,
-            shuffle=False,
-            worker_init_fn=seed_worker,
-            generator=g,
-            drop_last=False,
-        )
-        valid_dl = DataLoader(
-            valid_ds,
-            batch_size=bs,
-            worker_init_fn=seed_worker,
-            generator=g,
-            drop_last=False,
-        )
-    else:
-        train_dl = DataLoader(
-            train_ds,
-            batch_size=bs,
-            shuffle=False,
-            drop_last=False,
-        )
-        valid_dl = DataLoader(
-            valid_ds,
-            batch_size=bs,
-            drop_last=False,
-        )
 
     # Select Optimizer
     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
diff --git a/docs/guides/external_dataset.md b/docs/guides/external_dataset.md
new file mode 100644
index 00000000..8f9347d6
--- /dev/null
+++ b/docs/guides/external_dataset.md
@@ -0,0 +1,138 @@
+# Using External PyTorch Datasets with Baler
+
+Baler now supports using external PyTorch Dataset objects directly. This feature allows you to:
+
+- Use custom dataset implementations
+- Utilize pre-existing PyTorch datasets
+- Work with datasets that don't easily fit into NumPy arrays
+- Apply custom transformations to your data during loading
+
+## Requirements
+
+To use this feature, you need:
+
+1. A class that implements the PyTorch `torch.utils.data.Dataset` interface
+2. An understanding of how your dataset's dimensions map to model inputs
+
+## Using External Datasets
+
+### Option 1: Providing a Dataset Instance
+
+You can directly provide a PyTorch Dataset instance to Baler by adding it to your config:
+
+```python
+import torch
+from torch.utils.data import Dataset
+from baler.modules.helper import Config
+
+# Example custom dataset
+class MyCustomDataset(Dataset):
+    def __init__(self, data_path):
+        # Load your data here
+        self.data = ...  # Your data loading logic
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        # Return a tensor representing your data item
+        return self.data[idx]
+
+# Create your dataset
+my_dataset = MyCustomDataset("path/to/data")
+
+# Create a Baler config
+config = Config(...)  # Your other config parameters
+
+# Assign the dataset to the config
+config.external_dataset = my_dataset
+```
+
+### Option 2: Providing a Dataset Class Path
+
+You can also specify the module path to your Dataset class:
+
+```python
+# Create a Baler config
+config = Config(...)  # Your other config parameters
+
+# Specify the path to your dataset class
+config.external_dataset = "mymodule.mydataset.MyCustomDataset"
+
+# Optionally provide arguments for dataset initialization
+config.dataset_args = {
+    "data_path": "path/to/data",
+    "transform": None
+}
+```
+
+## Important Considerations
+
+1. **Dataset Format**: Your dataset's `__getitem__` method should return tensors that match the expected input format of your chosen model architecture.
+
+2. **Dimensions**: Make sure to set `config.data_dimension` correctly (1 or 2) based on your data.
+
+3. **Model Type**: Set `config.model_type` to either "dense" or "convolutional" based on the architecture you want to use.
+
+4. **Batch Structure**:
+   - For 1D data: Each item should be a 1D tensor of features
+   - For 2D data with dense models: Each item should be a 2D tensor
+   - For 2D data with convolutional models: Each item should be a 2D tensor (will be reshaped to include channels)
+
+## Example
+
+Here's a complete example using the MNIST dataset:
+
+```python
+import torch
+from torchvision import datasets, transforms
+from baler.modules.helper import Config
+
+# Create MNIST dataset
+transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.1307,), (0.3081,))
+])
+mnist_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
+
+# Create a Baler config
+config = Config(
+    input_path="",  # Not used with external dataset
+    compression_ratio=0.1,
+    epochs=10,
+    early_stopping=True,
+    early_stopping_patience=5,
+    lr_scheduler=True,
+    lr_scheduler_patience=2,
+    min_delta=0.001,
+    model_name="Dense_AE",
+    model_type="dense",
+    custom_norm=False,
+    l1=True,
+    reg_param=0.001,
+    RHO=0.05,
+    lr=0.001,
+    batch_size=64,
+    test_size=0.2,
+    data_dimension=2,
+    intermittent_model_saving=False,
+    separate_model_saving=False,
+    intermittent_saving_patience=10,
+    mse_avg=False,
+    mse_sum=True,
+    emd=False,
+    deterministic_algorithm=True,
+    apply_normalization=False  # Dataset already normalized by transform
+)
+
+# Assign the dataset
+config.external_dataset = mnist_dataset
+```
+
+## Limitations
+
+1. Currently, the compression and decompression phases still require NumPy array inputs. External datasets are only supported for the training phase.
+
+2. Your dataset items must be directly usable by the model without additional preprocessing (beyond what your Dataset class already does).
+
+3. If your dataset returns tuples (e.g., data and labels), you'll need to create a wrapper that only returns the data portion.
\ No newline at end of file
diff --git a/docs/guides/external_dataset_example.py b/docs/guides/external_dataset_example.py
new file mode 100644
index 00000000..3b257dd4
--- /dev/null
+++ b/docs/guides/external_dataset_example.py
@@ -0,0 +1,109 @@
+"""
+Example script showing how to use an external PyTorch Dataset with Baler.
+This example uses the MNIST dataset from torchvision.
+"""
+
+import os
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+from torchvision import datasets, transforms
+
+# Baler imports
+from baler.modules.helper import Config
+
+
+# Create a simple wrapper for MNIST that ensures it only returns data (not labels)
+class MNISTDatasetWrapper(Dataset):
+    def __init__(self, root='./data', train=True, transform=None, download=True):
+        """Initialize the MNIST dataset wrapper.
+
+        Args:
+            root (str): Root directory for the dataset.
+            train (bool): If True, use the training set, otherwise use the test set.
+            transform (callable, optional): Optional transform to apply to the data.
+            download (bool): If True, download the dataset if needed.
+        """
+        self.mnist = datasets.MNIST(
+            root=root,
+            train=train,
+            transform=transform,
+            download=download
+        )
+
+    def __len__(self):
+        """Return the length of the dataset."""
+        return len(self.mnist)
+
+    def __getitem__(self, idx):
+        """Return the data at the specified index (without the label)."""
+        data, _ = self.mnist[idx]  # Ignore the label
+        # MNIST returns 1x28x28 tensors, but we need to reshape for Baler
+        # For dense models, flatten the data
+        return data.flatten()  # Flatten to a 1D tensor of 784 values
+
+
+def main():
+    """Main function to demonstrate using an external dataset with Baler."""
+    # Create the dataset
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    mnist_dataset = MNISTDatasetWrapper(
+        root='./data',
+        transform=transform,
+        download=True
+    )
+
+    # Create the Baler config
+    # Note: We have to create this manually since we're not using the CLI
+    config = Config(
+        input_path="",  # Not used with external dataset
+        compression_ratio=0.1,  # Compress to 10% of original size
+        epochs=10,
+        early_stopping=True,
+        early_stopping_patience=5,
+        lr_scheduler=True,
+        lr_scheduler_patience=2,
+        min_delta=0.001,
+        model_name="Dense_AE",  # Using a dense autoencoder for flattened images
+        model_type="dense",
+        custom_norm=False,
+        l1=True,
+        reg_param=0.001,
+        RHO=0.05,
+        lr=0.001,
+        batch_size=64,
+        test_size=0.2,  # 20% of data used for validation
+        data_dimension=1,  # We flattened the images to 1D
+        intermittent_model_saving=False,
+        separate_model_saving=False,
+        intermittent_saving_patience=10,
+        mse_avg=False,
+        mse_sum=True,
+        emd=False,
+        deterministic_algorithm=True,
+        apply_normalization=False,  # Dataset already normalized by transform
+        activation_extraction=False,
+    )
+
+    # Set the external dataset
+    config.external_dataset = mnist_dataset
+
+    # Create output directory structure
+    project_path = "workspaces/examples/mnist_external_dataset"
+    os.makedirs(project_path, exist_ok=True)
+    os.makedirs(os.path.join(project_path, "output"), exist_ok=True)
+
+    # Import and call the training function
+    from baler import perform_training
+    perform_training(os.path.join(project_path, "output"), config, verbose=True)
+
+    print("Training completed successfully!")
+    print(f"Model saved in {project_path}/output")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/tests/test_external_dataset.py b/tests/test_external_dataset.py
new file mode 100644
index 00000000..76c03223
--- /dev/null
+++ b/tests/test_external_dataset.py
@@ -0,0 +1,247 @@
+"""
+Tests for external PyTorch Dataset support.
+"""
+
+import unittest
+import torch
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+from sklearn.model_selection import train_test_split
+
+from baler.modules import data_processing
+
+
+class SyntheticFeatureDataset(Dataset):
+    """A simple dataset that returns random tensors."""
+
+    def __init__(self, size=1000, feature_dim=10):
+        """Initialize with random data.
+
+        Args:
+            size (int): Number of samples
+            feature_dim (int): Dimension of each sample
+        """
+        self.data = torch.randn(size, feature_dim)
+
+    def __len__(self):
+        """Return the size of the dataset."""
+        return self.data.shape[0]
+
+    def __getitem__(self, idx):
+        """Return the item at the specified index."""
+        return self.data[idx]
+
+
+class SequentialFeatureDataset(Dataset):
+    """A dataset that returns sequential values as features."""
+
+    def __init__(self, size=100):
+        """Initialize with sequential data.
+
+        Args:
+            size (int): Number of samples
+        """
+        self.data = torch.arange(size).unsqueeze(1).float()  # Shape (size, 1)
+
+    def __len__(self):
+        """Return the size of the dataset."""
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        """Return the item at the specified index."""
+        return self.data[idx]
+
+
+class TestExternalDataset(unittest.TestCase):
+    """Test cases for external dataset support."""
+
+    def test_load_external_dataset(self):
+        """Test loading an external dataset."""
+        # Create a simple dataset
+        dataset = SyntheticFeatureDataset(size=1000, feature_dim=10)
+
+        # Load the dataset with data_processing.load_external_dataset
+        train_loader, val_loader = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=0.2,
+            batch_size=32,
+            shuffle=True,
+            random_state=42,
+            deterministic=True
+        )
+
+        # Check that the DataLoaders have the expected lengths
+        self.assertEqual(len(train_loader.dataset), 800)
+        self.assertEqual(len(val_loader.dataset), 200)
+
+        # Check that the batch size is correct
+        for batch in train_loader:
+            self.assertEqual(batch.shape[0], 32)  # Batch size
+            self.assertEqual(batch.shape[1], 10)  # Feature dimension
+            break
+
+        # Test with different parameters
+        train_loader, val_loader = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=0.5,
+            batch_size=64,
+            shuffle=False,
+            deterministic=False
+        )
+
+        # Check that the DataLoaders have the expected lengths
+        self.assertEqual(len(train_loader.dataset), 500)
+        self.assertEqual(len(val_loader.dataset), 500)
+
+    def test_seed_worker(self):
+        """Test the seed_worker function for reproducibility."""
+        # Create a simple dataset
+        dataset = SyntheticFeatureDataset(size=1000, feature_dim=10)
+
+        # Create two DataLoaders with the same seed
+        train_loader1, _ = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=0.2,
+            batch_size=32,
+            shuffle=True,
+            random_state=42,
+            deterministic=True
+        )
+
+        train_loader2, _ = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=0.2,
+            batch_size=32,
+            shuffle=True,
+            random_state=42,
+            deterministic=True
+        )
+
+        # Check that the batches are identical due to the same seed
+        for batch1, batch2 in zip(train_loader1, train_loader2):
+            torch.testing.assert_close(batch1, batch2)
+            break
+
+    def test_load_external_dataset_edge_cases(self):
+        """Test loading with edge case test_size values."""
+        dataset = SyntheticFeatureDataset(size=100, feature_dim=5)
+
+        # Test with test_size = 0.0
+        train_loader, val_loader = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=0.0,
+            batch_size=10,
+            deterministic=True
+        )
+        self.assertEqual(len(train_loader.dataset), 100)
+        self.assertEqual(len(val_loader.dataset), 0)
+        # Check if val_loader is empty
+        self.assertEqual(len(list(val_loader)), 0)
+
+        # Test with test_size = 1.0
+        train_loader, val_loader = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=1.0,
+            batch_size=10,
+            deterministic=True
+        )
+        self.assertEqual(len(train_loader.dataset), 0)
+        self.assertEqual(len(val_loader.dataset), 100)
+        # Check if train_loader is empty
+        self.assertEqual(len(list(train_loader)), 0)
+
+    def test_load_external_dataset_no_shuffle(self):
+        """Test loading without shuffling to ensure order is preserved."""
+        # Use an identifiable dataset with sequentially increasing values
+        size = 100
+        dataset = SequentialFeatureDataset(size=size)
+
+        # First get the indices that train_test_split would select
+        indices = list(range(size))
+        train_indices, val_indices = train_test_split(
+            indices, test_size=0.2, shuffle=False, random_state=42
+        )
+
+        # Now we can properly test that each loader preserves the order of its subset
+        expected_train_data = dataset.data[train_indices]
+        expected_val_data = dataset.data[val_indices]
+
+        train_loader, val_loader = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=0.2,
+            batch_size=10,
+            shuffle=False,  # Important: No shuffling
+            random_state=42,
+            deterministic=True
+        )
+
+        # Collect data from loaders
+        train_data_loaded = torch.cat([batch for batch in train_loader], dim=0)
+        val_data_loaded = torch.cat([batch for batch in val_loader], dim=0)
+
+        # Test that the ordering is preserved within each subset
+        torch.testing.assert_close(sorted(train_data_loaded), sorted(expected_train_data))
+        torch.testing.assert_close(sorted(val_data_loaded), sorted(expected_val_data))
+
+        # Additionally test that original order is preserved (values appear in original ascending order)
+        self.assertTrue(torch.all(train_data_loaded[:-1] <= train_data_loaded[1:]))
+        self.assertTrue(torch.all(val_data_loaded[:-1] <= val_data_loaded[1:]))
+
+
+class FeatureLabelDictDataset(Dataset):
+    """A dataset that returns dictionaries with features and labels."""
+    def __init__(self, size=100, feature_dim=5):
+        self.features = torch.randn(size, feature_dim)
+        self.labels = torch.randint(0, 2, (size,))
+
+    def __len__(self):
+        return self.features.shape[0]
+
+    def __getitem__(self, idx):
+        return {'features': self.features[idx], 'labels': self.labels[idx]}
+
+
+class TestComplexDatasetLoading(unittest.TestCase):
+    """Test cases for datasets returning complex types."""
+
+    def test_load_dict_dataset(self):
+        """Test loading a dataset that returns dictionaries."""
+        dataset = FeatureLabelDictDataset(size=100, feature_dim=5)
+
+        train_loader, val_loader = data_processing.load_external_dataset(
+            dataset=dataset,
+            test_size=0.3,
+            batch_size=16,
+            shuffle=True,
+            random_state=42,
+            deterministic=True
+        )
+
+        # Check loader sizes
+        self.assertEqual(len(train_loader.dataset), 70)
+        self.assertEqual(len(val_loader.dataset), 30)
+
+        # Check batch structure and content type
+        for batch in train_loader:
+            self.assertIsInstance(batch, dict)
+            self.assertIn('features', batch)
+            self.assertIn('labels', batch)
+            self.assertIsInstance(batch['features'], torch.Tensor)
+            self.assertIsInstance(batch['labels'], torch.Tensor)
+            self.assertEqual(batch['features'].shape[0], 16)  # Batch size
+            self.assertEqual(batch['features'].shape[1], 5)  # Feature dim
+            self.assertEqual(batch['labels'].shape[0], 16)  # Batch size
+            break  # Only check first batch
+
+        for batch in val_loader:
+            self.assertIsInstance(batch, dict)
+            self.assertIn('features', batch)
+            self.assertIn('labels', batch)
+            # Validation batch size might be smaller for the last batch
+            self.assertLessEqual(batch['features'].shape[0], 16)
+            self.assertEqual(batch['features'].shape[1], 5)
+            break  # Only check first batch
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file