Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stylistic fixes and addition of multi-modal scaffolding #810

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 45 additions & 34 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion flood_forecast/basic/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def simple_decode(model: Type[torch.nn.Module],
else:
# residual = output_len if max_seq_len - output_len - i >= 0 else max_seq_len % output_len
if output_len != out.shape[1]:
raise ValueError("Output length should laways equal the output shape")
raise ValueError("Output length should always equal the output shape")
real_target2[:, i:i + residual, 0:multi_targets] = out[:, :residual]
src = torch.cat((src[:, residual:, :], real_target2[:, i:i + residual, :]), 1)
ys = torch.cat((ys, real_target2[:, i:i + residual, :]), 1)
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Description:
This module contains functions for evaluating models. The basic logic flow is as follows:
1. `evaluate_model` is called from `trainer.py` at the end of training. It calls `infer_on_torch_model` which does the actual inference. # noqa
2. `infer_on_torch_model` calls `generate_predictions` which calls `generate_decoded_predictions` or `generate_predictions_non_decoded` depending on whether the model uses a decoder or not.
2. `infer_on_torch_mode` calls `generate_predictions` which calls `generate_decoded_predictions` or `generate_predictions_non_decoded` depending on whether the model uses a decoder or not.
3. `generate_decoded_predictions` calls `decoding_functions` which calls `greedy_decode` or `beam_decode` depending on the decoder function specified in the config file.
4. The returned value from `generate_decoded_predictions` is then used to calculate the evaluation metrics in `run_evaluation`.
5. `run_evaluation` returns the evaluation metrics to `evaluate_model` which returns them to `trainer.py`.
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/model_dict_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flood_forecast.multi_models.crossvivit import RoCrossViViT
from flood_forecast.multimodal_models.crossvivit import RoCrossViViT
from flood_forecast.transformer_xl.multi_head_base import MultiAttnHeadSimple
from flood_forecast.transformer_xl.transformer_basic import SimpleTransformer, CustomTransformerDecoder
from flood_forecast.transformer_xl.informer import Informer
Expand Down
Loading
Loading