Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions executors/accelerate/src/hypha/accelerate_executor/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from collections.abc import Iterator
from typing import Any

Expand All @@ -6,35 +7,50 @@
from snappy import uncompress
from torch.utils.data import IterableDataset

from .api import fetch
from .utils import get_preprocessor


class IterableStreamDataSet(IterableDataset): # type: ignore[type-arg]
def __init__(
self,
data_file_iter: Iterator[str],
batch_size: int,
model_inputs: list[str],
processor_inputs: list[str],
preprocessor: Any | None,
self, socket_path: str, work_dir: str, fetch_path: str, batch_size: int, config: dict[str, Any]
) -> None:
super(IterableStreamDataSet).__init__() # type: ignore[misc]
self.data_iter = data_file_iter
self.socket_path = socket_path
self.work_dir = work_dir
self.fetch_path = fetch_path
self.config = config
self.batch_size = batch_size
self.model_inputs = model_inputs
self.processor_inputs = processor_inputs
self.processor = preprocessor
self.model_inputs = config["model"]["input-names"]
self.processor_config = config.get("preprocessor", {})
self.processor_inputs = self.processor_config.get("input-names", [])

def __iter__(self): # type: ignore[no-untyped-def]
        # No sharding needed: each call to data_iter returns a unique instance
socket_path = self.socket_path
data_config = self.config["data"]
work_dir = self.work_dir

def wrap() -> Iterator[str]:
while True:
tensor_data = fetch(socket_path, data_config)
yield os.path.join(work_dir, tensor_data[0]["path"])

data_iter = iter(wrap())

processor = None
if self.processor_config:
processor = get_preprocessor(self.processor_config, self.fetch_path)

# Holds the "remainder" from the previous file
buffer = None
for path in self.data_iter:
for path in data_iter:
with open(path, "rb") as file:
raw_bytes = uncompress(file.read())
data = load(raw_bytes)
if self.processor:
if processor:
processed = {
**{k: v[0] for k, v in self.processor(**{k: data.pop(k) for k in self.processor_inputs}).items()},
**{k: v[0] for k, v in processor(**{k: data.pop(k) for k in self.processor_inputs}).items()},
**data,
}
else:
Expand Down
14 changes: 4 additions & 10 deletions executors/accelerate/src/hypha/accelerate_executor/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
from .model import get_model
from .utils import (
extract_gradients,
fetch_data,
get_adam,
get_preprocessor,
get_scheduler,
merge_models,
prepare_files,
Expand Down Expand Up @@ -105,17 +103,13 @@ def main(socket_path: str, work_dir: str, job_json: str) -> None: # noqa: PLR09
model = get_model(local_fetch_path, config["model"]["task"])
optimizer = get_adam(config["optimizer"], model.parameters())
scheduler = get_scheduler(config.get("scheduler"), optimizer)
preprocessor_config = config.get("preprocessor")
data_loader = torch.utils.data.DataLoader(
IterableStreamDataSet(
fetch_data(socket_path, config["data"], work_dir),
config["batch_size"],
config["model"]["input-names"],
preprocessor_config["input-names"] if preprocessor_config else [],
get_preprocessor(preprocessor_config, local_fetch_path),
),
IterableStreamDataSet(socket_path, work_dir, local_fetch_path, config["batch_size"], config),
batch_size=None,
pin_memory=True,
num_workers=4,
persistent_workers=True,
timeout=600,
)

model, optimizer, training_dataloader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler)
Expand Down
14 changes: 2 additions & 12 deletions executors/accelerate/src/hypha/accelerate_executor/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from collections.abc import Iterable, Iterator
from collections.abc import Iterable
from typing import Any

import torch
Expand All @@ -23,7 +22,7 @@
get_wsd_schedule,
)

from .api import Session, fetch
from .api import Session


def prepare_files(config: dict[str, Any], session: Session) -> None:
Expand Down Expand Up @@ -61,15 +60,6 @@ def get_adam(optimizer: dict[str, Any], parameters: Iterable[torch.Tensor]) -> O
return torch.optim.AdamW(parameters, lr=lr)


def fetch_data(socket_path: str, data: Any, work_dir: str) -> Iterator[str]:
    """Yield absolute paths of tensor files fetched via the executor socket.

    Produces an infinite stream: each step asks the service behind
    ``socket_path`` for the next batch descriptor and yields the path of
    its first entry, joined onto ``work_dir``.

    Args:
        socket_path: Unix-domain socket path the ``fetch`` API talks to.
        data: Data-source configuration forwarded verbatim to ``fetch``
            (``config["data"]`` at the call site — not a ``str`` as the
            previous annotation claimed; TODO confirm the exact schema
            against the ``fetch`` API).
        work_dir: Directory the fetched files are materialized under.

    Returns:
        An endless iterator of file paths (never raises ``StopIteration``).
    """

    def generate() -> Iterator[str]:
        while True:
            # fetch() returns a sequence of descriptors; only the first
            # entry's relative "path" is consumed here — presumably one
            # file per request. NOTE(review): verify against the API.
            tensor_data = fetch(socket_path, data)
            yield os.path.join(work_dir, tensor_data[0]["path"])

    # Calling a generator function already returns an iterator; the old
    # iter(wrap()) wrapper was redundant.
    return generate()


def get_loss_fn(loss_fn: str) -> Module:
if loss_fn == "l1":
return torch.nn.L1Loss()
Expand Down