15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,15 @@
# Changelog

## Unreleased

### Added
- Support for loading and operating external PyTorch `torch.utils.data.Dataset` objects (#382)
- New `load_external_dataset` function in `data_processing.py`
- Updated `perform_training` in `baler.py` to handle external datasets
- Updated `train` function in `training.py` to accept DataLoaders directly
- Added documentation and examples for using external datasets
- Added unit tests for the new functionality

### Changed

### Fixed
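
For orientation, the external dataset support described above expects a standard `torch.utils.data.Dataset`. A minimal sketch of such a dataset (the class name, sizes, and values are illustrative, not part of this PR) could look like:

```python
import torch
from torch.utils.data import Dataset


class RandomVectorDataset(Dataset):
    """Toy 1D dataset: each item is a bare feature tensor (no label),
    which matches how the new training path iterates batches."""

    def __init__(self, n_samples: int = 1000, n_features: int = 8):
        self.data = torch.randn(n_samples, n_features)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]
```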
84 changes: 84 additions & 0 deletions baler/baler.py
@@ -15,8 +15,12 @@
import os
import time
from math import ceil
import importlib
from typing import Optional, Union

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from .modules import helper
import gzip
@@ -95,6 +99,86 @@ def perform_training(output_path, config, verbose: bool):
Raises:
NameError: Baler currently only supports 1D (e.g. HEP) or 2D (e.g. CFD) data as inputs.
"""
# Check if an external dataset is provided
if hasattr(config, "external_dataset") and config.external_dataset is not None:
# Load the external dataset module
if verbose:
print(f"Using external dataset from {config.external_dataset}")

try:
# Import the external dataset
if isinstance(config.external_dataset, str):
# Assuming external_dataset is a module path like "mymodule.mydataset"
module_path, class_name = config.external_dataset.rsplit(".", 1)
module = importlib.import_module(module_path)
dataset_class = getattr(module, class_name)

# Initialize the dataset
if hasattr(config, "dataset_args") and config.dataset_args is not None:
external_dataset = dataset_class(**config.dataset_args)
else:
external_dataset = dataset_class()
else:
# Assuming external_dataset is already a Dataset instance
external_dataset = config.external_dataset

if not isinstance(external_dataset, Dataset):
raise ValueError("The provided external_dataset is not a PyTorch Dataset instance")

# Create DataLoaders
from .modules import data_processing
train_loader, val_loader = data_processing.load_external_dataset(
dataset=external_dataset,
test_size=config.test_size,
batch_size=config.batch_size,
shuffle=True,
random_state=42 if config.deterministic_algorithm else None,
deterministic=config.deterministic_algorithm
)

# Initialize model
model_object = helper.model_init(config.model_name)

# Get an example batch to determine the feature size.
# Note: this assumes the Dataset yields a bare feature tensor per item (not a (data, label) tuple).
device = helper.get_device()
example_batch = next(iter(train_loader)).to(device)

# Determine input dimensions based on the model type and data dimensions
if config.data_dimension == 2:
if config.model_type == "dense":
n_features = example_batch.shape[1] * example_batch.shape[2]
elif config.model_type == "convolutional":
# Get the flattened size from convolutional features
n_features = example_batch.shape[1] * example_batch.shape[2] * example_batch.shape[3]
else: # 1D data
n_features = example_batch.shape[1]

# Calculate latent space size based on compression ratio
z_dim = int(n_features * config.compression_ratio)

variables = {"n_features": n_features, "z_dim": z_dim}

if verbose:
print(f"Input features: {n_features}, Latent dimension: {z_dim}")

model = model_object(n_features, z_dim)

# Train the model with the external dataset
from .modules import training
trained_model = training.train(
model=model,
variables=variables,
train_loader=train_loader,
val_loader=val_loader,
project_path=output_path,
config=config
)

return
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(f"Failed to load external dataset: {str(e)}") from e

# Original code for numpy array-based datasets
(
train_set_norm,
test_set_norm,
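
The new branch in `perform_training` above reads a handful of config attributes. As a rough sketch of how they fit together (the attribute names are taken from the diff; the `SimpleNamespace` stand-in, the dotted path, and the values are illustrative assumptions, not Baler's actual config mechanism):

```python
from types import SimpleNamespace

# Stand-in for Baler's config object; only the attributes read by the
# external-dataset branch of perform_training are shown.
config = SimpleNamespace(
    external_dataset="my_project.datasets.RandomVectorDataset",  # dotted "module.Class" path
    dataset_args={"n_samples": 1000, "n_features": 8},           # forwarded as **kwargs to the class
    test_size=0.2,                    # validation split passed to load_external_dataset
    batch_size=128,
    deterministic_algorithm=True,     # seeds RNGs and fixes random_state=42
    model_name="AE",                  # resolved via helper.model_init (illustrative name)
    model_type="dense",
    data_dimension=1,                 # 1D data: n_features = batch.shape[1]
    compression_ratio=0.5,            # z_dim = int(n_features * compression_ratio)
)
```

With a config like this, the external-dataset branch builds the DataLoaders, trains, and returns before the numpy-based path below is reached.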
91 changes: 90 additions & 1 deletion baler/modules/data_processing.py
@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import List, Tuple, Optional, Union

import numpy as np
import torch
from numpy import ndarray
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import random

from ..modules import helper
from ..modules import models
@@ -201,3 +203,90 @@ def renormalize_func(norm_data: ndarray, min_list: List, range_list: List) -> ndarray:
min_list = np.array(min_list)
range_list = np.array(range_list)
return norm_data * range_list + min_list


def load_external_dataset(dataset: Dataset, test_size: float = 0.2,
batch_size: int = 128, shuffle: bool = True,
random_state: Optional[int] = 42,
deterministic: bool = False) -> Tuple[DataLoader, DataLoader]:
"""Load an external PyTorch Dataset and split it into training and validation sets.

Args:
dataset (Dataset): An instance of a PyTorch Dataset.
test_size (float): Proportion of the dataset to include in the validation split.
batch_size (int): How many samples per batch to load.
shuffle (bool): Whether to shuffle the data before splitting and in the training DataLoader (validation data is never shuffled).
random_state (Optional[int]): Controls the shuffling applied to the data before splitting; may be None when determinism is not required.
deterministic (bool): If True, seeds Python, NumPy, and PyTorch (and configures cuDNN) for reproducibility.

Returns:
Tuple[DataLoader, DataLoader]: Tuple containing training and validation DataLoaders.
"""
# Set the seed for reproducibility if deterministic is True
if deterministic:
torch.manual_seed(random_state)
np.random.seed(random_state)
random.seed(random_state)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(random_state)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Handle edge cases for test_size
total_size = len(dataset)
if test_size <= 0.0:
# All data goes to train set
train_indices = list(range(total_size))
val_indices = []
elif test_size >= 1.0:
# All data goes to validation set
train_indices = []
val_indices = list(range(total_size))
else:
# Regular split using train_test_split
indices = list(range(total_size))
train_indices, val_indices = train_test_split(
indices, test_size=test_size, random_state=random_state if shuffle else None, shuffle=shuffle
)

# Create subset datasets
train_subset = torch.utils.data.Subset(dataset, train_indices)
val_subset = torch.utils.data.Subset(dataset, val_indices)

# Define worker_init_fn for reproducibility
worker_init_fn = seed_worker if deterministic else None
generator = torch.Generator().manual_seed(random_state) if deterministic else None

# Create DataLoaders with appropriate batch size
train_batch_size = min(batch_size, len(train_subset)) if train_indices else 1
val_batch_size = min(batch_size, len(val_subset)) if val_indices else 1

# Create DataLoaders, ensuring we don't try to shuffle empty datasets
train_loader = DataLoader(
train_subset,
batch_size=train_batch_size,
shuffle=shuffle and len(train_subset) > 0,
worker_init_fn=worker_init_fn,
generator=generator
)

val_loader = DataLoader(
val_subset,
batch_size=val_batch_size,
shuffle=False, # Usually validation data is not shuffled
worker_init_fn=worker_init_fn,
generator=generator
)

return train_loader, val_loader


def seed_worker(worker_id):
"""Function to seed DataLoader workers for reproducibility.

Args:
worker_id: The ID of the worker
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
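
A short, self-contained usage sketch for `load_external_dataset` (the import path assumes the package layout shown in this diff; the stock `TensorDataset` yields 1-tuples, unlike the bare-tensor datasets iterated by the `perform_training` branch, so the batch is unpacked below):

```python
import torch
from torch.utils.data import TensorDataset

# Assumes the package layout shown in this diff (baler/modules/data_processing.py).
from baler.modules.data_processing import load_external_dataset

features = torch.randn(100, 8)      # 100 samples, 8 features (illustrative)
dataset = TensorDataset(features)   # each item is a 1-tuple: (features[i],)

train_loader, val_loader = load_external_dataset(
    dataset=dataset,
    test_size=0.2,
    batch_size=16,
    shuffle=True,
    random_state=42,
    deterministic=True,
)

print(len(train_loader.dataset), len(val_loader.dataset))  # 80 20
(batch,) = next(iter(train_loader))
print(batch.shape)                  # torch.Size([16, 8])
```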