diff --git a/executors/accelerate/src/hypha/accelerate_executor/dataset.py b/executors/accelerate/src/hypha/accelerate_executor/dataset.py
index b7bce441..4731f5b4 100644
--- a/executors/accelerate/src/hypha/accelerate_executor/dataset.py
+++ b/executors/accelerate/src/hypha/accelerate_executor/dataset.py
@@ -1,3 +1,4 @@
+import os
 from collections.abc import Iterator
 from typing import Any
 
@@ -6,35 +7,50 @@
 from snappy import uncompress
 from torch.utils.data import IterableDataset
 
+from .api import fetch
+from .utils import get_preprocessor
+
 
 class IterableStreamDataSet(IterableDataset):  # type: ignore[type-arg]
     def __init__(
-        self,
-        data_file_iter: Iterator[str],
-        batch_size: int,
-        model_inputs: list[str],
-        processor_inputs: list[str],
-        preprocessor: Any | None,
+        self, socket_path: str, work_dir: str, fetch_path: str, batch_size: int, config: dict[str, Any]
     ) -> None:
         super(IterableStreamDataSet).__init__()  # type: ignore[misc]
-        self.data_iter = data_file_iter
+        self.socket_path = socket_path
+        self.work_dir = work_dir
+        self.fetch_path = fetch_path
+        self.config = config
         self.batch_size = batch_size
-        self.model_inputs = model_inputs
-        self.processor_inputs = processor_inputs
-        self.processor = preprocessor
+        self.model_inputs = config["model"]["input-names"]
+        self.processor_config = config.get("preprocessor", {})
+        self.processor_inputs = self.processor_config.get("input-names", [])
 
     def __iter__(self):  # type: ignore[no-untyped-def]
         # Don't need sharding each call to data_iter returns a unique instance
+        socket_path = self.socket_path
+        data_config = self.config["data"]
+        work_dir = self.work_dir
+
+        def wrap() -> Iterator[str]:
+            while True:
+                tensor_data = fetch(socket_path, data_config)
+                yield os.path.join(work_dir, tensor_data[0]["path"])
+
+        data_iter = iter(wrap())
+
+        processor = None
+        if self.processor_config:
+            processor = get_preprocessor(self.processor_config, self.fetch_path)
 
         # Holds the "remainder" from the previous file
         buffer = None
-        for path in self.data_iter:
+        for path in data_iter:
             with open(path, "rb") as file:
                 raw_bytes = uncompress(file.read())
                 data = load(raw_bytes)
-            if self.processor:
+            if processor:
                 processed = {
-                    **{k: v[0] for k, v in self.processor(**{k: data.pop(k) for k in self.processor_inputs}).items()},
+                    **{k: v[0] for k, v in processor(**{k: data.pop(k) for k in self.processor_inputs}).items()},
                     **data,
                 }
             else:
diff --git a/executors/accelerate/src/hypha/accelerate_executor/training.py b/executors/accelerate/src/hypha/accelerate_executor/training.py
index f52bb644..394900a4 100644
--- a/executors/accelerate/src/hypha/accelerate_executor/training.py
+++ b/executors/accelerate/src/hypha/accelerate_executor/training.py
@@ -28,9 +28,7 @@
 from .model import get_model
 from .utils import (
     extract_gradients,
-    fetch_data,
     get_adam,
-    get_preprocessor,
     get_scheduler,
     merge_models,
     prepare_files,
@@ -105,17 +103,13 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None:  # noqa: PLR09
     model = get_model(local_fetch_path, config["model"]["task"])
     optimizer = get_adam(config["optimizer"], model.parameters())
     scheduler = get_scheduler(config.get("scheduler"), optimizer)
-    preprocessor_config = config.get("preprocessor")
     data_loader = torch.utils.data.DataLoader(
-        IterableStreamDataSet(
-            fetch_data(socket_path, config["data"], work_dir),
-            config["batch_size"],
-            config["model"]["input-names"],
-            preprocessor_config["input-names"] if preprocessor_config else [],
-            get_preprocessor(preprocessor_config, local_fetch_path),
-        ),
+        IterableStreamDataSet(socket_path, work_dir, local_fetch_path, config["batch_size"], config),
         batch_size=None,
         pin_memory=True,
+        num_workers=4,
+        persistent_workers=True,
+        timeout=600,
     )
 
     model, optimizer, training_dataloader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler)
diff --git a/executors/accelerate/src/hypha/accelerate_executor/utils.py b/executors/accelerate/src/hypha/accelerate_executor/utils.py
index b1164553..014bd403 100644
--- a/executors/accelerate/src/hypha/accelerate_executor/utils.py
+++ b/executors/accelerate/src/hypha/accelerate_executor/utils.py
@@ -1,5 +1,4 @@
-import os
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable
 from typing import Any
 
 import torch
@@ -23,7 +22,7 @@
     get_wsd_schedule,
 )
 
-from .api import Session, fetch
+from .api import Session
 
 
 def prepare_files(config: dict[str, Any], session: Session) -> None:
@@ -61,15 +60,6 @@ def get_adam(optimizer: dict[str, Any], parameters: Iterable[torch.Tensor]) -> O
     return torch.optim.AdamW(parameters, lr=lr)
 
 
-def fetch_data(socket_path: str, data: str, work_dir: str) -> Iterator[str]:
-    def wrap() -> Iterator[str]:
-        while True:
-            tensor_data = fetch(socket_path, data)
-            yield os.path.join(work_dir, tensor_data[0]["path"])
-
-    return iter(wrap())
-
-
 def get_loss_fn(loss_fn: str) -> Module:
     if loss_fn == "l1":
         return torch.nn.L1Loss()
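
Note on the DataLoader change above: with num_workers=4, torch pickles the dataset object into each worker process and calls __iter__ there. Moving the fetch loop and preprocessor construction out of main() and into __iter__ means only plain config values cross the process boundary; the live fetch generator and the preprocessor are built fresh in each worker. A minimal sketch of that pattern (illustrative names only, not part of this patch):

    from collections.abc import Iterator

    from torch.utils.data import DataLoader, IterableDataset


    class LazyResourceDataset(IterableDataset):  # hypothetical stand-in for IterableStreamDataSet
        def __init__(self, source: str) -> None:
            super().__init__()
            self.source = source  # only picklable configuration lives on the object

        def __iter__(self) -> Iterator[str]:
            # Stateful resources belong here: this body runs once per worker
            # process, mirroring how data_iter and processor are now built
            # inside IterableStreamDataSet.__iter__.
            for i in range(3):
                yield f"{self.source}/file-{i}"


    if __name__ == "__main__":
        loader = DataLoader(LazyResourceDataset("/tmp/work"), batch_size=None, num_workers=2)
        for path in loader:
            print(path)

In this toy version both workers replay the same static source, so each item appears twice. In the patched dataset every fetch() call returns a fresh file from the socket, so workers naturally receive disjoint data, which is why no explicit sharding is needed (per the comment in __iter__).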