Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for PyTorch 2.* and Python>=3.8 #1571

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions avalanche/benchmarks/classic/openloris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
a number of configuration parameters."""

from pathlib import Path
from typing import Union, Any, Optional
from typing_extensions import Literal
from typing import Union, Any, Optional, Literal

from avalanche.benchmarks.classic.classic_benchmarks_utils import (
check_vision_benchmark,
Expand Down
4 changes: 1 addition & 3 deletions avalanche/benchmarks/classic/stream51.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
# Website: www.continualai.org #
################################################################################
from pathlib import Path
from typing import List, Optional, Union

from typing_extensions import Literal
from typing import List, Optional, Union, Literal

from avalanche.benchmarks.datasets import Stream51
from avalanche.benchmarks.scenarios.deprecated.generic_benchmark_creation import (
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
""" LVIS PyTorch Object Detection Dataset """

from pathlib import Path
from typing import Optional, Union, List, Sequence
from typing import Optional, Union, List, Sequence, TypedDict

import torch
from PIL import Image
from torchvision.datasets.folder import default_loader
from torchvision.transforms import ToTensor
from typing_extensions import TypedDict

from avalanche.benchmarks.datasets import (
DownloadableDataset,
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/datasets/mini_imagenet/mini_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,9 @@
import csv
import glob
from pathlib import Path
from typing import Union, List, Tuple, Dict
from typing import Union, List, Tuple, Dict, Literal

from torchvision.datasets.folder import default_loader
from typing_extensions import Literal

import PIL
import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
Iterator,
TypeVar,
Union,
overload,
)
from typing_extensions import overload
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.dataset_utils import manage_advanced_indexing

Expand Down
2 changes: 1 addition & 1 deletion avalanche/benchmarks/scenarios/generic_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
Union,
Generic,
overload,
final,
)
from typing_extensions import final

import numpy as np

Expand Down
2 changes: 1 addition & 1 deletion avalanche/benchmarks/scenarios/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
TypeVar,
Union,
Protocol,
Literal,
)
from typing_extensions import Literal
import warnings
from avalanche.benchmarks.utils.data import AvalancheDataset
from avalanche.benchmarks.utils.utils import concat_datasets
Expand Down
4 changes: 0 additions & 4 deletions avalanche/benchmarks/utils/collate_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,19 @@
################################################################################

import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import (
List,
TypeVar,
Generic,
Sequence,
Tuple,
Dict,
Union,
overload,
)
from typing_extensions import TypeAlias

import torch
from torch import Tensor
from torch.utils.data import default_collate

BatchT = TypeVar("BatchT")
ExampleT = TypeVar("ExampleT")
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/utils/dataset_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
# Website: avalanche.continualai.org #
################################################################################

from typing import TypeVar, SupportsInt, Sequence
from typing import TypeVar, SupportsInt, Sequence, Protocol

from torch.utils.data.dataset import Dataset
from typing_extensions import Protocol

T_co = TypeVar("T_co", covariant=True)
TTargetType = TypeVar("TTargetType")
Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from abc import ABC, abstractmethod
import bisect
import copy
from typing import Iterator, overload
from typing_extensions import final
from typing import Iterator, overload, final
import numpy as np
from numpy import ndarray
from torch import Tensor
Expand Down
2 changes: 1 addition & 1 deletion avalanche/benchmarks/utils/transform_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
Union,
Callable,
Sequence,
Protocol,
)
from typing_extensions import Protocol

from avalanche.benchmarks.utils.transforms import (
MultiParamCompose,
Expand Down
3 changes: 1 addition & 2 deletions avalanche/distributed/distributed_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
import pickle
import warnings
from io import BytesIO
from typing import ContextManager, Optional, List, Any, Iterable, Dict, TypeVar
from typing import ContextManager, Optional, List, Any, Iterable, Dict, TypeVar, Literal

import torch
from torch import Tensor
from torch.nn.modules import Module
from torch.nn.parallel import DistributedDataParallel
from typing_extensions import Literal
from torch.distributed import init_process_group, broadcast_object_list


Expand Down
3 changes: 2 additions & 1 deletion avalanche/evaluation/metric_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
List,
Union,
overload,
Literal,
Protocol,
)
from typing_extensions import Literal, Protocol
from .metric_results import MetricValue, MetricType, AlternativeValues
from .metric_utils import (
get_metric_name,
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
################################################################################
from matplotlib.figure import Figure
from numpy import arange
from typing_extensions import Literal
from typing import (
Any,
Callable,
Expand All @@ -19,6 +18,7 @@
Optional,
TYPE_CHECKING,
List,
Literal,
)

import wandb
Expand Down
2 changes: 1 addition & 1 deletion avalanche/evaluation/metrics/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Callable,
Sequence,
Optional,
Protocol,
)

from avalanche.benchmarks.utils.data import AvalancheDataset
Expand All @@ -41,7 +42,6 @@
from json import JSONEncoder

from torch.utils.data import Subset, ConcatDataset
from typing_extensions import Protocol

from avalanche.evaluation import PluginMetric
from avalanche.evaluation.metric_results import MetricValue
Expand Down
4 changes: 1 addition & 3 deletions avalanche/evaluation/metrics/images_samples.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, TYPE_CHECKING, Tuple
from typing import List, TYPE_CHECKING, Tuple, Literal

from torch import Tensor
from torch.utils.data import DataLoader
Expand All @@ -15,8 +15,6 @@
)
from avalanche.evaluation.metric_utils import get_metric_name

from typing_extensions import Literal


if TYPE_CHECKING:
from avalanche.training.templates import SupervisedTemplate
Expand Down
3 changes: 1 addition & 2 deletions avalanche/evaluation/metrics/labels_repartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
List,
Counter,
overload,
Literal,
)

from matplotlib.figure import Figure
Expand All @@ -19,8 +20,6 @@
default_history_repartition_image_creator,
)

from typing_extensions import Literal


if TYPE_CHECKING:
from avalanche.training.templates import SupervisedTemplate
Expand Down
4 changes: 1 addition & 3 deletions avalanche/evaluation/metrics/mean_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable, Dict, Set, TYPE_CHECKING, List, Optional, TypeVar
from typing import Callable, Dict, Set, TYPE_CHECKING, List, Optional, TypeVar, Literal

import torch
from matplotlib.axes import Axes
Expand All @@ -26,8 +26,6 @@
from avalanche.evaluation.metric_results import MetricValue, AlternativeValues


from typing_extensions import Literal

if TYPE_CHECKING:
from avalanche.training.templates import SupervisedTemplate
from avalanche.evaluation.metric_results import MetricResult
Expand Down
4 changes: 1 addition & 3 deletions avalanche/training/plugins/lr_scheduling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import warnings
from typing import TYPE_CHECKING

from typing_extensions import Literal
from typing import TYPE_CHECKING, Literal

from avalanche.evaluation.metrics import Mean
from avalanche.training.plugins import SupervisedPlugin
Expand Down
3 changes: 1 addition & 2 deletions avalanche/training/templates/strategy_mixin_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Iterable, List, Optional, Sequence, Tuple, TypeVar
from typing_extensions import Protocol
from typing import Iterable, List, Optional, TypeVar, Protocol

from torch import Tensor
import torch
Expand Down
3 changes: 1 addition & 2 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ dependencies:
- conda-forge::sphinx-autoapi
- conda-forge::sphinx-copybutton
- pip:
- typing-extensions==4.4.0
- typing-extensions
- pytorchcv
- gdown
- ctrl-benchmark
- higher
- gym
- lvis
Expand Down
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ dependencies:
- conda-forge::pycocotools
- conda-forge::torchmetrics
- pip:
- typing-extensions==4.4.0
- typing-extensions
- pytorchcv
- gdown
- ctrl-benchmark
- gym
- higher
- lvis
Expand Down