diff --git a/core/custom_training_loop.py b/core/custom_training_loop.py index 0241145..57aba03 100644 --- a/core/custom_training_loop.py +++ b/core/custom_training_loop.py @@ -8,14 +8,22 @@ - go/dataset-service 0-copy integration """ -import datetime import os -from typing import Callable, Dict, Iterable, List, Mapping, Optional +import datetime +from typing import ( + Callable, + Dict, + Iterable, + List, + Mapping, + Optional +) + from tml.common import log_weights -import tml.common.checkpointing.snapshot as snapshot_lib from tml.core.losses import get_global_loss_detached +import tml.common.checkpointing.snapshot as snapshot_lib from tml.ml_logging.torch_logging import logging # type: ignore[attr-defined] from tml.core.train_pipeline import TrainPipelineSparseDist diff --git a/core/metrics.py b/core/metrics.py index 2384e4d..cdbc50c 100644 --- a/core/metrics.py +++ b/core/metrics.py @@ -5,7 +5,11 @@ """ from typing import Any, Dict -from tml.core.metric_mixin import MetricMixin, StratifyMixin, TaskMixin +from tml.core.metric_mixin import ( + MetricMixin, + StratifyMixin, + TaskMixin +) import torch import torchmetrics as tm