Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions src/grelu/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,188 @@
from grelu.utils import get_aggfunc, get_transform_func


class LabeledOneHotDataset(Dataset):
"""
A general Dataset class for DNA sequences and labels. All sequences and
labels will be stored in memory.

Args:
seqs: DNA sequences as one-hot.
labels: A numpy array of shape (B, T, L) containing the labels.
tasks: A list of task names or a pandas dataframe containing task information.
If a dataframe is supplied, the row indices should be the task names.
seq_len: Uniform expected length (in base pairs) for output sequences
genome: The name of the genome from which to read sequences. Only needed if
genomic intervals are supplied.
end: Which end of the sequence to resize if necessary. Supported values are "left",
"right" and "both".
rc: If True, sequences will be augmented by reverse complementation. If False,
they will not be reverse complemented.
max_seq_shift: Maximum number of bases to shift the sequence for augmentation. This
is normally a small value (< 10). If 0, sequences will not be augmented by shifting.
label_len: Uniform expected length (in base pairs) for output labels
max_pair_shift: Maximum number of bases to shift both the sequence and label for
augmentation. If 0, sequence and label pairs will not be augmented by shifting.
label_aggfunc: Function to aggregate the labels over bin_size.
bin_size: Number of bases to aggregate in the label. Only used if label_aggfunc is not None.
If None, it will be taken as equal to label_len.
min_label_clip: Minimum value for label
max_label_clip: Maximum value for label
label_transform_func: Function to transform label values.
seed: Random seed for reproducibility
augment_mode: "random" or "serial"
"""

def __init__(
self,
seqs: Union[str, Sequence, pd.DataFrame, np.ndarray],
labels: np.ndarray,
tasks: Optional[Union[Sequence, pd.DataFrame]] = None,
seq_len: Optional[int] = None,
genome: Optional[str] = None,
end: str = "both",
rc: bool = False,
max_seq_shift: int = 0,
label_len: Optional[int] = None,
max_pair_shift: int = 0,
label_aggfunc: Optional[Union[str, Callable]] = None,
bin_size: Optional[int] = None,
min_label_clip: Optional[int] = None,
max_label_clip: Optional[int] = None,
label_transform_func: Optional[Union[str, Callable]] = None,
seed: Optional[int] = None,
augment_mode: str = "serial",
):
super().__init__()

from grelu.transforms.label_transforms import LabelTransform

# Save params
self.end = end
self.genome = genome

# Label transformation params
self.min_label_clip = min_label_clip
self.max_label_clip = max_label_clip
self.label_transform_func = get_transform_func(label_transform_func)

# Calculate sequence and label length
self.seq_len = seq_len or max(get_lengths(seqs))
self.label_len = label_len or self.seq_len

# Calculate bin size
if (bin_size) is None and (label_aggfunc is not None):
bin_size = self.label_len
self.label_aggfunc = get_aggfunc(label_aggfunc)
self.bin_size = bin_size

# Save augmentation params
self.rc = rc
self.max_seq_shift = max_seq_shift
self.max_pair_shift = max_pair_shift
self.padded_seq_len = (
self.seq_len + (2 * self.max_seq_shift) + (2 * self.max_pair_shift)
)
self.padded_label_len = self.label_len + (2 * self.max_pair_shift)

# Ingest sequences
self._load_seqs(seqs)
self.n_seqs = len(self.seqs)

# Ingest tasks
self._load_tasks(tasks)
self.n_tasks = len(self.tasks)

# Ingest labels
self._load_labels(labels)

# Create label transformer
self.label_transform = LabelTransform(
min_clip=self.min_label_clip,
max_clip=self.max_label_clip,
transform_func=self.label_transform_func,
)

# Create augmenter
self.augmenter = Augmenter(
rc=self.rc,
max_seq_shift=self.max_seq_shift,
max_pair_shift=self.max_pair_shift,
seq_len=self.seq_len,
label_len=self.label_len,
seed=seed,
mode=augment_mode,
)
self.n_augmented = len(self.augmenter)
self.n_alleles = 1

# Set mode
self.predict = False

# Set mode
self.predict = False

def _load_seqs(self, seqs: Union[str, Sequence, pd.DataFrame, np.ndarray]) -> None:
# seqs = resize(seqs, seq_len=self.padded_seq_len, end=self.end)
self.seqs = seqs

def _load_tasks(self, tasks: Union[pd.DataFrame, List]) -> None:
if isinstance(tasks, List):
tasks = _create_task_data(tasks)
self.tasks = tasks

def _load_labels(self, labels: np.ndarray) -> None:
self.labels = labels

def __len__(self) -> int:
return self.n_seqs * self.n_augmented

def get_labels(self) -> np.ndarray:
"""
Return the labels as a numpy array of shape (B, T, L). This does not
account for data augmentation.
"""
labels = self.labels

# Aggregate label
if self.label_aggfunc is not None:
labels = rearrange(
labels,
"batch task (length bin_size) -> batch task length bin_size",
bin_size=self.bin_size,
)
labels = self.label_aggfunc(labels, axis=-1)

# Transform label
labels = self.label_transform(labels)

return labels

def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, Tensor]]:
# Get sequence and augmentation indices
seq_idx, augment_idx = _split_overall_idx(idx, (self.n_seqs, self.n_augmented))

# Get current sequence and label
seq = self.seqs[seq_idx]
label = self.labels[seq_idx]

# If using in prediction, return only the sequence
if self.predict:
return seq

else:
# Aggregate label
if self.label_aggfunc is not None:
label = rearrange(label, "t (l b) -> t l b", b=self.bin_size)
label = self.label_aggfunc(label, axis=-1)

# Transform label
if self.label_transform is not None:
label = self.label_transform(label)

return seq, Tensor(label)


class LabeledSeqDataset(Dataset):
"""
A general Dataset class for DNA sequences and labels. All sequences and
Expand Down