Skip to content
Merged
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
cef11f3
fix imports
dxoigmn Jun 15, 2023
9522075
Move _call_with_args_ and _return_as_dict_ functionality into CallWith
dxoigmn Jun 15, 2023
ba58264
Allow overwriting _call_with_args_ and _return_as_dict_ in CallWith.f…
dxoigmn Jun 15, 2023
60fc2ad
Add _train_mode_ and _inference_mode_ to CallWith
dxoigmn Jun 15, 2023
2e51c66
Revert "Add _train_mode_ and _inference_mode_ to CallWith"
dxoigmn Jun 15, 2023
1671755
Revert "Revert "Add _train_mode_ and _inference_mode_ to CallWith""
dxoigmn Jun 15, 2023
731b23d
cleanup
dxoigmn Jun 15, 2023
ac942c7
Merge branch 'better_sequentialdict3' into better_sequentialdict4
dxoigmn Jun 15, 2023
8e664f1
cleanup
dxoigmn Jun 15, 2023
ae1b836
Fix configs
dxoigmn Jun 15, 2023
39f9aaa
cleanup
dxoigmn Jun 15, 2023
0973933
bugfix
dxoigmn Jun 15, 2023
6d3f42a
Merge branch 'better_sequentialdict3' into better_sequentialdict4
dxoigmn Jun 15, 2023
2ec4e49
Only set train mode and inference mode on Modules
dxoigmn Jun 15, 2023
741d282
bugfix
dxoigmn Jun 15, 2023
15c5a5f
CallWith is not a Module
dxoigmn Jun 15, 2023
cb55a31
Merge branch 'better_sequentialdict3' into better_sequentialdict4
dxoigmn Jun 15, 2023
40dd262
cleanup
dxoigmn Jun 22, 2023
a218bb2
Merge branch 'main' into better_sequentialdict3
dxoigmn Jun 22, 2023
0716be3
Merge branch 'main' into better_sequentialdict3
dxoigmn Jun 22, 2023
ceaa744
bugfix
dxoigmn Jun 22, 2023
d607fdf
Merge branch 'main' into better_sequentialdict3
dxoigmn Jun 22, 2023
b8d473b
fix merge error
dxoigmn Jun 22, 2023
e9cf67b
Change call special arg names
dxoigmn Jun 22, 2023
3f3e302
Merge branch 'better_sequentialdict3' into better_sequentialdict4
dxoigmn Jun 22, 2023
9c7b05e
bugfix
dxoigmn Jun 23, 2023
82e66f5
Merge branch 'main' into better_sequentialdict3
dxoigmn Jun 23, 2023
14bde3e
Merge branch 'better_sequentialdict3' into better_sequentialdict4
dxoigmn Jun 23, 2023
fab9763
fix merge error
dxoigmn Jun 23, 2023
f9b43eb
cleanup
dxoigmn Jun 23, 2023
471757d
cleanup
dxoigmn Jun 23, 2023
368b87e
style
dxoigmn Jun 23, 2023
c5c734a
Merge remote-tracking branch 'origin/better_sequentialdict3' into bet…
dxoigmn Jun 27, 2023
8d44bb2
Merge branch 'main' into better_sequentialdict4
dxoigmn Jun 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions mart/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
from collections import OrderedDict
from contextlib import nullcontext
from typing import Callable, Iterable

import torch
Expand Down Expand Up @@ -126,6 +127,8 @@ def __init__(
module: Callable,
_call_with_args_: Iterable[str] | None = None,
_return_as_dict_: Iterable[str] | None = None,
_train_mode_: bool | None = None,
_inference_mode_: bool | None = None,
**kwarg_keys,
) -> None:
super().__init__()
Expand All @@ -134,18 +137,24 @@ def __init__(
self.arg_keys = _call_with_args_
self.kwarg_keys = kwarg_keys
self.return_keys = _return_as_dict_
self.train_mode = _train_mode_
self.inference_mode = _inference_mode_

def __call__(
self,
*args,
_args_: Iterable[str] | None = None,
_return_keys_: Iterable[str] | None = None,
_train_mode_: bool | None = None,
_inference_mode_: bool | None = None,
**kwargs,
):
module_name = self.module.__class__.__name__

arg_keys = _args_ or self.arg_keys
kwarg_keys = self.kwarg_keys
_train_mode_ = _train_mode_ or self.train_mode
_inference_mode_ = _inference_mode_ or self.inference_mode

# Change and replace args and kwargs that we call module with
if arg_keys is not None or len(kwarg_keys) > 0:
Expand Down Expand Up @@ -183,8 +192,24 @@ def __call__(
f"{module_name} only received kwargs: {', '.join(kwargs.keys())}."
) from ex

# FIXME: Add better error message
ret = self.module(*args, **kwargs)
# Apply train mode and inference mode, if necessary, and call module with args and kwargs
context = nullcontext()
if isinstance(self.module, torch.nn.Module):
old_train_mode = self.module.training

if _train_mode_ is not None:
self.module.train(_train_mode_)

if _inference_mode_ is not None:
context = torch.inference_mode(mode=_inference_mode_)

with context:
# FIXME: Add better error message
ret = self.module(*args, **kwargs)

if isinstance(self.module, torch.nn.Module):
if _train_mode_ is not None:
self.module.train(old_train_mode)

# Change returned values into dictionary, if necessary
return_keys = _return_keys_ or self.return_keys
Expand Down