Skip to content

Commit f9c9575

Browse files
committed
Cleanup
added requirements.txt and startet using pip only removed version from dgs.__init__ file and use importlib to obtain it Signed-off-by: Martin <[email protected]>
1 parent a60032f commit f9c9575

34 files changed

+143
-44
lines changed

dgs/__init__.py

-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
11
"""
22
Tracking via Dynamically Gated Similarities
3-
4-
TODO Provide more package information
53
"""
6-
7-
__version__ = "0.0.2"
8-
__author__ = "Martin Steinborn"
9-
__homepage__ = "https://bmmtstb.github.io/dynamically-gated-similarities/"
10-
__description__ = 'Code for Paper "Tracking with Dynamically Gated Similarities"'
11-
__url__ = "https://github.com/bmmtstb/dynamically-gated-similarities"

dgs/default_config.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
These values are used, iff the given config does not set own values.
55
"""
6+
67
import torch
78
from easydict import EasyDict
89

@@ -12,7 +13,7 @@
1213
# General #
1314
# ####### #
1415

15-
# cfg.name = "DEFAULT" # shouldn't be set, to force user to give it a name
16+
cfg.name = "DEFAULT"
1617
cfg.print_prio = "normal"
1718
cfg.working_memory_size = 30
1819

dgs/models/engine/engine.py

-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ class EngineModule(BaseModule):
121121
lr_sched: list[optim.lr_scheduler.LRScheduler]
122122
"""The learning-rate sheduler(s) can be changed by setting ``engine.lr_scheduler = [..., ...]``."""
123123

124-
@torch.enable_grad
125124
def __init__(
126125
self,
127126
config: Config,

dgs/models/engine/visual_sim_engine.py

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class VisualSimilarityEngine(EngineModule):
3434
val_dl: TorchDataLoader
3535
"""The torch DataLoader containing the validation (query) data."""
3636

37+
# The heart of the project might get a little larger...
38+
# pylint: disable=too-many-arguments,too-many-locals
39+
3740
def __init__(
3841
self,
3942
config: Config,

dgs/models/loss.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Functions to load and manage torch loss functions.
33
"""
4+
45
from typing import Type, Union
56

67
from torch import nn

dgs/models/metric.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""
22
Methods for handling the computation of distances and other metrics.
33
"""
4+
45
import warnings
56
from typing import Type, Union
67

78
import torch
89
from torch import nn
10+
from torch.linalg import vector_norm
911

1012
from dgs.utils.types import Metric
1113

@@ -71,8 +73,8 @@ def compute_cmc(
7173
def custom_cosine_similarity(input1: torch.Tensor, input2: torch.Tensor, dim: int, eps: float) -> torch.Tensor:
7274
"""See https://github.com/pytorch/pytorch/issues/104564#issuecomment-1625348908"""
7375
# get normalization value
74-
t1_div = torch.linalg.vector_norm(input1, dim=dim, keepdims=True)
75-
t2_div = torch.linalg.vector_norm(input2, dim=dim, keepdims=True)
76+
t1_div = vector_norm(input1, dim=dim, keepdim=True) # pylint: disable=not-callable
77+
t2_div = vector_norm(input2, dim=dim, keepdim=True) # pylint: disable=not-callable
7678

7779
t1_div = t1_div.clone()
7880
t2_div = t2_div.clone()

dgs/models/module.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Base model class as lowest building block for dynamic modules
33
"""
4+
45
import inspect
56
from abc import ABC, abstractmethod
67
from functools import wraps
@@ -17,7 +18,7 @@
1718
module_validations: Validations = {
1819
"name": ["str", ("longer", 2)],
1920
"print_prio": [("in", PRINT_PRIORITY)],
20-
"device": ["str", ("or", (("in", ["cuda", "cpu"]), ("instance", torch.device)))],
21+
"device": [("or", (("in", ["cuda", "cpu"]), ("type", torch.device)))],
2122
"gpus": ["optional", lambda gpus: isinstance(gpus, list) and all(isinstance(gpu, int) for gpu in gpus)],
2223
"num_workers": ["optional", "int", ("gte", 0)],
2324
"sp": [("instance", bool)],
@@ -97,7 +98,7 @@ class BaseModule(ABC):
9798
"""
9899

99100
@enable_keyboard_interrupt
100-
def __init__(self, config: Config, path: NodePath, validate_base: bool = False):
101+
def __init__(self, config: Config, path: NodePath):
101102
self.config: Config = config
102103
self.params: Config = get_sub_config(config, path)
103104
self._path: NodePath = path
@@ -106,16 +107,13 @@ def __init__(self, config: Config, path: NodePath, validate_base: bool = False):
106107
if not self.config["gpus"]:
107108
self.config["gpus"] = [-1]
108109
elif isinstance(self.config["gpus"], str):
109-
self.config["gpus"] = (
110-
[int(i) for i in self.config["gpus"].split(",")] if torch.cuda.device_count() >= 1 else [-1]
111-
)
110+
self.config["gpus"] = [int(i) for i in self.config["gpus"].split(",")]
111+
112112
# set default value of num_workers
113113
if not self.config["num_workers"]:
114114
self.config["num_workers"] = 0
115115

116-
# validate config when calling BaseModule class and flag is True
117-
if validate_base:
118-
self.validate_params(module_validations, "config")
116+
self.validate_params(module_validations, "config")
119117

120118
def validate_params(self, validations: Validations, attrib_name: str = "params") -> None:
121119
"""Given per key validations, validate this module's parameters.
@@ -173,7 +171,9 @@ def validate_params(self, validations: Validations, attrib_name: str = "params")
173171
if param_name not in getattr(self, attrib_name):
174172
if "optional" in list_of_validations:
175173
continue # value is optional and does not exist, skip validation
176-
raise InvalidParameterException(f"{param_name} is expected to be in module {self.__class__.__name__}")
174+
raise InvalidParameterException(
175+
f"'{param_name}' is expected to be in module '{self.__class__.__name__}'"
176+
)
177177

178178
# it is now safe to get the value
179179
value = getattr(self, attrib_name)[param_name]
@@ -202,13 +202,13 @@ def validate_params(self, validations: Validations, attrib_name: str = "params")
202202
if validate_value(value=value, data=data, validation=validation_name):
203203
continue
204204
raise InvalidParameterException(
205-
f"In module {self.__class__.__name__}, parameter {param_name} is not valid. "
206-
f"Value is {value} and is expected to have validation(s) {list_of_validations}."
205+
f"In module '{self.__class__.__name__}', parameter '{param_name}' is not valid. "
206+
f"Value is '{value}' and is expected to have validation(s) '{list_of_validations}'."
207207
)
208208
# no other case was true
209209
raise ValidationException(
210-
f"Validation is expected to be callable or tuple, but is {type(validation)}. "
211-
f"Current module: {self.__class__.__name__}, Parameter: {param_name}"
210+
f"Validation is expected to be callable or tuple, but is '{type(validation)}'. "
211+
f"Current module: '{self.__class__.__name__}', Parameter: '{param_name}'"
212212
)
213213

214214
@abstractmethod

dgs/models/optimizer.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Functions to load and manage torch optimizers.
33
"""
4+
45
from typing import Type, Union
56

67
from torch import optim

dgs/models/pose_warping/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Module to warp a given pose, pose-state -
33
or more generally to predict the next pose of a person given previous time steps.
44
"""
5+
56
from typing import Type
67

78
from dgs.utils.exceptions import InvalidParameterException

dgs/models/pose_warping/kalman.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Implementation if kalman filter for basic pose warping
33
"""
4+
45
import torch
56

67
from dgs.models.pose_warping.pose_warping import PoseWarpingModule

dgs/models/pose_warping/pose_warping.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Helpers and models for warping the pose-state of a track into the next time frame.
33
"""
4+
45
from abc import abstractmethod
56

67
import torch

dgs/models/similarity/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Modules for handling similarity functions or other models that return similarity scores between two (or more) inputs.
33
"""
4+
45
from typing import Type
56

67
from dgs.utils.exceptions import InvalidParameterException

dgs/models/similarity/combined.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Models that combine the results of two or more similarity matrices.
33
"""
4+
45
from abc import abstractmethod
56

67
import torch

dgs/models/similarity/pose_similarity.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Modules for computing the similarity between two poses.
33
"""
4+
45
import torch
56

67
from dgs.models.similarity.similarity import SimilarityModule

dgs/models/similarity/similarity.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
Similarity functions compute a similarity score "likeness" between two equally sized inputs.
55
"""
6+
67
from typing import Callable
78

89
import torch

dgs/models/states.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
definitions and helpers for pose-state(s)
33
"""
4+
45
from collections import UserDict
56
from typing import Union
67

dgs/utils/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
Contains functions for validating configuration and parameter of modules.
55
"""
6+
67
from copy import deepcopy
78
from typing import Union
89

dgs/utils/constants.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Predefined constants that will not change and might be used at different places.
33
"""
4+
45
import os
56

67
import torch

dgs/utils/files.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Contains helper functions for loading and interacting with files and paths.
33
"""
4+
45
import json
56
import os
67

dgs/utils/image.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
RGB Images in cv2 have a shape of ``[h x w x C]`` and the channels are in order GBR.
1515
Grayscale Images in cv2 have a shape of ``[h x w]``.
1616
"""
17+
1718
from typing import Iterable, Union
1819

1920
import torch

dgs/utils/timer.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Models, functions and helpers for timing operations.
33
"""
4+
45
import time
56
from collections import UserList
67
from datetime import timedelta

dgs/utils/torchtools.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Tools for handling recurring torch tasks. Mostly taken from the `torchreid package
33
<https://kaiyangzhou.github.io/deep-person-reid/_modules/torchreid/utils/torchtools.html#load_pretrained_weights>`_
44
"""
5+
56
import os
67
import pickle
78
import shutil

dgs/utils/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
definition of regularly used types
33
"""
4+
45
from typing import Callable, Union
56

67
import torch

dgs/utils/validation.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Utilities for validating recurring data types.
33
"""
4+
45
import os
56
from collections.abc import Iterable, Sized
67
from typing import Union
@@ -30,6 +31,7 @@
3031
"callable": (lambda x, _: callable(x)),
3132
"instance": isinstance, # alias
3233
"isinstance": isinstance,
34+
"type": (lambda x, d: isinstance(d, type) and isinstance(x, d)),
3335
"iterable": (lambda x, _: isinstance(x, Iterable)),
3436
"sized": (lambda x, _: isinstance(x, Sized)),
3537
# number

dgs/utils/visualization.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Matplotlib uses a different order for the images: `[B x H x W x C]`.
77
At least, the channel for matplotlib is RGB too.
88
"""
9+
910
from typing import Union
1011

1112
import numpy as np

docs/conf.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import os
1414
import sys
15+
from importlib.metadata import PackageNotFoundError, version as ilib_version
1516

1617
sys.path.insert(0, os.path.abspath(".."))
1718

@@ -23,10 +24,11 @@
2324
copyright = "2023, Martin Steinborn"
2425
author = "Martin Steinborn"
2526

26-
version_file = "../dgs/__init__.py"
27-
with open(version_file, "r") as f:
28-
exec(compile(f.read(), version_file, "exec"))
29-
__version__ = locals()["__version__"]
27+
try:
28+
__version__ = ilib_version("dynamically_gated_similarities")
29+
except PackageNotFoundError:
30+
__version__ = "0.0.0"
31+
3032
# The short X.Y version
3133
version = __version__[: __version__.find(".", __version__.find(".") + 1)]
3234
# The full version, including alpha/beta/rc tags

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "dynamically_gated_similarities"
3-
version = "0.0.2"
3+
version = "0.0.3"
44
authors = [
55
{ name = "Martin Steinborn", email = "[email protected]" },
66
]

0 commit comments

Comments
 (0)