Commit

Initial public release
InhwanBae committed Jun 10, 2024
1 parent b0f7515 commit 45660cd
Showing 29 changed files with 3,060 additions and 32 deletions.
429 changes: 408 additions & 21 deletions LICENSE

Large diffs are not rendered by default.

124 changes: 113 additions & 11 deletions README.md
@@ -2,35 +2,133 @@
<p align="center">
<a href="https://InhwanBae.github.io/"><strong>Inhwan Bae</strong></a>
·
<a href="https://sites.google.com/view/yjparkcv/"><strong>Young-Jae Park</strong></a>
<a href="https://www.youngjaepark.com/"><strong>Young-Jae Park</strong></a>
·
<a href="https://scholar.google.com/citations?user=Ei00xroAAAAJ"><strong>Hae-Gon Jeon</strong></a>
<br>
CVPR 2024
</p>

<p align="center">
<a href="https://inhwanbae.github.io/publication/singulartrajectory/"><strong><code>Project Page</code></strong></a>
<a href="https://arxiv.org/abs/2403.18452"><strong><code>CVPR Paper</code></strong></a>
<a href="https://github.com/InhwanBae/SingularTrajectory"><strong><code>Source Code</code></strong></a>
<a href="#-citation"><strong><code>Related Works</code></strong></a>
<a href="https://inhwanbae.github.io/publication/singulartrajectory/"><strong><code>Project Page</code></strong></a>
<a href="https://arxiv.org/abs/2403.18452"><strong><code>CVPR Paper</code></strong></a>
<a href="https://github.com/InhwanBae/SingularTrajectory"><strong><code>Source Code</code></strong></a>
<a href="#-citation"><strong><code>Related Works</code></strong></a>
</p>

<div align='center'>
<br><img src="img/singulartrajectory-model.png" width=70%>
<br>An overview of our SingularTrajectory framework.
</div>

[//]: # (<br>This repository contains the code for the EigenTrajectory&#40;𝔼𝕋&#41; space applied to the 10 traditional Euclidean-based trajectory predictors.)
[//]: # (<br>EigenTrajectory-LB-EBM achieves ADE/FDE of 0.21/0.34 while requiring only 1 hour for training! )
<br>This repository contains the code for the SingularTrajectory model, designed to handle five different trajectory prediction benchmarks.
<br>Our unified framework captures the general dynamics of human movement across various input modalities and trajectory lengths.

<br>

## 1️⃣ SingularTrajectory Model 1️⃣
* A diffusion-based universal trajectory prediction framework designed to bridge the performance gap across five tasks.
* A Singular space is constructed to unify various representations of human dynamics in the associated tasks (a minimal sketch follows this list).
* An adaptive anchor and cascaded denoising process correct initial prototype paths that are placed incorrectly.
* Our model outperforms existing methods on five public benchmarks: Deterministic, Stochastic, Domain Adaptation, Momentary Observation, and Few-Shot trajectory prediction.
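
The "Singular space" above can be illustrated with a minimal, hypothetical sketch (random tensors and illustrative sizes, not the actual training pipeline), assuming the basis comes from a truncated SVD of flattened training trajectories; it mirrors the `to_Singular_space`/`to_Euclidean_space` transforms in `SingularTrajectory/anchor.py` below.

```python
# Conceptual sketch (not the repo's exact pipeline): build a rank-k space
# from flattened trajectories with a truncated SVD, then project and reconstruct.
import torch

N, t, dim, k = 1000, 12, 2, 6            # trajectories, timesteps, xy, rank (illustrative values)
traj = torch.randn(N, t, dim)            # stand-in for real training trajectories

M = traj.reshape(N, t * dim).T           # (t*dim, N): each column is one flattened trajectory
U, S, Vh = torch.linalg.svd(M, full_matrices=False)
evec = U[:, :k]                          # (t*dim, k): truncated singular basis

C = evec.T @ M                           # (k, N): low-dimensional coordinates of every trajectory
recon = (evec @ C).T.reshape(N, t, dim)  # back to Euclidean trajectories

print("relative reconstruction error:", ((recon - traj).norm() / traj.norm()).item())
```

Trajectories from every task are projected onto the same small set of basis vectors, which is what lets a single predictor operate across input modalities and trajectory lengths.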

<br>

## Model Training
### Setup
**Environment**
<br>All models were trained and tested on Ubuntu 20.04 with Python 3.8, PyTorch 2.0.1, and CUDA 11.7.

**Dataset**
<br>Preprocessed [ETH](https://data.vision.ee.ethz.ch/cvl/aem/ewap_dataset_full.tgz) and [UCY](https://graphics.cs.ucy.ac.cy/research/downloads/crowd-data) datasets are released in this repository.
The train/validation/test splits are the same as those found in [Social-GAN](https://github.com/agrimgupta92/sgan).

You can download the dataset by running the following script.
```bash
./script/download_datasets.sh
```

### Train SingularTrajectory
To train our SingularTrajectory on each task using the ETH and UCY datasets simultaneously, we provide a bash script `train.sh` for simplified execution.
```bash
./script/train.sh -p <config_path> -t <experiment_tag> -d <space_separated_dataset_string> -i <space_separated_gpu_id_string>
```

**Examples**
```bash
# Stochastic prediction task
./script/train.sh -p stochastic/singulartrajectory -t SingularTrajectory-stochastic

# Deterministic prediction task
./script/train.sh -p deterministic/singulartrajectory -t SingularTrajectory-deterministic

# Momentary observation task
./script/train.sh -p momentary/singulartrajectory -t SingularTrajectory-momentary

# Domain adaptation task
./script/train.sh -p domain/singulartrajectory -t SingularTrajectory-domain -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"

# Few-shot task
./script/train.sh -p fewshot/singulartrajectory -t SingularTrajectory-fewshot

# (Optional) Stochastic domain adaptation task
./script/train.sh -p domain-stochastic/singulartrajectory -t SingularTrajectory-domain-stochastic -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"

# (Optional) All-in-one task
./script/train.sh -p allinone/singulartrajectory -t SingularTrajectory-allinone -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
```
If you want to train the model with custom hyperparameters, use `trainval.py` instead of the script file.
```bash
python trainval.py --cfg ./config/{task}/singulartrajectory-transformerdiffusion-{dataset}.json --tag SingularTrajectory-{task} --gpu_id 0
```
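
For reference, here is a hypothetical Python wrapper (not part of the repository) that launches `trainval.py` for every ETH/UCY split of one task; the dataset tokens (`eth`, `hotel`, `univ`, `zara1`, `zara2`) are an assumption and should be checked against the file names in `./config/`.

```python
# Hypothetical convenience wrapper (not part of the repo): runs trainval.py once per
# dataset for a single task. Assumes the ETH/UCY config files use the usual tokens
# (eth, hotel, univ, zara1, zara2); adjust to match the contents of ./config/.
import subprocess

task = "stochastic"
datasets = ["eth", "hotel", "univ", "zara1", "zara2"]   # assumed dataset tokens

for d in datasets:
    cfg = f"./config/{task}/singulartrajectory-transformerdiffusion-{d}.json"
    subprocess.run(
        ["python", "trainval.py", "--cfg", cfg,
         "--tag", f"SingularTrajectory-{task}", "--gpu_id", "0"],
        check=True,  # stop immediately if one run fails
    )
```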

<br>

## Model Evaluation
### Pretrained Models
We provide pretrained models in the [**release section**](https://github.com/InhwanBae/SingularTrajectory/releases/tag/v1.0).
You can download all 80 pretrained SingularTrajectory models at once by running the following script.
```bash
./script/download_pretrained_models.sh
```

### Evaluate SingularTrajectory
To evaluate our SingularTrajectory at once, we provide a bash script `test.sh` for simplified execution.
```bash
./script/test.sh -p <config_path> -t <experiment_tag> -d <space_separated_dataset_string> -i <space_separated_gpu_id_string>
```

**Examples**

```bash
# Stochastic prediction task
./script/test.sh -p stochastic/singulartrajectory -t SingularTrajectory-stochastic

# Deterministic prediction task
./script/test.sh -p deterministic/singulartrajectory -t SingularTrajectory-deterministic

# Momentary observation task
./script/test.sh -p momentary/singulartrajectory -t SingularTrajectory-momentary

# Domain adaptation task
./script/test.sh -p domain/singulartrajectory -t SingularTrajectory-domain -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"

# Few-shot task
./script/test.sh -p fewshot/singulartrajectory -t SingularTrajectory-fewshot

# (Optional) Stochastic domain adaptation task
./script/test.sh -p domain-stochastic/singulartrajectory -t SingularTrajectory-domain-stochastic -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"

# (Optional) All-in-one task
./script/test.sh -p allinone/singulartrajectory -t SingularTrajectory-allinone -d "A2B A2C A2D A2E B2A B2C B2D B2E C2A C2B C2D C2E D2A D2B D2C D2E E2A E2B E2C E2D" -i "0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4"
```

If you want to evaluate the model individually, you can use `trainval.py` with custom hyperparameters.
```bash
python trainval.py --test --cfg ./config/{task}/singulartrajectory-transformerdiffusion-{dataset}.json --tag SingularTrajectory-{task} --gpu_id 0
```

<br>

## 📖 Citation
If you find this code useful for your research, please cite our trajectory prediction papers :)
@@ -51,7 +149,7 @@
year={2024}
}
```
<details>
<summary>More Information (Click to expand)</summary>

```bibtex
@@ -99,4 +197,8 @@
```
</details>

### Acknowledgement
Part of our code is borrowed from [EigenTrajectory](https://github.com/InhwanBae/EigenTrajectory) and [LED](https://github.com/MediaBrain-SJTU/LED).
We thank the authors for releasing their code and models.

<br>
2 changes: 2 additions & 0 deletions SingularTrajectory/__init__.py
@@ -0,0 +1,2 @@
from .model import SingularTrajectory
from .normalizer import TrajNorm
142 changes: 142 additions & 0 deletions SingularTrajectory/anchor.py
@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
from .kmeans import BatchKMeans
from sklearn.cluster import KMeans
import numpy as np
from .homography import image2world, world2image


class AdaptiveAnchor(nn.Module):
r"""Adaptive anchor model
Args:
hyper_params (DotDict): The hyper-parameters
"""

def __init__(self, hyper_params):
super().__init__()

self.hyper_params = hyper_params
self.k = hyper_params.k
self.s = hyper_params.num_samples
self.dim = hyper_params.traj_dim

self.C_anchor = nn.Parameter(torch.zeros((self.k, self.s)))

def to_Singular_space(self, traj, evec):
r"""Transform Euclidean trajectories to Singular space coordinates
Args:
traj (torch.Tensor): The trajectory to be transformed
evec (torch.Tensor): The Singular space basis vectors
Returns:
C (torch.Tensor): The Singular space coordinates"""

# Euclidean space -> Singular space
tdim = evec.size(0)
M = traj.reshape(-1, tdim).T
C = evec.T.detach() @ M
return C

def batch_to_Singular_space(self, traj, evec):
# Euclidean space -> Singular space
tdim = evec.size(0)
M = traj.reshape(-1, tdim).transpose(1, 2)
C = evec.T.detach() @ M
return C

def to_Euclidean_space(self, C, evec):
r"""Transform Singular space coordinates to Euclidean trajectories
Args:
C (torch.Tensor): The Singular space coordinates
evec (torch.Tensor): The Singular space basis vectors
Returns:
traj (torch.Tensor): The Euclidean trajectory"""

# Singular space -> Euclidean
t = evec.size(0) // self.dim
M = evec.detach() @ C
traj = M.T.reshape(-1, t, self.dim)
return traj

def batch_to_Euclidean_space(self, C, evec):
# Singular space -> Euclidean
b = C.size(0)
t = evec.size(0) // self.dim
M = evec.detach() @ C
traj = M.transpose(1, 2).reshape(b, -1, t, self.dim)
return traj

def anchor_initialization(self, pred_traj_norm, V_pred_trunc):
r"""Anchor initialization on Singular space
Args:
pred_traj_norm (torch.Tensor): The normalized predicted trajectory
V_pred_trunc (torch.Tensor): The truncated Singular space basis vectors of the predicted trajectory
Note:
This function should be called once before training the model.
"""

# Trajectory projection
C_pred = self.to_Singular_space(pred_traj_norm, evec=V_pred_trunc).T.detach().numpy()
C_anchor = torch.FloatTensor(KMeans(n_clusters=self.s, random_state=0, init='k-means++', n_init=1).fit(C_pred).cluster_centers_.T)

# Register anchors as model parameters
self.C_anchor = nn.Parameter(C_anchor.to(self.C_anchor.device))

def adaptive_anchor_calculation(self, obs_traj, scene_id, vector_field, homography, space):
r"""Adaptive anchor calculation on Singular space"""

n_ped = obs_traj.size(0)
V_trunc = space.V_trunc

space.traj_normalizer.calculate_params(obs_traj.cuda().detach())
init_anchor = self.C_anchor.unsqueeze(dim=0).repeat_interleave(repeats=n_ped, dim=0).detach()
init_anchor = init_anchor.permute(2, 1, 0)
init_anchor_euclidean = space.batch_to_Euclidean_space(init_anchor, evec=V_trunc)
init_anchor_euclidean = space.traj_normalizer.denormalize(init_anchor_euclidean).cpu().numpy()
adaptive_anchor_euclidean = init_anchor_euclidean.copy()
obs_traj = obs_traj.cpu().numpy()

for ped_id in range(n_ped):
scene_name = scene_id[ped_id]
prototype_image = world2image(init_anchor_euclidean[:, ped_id], homography[scene_name])
startpoint_image = world2image(obs_traj[ped_id], homography[scene_name])[-1]
endpoint_image = prototype_image[:, -1, :]
endpoint_image = np.round(endpoint_image).astype(int)
size = np.array(vector_field[scene_name].shape[1::-1]) // 2
endpoint_image = np.clip(endpoint_image, a_min= -size // 2, a_max=size + size // 2 -1)
for s in range(self.s):
vector = np.array(vector_field[scene_name][endpoint_image[s, 1] + size[1] // 2, endpoint_image[s, 0] + size[0] // 2])[::-1] - size // 2
if vector[0] == endpoint_image[s, 0] and vector[1] == endpoint_image[s, 1]:
continue
else:
nearest_endpoint_image = vector
scale_xy = (nearest_endpoint_image - startpoint_image) / (endpoint_image[s] - startpoint_image)
prototype_image[s, :, :] = (prototype_image[s, :, :].copy() - startpoint_image) * scale_xy + startpoint_image

prototype_world = image2world(prototype_image, homography[scene_name])
adaptive_anchor_euclidean[:, ped_id] = prototype_world

adaptive_anchor_euclidean = space.traj_normalizer.normalize(torch.FloatTensor(adaptive_anchor_euclidean).cuda())
adaptive_anchor = space.batch_to_Singular_space(adaptive_anchor_euclidean, evec=V_trunc)
adaptive_anchor = adaptive_anchor.permute(2, 1, 0).cpu()
# If you don't want to use an image, return `init_anchor`.
return adaptive_anchor

def forward(self, C_residual, C_anchor):
r"""Anchor refinement on Singular space
Args:
C_residual (torch.Tensor): The predicted Singular space coordinates
Returns:
C_pred_refine (torch.Tensor): The refined Singular space coordinates
"""

C_pred_refine = C_anchor.detach() + C_residual
return C_pred_refine
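
To make the anchor initialization above concrete, here is a minimal, self-contained sketch of the k-means step in `anchor_initialization`, using random coordinates as a stand-in for trajectories already projected into the Singular space.

```python
# Minimal, self-contained sketch of the k-means anchor initialization above,
# using random stand-in data instead of the real pre-processed trajectories.
import torch
from sklearn.cluster import KMeans

k, s, N = 6, 20, 500                    # rank, number of anchors/samples, trajectories
C_pred = torch.randn(N, k).numpy()      # stand-in for projected future trajectories, shape (N, k)

kmeans = KMeans(n_clusters=s, random_state=0, init='k-means++', n_init=1).fit(C_pred)
C_anchor = torch.FloatTensor(kmeans.cluster_centers_.T)  # (k, s), as stored in AdaptiveAnchor.C_anchor

print(C_anchor.shape)  # torch.Size([6, 20])
```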