diff --git a/src/grelu/data/dataset.py b/src/grelu/data/dataset.py index 16668deb..ffa50c20 100755 --- a/src/grelu/data/dataset.py +++ b/src/grelu/data/dataset.py @@ -30,6 +30,188 @@ from grelu.utils import get_aggfunc, get_transform_func +class LabeledOneHotDataset(Dataset): + """ + A general Dataset class for DNA sequences and labels. All sequences and + labels will be stored in memory. + + Args: + seqs: DNA sequences as one-hot. + labels: A numpy array of shape (B, T, L) containing the labels. + tasks: A list of task names or a pandas dataframe containing task information. + If a dataframe is supplied, the row indices should be the task names. + seq_len: Uniform expected length (in base pairs) for output sequences + genome: The name of the genome from which to read sequences. Only needed if + genomic intervals are supplied. + end: Which end of the sequence to resize if necessary. Supported values are "left", + "right" and "both". + rc: If True, sequences will be augmented by reverse complementation. If False, + they will not be reverse complemented. + max_seq_shift: Maximum number of bases to shift the sequence for augmentation. This + is normally a small value (< 10). If 0, sequences will not be augmented by shifting. + label_len: Uniform expected length (in base pairs) for output labels + max_pair_shift: Maximum number of bases to shift both the sequence and label for + augmentation. If 0, sequence and label pairs will not be augmented by shifting. + label_aggfunc: Function to aggregate the labels over bin_size. + bin_size: Number of bases to aggregate in the label. Only used if label_aggfunc is not None. + If None, it will be taken as equal to label_len. + min_label_clip: Minimum value for label + max_label_clip: Maximum value for label + label_transform_func: Function to transform label values. + seed: Random seed for reproducibility + augment_mode: "random" or "serial" + """ + + def __init__( + self, + seqs: Union[str, Sequence, pd.DataFrame, np.ndarray], + labels: np.ndarray, + tasks: Optional[Union[Sequence, pd.DataFrame]] = None, + seq_len: Optional[int] = None, + genome: Optional[str] = None, + end: str = "both", + rc: bool = False, + max_seq_shift: int = 0, + label_len: Optional[int] = None, + max_pair_shift: int = 0, + label_aggfunc: Optional[Union[str, Callable]] = None, + bin_size: Optional[int] = None, + min_label_clip: Optional[int] = None, + max_label_clip: Optional[int] = None, + label_transform_func: Optional[Union[str, Callable]] = None, + seed: Optional[int] = None, + augment_mode: str = "serial", + ): + super().__init__() + + from grelu.transforms.label_transforms import LabelTransform + + # Save params + self.end = end + self.genome = genome + + # Label transformation params + self.min_label_clip = min_label_clip + self.max_label_clip = max_label_clip + self.label_transform_func = get_transform_func(label_transform_func) + + # Calculate sequence and label length + self.seq_len = seq_len or max(get_lengths(seqs)) + self.label_len = label_len or self.seq_len + + # Calculate bin size + if (bin_size) is None and (label_aggfunc is not None): + bin_size = self.label_len + self.label_aggfunc = get_aggfunc(label_aggfunc) + self.bin_size = bin_size + + # Save augmentation params + self.rc = rc + self.max_seq_shift = max_seq_shift + self.max_pair_shift = max_pair_shift + self.padded_seq_len = ( + self.seq_len + (2 * self.max_seq_shift) + (2 * self.max_pair_shift) + ) + self.padded_label_len = self.label_len + (2 * self.max_pair_shift) + + # Ingest sequences + self._load_seqs(seqs) + self.n_seqs = len(self.seqs) + + # Ingest tasks + self._load_tasks(tasks) + self.n_tasks = len(self.tasks) + + # Ingest labels + self._load_labels(labels) + + # Create label transformer + self.label_transform = LabelTransform( + min_clip=self.min_label_clip, + max_clip=self.max_label_clip, + transform_func=self.label_transform_func, + ) + + # Create augmenter + self.augmenter = Augmenter( + rc=self.rc, + max_seq_shift=self.max_seq_shift, + max_pair_shift=self.max_pair_shift, + seq_len=self.seq_len, + label_len=self.label_len, + seed=seed, + mode=augment_mode, + ) + self.n_augmented = len(self.augmenter) + self.n_alleles = 1 + + # Set mode + self.predict = False + + # Set mode + self.predict = False + + def _load_seqs(self, seqs: Union[str, Sequence, pd.DataFrame, np.ndarray]) -> None: + # seqs = resize(seqs, seq_len=self.padded_seq_len, end=self.end) + self.seqs = seqs + + def _load_tasks(self, tasks: Union[pd.DataFrame, List]) -> None: + if isinstance(tasks, List): + tasks = _create_task_data(tasks) + self.tasks = tasks + + def _load_labels(self, labels: np.ndarray) -> None: + self.labels = labels + + def __len__(self) -> int: + return self.n_seqs * self.n_augmented + + def get_labels(self) -> np.ndarray: + """ + Return the labels as a numpy array of shape (B, T, L). This does not + account for data augmentation. + """ + labels = self.labels + + # Aggregate label + if self.label_aggfunc is not None: + labels = rearrange( + labels, + "batch task (length bin_size) -> batch task length bin_size", + bin_size=self.bin_size, + ) + labels = self.label_aggfunc(labels, axis=-1) + + # Transform label + labels = self.label_transform(labels) + + return labels + + def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, Tensor]]: + # Get sequence and augmentation indices + seq_idx, augment_idx = _split_overall_idx(idx, (self.n_seqs, self.n_augmented)) + + # Get current sequence and label + seq = self.seqs[seq_idx] + label = self.labels[seq_idx] + + # If using in prediction, return only the sequence + if self.predict: + return seq + + else: + # Aggregate label + if self.label_aggfunc is not None: + label = rearrange(label, "t (l b) -> t l b", b=self.bin_size) + label = self.label_aggfunc(label, axis=-1) + + # Transform label + if self.label_transform is not None: + label = self.label_transform(label) + + return seq, Tensor(label) + + class LabeledSeqDataset(Dataset): """ A general Dataset class for DNA sequences and labels. All sequences and