diff --git a/LICENSE b/LICENSE
index 7830546..fe463e0 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,408 @@
-MIT License
-
-Copyright (c) 2023 Inhwan Bae
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
+
diff --git a/README.md b/README.md
index 1ced6cb..8a8bcd3 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
Inhwan Bae
·
- Young-Jae Park
+ Young-Jae Park
·
Hae-Gon Jeon
@@ -10,27 +10,125 @@
- Project Page
- CVPR Paper
- Source Code
- Related Works
+ Project Page
+ CVPR Paper
+ Source Code
+ Related Works
-
An overview of our SingularTrajectory framework..
+
An overview of our SingularTrajectory framework.
-[//]: # (
This repository contains the code for the EigenTrajectory(𝔼𝕋) space applied to the 10 traditional Euclidean-based trajectory predictors.)
-[//]: # (
EigenTrajectory-LB-EBM achieves ADE/FDE of 0.21/0.34 while requiring only 1 hour for training! )
+
This repository contains the code for the SingularTrajectory model, designed to handle five different trajectory prediction benchmarks.
+
Our unified framework ensures the general dynamics of human movements across various input modalities and trajectory lengths.
## 1️⃣ SingularTrajectory Model 1️⃣
+* A diffusion-based universal trajectory prediction framework designed to bridge the performance gap across five tasks.
+* A Singular space is constructed to unify various representations of human dynamics in the associated tasks (see the sketch after this list).
+* An adaptive anchor and cascaded denoising process correct initial prototype paths that are placed incorrectly.
+* Our model outperforms previous methods on five public benchmarks: Deterministic, Stochastic, Domain Adaptation, Momentary Observation, and Few-Shot.
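+
+The core idea of the Singular space can be sketched in a few lines of plain PyTorch. This is a minimal, illustrative sketch of the projection and reconstruction steps only, using toy tensor shapes; the full implementation (trajectory normalization, adaptive anchors, and the diffusion-based denoiser) lives in `SingularTrajectory/space.py`.
+```python
+import torch
+
+# Toy normalized future trajectories: (num_peds, t_pred, 2).
+pred_traj = torch.randn(100, 12, 2)
+k = 4  # number of basis vectors to keep (hyper-parameter `k`)
+
+# Build the Singular space with a truncated SVD of the flattened trajectories.
+M = pred_traj.reshape(100, -1).T                 # (t_pred * 2, num_peds)
+U, S, Vt = torch.linalg.svd(M, full_matrices=False)
+V_trunc = U[:, :k]                               # Singular space basis vectors
+
+# Project each trajectory onto k coefficients, then reconstruct it back.
+C = V_trunc.T @ M                                # (k, num_peds)
+recon = (V_trunc @ C).T.reshape(100, 12, 2)      # low-rank approximation
+print((pred_traj - recon).norm() / pred_traj.norm())
+```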
+
+
+
+## Model Training
+### Setup
+**Environment**
+
All models were trained and tested on Ubuntu 20.04 with Python 3.8 and PyTorch 2.0.1 with CUDA 11.7.
+
+**Dataset**
+
Preprocessed [ETH](https://data.vision.ee.ethz.ch/cvl/aem/ewap_dataset_full.tgz) and [UCY](https://graphics.cs.ucy.ac.cy/research/downloads/crowd-data) datasets are released in this repository.
+The train/validation/test splits are the same as those found in [Social-GAN](https://github.com/agrimgupta92/sgan).
+
+You can download the dataset by running the following script.
+```bash
+./scripts/download_datasets.sh
+```
+
+### Train SingularTrajectory
+To train our SingularTrajectory on each task using the ETH and UCY datasets simultaneously, we provide a bash script `train.sh` for simplified execution.
+```bash
+./scripts/train.sh -p <config_prefix> -t <experiment_tag> -d <dataset_list> -i <gpu_id_list>
+```
+
+**Examples**
+```bash
+# Stochastic prediction task
+./script/train.sh -p stochastic/singulartrajectory -t SingularTrajectory-stochastic
+
+# Deterministic prediction task
+./script/train.sh -p deterministic/singulartrajectory -t SingularTrajectory-deterministic
+
+# Momentary observation task
+./script/train.sh -p momentary/singulartrajectory -t SingularTrajectory-momentary
+
+# Domain adaptation task
+./script/train.sh -p domain/singulartrajectory -t SingularTrajectory-domain -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
+
+# Few-shot task
+./script/train.sh -p fewshot/singulartrajectory -t SingularTrajectory-fewshot
+
+# (Optional) Stochastic domain adaptation task
+./script/train.sh -p domain-stochastic/singulartrajectory -t SingularTrajectory-domain-stochastic -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
+
+# (Optional) All-in-one task
+./script/train.sh -p allinone/singulartrajectory -t SingularTrajectory-allinone -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
+```
+If you want to train the model with custom hyperparameters, use `trainval.py` instead of the script file.
+```bash
+python trainval.py --cfg ./config/{task}/singulartrajectory-transformerdiffusion-{dataset}.json --tag SingularTrajectory-{task} --gpu_id 0
+```
+
+
+
+## Model Evaluation
+### Pretrained Models
+We provide pretrained models in the [**release section**](https://github.com/InhwanBae/SingularTrajectory/releases/tag/v1.0).
+You can download all pretrained models at once by running the script below; this will download all 80 SingularTrajectory models.
+```bash
+./scripts/download_pretrained_models.sh
+```
+
+### Evaluate SingularTrajectory
+To evaluate our SingularTrajectory models at once, we provide a bash script `test.sh` for simplified execution.
+```bash
+./scripts/test.sh -p <config_prefix> -t <experiment_tag> -d <dataset_list> -i <gpu_id_list>
+```
+
+**Examples**
+
+```bash
+# Stochastic prediction task
+./script/test.sh -p stochastic/singulartrajectory -t SingularTrajectory-stochastic
+
+# Deterministic prediction task
+./script/test.sh -p deterministic/singulartrajectory -t SingularTrajectory-deterministic
+
+# Momentary observation task
+./script/test.sh -p momentary/singulartrajectory -t SingularTrajectory-momentary
-- [x] 3/27 Paper [released](https://arxiv.org/abs/2403.18452)!
-- [ ] Source code will be released shortly. (Please check our [project page](https://inhwanbae.github.io/publication/singulartrajectory/) :)
+
+# Domain adaptation task
+./script/test.sh -p domain/singulartrajectory -t SingularTrajectory-domain -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
+
+# Few-shot task
+./script/test.sh -p fewshot/singulartrajectory -t SingularTrajectory-fewshot
+
+# (Optional) Stochastic domain adaptation task
+./script/test.sh -p domain-stochastic/singulartrajectory -t SingularTrajectory-domain-stochastic -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
+
+# (Optional) All-in-one task
+./script/test.sh -p allinone/singulartrajectory -t SingularTrajectory-allinone -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
+```
+
+If you want to evaluate the model individually, you can use `trainval.py` with custom hyperparameters.
+```bash
+python trainval.py --test --cfg ./config/{task}/singulartrajectory-transformerdiffusion-{dataset}.json --tag SingularTrajectory-{task} --gpu_id 0
+```
+
+
## 📖 Citation
If you find this code useful for your research, please cite our trajectory prediction papers :)
@@ -51,7 +149,7 @@ If you find this code useful for your research, please cite our trajectory predi
year={2024}
}
```
-
+
More Information (Click to expand)
```bibtex
@@ -99,4 +197,8 @@ If you find this code useful for your research, please cite our trajectory predi
```
+### Acknowledgement
+Part of our code is borrowed from [EigenTrajectory](https://github.com/InhwanBae/EigenTrajectory) and [LED](https://github.com/MediaBrain-SJTU/LED).
+We thank the authors for releasing their code and models.
+
diff --git a/SingularTrajectory/__init__.py b/SingularTrajectory/__init__.py
new file mode 100644
index 0000000..32cfa27
--- /dev/null
+++ b/SingularTrajectory/__init__.py
@@ -0,0 +1,2 @@
+from .model import SingularTrajectory
+from .normalizer import TrajNorm
diff --git a/SingularTrajectory/anchor.py b/SingularTrajectory/anchor.py
new file mode 100644
index 0000000..bf7893c
--- /dev/null
+++ b/SingularTrajectory/anchor.py
@@ -0,0 +1,142 @@
+import torch
+import torch.nn as nn
+from .kmeans import BatchKMeans
+from sklearn.cluster import KMeans
+import numpy as np
+from .homography import image2world, world2image
+
+
+class AdaptiveAnchor(nn.Module):
+ r"""Adaptive anchor model
+
+ Args:
+ hyper_params (DotDict): The hyper-parameters
+ """
+
+ def __init__(self, hyper_params):
+ super().__init__()
+
+ self.hyper_params = hyper_params
+ self.k = hyper_params.k
+ self.s = hyper_params.num_samples
+ self.dim = hyper_params.traj_dim
+
+ self.C_anchor = nn.Parameter(torch.zeros((self.k, self.s)))
+
+ def to_Singular_space(self, traj, evec):
+ r"""Transform Euclidean trajectories to Singular space coordinates
+
+ Args:
+ traj (torch.Tensor): The trajectory to be transformed
+ evec (torch.Tensor): The Singular space basis vectors
+
+ Returns:
+ C (torch.Tensor): The Singular space coordinates"""
+
+ # Euclidean space -> Singular space
+ tdim = evec.size(0)
+ M = traj.reshape(-1, tdim).T
+ C = evec.T.detach() @ M
+ return C
+
+ def batch_to_Singular_space(self, traj, evec):
+ # Euclidean space -> Singular space
+ tdim = evec.size(0)
+ M = traj.reshape(-1, tdim).transpose(1, 2)
+ C = evec.T.detach() @ M
+ return C
+
+ def to_Euclidean_space(self, C, evec):
+ r"""Transform Singular space coordinates to Euclidean trajectories
+
+ Args:
+ C (torch.Tensor): The Singular space coordinates
+ evec (torch.Tensor): The Singular space basis vectors
+
+ Returns:
+ traj (torch.Tensor): The Euclidean trajectory"""
+
+ # Singular space -> Euclidean
+ t = evec.size(0) // self.dim
+ M = evec.detach() @ C
+ traj = M.T.reshape(-1, t, self.dim)
+ return traj
+
+ def batch_to_Euclidean_space(self, C, evec):
+ # Singular space -> Euclidean
+ b = C.size(0)
+ t = evec.size(0) // self.dim
+ M = evec.detach() @ C
+ traj = M.transpose(1, 2).reshape(b, -1, t, self.dim)
+ return traj
+
+ def anchor_initialization(self, pred_traj_norm, V_pred_trunc):
+ r"""Anchor initialization on Singular space
+
+ Args:
+ pred_traj_norm (torch.Tensor): The normalized predicted trajectory
+ V_pred_trunc (torch.Tensor): The truncated Singular space basis vectors of the predicted trajectory
+
+ Note:
+ This function should be called once before training the model.
+ """
+
+ # Trajectory projection
+ C_pred = self.to_Singular_space(pred_traj_norm, evec=V_pred_trunc).T.detach().numpy()
+ C_anchor = torch.FloatTensor(KMeans(n_clusters=self.s, random_state=0, init='k-means++', n_init=1).fit(C_pred).cluster_centers_.T)
+
+ # Register anchors as model parameters
+ self.C_anchor = nn.Parameter(C_anchor.to(self.C_anchor.device))
+
+ def adaptive_anchor_calculation(self, obs_traj, scene_id, vector_field, homography, space):
+ r"""Adaptive anchor calculation on Singular space"""
+
+ n_ped = obs_traj.size(0)
+ V_trunc = space.V_trunc
+
+ space.traj_normalizer.calculate_params(obs_traj.cuda().detach())
+ init_anchor = self.C_anchor.unsqueeze(dim=0).repeat_interleave(repeats=n_ped, dim=0).detach()
+ init_anchor = init_anchor.permute(2, 1, 0)
+ init_anchor_euclidean = space.batch_to_Euclidean_space(init_anchor, evec=V_trunc)
+ init_anchor_euclidean = space.traj_normalizer.denormalize(init_anchor_euclidean).cpu().numpy()
+ adaptive_anchor_euclidean = init_anchor_euclidean.copy()
+ obs_traj = obs_traj.cpu().numpy()
+
+        for ped_id in range(n_ped):
+            scene_name = scene_id[ped_id]
+            # Project the prototype paths and the last observed point into image coordinates.
+            prototype_image = world2image(init_anchor_euclidean[:, ped_id], homography[scene_name])
+            startpoint_image = world2image(obs_traj[ped_id], homography[scene_name])[-1]
+            endpoint_image = prototype_image[:, -1, :]
+            endpoint_image = np.round(endpoint_image).astype(int)
+            size = np.array(vector_field[scene_name].shape[1::-1]) // 2
+            endpoint_image = np.clip(endpoint_image, a_min=-size // 2, a_max=size + size // 2 - 1)
+            for s in range(self.s):
+                # Look up the location stored in the precomputed vector field at the prototype endpoint.
+                vector = np.array(vector_field[scene_name][endpoint_image[s, 1] + size[1] // 2, endpoint_image[s, 0] + size[0] // 2])[::-1] - size // 2
+                if vector[0] == endpoint_image[s, 0] and vector[1] == endpoint_image[s, 1]:
+                    continue
+                else:
+                    # If it differs, rescale the prototype path so that its endpoint moves to that location.
+                    nearest_endpoint_image = vector
+                    scale_xy = (nearest_endpoint_image - startpoint_image) / (endpoint_image[s] - startpoint_image)
+                    prototype_image[s, :, :] = (prototype_image[s, :, :].copy() - startpoint_image) * scale_xy + startpoint_image
+
+ prototype_world = image2world(prototype_image, homography[scene_name])
+ adaptive_anchor_euclidean[:, ped_id] = prototype_world
+
+ adaptive_anchor_euclidean = space.traj_normalizer.normalize(torch.FloatTensor(adaptive_anchor_euclidean).cuda())
+ adaptive_anchor = space.batch_to_Singular_space(adaptive_anchor_euclidean, evec=V_trunc)
+ adaptive_anchor = adaptive_anchor.permute(2, 1, 0).cpu()
+ # If you don't want to use an image, return `init_anchor`.
+ return adaptive_anchor
+
+ def forward(self, C_residual, C_anchor):
+ r"""Anchor refinement on Singular space
+
+ Args:
+            C_residual (torch.Tensor): The predicted residual coordinates in the Singular space
+            C_anchor (torch.Tensor): The adaptive anchor coordinates in the Singular space
+
+ Returns:
+ C_pred_refine (torch.Tensor): The refined Singular space coordinates
+ """
+
+ C_pred_refine = C_anchor.detach() + C_residual
+ return C_pred_refine
diff --git a/SingularTrajectory/homography.py b/SingularTrajectory/homography.py
new file mode 100644
index 0000000..bc36b8d
--- /dev/null
+++ b/SingularTrajectory/homography.py
@@ -0,0 +1,102 @@
+import numpy as np
+import torch
+
+
+def image2world(coord, H):
+ r"""Convert image coordinates to world coordinates.
+
+ Args:
+ coord (np.ndarray or torch.tensor): Image coordinates, shape (..., 2).
+ H (np.ndarray or torch.tensor): Homography matrix, shape (3, 3).
+
+ Returns:
+ np.ndarray: World coordinates.
+ """
+
+ assert coord.shape[-1] == 2
+ assert H.shape == (3, 3)
+ assert type(coord) == type(H)
+
+ shape = coord.shape
+ coord = coord.reshape(-1, 2)
+
+ if isinstance(coord, np.ndarray):
+ x, y = coord[..., 0], coord[..., 1]
+ world = (H @ np.stack([x, y, np.ones_like(x)], axis=-1).T).T
+ world = world / world[..., [2]]
+ world = world[..., :2]
+
+ elif isinstance(coord, torch.Tensor):
+ x, y = coord[..., 0], coord[..., 1]
+ world = (H @ torch.stack([x, y, torch.ones_like(x)], dim=-1).T).T
+ world = world / world[..., [2]]
+ world = world[..., :2]
+
+ else:
+ raise NotImplementedError
+
+ return world.reshape(shape)
+
+
+def world2image(coord, H, transpose=False):
+ r"""Convert world coordinates to image coordinates.
+
+ Args:
+ coord (np.ndarray or torch.tensor): World coordinates, shape (..., 2).
+ H (np.ndarray or torch.tensor): Homography matrix, shape (3, 3).
+
+ Returns:
+ np.ndarray: Image coordinates.
+ """
+
+ assert coord.shape[-1] == 2
+ assert H.shape == (3, 3)
+ assert type(coord) == type(H)
+
+ shape = coord.shape
+ coord = coord.reshape(-1, 2)
+
+ if isinstance(coord, np.ndarray):
+ x, y = coord[..., 0], coord[..., 1]
+ image = (np.linalg.inv(H) @ np.stack([x, y, np.ones_like(x)], axis=-1).T).T
+ image = image / image[..., [2]]
+ image = image[..., :2]
+
+ elif isinstance(coord, torch.Tensor):
+ x, y = coord[..., 0], coord[..., 1]
+ image = (torch.linalg.inv(H) @ torch.stack([x, y, torch.ones_like(x)], dim=-1).T).T
+ image = image / image[..., [2]]
+ image = image[..., :2]
+
+ else:
+ raise NotImplementedError
+
+ return image.reshape(shape)
+
+
+def generate_homography(shift_w: float=0, shift_h: float=0, rotate: float=0, scale: float=1):
+ r"""Generate a homography matrix.
+
+ Args:
+        shift_w (float): Shift in the x direction.
+        shift_h (float): Shift in the y direction.
+        rotate (float): Rotation angle in radians.
+ scale (float): Scale factor.
+
+ Returns:
+ np.ndarray: Homography matrix, shape (3, 3).
+ """
+
+ H = np.eye(3)
+ H[0, 2] = shift_w
+ H[1, 2] = shift_h
+ H[2, 2] = scale
+
+ if rotate != 0:
+ # rotation matrix
+ R = np.array([[np.cos(rotate), -np.sin(rotate), 0],
+ [np.sin(rotate), np.cos(rotate), 0],
+ [0, 0, 1]])
+ H = H @ R
+
+ return H
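+
+
+if __name__ == "__main__":
+    # Minimal self-check sketch (illustrative only, not used by the training pipeline):
+    # image2world() applies H while world2image() applies inv(H), so chaining the two
+    # should reproduce the input coordinates for any invertible homography.
+    H = generate_homography(shift_w=3.0, shift_h=-2.0, rotate=np.pi / 6, scale=1.0)
+    pts_world = np.random.rand(5, 12, 2)
+    pts_recon = image2world(world2image(pts_world, H), H)
+    print(np.abs(pts_world - pts_recon).max())  # expected to be close to zero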
diff --git a/SingularTrajectory/kmeans.py b/SingularTrajectory/kmeans.py
new file mode 100644
index 0000000..7076452
--- /dev/null
+++ b/SingularTrajectory/kmeans.py
@@ -0,0 +1,279 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from time import time
+
+
+class BatchKMeans(nn.Module):
+ r"""Run multiple independent K-means algorithms in parallel.
+
+ Args:
+ n_clusters (int): Number of clusters
+ max_iter (int): Maximum number of iterations (default: 100)
+ tol (float): Tolerance (default: 0.0001)
+        n_redo (int): Number of times k-means will be run with differently initialized centroids.
+                      The centroids with the lowest inertia will be selected as the final result. (default: 1)
+        init_mode (str): Initialization method.
+                         'random': randomly choose initial centroids from the input data.
+                         'kmeans++': use the k-means++ algorithm to initialize centroids. (default: 'kmeans++')
+ """
+
+ def __init__(self, n_clusters, n_redo=1, max_iter=100, tol=1e-4, init_mode="kmeans++", verbose=False):
+ super(BatchKMeans, self).__init__()
+ self.n_redo = n_redo
+ self.n_clusters = n_clusters
+ self.max_iter = max_iter
+ self.tol = tol
+ self.init_mode = init_mode
+ self.verbose = verbose
+
+ self.register_buffer("centroids", None)
+
+ def load_state_dict(self, state_dict, **kwargs):
+ r"""Override the default load_state_dict() to load custom buffers."""
+
+ for k, v in state_dict.items():
+ if "." not in k:
+ assert hasattr(self, k), f"attribute {k} does not exist"
+ delattr(self, k)
+ self.register_buffer(k, v)
+
+ for name, module in self.named_children():
+ sd = {k.replace(name + ".", ""): v for k, v in state_dict.items() if k.startswith(name + ".")}
+ module.load_state_dict(sd)
+
+ @staticmethod
+ def calculate_error(a, b):
+ r"""Compute L2 error between a and b"""
+
+ diff = a - b
+ diff.pow_(2)
+ return diff.sum()
+
+ @staticmethod
+ def calculate_inertia(a):
+ r"""Compute inertia of a"""
+
+ return (-a).mean()
+
+ @staticmethod
+ def euc_sim(a, b):
+ r"""Compute batched negative squared Euclidean distance between 'a' and 'b'
+
+ Args:
+ a (torch.Tensor): Vector of shape (..., d_vector, m)
+ b (torch.Tensor): Vector of shape (..., d_vector, n)
+
+ Returns:
+ y (torch.Tensor): Vector of shape (..., m, n)
+ """
+
+ y = a.transpose(-2, -1) @ b
+ y.mul_(2)
+ y.sub_(a.pow(2).sum(dim=-2)[..., :, None])
+ y.sub_(b.pow(2).sum(dim=-2)[..., None, :])
+
+ return y
+
+ def kmeanspp(self, data):
+ r"""Initialize centroids with k-means++ algorithm
+
+ Args:
+ data (torch.Tensor): Vector of shape (..., d_vector, n_data)
+
+ Returns:
+ centroids (torch.Tensor): Vector of shape (..., d_vector, n_clusters)
+ """
+
+ d_vector, n_data = data.shape[-2:]
+ centroids = torch.zeros(*data.shape[:-2], d_vector, self.n_clusters, device=data.device, dtype=data.dtype)
+
+ # Select initial centroid
+ centroids[..., 0] = data[..., np.random.randint(n_data)]
+ for i in range(1, self.n_clusters):
+ current_centroids = centroids[..., :i].contiguous()
+ sims = self.euc_sim(data, current_centroids)
+ max_sims_v, max_sims_i = sims.max(dim=-1)
+ index = max_sims_v.argmin(dim=-1) # (batch,)
+
+ if data.dim() == 2:
+ new_centroid = data[:, index]
+ elif data.dim() == 3:
+ arange = torch.arange(data.size(0), device=data.device)
+ new_centroid = data[arange, :, index] # (batch, d_vector)
+ elif data.dim() == 4:
+ arange_w = torch.arange(data.size(0), device=data.device).unsqueeze(dim=1)
+ arange_h = torch.arange(data.size(1), device=data.device).unsqueeze(dim=0)
+ new_centroid = data[arange_w, arange_h, :, index]
+ else:
+ raise NotImplementedError
+
+ centroids[..., i] = new_centroid
+ return centroids
+
+ def initialize_centroids(self, data):
+ r"""
+ Initialize centroids with init_method specified in __init__
+
+ Args:
+ data (torch.Tensor) Vector of shape (..., d_vector, n_data)
+
+ Returns:
+ centroids (torch.Tensor) Vector of shape (..., d_vector, n_clusters)
+ """
+ n_data = data.size(-1)
+ if self.init_mode == "random":
+ random_index = np.random.choice(n_data, size=[self.n_clusters], replace=False)
+ centroids = data[:, :, random_index].clone()
+
+ if self.verbose:
+ print("centroids are randomly initialized.")
+
+ elif self.init_mode == "kmeans++":
+ centroids = self.kmeanspp(data).clone()
+
+ if self.verbose:
+ print("centroids are initialized with kmeans++.")
+
+ else:
+ raise NotImplementedError
+
+ return centroids
+
+ def get_labels(self, data, centroids):
+ r"""Compute labels of data
+
+ Args:
+ data (torch.Tensor): Vector of shape (..., d_vector, n_data)
+ centroids (torch.Tensor): Vector of shape (..., d_vector, n_clusters)
+
+ Returns:
+ maxsims (torch.Tensor): Vector of shape (..., n_data)
+ labels (torch.Tensor): Vector of shape (..., n_data)
+ """
+
+ sims = self.euc_sim(data, centroids)
+ maxsims, labels = sims.max(dim=-1)
+
+ return maxsims, labels
+
+ def compute_centroids_loop(self, data, labels):
+ r"""Compute centroids of data
+
+ Args:
+ data (torch.Tensor): Vector of shape (..., d_vector, n_data)
+ labels (torch.Tensor): Vector of shape (..., n_data)
+
+ Returns:
+ centroids (torch.Tensor): Vector of shape (..., d_vector, n_clusters)
+ """
+
+ ### Naive method with loop ###
+ # l, d, m = data.shape
+ # centroids = torch.zeros(l, d, self.n_clusters, device=data.device, dtype=data.dtype)
+ # for j in range(l):
+ # unique_labels, counts = labels[j].unique(return_counts=True)
+ # for i, count in zip(unique_labels, counts):
+ # centroids[j, :, i] = data[j, :, labels[j] == i].sum(dim=1) / count
+
+ ### Fastest method ###
+ mask = [labels == i for i in range(self.n_clusters)]
+        mask = torch.stack(mask, dim=-1)  # (..., n_data, n_clusters)
+ centroids = (data.unsqueeze(dim=-1) * mask.unsqueeze(dim=-3)).sum(dim=-2) / mask.sum(dim=-2, keepdim=True)
+
+ return centroids
+
+ def compute_centroids(self, data, labels):
+ r"""Compute centroids of data
+
+ Args:
+ data (torch.Tensor): Vector of shape (..., d_vector, n_data)
+ labels (torch.Tensor): Vector of shape (..., n_data)
+
+ Returns:
+ centroids (torch.Tensor): Vector of shape (..., d_vector, n_clusters)
+ """
+
+ centroids = self.compute_centroids_loop(data, labels)
+ return centroids
+
+ def fit(self, data, centroids=None):
+ r"""Perform K-means clustering, and return final labels
+
+ Args:
+ data (torch.Tensor): data to be clustered, shape (l, d_vector, n_data)
+ centroids (torch.Tensor): initial centroids, shape (l, d_vector, n_clusters)
+
+ Returns:
+ best_labels (torch.Tensor): final labels, shape (l, n_data)
+ """
+
+ assert data.is_contiguous(), "use .contiguous()"
+
+ best_centroids = None
+ best_error = 1e32
+ best_labels = None
+ best_inertia = 1e32
+
+ if self.verbose:
+ tm = time()
+
+ for i in range(self.n_redo):
+ if self.verbose:
+ tm_i = time()
+
+ if centroids is None:
+ centroids = self.initialize_centroids(data)
+
+ for j in range(self.max_iter):
+ # clustering iteration
+ maxsims, labels = self.get_labels(data, centroids)
+ new_centroids = self.compute_centroids(data, labels)
+ error = self.calculate_error(centroids, new_centroids)
+ centroids = new_centroids
+ inertia = self.calculate_inertia(maxsims)
+
+ if self.verbose:
+ print(f"----iteration {j} of {i}th redo, error={error.item()}, inertia={inertia.item()}")
+
+ if error <= self.tol:
+ break
+
+ if inertia < best_inertia:
+ best_centroids = centroids
+ best_error = error
+ best_labels = labels
+ best_inertia = inertia
+
+ centroids = None
+
+ if self.verbose:
+ print(
+                    f"--{i}th redo finished, error: {error.item()}, inertia: {inertia.item()}, time spent: {round(time() - tm_i, 4)} sec")
+
+ self.register_buffer("centroids", best_centroids)
+
+ if self.verbose:
+ print(f"finished {self.n_redo} redos in {round(time() - tm, 4)} sec, final_inertia: {best_inertia}")
+
+ return best_labels
+
+ def predict(self, query):
+ r"""Predict the closest cluster center each sample in query belongs to.
+
+ Args:
+ query (torch.Tensor): Vector of shape (l, d_vector, n_query)
+
+ Returns:
+ labels (torch.Tensor): Vector of shape (l, n_query)
+ """
+
+ _, labels = self.get_labels(query, self.centroids)
+ return labels
+
+
+if __name__ == "__main__":
+ x = torch.randn(13, 29, 2, 1000).cuda()
+ multi_k_means = BatchKMeans(n_clusters=20, n_redo=1)
+ multi_k_means.fit(x)
+ print(multi_k_means.centroids.shape)
diff --git a/SingularTrajectory/model.py b/SingularTrajectory/model.py
new file mode 100644
index 0000000..8d8f542
--- /dev/null
+++ b/SingularTrajectory/model.py
@@ -0,0 +1,153 @@
+import torch
+import torch.nn as nn
+from .anchor import AdaptiveAnchor
+from .space import SingularSpace
+
+
+class SingularTrajectory(nn.Module):
+ r"""The SingularTrajectory model
+
+ Args:
+ baseline_model (nn.Module): The baseline model
+ hook_func (dict): The bridge functions for the baseline model
+ hyper_params (DotDict): The hyper-parameters
+ """
+
+ def __init__(self, baseline_model, hook_func, hyper_params):
+ super().__init__()
+
+ self.baseline_model = baseline_model
+ self.hook_func = hook_func
+ self.hyper_params = hyper_params
+ self.t_obs, self.t_pred = hyper_params.obs_len, hyper_params.pred_len
+ self.obs_svd, self.pred_svd = hyper_params.obs_svd, hyper_params.pred_svd
+ self.k = hyper_params.k
+ self.s = hyper_params.num_samples
+ self.dim = hyper_params.traj_dim
+ self.static_dist = hyper_params.static_dist
+
+ self.Singular_space_m = SingularSpace(hyper_params=hyper_params, norm_sca=True)
+ self.Singular_space_s = SingularSpace(hyper_params=hyper_params, norm_sca=False)
+ self.adaptive_anchor_m = AdaptiveAnchor(hyper_params=hyper_params)
+ self.adaptive_anchor_s = AdaptiveAnchor(hyper_params=hyper_params)
+
+ def calculate_parameters(self, obs_traj, pred_traj):
+ r"""Generate the Sinuglar space of the SingularTrajectory model
+
+ Args:
+ obs_traj (torch.Tensor): The observed trajectory
+ pred_traj (torch.Tensor): The predicted trajectory
+
+ Note:
+ This function should be called once before training the model.
+ """
+
+ # Mask out static trajectory
+ mask = self.calculate_mask(obs_traj)
+ obs_m_traj, pred_m_traj = obs_traj[mask], pred_traj[mask]
+ obs_s_traj, pred_s_traj = obs_traj[~mask], pred_traj[~mask]
+
+ # Descriptor initialization
+ data_m = self.Singular_space_m.parameter_initialization(obs_m_traj, pred_m_traj)
+ data_s = self.Singular_space_s.parameter_initialization(obs_s_traj, pred_s_traj)
+
+ # Anchor initialization
+ self.adaptive_anchor_m.anchor_initialization(*data_m)
+ self.adaptive_anchor_s.anchor_initialization(*data_s)
+
+ def calculate_adaptive_anchor(self, dataset):
+ obs_traj, pred_traj = dataset.obs_traj, dataset.pred_traj
+ scene_id = dataset.scene_id
+ vector_field = dataset.vector_field
+ homography = dataset.homography
+
+ # Mask out static trajectory
+ mask = self.calculate_mask(obs_traj)
+ obs_m_traj, scene_id_m = obs_traj[mask], scene_id[mask]
+ obs_s_traj, scene_id_s = obs_traj[~mask], scene_id[~mask]
+
+ n_ped = pred_traj.size(0)
+ anchor = torch.zeros((n_ped, self.k, self.s), dtype=torch.float)
+ anchor[mask] = self.adaptive_anchor_m.adaptive_anchor_calculation(obs_m_traj, scene_id_m, vector_field, homography, self.Singular_space_m)
+ anchor[~mask] = self.adaptive_anchor_s.adaptive_anchor_calculation(obs_s_traj, scene_id_s, vector_field, homography, self.Singular_space_s)
+
+ return anchor
+
+ def calculate_mask(self, obs_traj):
+ if obs_traj.size(1) <= 2:
+ mask = (obs_traj[:, -1] - obs_traj[:, -2]).div(1).norm(p=2, dim=-1) > self.static_dist
+ else:
+ mask = (obs_traj[:, -1] - obs_traj[:, -3]).div(2).norm(p=2, dim=-1) > self.static_dist
+ return mask
+
+ def forward(self, obs_traj, adaptive_anchor, pred_traj=None, addl_info=None):
+ r"""The forward function of the SingularTrajectory model
+
+ Args:
+            obs_traj (torch.Tensor): The observed trajectory
+            adaptive_anchor (torch.Tensor): The adaptive anchor coordinates in the Singular space
+            pred_traj (torch.Tensor): The predicted trajectory (optional, for training only)
+            addl_info (dict): The additional information (optional, if baseline model requires)
+
+ Returns:
+ output (dict): The output of the model (recon_traj, loss, etc.)
+ """
+
+ n_ped = obs_traj.size(0)
+
+ # Filter out static trajectory
+ mask = self.calculate_mask(obs_traj)
+ obs_m_traj = obs_traj[mask]
+ obs_s_traj = obs_traj[~mask]
+ pred_m_traj_gt = pred_traj[mask] if pred_traj is not None else None
+ pred_s_traj_gt = pred_traj[~mask] if pred_traj is not None else None
+
+ # Projection
+ C_m_obs, C_m_pred_gt = self.Singular_space_m.projection(obs_m_traj, pred_m_traj_gt)
+ C_s_obs, C_s_pred_gt = self.Singular_space_s.projection(obs_s_traj, pred_s_traj_gt)
+ C_obs = torch.zeros((self.k, n_ped), dtype=torch.float, device=obs_traj.device)
+ C_obs[:, mask], C_obs[:, ~mask] = C_m_obs, C_s_obs
+
+ # Absolute coordinate
+ obs_m_ori = self.Singular_space_m.traj_normalizer.traj_ori.squeeze(dim=1).T
+ obs_s_ori = self.Singular_space_s.traj_normalizer.traj_ori.squeeze(dim=1).T
+ obs_ori = torch.zeros((2, n_ped), dtype=torch.float, device=obs_traj.device)
+ obs_ori[:, mask], obs_ori[:, ~mask] = obs_m_ori, obs_s_ori
+ obs_ori -= obs_ori.mean(dim=1, keepdim=True)
+
+ ### Adaptive anchor per agent
+ C_anchor = adaptive_anchor.permute(1, 0, 2)
+ addl_info["anchor"] = C_anchor.clone()
+
+ # Trajectory prediction
+ input_data = self.hook_func.model_forward_pre_hook(C_obs, obs_ori, addl_info)
+ output_data = self.hook_func.model_forward(input_data, self.baseline_model)
+ C_pred_refine = self.hook_func.model_forward_post_hook(output_data, addl_info) * 0.1
+
+ C_m_pred = self.adaptive_anchor_m(C_pred_refine[:, mask], C_anchor[:, mask])
+ C_s_pred = self.adaptive_anchor_s(C_pred_refine[:, ~mask], C_anchor[:, ~mask])
+
+ # Reconstruction
+ pred_m_traj_recon = self.Singular_space_m.reconstruction(C_m_pred)
+ pred_s_traj_recon = self.Singular_space_s.reconstruction(C_s_pred)
+ pred_traj_recon = torch.zeros((self.s, n_ped, self.t_pred, self.dim), dtype=torch.float, device=obs_traj.device)
+ pred_traj_recon[:, mask], pred_traj_recon[:, ~mask] = pred_m_traj_recon, pred_s_traj_recon
+
+ output = {"recon_traj": pred_traj_recon}
+
+ if pred_traj is not None:
+ C_pred = torch.zeros((self.k, n_ped, self.s), dtype=torch.float, device=obs_traj.device)
+ C_pred[:, mask], C_pred[:, ~mask] = C_m_pred, C_s_pred
+
+ # Low-rank approximation for gt trajectory
+ C_pred_gt = torch.zeros((self.k, n_ped), dtype=torch.float, device=obs_traj.device)
+ C_pred_gt[:, mask], C_pred_gt[:, ~mask] = C_m_pred_gt, C_s_pred_gt
+ C_pred_gt = C_pred_gt.detach()
+
+ # Loss calculation
+ error_coefficient = (C_pred - C_pred_gt.unsqueeze(dim=-1)).norm(p=2, dim=0)
+ error_displacement = (pred_traj_recon - pred_traj.unsqueeze(dim=0)).norm(p=2, dim=-1)
+ output["loss_eigentraj"] = error_coefficient.min(dim=-1)[0].mean()
+ output["loss_euclidean_ade"] = error_displacement.mean(dim=-1).min(dim=0)[0].mean()
+ output["loss_euclidean_fde"] = error_displacement[:, :, -1].min(dim=0)[0].mean()
+
+ return output
diff --git a/SingularTrajectory/normalizer.py b/SingularTrajectory/normalizer.py
new file mode 100644
index 0000000..c1a13c8
--- /dev/null
+++ b/SingularTrajectory/normalizer.py
@@ -0,0 +1,68 @@
+import torch
+
+
+class TrajNorm:
+ r"""Normalize trajectory with shape (num_peds, length_of_time, 2)
+
+ Args:
+ ori (bool): Whether to normalize the trajectory with the origin
+ rot (bool): Whether to normalize the trajectory with the rotation
+ sca (bool): Whether to normalize the trajectory with the scale
+ """
+
+ def __init__(self, ori=True, rot=True, sca=True):
+ self.ori, self.rot, self.sca = ori, rot, sca
+ self.traj_ori, self.traj_rot, self.traj_sca = None, None, None
+
+ def calculate_params(self, traj):
+ r"""Calculate the normalization parameters"""
+
+ if self.ori:
+ self.traj_ori = traj[:, [-1]]
+ if self.rot:
+ if traj.size(1) <= 2:
+ dir = traj[:, -1] - traj[:, -2]
+ else:
+ dir = traj[:, -1] - traj[:, -3]
+ rot = torch.atan2(dir[:, 1], dir[:, 0])
+ self.traj_rot = torch.stack([torch.stack([rot.cos(), -rot.sin()], dim=1),
+ torch.stack([rot.sin(), rot.cos()], dim=1)], dim=1)
+ if self.sca:
+ if traj.size(1) <= 2:
+ self.traj_sca = 1. / (traj[:, -1] - traj[:, -2]).norm(p=2, dim=-1)[:, None, None] * 1
+ else:
+ self.traj_sca = 1. / (traj[:, -1] - traj[:, -3]).norm(p=2, dim=-1)[:, None, None] * 2
+ # self.traj_sca[self.traj_sca.isnan() | self.traj_sca.isinf()] = 1e2
+
+ def get_params(self):
+ r"""Get the normalization parameters"""
+
+ return self.ori, self.rot, self.sca, self.traj_ori, self.traj_rot, self.traj_sca
+
+ def set_params(self, ori, rot, sca, traj_ori, traj_rot, traj_sca):
+ r"""Set the normalization parameters"""
+
+ self.ori, self.rot, self.sca = ori, rot, sca
+ self.traj_ori, self.traj_rot, self.traj_sca = traj_ori, traj_rot, traj_sca
+
+ def normalize(self, traj):
+ r"""Normalize the trajectory"""
+
+ if self.ori:
+ traj = traj - self.traj_ori
+ if self.rot:
+ traj = traj @ self.traj_rot
+ if self.sca:
+ traj = traj * self.traj_sca
+ return traj
+
+ def denormalize(self, traj):
+ r"""Denormalize the trajectory"""
+
+ if self.sca:
+ traj = traj / self.traj_sca
+ if self.rot:
+ traj = traj @ self.traj_rot.transpose(-1, -2)
+ if self.ori:
+ traj = traj + self.traj_ori
+ return traj
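+
+
+if __name__ == "__main__":
+    # Minimal round-trip sketch (illustrative only, not used by the training pipeline):
+    # normalize() followed by denormalize() should reproduce the input trajectory.
+    traj = torch.randn(7, 8, 2)  # (num_peds, length_of_time, 2)
+    normalizer = TrajNorm(ori=True, rot=True, sca=True)
+    normalizer.calculate_params(traj)
+    traj_norm = normalizer.normalize(traj)
+    print(torch.allclose(normalizer.denormalize(traj_norm), traj, atol=1e-4))  # expected: True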
diff --git a/SingularTrajectory/space.py b/SingularTrajectory/space.py
new file mode 100644
index 0000000..51d3c6f
--- /dev/null
+++ b/SingularTrajectory/space.py
@@ -0,0 +1,214 @@
+import torch
+import torch.nn as nn
+from .normalizer import TrajNorm
+import numpy as np
+from sklearn.cluster import KMeans
+from scipy.interpolate import BSpline
+
+
+class SingularSpace(nn.Module):
+ r"""Singular space model
+
+ Args:
+ hyper_params (DotDict): The hyper-parameters
+ norm_ori (bool): Whether to normalize the trajectory with the origin
+ norm_rot (bool): Whether to normalize the trajectory with the rotation
+ norm_sca (bool): Whether to normalize the trajectory with the scale"""
+
+ def __init__(self, hyper_params, norm_ori=True, norm_rot=True, norm_sca=True):
+ super().__init__()
+
+ self.hyper_params = hyper_params
+ self.t_obs, self.t_pred = hyper_params.obs_len, hyper_params.pred_len
+ self.obs_svd, self.pred_svd = hyper_params.obs_svd, hyper_params.pred_svd
+ self.k = hyper_params.k
+ self.s = hyper_params.num_samples
+ self.dim = hyper_params.traj_dim
+ self.traj_normalizer = TrajNorm(ori=norm_ori, rot=norm_rot, sca=norm_sca)
+
+ self.V_trunc = nn.Parameter(torch.zeros((self.t_pred * self.dim, self.k)))
+ self.V_obs_trunc = nn.Parameter(torch.zeros((self.t_obs * self.dim, self.k)))
+ self.V_pred_trunc = nn.Parameter(torch.zeros((self.t_pred * self.dim, self.k)))
+
+ def normalize_trajectory(self, obs_traj, pred_traj=None):
+ r"""Trajectory normalization
+
+ Args:
+ obs_traj (torch.Tensor): The observed trajectory
+ pred_traj (torch.Tensor): The predicted trajectory (Optional, for training only)
+
+ Returns:
+ obs_traj_norm (torch.Tensor): The normalized observed trajectory
+ pred_traj_norm (torch.Tensor): The normalized predicted trajectory
+ """
+
+ self.traj_normalizer.calculate_params(obs_traj)
+ obs_traj_norm = self.traj_normalizer.normalize(obs_traj)
+ pred_traj_norm = self.traj_normalizer.normalize(pred_traj) if pred_traj is not None else None
+ return obs_traj_norm, pred_traj_norm
+
+ def denormalize_trajectory(self, traj_norm):
+ r"""Trajectory denormalization
+
+ Args:
+ traj_norm (torch.Tensor): The trajectory to be denormalized
+
+ Returns:
+ traj (torch.Tensor): The denormalized trajectory
+ """
+
+ traj = self.traj_normalizer.denormalize(traj_norm)
+ return traj
+
+ def to_Singular_space(self, traj, evec):
+ r"""Transform Euclidean trajectories to Singular space coordinates
+
+ Args:
+ traj (torch.Tensor): The trajectory to be transformed
+ evec (torch.Tensor): The Singular space basis vectors
+
+ Returns:
+ C (torch.Tensor): The Singular space coordinates"""
+
+ # Euclidean space -> Singular space
+ tdim = evec.size(0)
+ M = traj.reshape(-1, tdim).T
+ C = evec.T.detach() @ M
+ return C
+
+ def batch_to_Singular_space(self, traj, evec):
+ # Euclidean space -> Singular space
+ tdim = evec.size(0)
+ M = traj.reshape(traj.size(0), traj.size(1), tdim).transpose(1, 2)
+ C = evec.T.detach() @ M
+ return C
+
+ def to_Euclidean_space(self, C, evec):
+ r"""Transform Singular space coordinates to Euclidean trajectories
+
+ Args:
+ C (torch.Tensor): The Singular space coordinates
+ evec (torch.Tensor): The Singular space basis vectors
+
+ Returns:
+ traj (torch.Tensor): The Euclidean trajectory"""
+
+ # Singular space -> Euclidean
+ t = evec.size(0) // self.dim
+ M = evec.detach() @ C
+ traj = M.T.reshape(-1, t, self.dim)
+ return traj
+
+ def batch_to_Euclidean_space(self, C, evec):
+ # Singular space -> Euclidean
+ b = C.size(0)
+ t = evec.size(0) // self.dim
+ M = evec.detach() @ C
+ traj = M.transpose(1, 2).reshape(b, -1, t, self.dim)
+ return traj
+
+ def truncated_SVD(self, traj, k=None, full_matrices=False):
+ r"""Truncated Singular Value Decomposition
+
+ Args:
+ traj (torch.Tensor): The trajectory to be decomposed
+ k (int): The number of singular values and vectors to be computed
+ full_matrices (bool): Whether to compute full-sized matrices
+
+ Returns:
+ U_trunc (torch.Tensor): The truncated left singular vectors
+ S_trunc (torch.Tensor): The truncated singular values
+ Vt_trunc (torch.Tensor): The truncated right singular vectors
+ """
+
+ assert traj.size(2) == self.dim # NTC
+ k = self.k if k is None else k
+
+ # Singular Value Decomposition
+ M = traj.reshape(-1, traj.size(1) * self.dim).T
+ U, S, Vt = torch.linalg.svd(M, full_matrices=full_matrices)
+
+ # Truncated SVD
+ U_trunc, S_trunc, Vt_trunc = U[:, :k], S[:k], Vt[:k, :]
+ return U_trunc, S_trunc, Vt_trunc.T
+
+ def parameter_initialization(self, obs_traj, pred_traj):
+ r"""Initialize the Singular space basis vectors parameters (for training only)
+
+ Args:
+ obs_traj (torch.Tensor): The observed trajectory
+ pred_traj (torch.Tensor): The predicted trajectory
+
+ Returns:
+ pred_traj_norm (torch.Tensor): The normalized predicted trajectory
+            V_pred_trunc (torch.Tensor): The truncated singular vectors of the predicted trajectory
+
+ Note:
+ This function should be called once before training the model."""
+
+ # Normalize trajectory
+ obs_traj_norm, pred_traj_norm = self.normalize_trajectory(obs_traj, pred_traj)
+ V_trunc, _, _ = self.truncated_SVD(pred_traj_norm)
+
+ # Pre-calculate the transformation matrix
+        # Here, we use the Irwin–Hall polynomial (B-spline) basis functions
+        degree = 2
+        twot_win = self.dim * self.t_pred
+        twot_hist = self.dim * self.t_obs
+ steps = np.linspace(0., 1., twot_hist)
+ knot = twot_win - degree + 1
+ knots_qu = np.concatenate([np.zeros(degree), np.linspace(0, 1, knot), np.ones(degree)])
+ C_hist = np.zeros([twot_hist, twot_win])
+ for i in range(twot_win):
+ C_hist[:, i] = BSpline(knots_qu, (np.arange(twot_win) == i).astype(float), degree, extrapolate=False)(steps)
+ C_hist = torch.FloatTensor(C_hist)
+
+ V_obs_trunc = C_hist @ V_trunc
+ V_pred_trunc = V_trunc
+
+ # Register basis vectors as model parameters
+ self.V_trunc = nn.Parameter(V_trunc.to(self.V_trunc.device))
+ self.V_obs_trunc = nn.Parameter(V_obs_trunc.to(self.V_obs_trunc.device))
+ self.V_pred_trunc = nn.Parameter(V_pred_trunc.to(self.V_pred_trunc.device))
+
+ # Reuse values for anchor generation
+ return pred_traj_norm, V_pred_trunc
+
+ def projection(self, obs_traj, pred_traj=None):
+ r"""Trajectory projection to the Singular space
+
+ Args:
+ obs_traj (torch.Tensor): The observed trajectory
+ pred_traj (torch.Tensor): The predicted trajectory (optional, for training only)
+
+ Returns:
+ C_obs (torch.Tensor): The observed trajectory in the Singular space
+ C_pred (torch.Tensor): The predicted trajectory in the Singular space (optional, for training only)
+ """
+
+ # Trajectory Projection
+ obs_traj_norm, pred_traj_norm = self.normalize_trajectory(obs_traj, pred_traj)
+ C_obs = self.to_Singular_space(obs_traj_norm, evec=self.V_obs_trunc).detach()
+ C_pred = self.to_Singular_space(pred_traj_norm, evec=self.V_pred_trunc).detach() if pred_traj is not None else None
+ return C_obs, C_pred
+
+ def reconstruction(self, C_pred):
+ r"""Trajectory reconstruction from the Singular space
+
+ Args:
+ C_pred (torch.Tensor): The predicted trajectory in the Singular space
+
+ Returns:
+ pred_traj (torch.Tensor): The predicted trajectory in the Euclidean space
+ """
+
+ C_pred = C_pred.permute(2, 0, 1)
+ pred_traj = self.batch_to_Euclidean_space(C_pred, evec=self.V_pred_trunc)
+ pred_traj = self.denormalize_trajectory(pred_traj)
+
+ return pred_traj
+
+ def forward(self, C_pred):
+ r"""Alias for reconstruction"""
+
+ return self.reconstruction(C_pred)
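+
+ # End-to-end sketch (illustrative): projection() maps normalized trajectories onto the k basis
+ # vectors registered above, the baseline denoiser predicts Singular-space coefficients C_pred
+ # (assumed shape (k, num_ped, num_samples)), and reconstruction() returns Euclidean trajectories
+ # of shape (num_samples, num_ped, t_pred, dim).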
diff --git a/baseline/__init__.py b/baseline/__init__.py
new file mode 100644
index 0000000..25d94d0
--- /dev/null
+++ b/baseline/__init__.py
@@ -0,0 +1 @@
+from . import transformerdiffusion
diff --git a/baseline/transformerdiffusion/__init__.py b/baseline/transformerdiffusion/__init__.py
new file mode 100644
index 0000000..1732dcb
--- /dev/null
+++ b/baseline/transformerdiffusion/__init__.py
@@ -0,0 +1,2 @@
+from .model import DiffusionModel as TrajectoryPredictor
+from .bridge import model_forward_pre_hook, model_forward, model_forward_post_hook
diff --git a/baseline/transformerdiffusion/bridge.py b/baseline/transformerdiffusion/bridge.py
new file mode 100644
index 0000000..552c18a
--- /dev/null
+++ b/baseline/transformerdiffusion/bridge.py
@@ -0,0 +1,33 @@
+from collections import defaultdict
+import torch
+
+
+def model_forward_pre_hook(obs_data, obs_ori, addl_info):
+ # Pre-process input data for the baseline model
+ if obs_ori is not None:
+ obs_data = torch.cat([obs_data, obs_ori], dim=0)
+
+ scene_mask = addl_info["scene_mask"]
+ num_samples = addl_info["num_samples"]
+ anchor = addl_info["anchor"]
+
+ obs_data = torch.cat([obs_data.transpose(1, 0), anchor.permute(1, 2, 0).flatten(start_dim=1)], dim=1)
+ obs_data = obs_data.unsqueeze(dim=-1)
+
+ loc = anchor.permute(1, 2, 0).unsqueeze(dim=-1)
+ input_data = [obs_data, scene_mask, loc]
+
+ return input_data
+
+
+def model_forward(input_data, baseline_model):
+ # Forward the baseline model with input data
+ output_data = baseline_model(*input_data)
+ return output_data
+
+
+def model_forward_post_hook(output_data, addl_info=None):
+ # Post-process output data of the baseline model
+ pred_data = output_data.squeeze(dim=-1).permute(2, 0, 1)
+
+ return pred_data
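+
+ # Usage sketch (illustrative; the hooks are intended to be called in this order around the
+ # baseline model):
+ # input_data = model_forward_pre_hook(obs_data, obs_ori, addl_info)
+ # output_data = model_forward(input_data, baseline_model)
+ # pred_data = model_forward_post_hook(output_data) # Singular-space coefficients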
diff --git a/baseline/transformerdiffusion/layers.py b/baseline/transformerdiffusion/layers.py
new file mode 100644
index 0000000..28ead28
--- /dev/null
+++ b/baseline/transformerdiffusion/layers.py
@@ -0,0 +1,143 @@
+import math
+import torch
+import torch.nn as nn
+from torch.nn import Module, Linear
+
+
+class PositionalEncoding(nn.Module):
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
+ super().__init__()
+
+ self.dropout = nn.Dropout(p=dropout)
+ pe = torch.zeros(max_len, d_model)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0).transpose(0, 1)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ x = x + self.pe[: x.size(0), :]
+ return self.dropout(x)
+
+
+class ConcatSquashLinear(Module):
+ def __init__(self, dim_in, dim_out, dim_ctx):
+ super(ConcatSquashLinear, self).__init__()
+ self._layer = Linear(dim_in, dim_out)
+ self._hyper_bias = Linear(dim_ctx, dim_out, bias=False)
+ self._hyper_gate = Linear(dim_ctx, dim_out)
+
+ def forward(self, ctx, x):
+ gate = torch.sigmoid(self._hyper_gate(ctx))
+ bias = self._hyper_bias(ctx)
+ ret = self._layer(x) * gate + bias
+ return ret
+
+ def batch_generate(self, ctx, x):
+ gate = torch.sigmoid(self._hyper_gate(ctx))
+ bias = self._hyper_bias(ctx)
+ ret = self._layer(x) * gate + bias
+ return ret
+
+
+class GAT(nn.Module):
+ def __init__(self, in_feat=2, out_feat=64, n_head=4, dropout=0.1, skip=True):
+ super(GAT, self).__init__()
+ self.in_feat = in_feat
+ self.out_feat = out_feat
+ self.n_head = n_head
+ self.skip = skip
+ self.w = nn.Parameter(torch.Tensor(n_head, in_feat, out_feat))
+ self.a_src = nn.Parameter(torch.Tensor(n_head, out_feat, 1))
+ self.a_dst = nn.Parameter(torch.Tensor(n_head, out_feat, 1))
+ self.bias = nn.Parameter(torch.Tensor(out_feat))
+
+ self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
+ self.softmax = nn.Softmax(dim=-1)
+ self.dropout = nn.Dropout(dropout)
+
+ nn.init.xavier_uniform_(self.w, gain=1.414)
+ nn.init.xavier_uniform_(self.a_src, gain=1.414)
+ nn.init.xavier_uniform_(self.a_dst, gain=1.414)
+ nn.init.constant_(self.bias, 0)
+
+ def forward(self, h, mask):
+ h_prime = h.unsqueeze(1) @ self.w
+ attn_src = h_prime @ self.a_src
+ attn_dst = h_prime @ self.a_dst
+ attn = attn_src @ attn_dst.permute(0, 1, 3, 2)
+ attn = self.leaky_relu(attn)
+ attn = self.softmax(attn)
+ attn = self.dropout(attn)
+ attn = attn * mask if mask is not None else attn
+ out = (attn @ h_prime).sum(dim=1) + self.bias
+ if self.skip:
+ out += h_prime.sum(dim=1)
+ return out, attn
+
+
+class MLP(nn.Module):
+ def __init__(self, in_feat, out_feat, hid_feat=(1024, 512), activation=None, dropout=-1):
+ super(MLP, self).__init__()
+ dims = (in_feat, ) + hid_feat + (out_feat, )
+
+ self.layers = nn.ModuleList()
+ for i in range(len(dims) - 1):
+ self.layers.append(nn.Linear(dims[i], dims[i + 1]))
+
+ self.activation = activation if activation is not None else lambda x: x
+ self.dropout = nn.Dropout(dropout) if dropout != -1 else lambda x: x
+
+ def forward(self, x):
+ for i in range(len(self.layers)):
+ x = self.activation(x)
+ x = self.dropout(x)
+ x = self.layers[i](x)
+ return x
+
+
+class social_transformer(nn.Module):
+ def __init__(self, past_len):
+ super(social_transformer, self).__init__()
+ self.encode_past = nn.Linear(past_len*6, 256, bias=False)
+ self.layer = nn.TransformerEncoderLayer(d_model=256, nhead=2, dim_feedforward=256)
+ self.transformer_encoder = nn.TransformerEncoder(self.layer, num_layers=2)
+
+ def forward(self, h, mask):
+ h_feat = self.encode_past(h.reshape(h.size(0), -1)).unsqueeze(1)
+ h_feat_ = self.transformer_encoder(h_feat, mask)
+ h_feat = h_feat + h_feat_
+
+ return h_feat
+
+
+class st_encoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ channel_in = 6
+ channel_out = 32
+ dim_kernel = 3
+ self.dim_embedding_key = 256
+ self.spatial_conv = nn.Conv1d(channel_in, channel_out, dim_kernel, stride=1, padding=1)
+ self.temporal_encoder = nn.GRU(channel_out, self.dim_embedding_key, 1, batch_first=True)
+ self.relu = nn.ReLU()
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.kaiming_normal_(self.spatial_conv.weight)
+ nn.init.kaiming_normal_(self.temporal_encoder.weight_ih_l0)
+ nn.init.kaiming_normal_(self.temporal_encoder.weight_hh_l0)
+ nn.init.zeros_(self.spatial_conv.bias)
+ nn.init.zeros_(self.temporal_encoder.bias_ih_l0)
+ nn.init.zeros_(self.temporal_encoder.bias_hh_l0)
+
+ def forward(self, X):
+ X_t = torch.transpose(X, 1, 2)
+ X_after_spatial = self.relu(self.spatial_conv(X_t))
+ X_embed = torch.transpose(X_after_spatial, 1, 2)
+ output_x, state_x = self.temporal_encoder(X_embed)
+ state_x = state_x.squeeze(0)
+ return state_x
+
diff --git a/baseline/transformerdiffusion/model.py b/baseline/transformerdiffusion/model.py
new file mode 100644
index 0000000..054c45c
--- /dev/null
+++ b/baseline/transformerdiffusion/model.py
@@ -0,0 +1,173 @@
+import math
+import torch
+import torch.nn as nn
+from torch.nn import Module, Linear
+import numpy as np
+from .layers import PositionalEncoding, ConcatSquashLinear
+
+
+class st_encoder(nn.Module):
+ """Transformer Denoising Model
+ codebase borrowed from https://github.com/MediaBrain-SJTU/LED"""
+ def __init__(self):
+ super().__init__()
+ channel_in = 2
+ channel_out = 32
+ dim_kernel = 3
+ self.dim_embedding_key = 256
+ self.spatial_conv = nn.Conv1d(channel_in, channel_out, dim_kernel, stride=1, padding=1)
+ self.temporal_encoder = nn.GRU(channel_out, self.dim_embedding_key, 1, batch_first=True)
+ self.relu = nn.ReLU()
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.kaiming_normal_(self.spatial_conv.weight)
+ nn.init.kaiming_normal_(self.temporal_encoder.weight_ih_l0)
+ nn.init.kaiming_normal_(self.temporal_encoder.weight_hh_l0)
+ nn.init.zeros_(self.spatial_conv.bias)
+ nn.init.zeros_(self.temporal_encoder.bias_ih_l0)
+ nn.init.zeros_(self.temporal_encoder.bias_hh_l0)
+
+ def forward(self, X):
+ X_t = torch.transpose(X, 1, 2)
+ X_after_spatial = self.relu(self.spatial_conv(X_t))
+ X_embed = torch.transpose(X_after_spatial, 1, 2)
+ output_x, state_x = self.temporal_encoder(X_embed)
+ state_x = state_x.squeeze(0)
+ return state_x
+
+
+class social_transformer(nn.Module):
+ """Transformer Denoising Model
+ codebase borrowed from https://github.com/MediaBrain-SJTU/LED"""
+ def __init__(self, cfg):
+ super(social_transformer, self).__init__()
+ self.encode_past = nn.Linear(cfg.k*cfg.s+6, 256, bias=False)
+ self.layer = nn.TransformerEncoderLayer(d_model=256, nhead=2, dim_feedforward=256)
+ self.transformer_encoder = nn.TransformerEncoder(self.layer, num_layers=2)
+
+ def forward(self, h, mask):
+ h_feat = self.encode_past(h.reshape(h.size(0), -1)).unsqueeze(1)
+ h_feat_ = self.transformer_encoder(h_feat, mask)
+ h_feat = h_feat + h_feat_
+
+ return h_feat
+
+
+class TransformerDenoisingModel(Module):
+ """Transformer Denoising Model
+ codebase borrowed from https://github.com/MediaBrain-SJTU/LED"""
+ def __init__(self, context_dim=256, cfg=None):
+ super().__init__()
+ self.context_dim = context_dim
+ self.spatial_dim = 1
+ self.temporal_dim = cfg.k
+ self.n_samples = cfg.s
+ self.encoder_context = social_transformer(cfg)
+ self.pos_emb = PositionalEncoding(d_model=2*context_dim, dropout=0.1, max_len=24)
+ self.concat1 = ConcatSquashLinear(self.n_samples*self.spatial_dim*self.temporal_dim, 2*context_dim, context_dim+3)
+ self.concat3 = ConcatSquashLinear(2*context_dim, context_dim, context_dim+3)
+ self.concat4 = ConcatSquashLinear(context_dim, context_dim//2, context_dim+3)
+ self.linear = ConcatSquashLinear(context_dim//2, self.n_samples*self.spatial_dim*self.temporal_dim, context_dim+3)
+
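+ # Note: forward() below references self.transformer_encoder, which is not created in __init__
+ # above; the sampling path used by DiffusionModel.p_sample goes through generate_accelerate()
+ # instead.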
+ def forward(self, x, beta, context, mask):
+ batch_size = x.size(0)
+ beta = beta.view(batch_size, 1, 1)
+ time_emb = torch.cat([beta, torch.sin(beta), torch.cos(beta)], dim=-1)
+ ctx_emb = torch.cat([time_emb, context], dim=-1)
+ x = self.concat1(ctx_emb, x)
+ final_emb = x.permute(1, 0, 2)
+ final_emb = self.pos_emb(final_emb)
+ trans = self.transformer_encoder(final_emb).permute(1, 0, 2)
+ trans = self.concat3(ctx_emb, trans)
+ trans = self.concat4(ctx_emb, trans)
+ return self.linear(ctx_emb, trans)
+
+ def encode_context(self, context, mask):
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+ context = self.encoder_context(context, mask)
+ return context
+
+ def generate_accelerate(self, x, beta, context, mask):
+ beta = beta.view(beta.size(0), 1)
+ time_emb = torch.cat([beta, torch.sin(beta), torch.cos(beta)], dim=-1)
+ ctx_emb = torch.cat([time_emb, context.view(-1, self.context_dim*self.spatial_dim)], dim=-1)
+
+ trans = self.concat1.batch_generate(ctx_emb, x.view(-1, self.n_samples*self.temporal_dim*self.spatial_dim))
+ trans = self.concat3.batch_generate(ctx_emb, trans)
+ trans = self.concat4.batch_generate(ctx_emb, trans)
+ return self.linear.batch_generate(ctx_emb, trans).view(-1, self.n_samples, self.temporal_dim, self.spatial_dim)
+
+
+class DiffusionModel(Module):
+ """Transformer Denoising Model
+ codebase borrowed from https://github.com/MediaBrain-SJTU/LED"""
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.model = TransformerDenoisingModel(context_dim=256, cfg=cfg)
+
+ self.betas = self.make_beta_schedule(
+ schedule=self.cfg.beta_schedule, n_timesteps=self.cfg.steps,
+ start=self.cfg.beta_start, end=self.cfg.beta_end).cuda()
+
+ self.alphas = 1 - self.betas
+ self.alphas_prod = torch.cumprod(self.alphas, 0)
+ self.alphas_bar_sqrt = torch.sqrt(self.alphas_prod)
+ self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_prod)
+
+ def make_beta_schedule(self, schedule: str = 'linear',
+ n_timesteps: int = 1000,
+ start: float = 1e-5, end: float = 1e-2) -> torch.Tensor:
+ if schedule == 'linear':
+ betas = torch.linspace(start, end, n_timesteps)
+ elif schedule == "quad":
+ betas = torch.linspace(start ** 0.5, end ** 0.5, n_timesteps) ** 2
+ elif schedule == "sigmoid":
+ betas = torch.linspace(-6, 6, n_timesteps)
+ betas = torch.sigmoid(betas) * (end - start) + start
+ else:
+ raise ValueError(f"Unknown beta schedule: {schedule}")
+ return betas
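+ # Example (illustrative): with the values used by the trainer (schedule='linear', steps=10,
+ # start=1e-4, end=5e-2), this returns 10 evenly spaced beta values from 1e-4 to 5e-2.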
+
+ def extract(self, input, t, x):
+ shape = x.shape
+ out = torch.gather(input, 0, t.to(input.device))
+ reshape = [t.shape[0]] + [1] * (len(shape) - 1)
+ return out.reshape(*reshape)
+
+ def forward(self, past_traj, traj_mask, loc):
+ pred_traj = self.p_sample_forward(past_traj, traj_mask, loc)
+ return pred_traj
+
+ def p_sample(self, x, mask, cur_y, t, context):
+ t = torch.tensor([t]).cuda()
+ beta = self.extract(self.betas, t.repeat(x.shape[0]), cur_y)
+ eps_theta = self.model.generate_accelerate(cur_y, beta, context, mask)
+ eps_factor = ((1 - self.extract(self.alphas, t, cur_y)) / self.extract(self.one_minus_alphas_bar_sqrt, t, cur_y))
+ mean = (1 / self.extract(self.alphas, t, cur_y).sqrt()) * (cur_y - (eps_factor * eps_theta))
+
+ # Fix the random seed for reproducibility
+ if False:
+ z = torch.randn_like(cur_y).to(x.device)
+ else:
+ rng = np.random.default_rng(seed=0)
+ z = torch.Tensor(rng.normal(loc=0, scale=1.0, size=cur_y.shape)).cuda()
+
+ sigma_t = self.extract(self.betas, t, cur_y).sqrt()
+ sample = mean + sigma_t * z * 0.00001
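+ # The mean above is the standard DDPM posterior mean,
+ # mean = (1 / sqrt(alpha_t)) * (y_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta),
+ # and the noise term is scaled by 1e-5, so the reverse process is effectively deterministic.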
+ return (sample)
+
+ def p_sample_forward(self, x, mask, loc):
+ prediction_total = torch.Tensor().cuda()
+
+ # Fix the random seed for reproducibility
+ if False:
+ cur_y = torch.randn_like(loc)
+ else:
+ rng = np.random.default_rng(seed=0)
+ cur_y = torch.Tensor(rng.normal(loc=0, scale=1.0, size=loc.shape)).cuda()
+
+ context = self.model.encode_context(x, mask)
+ for i in reversed(range(self.cfg.steps)):
+ cur_y = self.p_sample(x, mask, cur_y, i, context)
+ prediction_total = cur_y
+ return prediction_total
diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/config/config_example.json b/config/config_example.json
new file mode 100644
index 0000000..d50de14
--- /dev/null
+++ b/config/config_example.json
@@ -0,0 +1,29 @@
+{
+ "dataset_dir": "./datasets/",
+ "checkpoint_dir": "./checkpoints/",
+ "task": "stochastic",
+
+ "dataset": "zara2",
+ "traj_dim": 2,
+ "obs_len": 8,
+ "obs_step": 10,
+ "pred_len": 12,
+ "pred_step": 10,
+ "skip": 1,
+
+ "k": 4,
+ "static_dist": 0.3,
+ "num_samples": 20,
+ "obs_svd": true,
+ "pred_svd": true,
+ "baseline": "transformerdiffusion",
+
+ "batch_size": 512,
+ "num_epochs": 256,
+ "lr": 0.001,
+ "weight_decay": 0.0001,
+ "clip_grad": 10,
+ "lr_schd": true,
+ "lr_schd_step": 64,
+ "lr_schd_gamma": 0.5
+}
diff --git a/datasets/.gitkeep b/datasets/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/script/download_datasets.sh b/script/download_datasets.sh
new file mode 100644
index 0000000..96bca78
--- /dev/null
+++ b/script/download_datasets.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+echo "Download ETH-UCY datasets"
+
+wget -O datasets.zip 'https://www.dropbox.com/s/8n02xqv3l9q18r1/datasets.zip?dl=0'
+unzip -q datasets.zip
+rm -rf datasets.zip
+
+wget -O datasets-domainextension.zip https://github.com/InhwanBae/SingularTrajectory/releases/download/v1.0/datasets-domainextension.zip
+unzip -q datasets-domainextension.zip
+rm -rf datasets-domainextension.zip
+
+wget -O datasets-imageextension.zip https://github.com/InhwanBae/SingularTrajectory/releases/download/v1.0/datasets-imageextension.zip
+unzip -q datasets-imageextension.zip
+rm -rf datasets-imageextension.zip
+
+echo "Done."
diff --git a/script/download_pretrained_models.sh b/script/download_pretrained_models.sh
new file mode 100644
index 0000000..fd8e8eb
--- /dev/null
+++ b/script/download_pretrained_models.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+task_array=("stochastic" "deterministic" "momentary" "domain" "fewshot" "domain-stochastic" "allinone")
+
+for (( i=0; i<${#task_array[@]}; i++ ))
+do
+ echo "Download pre-trained model for ${task_array[$i]} task."
+ wget -O ${task_array[$i]}.zip https://github.com/InhwanBae/SingularTrajectory/releases/download/v1.0/SingularTrajectory-${task_array[$i]}-pretrained.zip
+ unzip -q ${task_array[$i]}.zip
+ rm -rf ${task_array[$i]}.zip
+done
+
+echo "Done."
diff --git a/script/generate_vector_field.py b/script/generate_vector_field.py
new file mode 100644
index 0000000..bc7f2c5
--- /dev/null
+++ b/script/generate_vector_field.py
@@ -0,0 +1,83 @@
+import numpy as np
+from tqdm import tqdm
+
+
+def check_nonzero(a, x, y):
+ try:
+ if 0 <= x < a.shape[0] and 0 <= y < a.shape[1]:
+ return a[x, y] == 1
+ return False
+ except IndexError:
+ return False
+
+
+def nearest_nonzero_idx(a, x, y):
+ try:
+ if 0 <= x < a.shape[0] and 0 <= y < a.shape[1]:
+ if a[x, y] != 0:
+ return x, y
+ except IndexError:
+ pass
+
+ r,c = np.nonzero(a)
+ min_idx = ((r - x)**2 + (c - y)**2).argmin()
+ return r[min_idx], c[min_idx]
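+
+ # Example (illustrative): for a = np.array([[0, 0], [1, 0]]), nearest_nonzero_idx(a, 0, 1)
+ # returns (1, 0), the index of the nearest nonzero entry in Euclidean distance.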
+
+
+def main(id):
+ # IMAGE_SCALE_DOWN = 8
+ img_file_list = ['seq_eth', 'seq_hotel', 'students003', 'crowds_zara01', 'crowds_zara02'][id]
+ print(img_file_list)
+
+ import PIL.Image as Image
+ img = Image.open(f'./datasets/image/{img_file_list}_map.png')
+ img = img.convert('RGB')
+ img = np.array(img)
+ # img = img[::IMAGE_SCALE_DOWN, ::IMAGE_SCALE_DOWN, :]
+ img = img[:, :, 0]
+ img = img > 0.5
+ img = img.astype(np.int32)
+
+ img_padded = np.pad(img, ((img.shape[0] // 2,) * 2, (img.shape[1] // 2,) * 2), 'constant', constant_values=0)
+ print(img.shape, img_padded.shape)
+ img = img_padded
+
+ vector_field = np.zeros(img.shape + (2,))
+ pbar = tqdm(total=img.shape[0] * img.shape[1])
+ for x in range(img.shape[0]):
+ for y in range(img.shape[1]):
+ vector_field[x, y] = np.array(nearest_nonzero_idx(img, x, y))
+ pbar.update(1)
+ pbar.close()
+
+ # Faster version with ProcessPoolExecutor()
+ # import concurrent.futures
+ # def nearest_nonzero_idx_wrapper(args):
+ # return nearest_nonzero_idx(img, args[0], args[1])
+
+ # vector_field_fast = np.zeros(img.shape + (2,))
+ # pbar = tqdm(total=img.shape[0] * img.shape[1])
+ # with concurrent.futures.ProcessPoolExecutor(max_workers=64) as executor:
+ # coords = []
+ # for x in range(img.shape[0]):
+ # for y in range(img.shape[1]):
+ # print(x, y)
+ # coords.append((x, y))
+ # # future = executor.submit(nearest_nonzero_idx_wrapper, (img, x, y))
+ # # vector_field_fast[x, y] = future.result()
+
+ # for coord, vector in zip(coords, executor.map(nearest_nonzero_idx_wrapper, coords)):
+ # vector_field_fast[coord[0], coord[1]] = vector
+ # pbar.update(1)
+
+ # print("allcolse:", np.allclose(vector_field, vector_field_fast))
+ # pbar.close()
+
+ np.save(f"./datasets/vectorfield/{img_file_list}_vector_field.npy", vector_field)
+ # np.savetxt(f"./datasets/vectorfield/{img_file_list}_vector_field_x.txt", vector_field[:, :, 0], fmt='%d')
+ # np.savetxt(f"./datasets/vectorfield/{img_file_list}_vector_field_y.txt", vector_field[:, :, 1], fmt='%d')
+
+
+if "__main__" == __name__:
+ for i in range(5):
+ main(id=i)
diff --git a/script/test.sh b/script/test.sh
new file mode 100644
index 0000000..7883b39
--- /dev/null
+++ b/script/test.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+echo "Start evaluation task queues"
+
+# Hyperparameters
+dataset_array=("eth" "hotel" "univ" "zara1" "zara2")
+device_id_array=(0 1 2 3 4)
+tag="SingularTrajectory-stochastic"
+config_path="./config/"
+config_prefix="stochastic/singulartrajectory"
+baseline="transformerdiffusion"
+
+# Arguments
+while getopts t:b:c:p:d:i: flag
+do
+ case "${flag}" in
+ t) tag=${OPTARG};;
+ b) baseline=${OPTARG};;
+ c) config_path=${OPTARG};;
+ p) config_prefix=${OPTARG};;
+ d) dataset_array=(${OPTARG});;
+ i) device_id_array=(${OPTARG});;
+ *) echo "usage: $0 [-t TAG] [-b BASELINE] [-c CONFIG_PATH] [-p CONFIG_PREFIX] [-d \"eth hotel univ zara1 zara2\"] [-i \"0 1 2 3 4\"]" >&2
+ exit 1 ;;
+ esac
+done
+
+if [ ${#dataset_array[@]} -ne ${#device_id_array[@]} ]
+then
+ printf "Arrays must all be same length. "
+ printf "len(dataset_array)=${#dataset_array[@]} and len(device_id_array)=${#device_id_array[@]}\n"
+ exit 1
+fi
+
+# Start test tasks
+for (( i=0; i<${#dataset_array[@]}; i++ ))
+do
+ printf "Evaluate ${dataset_array[$i]}"
+ CUDA_VISIBLE_DEVICES=${device_id_array[$i]} python3 trainval.py \
+ --cfg "${config_path}""${config_prefix}"-"${baseline}"-"${dataset_array[$i]}".json \
+ --tag "${tag}" --gpu_id ${device_id_array[$i]} --test
+done
+
+echo "Done."
\ No newline at end of file
diff --git a/script/train.sh b/script/train.sh
new file mode 100644
index 0000000..3f9908a
--- /dev/null
+++ b/script/train.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+echo "Start training task queues"
+
+# Hyperparameters
+dataset_array=("eth" "hotel" "univ" "zara1" "zara2")
+device_id_array=(0 1 2 3 4)
+tag="SingularTrajectory-stochastic"
+config_path="./config/"
+config_prefix="stochastic/singulartrajectory"
+baseline="transformerdiffusion"
+
+# Arguments
+while getopts t:b:c:p:d:i: flag
+do
+ case "${flag}" in
+ t) tag=${OPTARG};;
+ b) baseline=${OPTARG};;
+ c) config_path=${OPTARG};;
+ p) config_prefix=${OPTARG};;
+ d) dataset_array=(${OPTARG});;
+ i) device_id_array=(${OPTARG});;
+ *) echo "usage: $0 [-t TAG] [-b BASELINE] [-c CONFIG_PATH] [-p CONFIG_PREFIX] [-d \"eth hotel univ zara1 zara2\"] [-i \"0 1 2 3 4\"]" >&2
+ exit 1 ;;
+ esac
+done
+
+if [ ${#dataset_array[@]} -ne ${#device_id_array[@]} ]
+then
+ printf "Arrays must all be same length. "
+ printf "len(dataset_array)=${#dataset_array[@]} and len(device_id_array)=${#device_id_array[@]}\n"
+ exit 1
+fi
+
+# Signal handler
+pid_array=()
+
+sighdl ()
+{
+ echo "Kill training processes"
+ for (( i=0; i<${#dataset_array[@]}; i++ ))
+ do
+ kill ${pid_array[$i]}
+ done
+ echo "Done."
+ exit 0
+}
+
+trap sighdl SIGINT SIGTERM
+
+# Start training tasks
+for (( i=0; i<${#dataset_array[@]}; i++ ))
+do
+ printf "Evaluate ${dataset_array[$i]}"
+ CUDA_VISIBLE_DEVICES=${device_id_array[$i]} python3 trainval.py \
+ --cfg "${config_path}""${config_prefix}"-"${baseline}"-"${dataset_array[$i]}".json \
+ --tag "${tag}" --gpu_id ${device_id_array[$i]} &
+ pid_array[$i]=$!
+ printf " job ${#pid_array[@]} pid ${pid_array[$i]}\n"
+done
+
+for (( i=0; i<${#dataset_array[@]}; i++ ))
+do
+ wait ${pid_array[$i]}
+done
+
+echo "Done."
\ No newline at end of file
diff --git a/trainval.py b/trainval.py
new file mode 100644
index 0000000..c28f234
--- /dev/null
+++ b/trainval.py
@@ -0,0 +1,39 @@
+import os
+import argparse
+import baseline
+from SingularTrajectory import *
+from utils import *
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', default="./config/singulartrajectory-transformerdiffusion-zara1.json", type=str, help="config file path")
+ parser.add_argument('--tag', default="SingularTrajectory-TEMP", type=str, help="personal tag for the model")
+ parser.add_argument('--gpu_id', default="0", type=str, help="gpu id for the model")
+ parser.add_argument('--test', default=False, action='store_true', help="evaluation mode")
+ args = parser.parse_args()
+
+ print("===== Arguments =====")
+ print_arguments(vars(args))
+
+ print("===== Configs =====")
+ hyper_params = get_exp_config(args.cfg)
+ print_arguments(hyper_params)
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
+ PredictorModel = getattr(baseline, hyper_params.baseline).TrajectoryPredictor
+ hook_func = DotDict({"model_forward_pre_hook": getattr(baseline, hyper_params.baseline).model_forward_pre_hook,
+ "model_forward": getattr(baseline, hyper_params.baseline).model_forward,
+ "model_forward_post_hook": getattr(baseline, hyper_params.baseline).model_forward_post_hook})
+ ModelTrainer = getattr(trainer, *[s for s in trainer.__dict__.keys() if hyper_params.baseline in s.lower()])
+ trainer = ModelTrainer(base_model=PredictorModel, model=SingularTrajectory, hook_func=hook_func,
+ args=args, hyper_params=hyper_params)
+
+ if not args.test:
+ trainer.init_descriptor()
+ trainer.fit()
+ else:
+ trainer.load_model()
+ print("Testing...", end=' ')
+ results = trainer.test()
+ print(f"Scene: {hyper_params.dataset}", *[f"{meter}: {value:.8f}" for meter, value in results.items()])
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..f312a75
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,5 @@
+from .dataloader import get_dataloader, TrajectoryDataset, TrajBatchSampler, traj_collate_fn
+from .metrics import compute_batch_ade, compute_batch_fde, compute_batch_tcc, compute_batch_col, AverageMeter
+from .utils import reproducibility_settings, get_exp_config, DotDict, print_arguments, augment_trajectory
+from .trainer import STTransformerDiffusionTrainer
+from .homography import image2world, world2image, generate_homography
diff --git a/utils/dataloader.py b/utils/dataloader.py
new file mode 100644
index 0000000..cac10eb
--- /dev/null
+++ b/utils/dataloader.py
@@ -0,0 +1,290 @@
+import os
+import math
+import torch
+import numpy as np
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import Sampler
+from torch.utils.data.dataloader import DataLoader
+from .homography import generate_homography
+from PIL import Image
+
+
+def get_dataloader(data_dir, phase, obs_len, pred_len, batch_size, skip=1):
+ r"""Get dataloader for a specific phase
+
+ Args:
+ data_dir (str): path to the dataset directory
+ phase (str): phase of the data, one of 'train', 'val', 'test'
+ obs_len (int): length of observed trajectory
+ pred_len (int): length of predicted trajectory
+ batch_size (int): batch size
+
+ Returns:
+ loader_phase (torch.utils.data.DataLoader): dataloader for the specific phase
+ """
+
+ assert phase in ['train', 'val', 'test']
+
+ data_set = data_dir + '/' + phase + '/'
+ shuffle = True if phase == 'train' else False
+ drop_last = True if phase == 'train' else False
+
+ dataset_phase = TrajectoryDataset(data_set, obs_len=obs_len, pred_len=pred_len, skip=skip)
+ sampler_phase = None
+ if batch_size > 1:
+ sampler_phase = TrajBatchSampler(dataset_phase, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+ loader_phase = DataLoader(dataset_phase, collate_fn=traj_collate_fn, batch_sampler=sampler_phase, pin_memory=True)
+ return loader_phase
+
+
+def traj_collate_fn(data):
+ r"""Collate function for the dataloader
+
+ Args:
+ data (list): list of tuples of (obs_seq, pred_seq, non_linear_ped, loss_mask, seq_start_end, scene_id)
+
+ Returns:
+ obs_seq_list (torch.Tensor): (num_ped, obs_len, 2)
+ pred_seq_list (torch.Tensor): (num_ped, pred_len, 2)
+ non_linear_ped_list (torch.Tensor): (num_ped,)
+ loss_mask_list (torch.Tensor): (num_ped, obs_len + pred_len)
+ scene_mask (torch.Tensor): (num_ped, num_ped)
+ seq_start_end (torch.Tensor): (num_ped, 2)
+ scene_id
+ """
+
+ data_collated = {}
+ for k in data[0].keys():
+ data_collated[k] = [d[k] for d in data]
+
+ _len = [len(seq) for seq in data_collated["obs_traj"]]
+ cum_start_idx = [0] + np.cumsum(_len).tolist()
+ seq_start_end = [[start, end] for start, end in zip(cum_start_idx, cum_start_idx[1:])]
+ seq_start_end = torch.LongTensor(seq_start_end)
+ scene_mask = torch.zeros(sum(_len), sum(_len), dtype=torch.bool)
+ for idx, (start, end) in enumerate(seq_start_end):
+ scene_mask[start:end, start:end] = 1
+
+ data_collated["obs_traj"] = torch.cat(data_collated["obs_traj"], dim=0)
+ data_collated["pred_traj"] = torch.cat(data_collated["pred_traj"], dim=0)
+ data_collated["anchor"] = torch.cat(data_collated["anchor"], dim=0)
+ data_collated["non_linear_ped"] = torch.cat(data_collated["non_linear_ped"], dim=0)
+ data_collated["loss_mask"] = torch.cat(data_collated["loss_mask"], dim=0)
+ data_collated["scene_mask"] = scene_mask
+ data_collated["seq_start_end"] = seq_start_end
+ data_collated["frame"] = torch.cat(data_collated["frame"], dim=0)
+ data_collated["scene_id"] = np.concatenate(data_collated["scene_id"], axis=0)
+
+ return data_collated
+
+
+class TrajBatchSampler(Sampler):
+ r"""Samples batched elements by yielding a mini-batch of indices.
+ Args:
+ data_source (Dataset): dataset to sample from
+ batch_size (int): Size of mini-batch.
+ shuffle (bool, optional): set to ``True`` to have the data reshuffled
+ at every epoch (default: ``False``).
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``
+ generator (Generator): Generator used in sampling.
+ """
+
+ def __init__(self, data_source, batch_size=64, shuffle=False, drop_last=False, generator=None):
+ self.data_source = data_source
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.drop_last = drop_last
+ self.generator = generator
+
+ def __iter__(self):
+ assert len(self.data_source) == len(self.data_source.num_peds_in_seq)
+
+ if self.shuffle:
+ if self.generator is None:
+ generator = torch.Generator()
+ generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+ else:
+ generator = self.generator
+ indices = torch.randperm(len(self.data_source), generator=generator).tolist()
+ else:
+ indices = list(range(len(self.data_source)))
+ num_peds_indices = self.data_source.num_peds_in_seq[indices]
+
+ batch = []
+ total_num_peds = 0
+ for idx, num_peds in zip(indices, num_peds_indices):
+ batch.append(idx)
+ total_num_peds += num_peds
+ if total_num_peds >= self.batch_size:
+ yield batch
+ batch = []
+ total_num_peds = 0
+ if len(batch) > 0 and not self.drop_last:
+ yield batch
+
+ def __len__(self):
+ # Approximated number of batches.
+ # The order of trajectories can be shuffled, so this number can vary from run to run.
+ if self.drop_last:
+ return sum(self.data_source.num_peds_in_seq) // self.batch_size
+ else:
+ return (sum(self.data_source.num_peds_in_seq) + self.batch_size - 1) // self.batch_size
+
+
+def read_file(_path, delim='\t'):
+ data = []
+ if delim == 'tab':
+ delim = '\t'
+ elif delim == 'space':
+ delim = ' '
+ with open(_path, 'r') as f:
+ for line in f:
+ line = line.strip().split(delim)
+ line = [float(i) for i in line]
+ data.append(line)
+ return np.asarray(data)
+
+
+def poly_fit(traj, traj_len, threshold):
+ """
+ Input:
+ - traj: Numpy array of shape (2, traj_len)
+ - traj_len: Len of trajectory
+ - threshold: Minimum error to be considered for non-linear traj
+ Output:
+ - int: 1 -> Non Linear 0-> Linear
+ """
+ t = np.linspace(0, traj_len - 1, traj_len)
+ res_x = np.polyfit(t, traj[0, -traj_len:], 2, full=True)[1]
+ res_y = np.polyfit(t, traj[1, -traj_len:], 2, full=True)[1]
+ if res_x + res_y >= threshold:
+ return 1.0
+ else:
+ return 0.0
+
+
+class TrajectoryDataset(Dataset):
+ """Dataloder for the Trajectory datasets"""
+
+ def __init__(self, data_dir, obs_len=8, pred_len=12, skip=1, threshold=0.02, min_ped=1, delim='\t'):
+ """
+ Args:
+ - data_dir: Directory containing dataset files in the format
+ - obs_len: Number of time-steps in input trajectories
+ - pred_len: Number of time-steps in output trajectories
+ - skip: Number of frames to skip while making the dataset
+ - threshold: Minimum error to be considered for non-linear traj when using a linear predictor
+ - min_ped: Minimum number of pedestrians that should be in a sequence
+ - delim: Delimiter in the dataset files
+ """
+ super(TrajectoryDataset, self).__init__()
+
+ self.data_dir = data_dir
+ self.obs_len = obs_len
+ self.pred_len = pred_len
+ self.skip = skip
+ self.seq_len = self.obs_len + self.pred_len
+ self.delim = delim
+
+ all_files = sorted(os.listdir(self.data_dir))
+ all_files = [os.path.join(self.data_dir, _path) for _path in all_files]
+ num_peds_in_seq = []
+ seq_list = []
+ loss_mask_list = []
+ non_linear_ped = []
+ frame_list = []
+ scene_id = []
+ self.homography = {}
+ self.vector_field = {}
+ scene_img_map = {'biwi_eth': 'seq_eth', 'biwi_hotel': 'seq_hotel',
+ 'students001': 'students003', 'students003': 'students003', 'uni_examples': 'students003',
+ 'crowds_zara01': 'crowds_zara01', 'crowds_zara02': 'crowds_zara02', 'crowds_zara03': 'crowds_zara02'}
+
+ for path in all_files:
+ parent_dir, scene_name = os.path.split(path)
+ parent_dir, phase = os.path.split(parent_dir)
+ parent_dir, dataset_name = os.path.split(parent_dir)
+ scene_name, _ = os.path.splitext(scene_name)
+ scene_name = scene_name.replace('_' + phase, '')
+ self.vector_field[scene_name] = np.load(os.path.join(parent_dir, "vectorfield", scene_img_map[scene_name] + "_vector_field.npy"))
+
+ if dataset_name in ["eth", "hotel", "univ", "zara1", "zara2"]:
+ homography_file = os.path.join(parent_dir, "homography", scene_name + "_H.txt")
+ self.homography[scene_name] = np.loadtxt(homography_file)
+ elif dataset_name in [aa + '2' + bb for aa in ['A', 'B', 'C', 'D', 'E'] for bb in ['A', 'B', 'C', 'D', 'E'] if aa != bb]:
+ homography_file = os.path.join(parent_dir, "homography", scene_name + "_H.txt")
+ self.homography[scene_name] = np.loadtxt(homography_file)
+
+ # Load data
+ data = read_file(path, delim)
+ frames = np.unique(data[:, 0]).tolist()
+ frame_data = []
+ for frame in frames:
+ frame_data.append(data[frame == data[:, 0], :])
+ num_sequences = int(math.ceil((len(frames) - self.seq_len + 1) / skip))
+
+ for idx in range(0, num_sequences * self.skip + 1, skip):
+ curr_seq_data = np.concatenate(frame_data[idx:idx + self.seq_len], axis=0)
+ peds_in_curr_seq = np.unique(curr_seq_data[:, 1])
+ curr_seq = np.zeros((len(peds_in_curr_seq), 2, self.seq_len))
+ curr_loss_mask = np.zeros((len(peds_in_curr_seq), self.seq_len))
+ num_peds_considered = 0
+ _non_linear_ped = []
+ for _, ped_id in enumerate(peds_in_curr_seq):
+ curr_ped_seq = curr_seq_data[curr_seq_data[:, 1] == ped_id, :]
+ curr_ped_seq = np.around(curr_ped_seq, decimals=4)
+ pad_front = frames.index(curr_ped_seq[0, 0]) - idx
+ pad_end = frames.index(curr_ped_seq[-1, 0]) - idx + 1
+ if pad_end - pad_front != self.seq_len:
+ continue
+ curr_ped_seq = np.transpose(curr_ped_seq[:, 2:])
+ _idx = num_peds_considered
+ curr_seq[_idx, :, pad_front:pad_end] = curr_ped_seq
+ # Linear vs Non-Linear Trajectory
+ _non_linear_ped.append(poly_fit(curr_ped_seq, pred_len, threshold))
+ curr_loss_mask[_idx, pad_front:pad_end] = 1
+ num_peds_considered += 1
+
+ if num_peds_considered > min_ped:
+ non_linear_ped += _non_linear_ped
+ num_peds_in_seq.append(num_peds_considered)
+ loss_mask_list.append(curr_loss_mask[:num_peds_considered])
+ seq_list.append(curr_seq[:num_peds_considered])
+ frame_list.extend([frames[idx]] * num_peds_considered)
+ scene_id.extend([scene_name] * num_peds_considered)
+
+ self.num_seq = len(seq_list)
+ seq_list = np.concatenate(seq_list, axis=0)
+ loss_mask_list = np.concatenate(loss_mask_list, axis=0)
+ non_linear_ped = np.asarray(non_linear_ped)
+ self.num_peds_in_seq = np.array(num_peds_in_seq)
+ self.frame_list = np.array(frame_list, dtype=np.int32)
+ self.scene_id = np.array(scene_id)
+
+ # Convert numpy -> Torch Tensor
+ self.obs_traj = torch.from_numpy(seq_list[:, :, :self.obs_len]).type(torch.float).permute(0, 2, 1) # NTC
+ self.pred_traj = torch.from_numpy(seq_list[:, :, self.obs_len:]).type(torch.float).permute(0, 2, 1) # NTC
+ self.loss_mask = torch.from_numpy(loss_mask_list).type(torch.float).gt(0.5)
+ self.non_linear_ped = torch.from_numpy(non_linear_ped).type(torch.float).gt(0.5)
+ cum_start_idx = [0] + np.cumsum(num_peds_in_seq).tolist()
+ self.seq_start_end = [(start, end) for start, end in zip(cum_start_idx, cum_start_idx[1:])]
+ self.frame_list = torch.from_numpy(self.frame_list).type(torch.long)
+ self.anchor = None
+
+ def __len__(self):
+ return self.num_seq
+
+ def __getitem__(self, index):
+ start, end = self.seq_start_end[index]
+ out = {"obs_traj": self.obs_traj[start:end],
+ "pred_traj": self.pred_traj[start:end],
+ "anchor": self.anchor[start:end],
+ "non_linear_ped": self.non_linear_ped[start:end],
+ "loss_mask": self.loss_mask[start:end],
+ "scene_mask": None,
+ "seq_start_end": [[0, end - start]],
+ "frame": self.frame_list[start:end],
+ "scene_id": self.scene_id[start:end]}
+ return out
diff --git a/utils/homography.py b/utils/homography.py
new file mode 100644
index 0000000..9ece40b
--- /dev/null
+++ b/utils/homography.py
@@ -0,0 +1,101 @@
+import numpy as np
+import torch
+
+
+def image2world(coord, H):
+ r"""Convert image coordinates to world coordinates.
+
+ Args:
+ coord (np.ndarray or torch.tensor): Image coordinates, shape (..., 2).
+ H (np.ndarray or torch.tensor): Homography matrix, shape (3, 3).
+
+ Returns:
+ np.ndarray: World coordinates.
+ """
+
+ assert coord.shape[-1] == 2
+ assert H.shape == (3, 3)
+ assert type(coord) == type(H)
+
+ shape = coord.shape
+ coord = coord.reshape(-1, 2)
+
+ if isinstance(coord, np.ndarray):
+ x, y = coord[..., 0], coord[..., 1]
+ world = (H @ np.stack([x, y, np.ones_like(x)], axis=-1).T).T
+ world = world / world[..., [2]]
+ world = world[..., :2]
+
+ elif isinstance(coord, torch.Tensor):
+ x, y = coord[..., 0], coord[..., 1]
+ world = (H @ torch.stack([x, y, torch.ones_like(x)], dim=-1).T).T
+ world = world / world[..., [2]]
+ world = world[..., :2]
+
+ else:
+ raise NotImplementedError
+
+ return world.reshape(shape)
+
+
+def world2image(coord, H, transpose=False):
+ r"""Convert world coordinates to image coordinates.
+
+ Args:
+ coord (np.ndarray or torch.tensor): World coordinates, shape (..., 2).
+ H (np.ndarray or torch.tensor): Homography matrix, shape (3, 3).
+
+ Returns:
+ np.ndarray: Image coordinates.
+ """
+
+ assert coord.shape[-1] == 2
+ assert H.shape == (3, 3)
+ assert type(coord) == type(H)
+
+ shape = coord.shape
+ coord = coord.reshape(-1, 2)
+
+ if isinstance(coord, np.ndarray):
+ x, y = coord[..., 0], coord[..., 1]
+ image = (np.linalg.inv(H) @ np.stack([x, y, np.ones_like(x)], axis=-1).T).T
+ image = image / image[..., [2]]
+ image = image[..., :2]
+
+ elif isinstance(coord, torch.Tensor):
+ x, y = coord[..., 0], coord[..., 1]
+ image = (torch.linalg.inv(H) @ torch.stack([x, y, torch.ones_like(x)], dim=-1).T).T
+ image = image / image[..., [2]]
+ image = image[..., :2]
+
+ else:
+ raise NotImplementedError
+
+ return image.reshape(shape)
+
+
+def generate_homography(shift_w: float=0, shift_h: float=0, rotate: float=0, scale: float=1):
+ r"""Generate a homography matrix.
+
+ Args:
+ shift_w (float): Shift in the x (width) direction.
+ shift_h (float): Shift in the y (height) direction.
+ rotate (float): Rotation angle in radians.
+ scale (float): Scale factor.
+
+ Returns:
+ np.ndarray: Homography matrix, shape (3, 3).
+ """
+
+ H = np.eye(3)
+ H[0, 2] = shift_w
+ H[1, 2] = shift_h
+ H[2, 2] = scale
+
+ if rotate != 0:
+ # rotation matrix
+ R = np.array([[np.cos(rotate), -np.sin(rotate), 0],
+ [np.sin(rotate), np.cos(rotate), 0],
+ [0, 0, 1]])
+ H = H @ R
+
+ return H
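+
+ # Example (illustrative): a pure translation by (2, 3) in world units:
+ # H = generate_homography(shift_w=2, shift_h=3)
+ # image2world(np.array([[0., 0.]]), H) # -> array([[2., 3.]])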
diff --git a/utils/metrics.py b/utils/metrics.py
new file mode 100644
index 0000000..1195271
--- /dev/null
+++ b/utils/metrics.py
@@ -0,0 +1,110 @@
+import numpy as np
+import torch
+
+
+class AverageMeter(object):
+ r"""Stores the results of a metric and computes its average"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.data = list()
+
+ def append(self, value):
+ self.data.append([value])
+
+ def extend(self, values):
+ self.data.append(values)
+
+ def mean(self):
+ return np.concatenate(self.data, axis=0).mean()
+
+ def sum(self):
+ return np.concatenate(self.data, axis=0).sum()
+
+ def __len__(self):
+ return np.concatenate(self.data, axis=0).shape[0]
+
+
+def compute_batch_ade(pred, gt):
+ r"""Compute ADE(average displacement error) scores for each pedestrian
+
+ Args:
+ pred (torch.Tensor): (num_samples, num_ped, seq_len, 2)
+ gt (torch.Tensor): (1, num_ped, seq_len, 2) or (num_ped, seq_len, 2)
+
+ Returns:
+ ADEs (np.ndarray): (num_ped,)
+ """
+ temp = (pred - gt).norm(p=2, dim=-1)
+ ADEs = temp.mean(dim=2).min(dim=0)[0]
+ return ADEs.detach().cpu().numpy()
+
+
+def compute_batch_fde(pred, gt):
+ r"""Compute FDE(final displacement error) scores for each pedestrian
+
+ Args:
+ pred (torch.Tensor): (num_samples, num_ped, seq_len, 2)
+ gt (torch.Tensor): (1, num_ped, seq_len, 2) or (num_ped, seq_len, 2)
+
+ Returns:
+ FDEs (np.ndarray): (num_ped,)
+ """
+ temp = (pred - gt).norm(p=2, dim=-1)
+ FDEs = temp[:, :, -1].min(dim=0)[0]
+ return FDEs.detach().cpu().numpy()
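+
+ # Example (illustrative): if one sample matches the ground truth exactly, e.g.
+ # pred = gt.unsqueeze(0) with gt of shape (num_ped, seq_len, 2), both compute_batch_ade(pred, gt)
+ # and compute_batch_fde(pred, gt) return arrays of zeros with shape (num_ped,).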
+
+
+def compute_batch_tcc(pred, gt):
+ r"""Compute TCC(temporal correlation coefficient) scores for each pedestrian
+
+ Args:
+ pred (torch.Tensor): (num_samples, num_ped, seq_len, 2)
+ gt (torch.Tensor): (1, num_ped, seq_len, 2) or (num_ped, seq_len, 2)
+
+ Returns:
+ TCCs (np.ndarray): (num_ped,)
+ """
+
+ gt = gt.squeeze(dim=0) if gt.dim() == 4 else gt
+ temp = (pred - gt).norm(p=2, dim=-1)
+ pred_best = pred[temp[:, :, -1].argmin(dim=0), range(pred.size(1)), :, :]
+ pred_gt_stack = torch.stack([pred_best, gt], dim=0)
+ pred_gt_stack = pred_gt_stack.permute(3, 1, 0, 2)
+ covariance = pred_gt_stack - pred_gt_stack.mean(dim=-1, keepdim=True)
+ factor = 1 / (covariance.shape[-1] - 1)
+ covariance = factor * covariance @ covariance.transpose(-1, -2)
+ variance = covariance.diagonal(offset=0, dim1=-2, dim2=-1)
+ stddev = variance.sqrt()
+ corrcoef = covariance / stddev.unsqueeze(-1) / stddev.unsqueeze(-2)
+ corrcoef = corrcoef.clamp(-1, 1)
+ corrcoef[torch.isnan(corrcoef)] = 0
+ TCCs = corrcoef[:, :, 0, 1].mean(dim=0)
+ return TCCs.detach().cpu().numpy()
+
+
+def compute_batch_col(pred, gt):
+ r"""Compute COL(collision rate) scores for each pedestrian
+
+ Args:
+ pred (torch.Tensor): (num_samples, num_ped, seq_len, 2)
+ gt (torch.Tensor): (1, num_ped, seq_len, 2) or (num_ped, seq_len, 2)
+
+ Returns:
+ COLs (np.ndarray): (num_ped,)
+ """
+
+ pred = pred.permute(0, 2, 1, 3)
+ num_interp, thres = 4, 0.2
+ pred_fp = pred[:, [0], :, :]
+ pred_rel = pred[:, 1:] - pred[:, :-1]
+ pred_rel_dense = pred_rel.div(num_interp).unsqueeze(dim=2).repeat_interleave(repeats=num_interp, dim=2).contiguous()
+ pred_rel_dense = pred_rel_dense.reshape(pred.size(0), num_interp * (pred.size(1) - 1), pred.size(2), pred.size(3))
+ pred_dense = torch.cat([pred_fp, pred_rel_dense], dim=1).cumsum(dim=1)
+ col_mask = pred_dense[:, :3 * num_interp + 2].unsqueeze(dim=2).repeat_interleave(repeats=pred.size(2), dim=2)
+ col_mask = (col_mask - col_mask.transpose(2, 3)).norm(p=2, dim=-1)
+ col_mask = col_mask.add(torch.eye(n=pred.size(2), device=pred.device)[None, None, :, :]).min(dim=1)[0].lt(thres)
+ COLs = col_mask.sum(dim=1).gt(0).type(pred.type()).mean(dim=0).mul(100)
+ return COLs.detach().cpu().numpy()
diff --git a/utils/trainer.py b/utils/trainer.py
new file mode 100644
index 0000000..10059a0
--- /dev/null
+++ b/utils/trainer.py
@@ -0,0 +1,340 @@
+import os
+import pickle
+import torch
+import numpy as np
+from tqdm import tqdm
+from . import *
+
+
+class STTrainer:
+ r"""Base class for all Trainers"""
+
+ def __init__(self, args, hyper_params):
+ print("Trainer initiating...")
+
+ # Reproducibility
+ reproducibility_settings(seed=0)
+
+ self.args, self.hyper_params = args, hyper_params
+ self.model, self.optimizer, self.scheduler = None, None, None
+ self.loader_train, self.loader_val, self.loader_test = None, None, None
+ self.dataset_dir = hyper_params.dataset_dir + hyper_params.dataset + '/'
+ self.checkpoint_dir = hyper_params.checkpoint_dir + '/' + args.tag + '/' + hyper_params.dataset + '/'
+ print("Checkpoint dir:", self.checkpoint_dir)
+ self.log = {'train_loss': [], 'val_loss': []}
+ self.stats_func, self.stats_meter = None, None
+ self.reset_metric()
+
+ if not args.test:
+ # Save arguments and configs
+ if not os.path.exists(self.checkpoint_dir):
+ os.makedirs(self.checkpoint_dir)
+
+ with open(self.checkpoint_dir + 'args.pkl', 'wb') as fp:
+ pickle.dump(args, fp)
+
+ with open(self.checkpoint_dir + 'config.pkl', 'wb') as fp:
+ pickle.dump(hyper_params, fp)
+
+ def init_descriptor(self):
+ # Singular space initialization
+ print("Singular space initialization...")
+ obs_traj, pred_traj = self.loader_train.dataset.obs_traj, self.loader_train.dataset.pred_traj
+ obs_traj, pred_traj = augment_trajectory(obs_traj, pred_traj)
+ self.model.calculate_parameters(obs_traj, pred_traj)
+ print("Anchor generation...")
+
+ def init_adaptive_anchor(self, dataset):
+ print("Adaptive anchor initialization...")
+ dataset.anchor = self.model.calculate_adaptive_anchor(dataset)
+
+ def train(self, epoch):
+ raise NotImplementedError
+
+ @torch.no_grad()
+ def valid(self, epoch):
+ raise NotImplementedError
+
+ @torch.no_grad()
+ def test(self):
+ raise NotImplementedError
+
+ def fit(self):
+ print("Training started...")
+ for epoch in range(self.hyper_params.num_epochs):
+ self.train(epoch)
+ self.valid(epoch)
+
+ if self.hyper_params.lr_schd:
+ self.scheduler.step()
+
+ # Save the best model
+ if epoch == 0 or self.log['val_loss'][-1] < min(self.log['val_loss'][:-1]):
+ self.save_model()
+
+ print(" ")
+ print("Dataset: {0}, Epoch: {1}".format(self.hyper_params.dataset, epoch))
+ print("Train_loss: {0:.8f}, Val_los: {1:.8f}".format(self.log['train_loss'][-1], self.log['val_loss'][-1]))
+ print("Min_val_epoch: {0}, Min_val_loss: {1:.8f}".format(np.array(self.log['val_loss']).argmin(),
+ np.array(self.log['val_loss']).min()))
+ print(" ")
+ print("Done.")
+
+ def reset_metric(self):
+ self.stats_func = {'ADE': compute_batch_ade, 'FDE': compute_batch_fde}
+ self.stats_meter = {x: AverageMeter() for x in self.stats_func.keys()}
+
+ def get_metric(self):
+ return self.stats_meter
+
+ def load_model(self, filename='model_best.pth'):
+ model_path = self.checkpoint_dir + filename
+ self.model.load_state_dict(torch.load(model_path))
+
+ def save_model(self, filename='model_best.pth'):
+ if not os.path.exists(self.checkpoint_dir):
+ os.makedirs(self.checkpoint_dir)
+ model_path = self.checkpoint_dir + filename
+ torch.save(self.model.state_dict(), model_path)
+
+
+class STSequencedMiniBatchTrainer(STTrainer):
+ r"""Base class using sequenced mini-batch training strategy"""
+
+ def __init__(self, args, hyper_params):
+ super().__init__(args, hyper_params)
+
+ # Dataset preprocessing
+ obs_len, pred_len, skip = hyper_params.obs_len, hyper_params.pred_len, hyper_params.skip
+ self.loader_train = get_dataloader(self.dataset_dir, 'train', obs_len, pred_len, batch_size=1, skip=skip)
+ self.loader_val = get_dataloader(self.dataset_dir, 'val', obs_len, pred_len, batch_size=1)
+ self.loader_test = get_dataloader(self.dataset_dir, 'test', obs_len, pred_len, batch_size=1)
+
+ def train(self, epoch):
+ self.model.train()
+ loss_batch = 0
+ is_first_loss = True
+
+ for cnt, batch in enumerate(tqdm(self.loader_train, desc=f'Train Epoch {epoch}', mininterval=1)):
+ obs_traj, pred_traj = [tensor.cuda(non_blocking=True) for tensor in batch[:2]]
+
+ self.optimizer.zero_grad()
+
+ output = self.model(obs_traj, pred_traj)
+
+ loss = output["loss_euclidean_ade"]
+ loss[torch.isnan(loss)] = 0
+
+ if (cnt + 1) % self.hyper_params.batch_size != 0 and (cnt + 1) != len(self.loader_train):
+ if is_first_loss:
+ is_first_loss = False
+ loss_cum = loss
+ else:
+ loss_cum += loss
+
+ else:
+ is_first_loss = True
+ loss_cum += loss
+ loss_cum /= self.hyper_params.batch_size
+ loss_cum.backward()
+
+ if self.hyper_params.clip_grad is not None:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hyper_params.clip_grad)
+
+ self.optimizer.step()
+ loss_batch += loss_cum.item()
+
+ self.log['train_loss'].append(loss_batch / len(self.loader_train))
+
+ @torch.no_grad()
+ def valid(self, epoch):
+ self.model.eval()
+ loss_batch = 0
+
+ for cnt, batch in enumerate(tqdm(self.loader_val, desc=f'Valid Epoch {epoch}', mininterval=1)):
+ obs_traj, pred_traj = [tensor.cuda(non_blocking=True) for tensor in batch[:2]]
+
+ output = self.model(obs_traj, pred_traj)
+
+ recon_loss = output["loss_euclidean_fde"] * obs_traj.size(0)
+ loss_batch += recon_loss.item()
+
+ num_ped = sum(self.loader_val.dataset.num_peds_in_seq)
+ self.log['val_loss'].append(loss_batch / num_ped)
+
+ @torch.no_grad()
+ def test(self):
+ self.model.eval()
+ self.reset_metric()
+
+ for batch in tqdm(self.loader_test, desc=f"Test {self.hyper_params.dataset.upper()} scene"):
+ obs_traj, pred_traj = [tensor.cuda(non_blocking=True) for tensor in batch[:2]]
+
+ output = self.model(obs_traj)
+
+ # Evaluate trajectories
+ for metric in self.stats_func.keys():
+ value = self.stats_func[metric](output["recon_traj"], pred_traj)
+ self.stats_meter[metric].extend(value)
+
+ return {x: self.stats_meter[x].mean() for x in self.stats_meter.keys()}
+
+
+class STCollatedMiniBatchTrainer(STTrainer):
+ r"""Base class using collated mini-batch training strategy"""
+
+ def __init__(self, args, hyper_params):
+ super().__init__(args, hyper_params)
+
+ # Dataset preprocessing
+ batch_size = hyper_params.batch_size
+ obs_len, pred_len, skip = hyper_params.obs_len, hyper_params.pred_len, hyper_params.skip
+ self.loader_train = get_dataloader(self.dataset_dir, 'train', obs_len, pred_len, batch_size=batch_size, skip=skip)
+ self.loader_val = get_dataloader(self.dataset_dir, 'val', obs_len, pred_len, batch_size=batch_size)
+ self.loader_test = get_dataloader(self.dataset_dir, 'test', obs_len, pred_len, batch_size=1)
+
+ def train(self, epoch):
+ self.model.train()
+ loss_batch = 0
+
+ for cnt, batch in enumerate(tqdm(self.loader_train, desc=f'Train Epoch {epoch}', mininterval=1)):
+ obs_traj, pred_traj = [tensor.cuda(non_blocking=True) for tensor in batch[:2]]
+
+ self.optimizer.zero_grad()
+
+ output = self.model(obs_traj, pred_traj)
+
+ loss = output["loss_euclidean_ade"]
+ loss[torch.isnan(loss)] = 0
+ loss_batch += loss.item()
+
+ loss.backward()
+ if self.hyper_params.clip_grad is not None:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hyper_params.clip_grad)
+ self.optimizer.step()
+
+ self.log['train_loss'].append(loss_batch / len(self.loader_train))
+
+ @torch.no_grad()
+ def valid(self, epoch):
+ self.model.eval()
+ loss_batch = 0
+
+ for cnt, batch in enumerate(tqdm(self.loader_val, desc=f'Valid Epoch {epoch}', mininterval=1)):
+ obs_traj, pred_traj = [tensor.cuda(non_blocking=True) for tensor in batch[:2]]
+
+ output = self.model(obs_traj, pred_traj)
+
+ recon_loss = output["loss_euclidean_fde"] * obs_traj.size(0)
+ loss_batch += recon_loss.item()
+
+ num_ped = sum(self.loader_val.dataset.num_peds_in_seq)
+ self.log['val_loss'].append(loss_batch / num_ped)
+
+ @torch.no_grad()
+ def test(self):
+ self.model.eval()
+ self.reset_metric()
+
+ for batch in tqdm(self.loader_test, desc=f"Test {self.hyper_params.dataset.upper()} scene"):
+ obs_traj, pred_traj = [tensor.cuda(non_blocking=True) for tensor in batch[:2]]
+
+ output = self.model(obs_traj)
+
+ # Evaluate trajectories
+ for metric in self.stats_func.keys():
+ value = self.stats_func[metric](output["recon_traj"], pred_traj)
+ self.stats_meter[metric].extend(value)
+
+ return {x: self.stats_meter[x].mean() for x in self.stats_meter.keys()}
+
+
+class STTransformerDiffusionTrainer(STCollatedMiniBatchTrainer):
+ r"""SingularTrajectory model trainer"""
+
+ def __init__(self, base_model, model, hook_func, args, hyper_params):
+ super().__init__(args, hyper_params)
+ cfg = DotDict({'scheduler': 'ddim', 'steps': 10, 'beta_start': 1.e-4, 'beta_end': 5.e-2, 'beta_schedule': 'linear',
+ 'k': hyper_params.k, 's': hyper_params.num_samples})
+ predictor_model = base_model(cfg).cuda()
+ eigentraj_model = model(baseline_model=predictor_model, hook_func=hook_func, hyper_params=hyper_params).cuda()
+ self.model = eigentraj_model
+ self.optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=hyper_params.lr,
+ weight_decay=hyper_params.weight_decay)
+
+ if hyper_params.lr_schd:
+ self.scheduler = torch.optim.lr_scheduler.StepLR(optimizer=self.optimizer,
+ step_size=hyper_params.lr_schd_step,
+ gamma=hyper_params.lr_schd_gamma)
+
+ def train(self, epoch):
+ self.model.train()
+ loss_batch = 0
+
+ if self.loader_train.dataset.anchor is None:
+ self.init_adaptive_anchor(self.loader_train.dataset)
+
+ for cnt, batch in enumerate(tqdm(self.loader_train, desc=f'Train Epoch {epoch}', mininterval=1)):
+ obs_traj, pred_traj = batch["obs_traj"].cuda(non_blocking=True), batch["pred_traj"].cuda(non_blocking=True)
+ adaptive_anchor = batch["anchor"].cuda(non_blocking=True)
+ scene_mask, seq_start_end = batch["scene_mask"].cuda(non_blocking=True), batch["seq_start_end"].cuda(non_blocking=True)
+
+ self.optimizer.zero_grad()
+
+ additional_information = {"scene_mask": scene_mask, "num_samples": self.hyper_params.num_samples}
+ output = self.model(obs_traj, adaptive_anchor, pred_traj, addl_info=additional_information)
+
+ loss = output["loss_euclidean_ade"]
+ loss[torch.isnan(loss)] = 0
+ loss_batch += loss.item()
+
+ loss.backward()
+ if self.hyper_params.clip_grad is not None:
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.hyper_params.clip_grad)
+ self.optimizer.step()
+
+ self.log['train_loss'].append(loss_batch / len(self.loader_train))
+
+ @torch.no_grad()
+ def valid(self, epoch):
+ self.model.eval()
+ loss_batch = 0
+
+ if self.loader_val.dataset.anchor is None:
+ self.init_adaptive_anchor(self.loader_val.dataset)
+
+ for cnt, batch in enumerate(tqdm(self.loader_val, desc=f'Valid Epoch {epoch}', mininterval=1)):
+ obs_traj, pred_traj = batch["obs_traj"].cuda(non_blocking=True), batch["pred_traj"].cuda(non_blocking=True)
+ adaptive_anchor = batch["anchor"].cuda(non_blocking=True)
+ scene_mask, seq_start_end = batch["scene_mask"].cuda(non_blocking=True), batch["seq_start_end"].cuda(non_blocking=True)
+
+ additional_information = {"scene_mask": scene_mask, "num_samples": self.hyper_params.num_samples}
+ output = self.model(obs_traj, adaptive_anchor, pred_traj, addl_info=additional_information)
+
+ recon_loss = output["loss_euclidean_fde"] * obs_traj.size(0)
+ loss_batch += recon_loss.item()
+
+ num_ped = sum(self.loader_val.dataset.num_peds_in_seq)
+ self.log['val_loss'].append(loss_batch / num_ped)
+
+ @torch.no_grad()
+ def test(self):
+ self.model.eval()
+ self.reset_metric()
+
+ if self.loader_test.dataset.anchor is None:
+ self.init_adaptive_anchor(self.loader_test.dataset)
+
+ for cnt, batch in enumerate(tqdm(self.loader_test, desc=f"Test {self.hyper_params.dataset.upper()} scene")):
+ obs_traj, pred_traj = batch["obs_traj"].cuda(non_blocking=True), batch["pred_traj"].cuda(non_blocking=True)
+ adaptive_anchor = batch["anchor"].cuda(non_blocking=True)
+ scene_mask, seq_start_end = batch["scene_mask"].cuda(non_blocking=True), batch["seq_start_end"].cuda(non_blocking=True)
+
+ additional_information = {"scene_mask": scene_mask, "num_samples": self.hyper_params.num_samples}
+ output = self.model(obs_traj, adaptive_anchor, addl_info=additional_information)
+
+ for metric in self.stats_func.keys():
+ value = self.stats_func[metric](output["recon_traj"], pred_traj)
+ self.stats_meter[metric].extend(value)
+
+ return {x: self.stats_meter[x].mean() for x in self.stats_meter.keys()}
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000..857a95c
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,91 @@
+import os
+import json
+import random
+import numpy as np
+import torch
+
+
+def reproducibility_settings(seed: int = 0):
+ r"""Set the random seed for reproducibility"""
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cuda.matmul.allow_tf32 = False # Settings for 3090
+ torch.backends.cudnn.allow_tf32 = False # Settings for 3090
+
+
+def get_exp_config(file: str):
+ r"""Load the configuration files"""
+
+ assert os.path.exists(file), f"File {file} does not exist!"
+ file = open(file)
+ config = json.load(file)
+ for k in config.keys():
+ if type(config[k]) == dict:
+ config[k] = DotDict(config[k])
+ return DotDict(config)
+
+
+class DotDict(dict):
+ r"""dot.notation access to dictionary attributes"""
+
+ __getattr__ = dict.get
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+ __getstate__ = dict
+ __setstate__ = dict.update
+
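+ # Example (illustrative): cfg = DotDict({"lr": 1e-3}) supports both cfg.lr and cfg["lr"];
+ # get_exp_config() above wraps loaded configs in DotDict for this reason.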
+
+def print_arguments(args, length=100, sep=': ', delim=' | '):
+ r"""Print the arguments in a nice format
+
+ Args:
+ args (dict): arguments
+ length (int): maximum length of each line
+ sep (str): separator between key and value
+ delim (str): delimiter between lines
+ """
+
+ text = []
+ for key in args.keys():
+ text.append('{}{}{}'.format(key, sep, args[key]))
+
+ cl = 0
+ for n, line in enumerate(text):
+ if cl + len(line) > length:
+ print('')
+ cl = 0
+ print(line, end='')
+ cl += len(line)
+ if n != len(text) - 1:
+ print(delim, end='')
+ cl += len(delim)
+ print('')
+
+
+def augment_trajectory(obs_traj, pred_traj, flip=True, reverse=True):
+ r"""Flip and reverse the trajectory
+
+ Args:
+ obs_traj (torch.Tensor): observed trajectory with shape (num_peds, obs_len, 2)
+ pred_traj (torch.Tensor): predicted trajectory with shape (num_peds, pred_len, 2)
+ flip (bool): whether to flip the trajectory
+ reverse (bool): whether to reverse the trajectory
+ """
+
+ if flip:
+ obs_traj = torch.cat([obs_traj, obs_traj * torch.FloatTensor([[[1, -1]]])], dim=0)
+ pred_traj = torch.cat([pred_traj, pred_traj * torch.FloatTensor([[[1, -1]]])], dim=0)
+ elif reverse:
+ full_traj = torch.cat([obs_traj, pred_traj], dim=1) # NTC
+ obs_traj = torch.cat([obs_traj, full_traj.flip(1)[:, :obs_traj.size(1)]], dim=0)
+ pred_traj = torch.cat([pred_traj, full_traj.flip(1)[:, obs_traj.size(1):]], dim=0)
+ return obs_traj, pred_traj
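+
+ # Note: with the default arguments only the flip branch runs (reverse is on an elif, so it
+ # applies only when flip=False and reverse=True); each applied branch doubles the number of
+ # trajectories along the pedestrian axis.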
+
+
+if __name__ == '__main__':
+ cfg = get_exp_config("../config/stochastic/singulartrajectory-transformerdiffusion-eth.json")
+ print(cfg)