
[ENH] Implementing TimeXer model from thuml #1797


Open
wants to merge 14 commits into base: main

Conversation

PranavBhatP
Contributor

Description

This PR addresses #1793 and aims to implement the TimeXer model, aligned with PTF's design. It is currently a draft.

Checklist

  • Linked issues (if existing)
  • Amended changelog for large changes (and added myself there as contributor)
  • Added/modified tests
  • Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with pre-commit install.
    To run hooks independent of commit, execute pre-commit run --all-files

Make sure to have fun coding!


codecov bot commented Mar 16, 2025

Codecov Report

Attention: Patch coverage is 0% with 280 lines in your changes missing coverage. Please review.

Please upload report for BASE (main@7d64fce). Learn more about missing BASE report.

Files with missing lines                            Patch %   Lines
pytorch_forecasting/models/timexer/sub_modules.py     0.00%   158 Missing ⚠️
pytorch_forecasting/models/timexer/_timexer.py        0.00%   119 Missing ⚠️
pytorch_forecasting/models/timexer/__init__.py        0.00%     3 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1797   +/-   ##
=======================================
  Coverage        ?   82.40%           
=======================================
  Files           ?       49           
  Lines           ?     5581           
  Branches        ?        0           
=======================================
  Hits            ?     4599           
  Misses          ?      982           
  Partials        ?        0           
Flag     Coverage Δ
cpu      82.40% <0.00%> (?)
pytest   82.40% <0.00%> (?)


@PranavBhatP
Contributor Author

PranavBhatP commented Mar 21, 2025

@fkiraly the implementation of TimeXer seems to be working fine with PTF in its current design, and all the functionality highlighted in the thuml paper works here as well. I will proceed with the PTF v2.0 design suggested by @phoeenniixx and hopefully come up with an interface to integrate more of these models.

@PranavBhatP PranavBhatP marked this pull request as ready for review March 22, 2025 19:00
@fkiraly fkiraly moved this from PR in progress to PR under review in Dec 2024 - Mar 2025 mentee projects Mar 24, 2025
@fkiraly
Collaborator

fkiraly commented Apr 4, 2025

FYI @agobbifbk, @benHeid

@agobbifbk

A couple of comments:

  • I would add somewhere in the code that this is a port from THUML (credits?). Maybe it is there and I didn't see it.
  • The architecture seems to be aligned with the one reported in THUML.
  • I have some doubts about line 450 of _timexer.py, prediction = prediction.repeat(1, 1, 1, num_quantiles): it seems that you replicate the prediction instead of increasing the output shape of the architecture at the beginning (see the sketch after this list).
  • I see an unticked box for multi-output prediction; does this model support multi-output? And multi-output with quantile loss?
  • I don't have enough knowledge to review the connection between data and model.
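
A sketch of that concern (shapes assumed for illustration): repeat only copies values along the new axis, so every "quantile" carries the same point forecast instead of being learned.

    import torch

    prediction = torch.randn(2, 5, 1, 1)  # (batch, length, targets, 1)
    num_quantiles = 3
    repeated = prediction.repeat(1, 1, 1, num_quantiles)  # (2, 5, 1, 3)
    # all quantile slices are identical copies of the point forecast
    print(torch.equal(repeated[..., 0], repeated[..., 2]))  # True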

@PranavBhatP
Contributor Author

Hi @agobbifbk, thanks for the review!

I would add somewhere in the code that this is a port from THUML (credits?). Maybe it is there and I didn't see it.

Yes, I think I forgot to credit the original authors. Will do :).

I have some doubts about line 450 of _timexer.py, prediction = prediction.repeat(1, 1, 1, num_quantiles): it seems that you replicate the prediction instead of increasing the output shape of the architecture at the beginning.

I had similar questions about how this model would handle quantile predictions; the original architecture doesn't seem to handle this, so I decided to patch it with this line of code (which might be a bad approach). I am not sure what changes I should make to fix this. Could you help me here?

I see an unticked box for multi-output prediction; does this model support multi-output? And multi-output with quantile loss?

It is indeed multi-output; the _forecast_multi method is native to the tslib package. As for multi-output with quantile loss, as you mentioned in the previous point, there is some difficulty in handling this, and we need to make changes for that.

@agobbifbk

I had similar questions about how this model would handle quantile predictions; the original architecture doesn't seem to handle this, so I decided to patch it with this line of code (which might be a bad approach). I am not sure what changes I should make to fix this. Could you help me here?

Sure, usually it is sufficient to increase the number of output channels. Suppose the model ends with something of shape BxLxC, where B is the batch size, L is the output length, and C is the number of target variables. You have two options. The first is to initialize the model with a different value of C: usually there are 3 quantiles (0.05, 0.5, 0.95), so you force the model to give 3C output channels instead of C; in this case you need to check how the quantile loss is implemented and what shape it expects. The other approach, the one we use in DSIPTS, is that all the models produce an output of shape BxLxCxM, where M is 1 in case of a standard loss or 3 in case of quantile loss.
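
For instance, a minimal sketch of the second option (all names here are hypothetical, not the DSIPTS or TimeXer code):

    import torch
    import torch.nn as nn

    class QuantileHead(nn.Module):
        """Projects hidden features to (B, L, C, M), with M quantiles."""

        def __init__(self, d_model, pred_len, n_targets, n_quantiles=3):
            super().__init__()
            self.pred_len = pred_len
            self.n_targets = n_targets
            self.n_quantiles = n_quantiles
            self.proj = nn.Linear(d_model, pred_len * n_targets * n_quantiles)

        def forward(self, h):  # h: (B, d_model)
            out = self.proj(h)  # (B, L * C * M)
            return out.view(-1, self.pred_len, self.n_targets, self.n_quantiles)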

My suggestion is to start from the implementation of the quantile loss and check its definition (in DSIPTS there is a multi-output version of it, which just sums the contribution of each channel), and then play with the output shape of the model!
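
As a rough sketch of such a loss (assuming the BxLxCxM convention above; this is an illustration, not the DSIPTS implementation):

    import torch

    def multioutput_quantile_loss(y_pred, y_true, quantiles=(0.05, 0.5, 0.95)):
        # y_pred: (B, L, C, M) with M == len(quantiles); y_true: (B, L, C)
        losses = []
        for i, q in enumerate(quantiles):
            err = y_true - y_pred[..., i]  # (B, L, C)
            losses.append(torch.max(q * err, (q - 1) * err))  # pinball loss
        # average over batch and time, then sum the contribution of each channel
        return torch.stack(losses, dim=-1).mean(dim=(0, 1)).sum()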

Let me know if this is sufficient to finish the job :-)

@PranavBhatP
Contributor Author

Thanks for the suggestion! I will try to make these changes.

@PranavBhatP
Contributor Author

@agobbifbk, I've made changes to the output layer by adding the quantiles as a separate dimension in the FlattenHead submodule class, and I handle the final output in the forward method of the TimeXer class. As per your suggestion, the output tensor now has a 4th dimension, giving BxLxCxM.
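
Roughly along these lines (a sketch only; the names and exact wiring are assumptions, not the PR's code):

    import torch
    import torch.nn as nn

    class FlattenHead(nn.Module):
        # (B, n_vars, d_model, n_patches) -> (B, pred_len, n_vars, n_quantiles)
        def __init__(self, d_model, n_patches, pred_len, n_quantiles=3):
            super().__init__()
            self.pred_len = pred_len
            self.n_quantiles = n_quantiles
            self.flatten = nn.Flatten(start_dim=-2)
            self.linear = nn.Linear(d_model * n_patches, pred_len * n_quantiles)

        def forward(self, x):
            x = self.flatten(x)  # (B, n_vars, d_model * n_patches)
            x = self.linear(x)   # (B, n_vars, pred_len * n_quantiles)
            x = x.view(x.size(0), x.size(1), self.pred_len, self.n_quantiles)
            return x.permute(0, 2, 1, 3)  # (B, pred_len, n_vars, n_quantiles)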

@agobbifbk

Nice!
Now it should work correctly with quantile loss. I suggest trying it on a real dataset with the quantile loss and plotting the results.
One suggestion:

    enc_out = torch.reshape(
        enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
    )

This code is hard to understand; I suggest giving names to the variables (context_length?) so that the next time anyone reads this they will immediately understand the shapes of those tensors. What do you think?
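
For example (a hypothetical named version, assuming the trailing axes of enc_out are the patch count and the hidden size):

    n_patches, d_model = enc_out.shape[-2], enc_out.shape[-1]
    enc_out = torch.reshape(enc_out, (-1, n_vars, n_patches, d_model))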

I still have a doubt on the quantile loss; this is more of a general PTF question.
Now that you are returning something with an extra dimension (quantile), do you check that the dimensions are in the correct order?

    prediction = self.transform_output(
        prediction=prediction, target_scale=x["target_scale"]
    )
    return self.to_network_output(prediction=prediction)

Also in this case, instead of using range(prediction.size(2)), it may be more readable to name a variable so devs can understand which dimension you are cycling over.
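
For example (assuming dimension 2 indexes the target variables):

    n_targets = prediction.size(2)  # dimension 2 = target variables
    for target_idx in range(n_targets):
        ...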

Wrapping up: clean up the code a little, ensure the output is in the correct form, try to train the model, and plot the outputs to see whether it gives the confidence intervals correctly.
THX!!!
