|
| 1 | +from torch.utils.data import IterDataPipe |
| 2 | +from typing import Dict, List, Tuple, Any, Callable, Iterable, Iterator, Union |
| 3 | + |
| 4 | +import os |
| 5 | +import functools |
| 6 | + |
| 7 | + |
| 8 | +# defaut group key is the file pathname without the extension. |
| 9 | +# Assuming the passed in data is a tuple and 1st item is file's pathname. |
| 10 | +def default_group_key_fn(dataitem : Tuple[str, Any]): |
| 11 | + return os.path.splitext(dataitem[0])[0] |
| 12 | + |
| 13 | + |
| 14 | +def default_sort_data_fn(datalist : List[Tuple[str, Any]]): |
| 15 | + txt_ext = ['.json', '.jsn', '.txt', '.text'] |
| 16 | + |
| 17 | + def cmp_fn(a : Tuple[str, Any], b : Tuple[str, Any]): |
| 18 | + a_is_txt = os.path.splitext(a[0])[1] in txt_ext |
| 19 | + b_is_txt = os.path.splitext(b[0])[1] in txt_ext |
| 20 | + |
| 21 | + # if a is txt but b is not, b go front |
| 22 | + if a_is_txt and not b_is_txt: |
| 23 | + return 1 |
| 24 | + # if a is not txt but b is txt, a go front |
| 25 | + if not a_is_txt and b_is_txt: |
| 26 | + return -1 |
| 27 | + # if a and b both are or are not txt, sort in alphabetic order |
| 28 | + if a[0] < b[0]: |
| 29 | + return -1 |
| 30 | + elif a[0] > b[0]: |
| 31 | + return 1 |
| 32 | + return 0 |
| 33 | + |
| 34 | + return sorted(datalist, key=functools.cmp_to_key(cmp_fn)) |
| 35 | + |
| 36 | + |
| 37 | +class GroupByKeyIterDataPipe(IterDataPipe): |
| 38 | + r""" :class:`GroupByKeyIterDataPipe`. |
| 39 | +
|
| 40 | + Iterable datapipe to group data from input iterable by keys which are generated from `group_key_fn`, |
| 41 | + yields a list with `group_size` items in it, each item in the list is a tuple of key and data |
| 42 | +
|
| 43 | + args: |
| 44 | + datapipe: Iterable datapipe that provides data. (typically str key (eg. pathname) and data stream in tuples) |
| 45 | + group_size: the size of group |
| 46 | + max_buffer_size: the max size of stream buffer which is used to store not yet grouped but iterated data |
| 47 | + group_key_fn: a function which is used to generate group key from the data in the input datapipe |
| 48 | + sort_data_fn: a function which is used to sort the grouped data before yielding back |
| 49 | + length: a nominal length of the datapipe |
| 50 | + """ |
| 51 | + |
| 52 | + def __init__( |
| 53 | + self, |
| 54 | + datapipe : Iterable[Tuple[str, Any]], |
| 55 | + *, |
| 56 | + group_size : int, |
| 57 | + max_buffer_size : Union[int, None] = None, |
| 58 | + group_key_fn : Callable = default_group_key_fn, |
| 59 | + sort_data_fn : Callable = default_sort_data_fn, |
| 60 | + length: int = -1): |
| 61 | + super().__init__() |
| 62 | + |
| 63 | + assert group_size > 0 |
| 64 | + self.datapipe : Iterable[Tuple[str, Any]] = datapipe |
| 65 | + self.group_size : int = group_size |
| 66 | + |
| 67 | + # default max buffer size is group_size * 10 |
| 68 | + self.max_buffer_size = max_buffer_size if max_buffer_size is not None else group_size * 10 |
| 69 | + assert self.max_buffer_size >= self.group_size |
| 70 | + |
| 71 | + self.group_key_fn : Callable = group_key_fn |
| 72 | + self.sort_data_fn : Callable = sort_data_fn |
| 73 | + self.curr_buffer_size : int = 0 |
| 74 | + self.stream_buffer : Dict[str, List[Tuple[str, Any]]] = {} |
| 75 | + self.length : int = length |
| 76 | + |
| 77 | + |
| 78 | + def __iter__(self) -> Iterator[list]: |
| 79 | + if self.group_size == 1: |
| 80 | + for data in self.datapipe: |
| 81 | + yield [data] |
| 82 | + else: |
| 83 | + for data in self.datapipe: |
| 84 | + key = self.group_key_fn(data) |
| 85 | + if key not in self.stream_buffer: |
| 86 | + self.stream_buffer[key] = [] |
| 87 | + res = self.stream_buffer[key] |
| 88 | + res.append(data) |
| 89 | + if len(res) == self.group_size: |
| 90 | + yield self.sort_data_fn(res) |
| 91 | + del self.stream_buffer[key] |
| 92 | + self.curr_buffer_size = self.curr_buffer_size - self.group_size + 1 |
| 93 | + else: |
| 94 | + if self.curr_buffer_size == self.max_buffer_size: |
| 95 | + raise OverflowError( |
| 96 | + "stream_buffer is overflow, please adjust the order of data " |
| 97 | + "in the input datapipe or increase the buffer size!") |
| 98 | + self.curr_buffer_size = self.curr_buffer_size + 1 |
| 99 | + |
| 100 | + if self.curr_buffer_size > 0: |
| 101 | + msg = "Not able to group [{}] with group size {}.".format( |
| 102 | + ','.join([v[0] for _, vs in self.stream_buffer.items() for v in vs]), str(self.group_size)) |
| 103 | + raise RuntimeError(msg) |
| 104 | + |
| 105 | + |
| 106 | + def __len__(self): |
| 107 | + if self.length == -1: |
| 108 | + raise NotImplementedError |
| 109 | + return self.length |
0 commit comments