Link to the paper:
Install conda and create an environment with Python 3.10.
Install the CUDA build of PyTorch 1.13.
Example command for installing PyTorch with conda for CUDA 11.7:
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
Check that PyTorch can use CUDA.
- This can be done using
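A minimal sketch of such a check, assuming only that PyTorch may or may not be installed. torch.cuda.is_available(), torch.cuda.device_count() and torch.version.cuda are standard PyTorch attributes; the helper name cuda_report is made up for illustration and is not part of this repo.

```python
import importlib


def cuda_report() -> dict:
    """Hypothetical helper: report whether PyTorch is installed and can see a CUDA device."""
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        # PyTorch is not installed in this environment.
        return {"torch_version": None, "built_with_cuda": None,
                "cuda_available": False, "device_count": 0}
    available = bool(torch.cuda.is_available())
    return {
        "torch_version": torch.__version__,
        # CUDA version this PyTorch build targets, e.g. "11.7"; None for CPU-only builds.
        "built_with_cuda": torch.version.cuda,
        "cuda_available": available,
        "device_count": torch.cuda.device_count() if available else 0,
    }


if __name__ == "__main__":
    print(cuda_report())
```

With the install command above on a working GPU machine, cuda_available should be True and built_with_cuda should read "11.7".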
Install PyTorch Lightning 1.8 (supports PyTorch 1.10, 1.11, 1.12 and 1.13):
pip install "pytorch-lightning==1.8.*"
- Ensure that the PyTorch Lightning version in use matches the compatibility matrix:
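The compatibility rule stated above can also be checked in code. This is a hypothetical sketch (is_supported and SUPPORTED_TORCH are not part of this repo) that encodes only the one row given in this README; the full matrix should be taken from the PyTorch Lightning documentation.

```python
# Hypothetical helper. The table holds only the row stated in this README
# (PyTorch Lightning 1.8 supports PyTorch 1.10-1.13); extend it from the official matrix.
SUPPORTED_TORCH = {"1.8": {"1.10", "1.11", "1.12", "1.13"}}


def major_minor(version: str) -> str:
    """Reduce a version string such as '1.13.0+cu117' to '1.13'."""
    public = version.split("+", 1)[0]  # drop a local build tag such as '+cu117'
    return ".".join(public.split(".")[:2])


def is_supported(lightning_version: str, torch_version: str) -> bool:
    """Return True if the torch version is listed for this Lightning release."""
    row = major_minor(lightning_version)
    if row not in SUPPORTED_TORCH:
        raise ValueError(f"no compatibility entry for PyTorch Lightning {row}")
    return major_minor(torch_version) in SUPPORTED_TORCH[row]
```

The installed versions can be fed in with importlib.metadata.version("pytorch-lightning") and importlib.metadata.version("torch").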
Install Hydra.
If a protobuf-related error occurs, install protobuf 3.20.x:
pip install "protobuf==3.20.*"
Install cv2:
conda install -c conda-forge opencv
Install tfrecord:
pip install tfrecord
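After these steps, a quick way to confirm that every package can be found is to look up each module without importing it. The module list is an assumption that maps the packages installed above to their import names (opencv provides cv2, the pip package hydra-core provides hydra, protobuf provides google.protobuf); missing_modules is a hypothetical helper.

```python
import importlib.util
from typing import List, Sequence

# Assumed import names for the packages installed above (hypothetical list).
REQUIRED_MODULES = ("torch", "pytorch_lightning", "hydra", "cv2", "tfrecord", "google.protobuf")


def missing_modules(names: Sequence[str] = REQUIRED_MODULES) -> List[str]:
    """Hypothetical helper: return the names that cannot be located in this environment."""
    missing = []
    for name in names:
        try:
            found = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # Raised for dotted names whose parent package (e.g. 'google') is absent.
            found = False
        if not found:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("Missing modules:", missing_modules() or "none")
```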
Download the Waymo Open Motion Dataset from:
- Version 1.1 of the data was used for the experiments.
The official source partitions the data into training, validation and test sets with a 70/15/15 split.
- Partial data used for the current experiments:
- Train: first 66 tfrecord files
- Val: first 10 tfrecord files
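One way to build the partial train and val folders is to symlink the first N shards of each official split into a separate folder. This is a sketch, not the authors' procedure: the helper name link_first_n, the example paths, and the assumption that shard file names contain "tfrecord" and sort in shard order (zero-padded indices) are all illustrative.

```python
from pathlib import Path
from typing import List


def link_first_n(src_dir: str, dst_dir: str, n: int, pattern: str = "*tfrecord*") -> List[Path]:
    """Hypothetical helper: symlink the first n shards (sorted by name) of src_dir into dst_dir."""
    shards = sorted(p for p in Path(src_dir).glob(pattern) if p.is_file())[:n]
    out_dir = Path(dst_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    linked = []
    for shard in shards:
        link = out_dir / shard.name
        if not (link.exists() or link.is_symlink()):
            link.symlink_to(shard.resolve())  # a link avoids copying large files
        linked.append(link)
    return linked
```

For example, link_first_n("waymo/training", "data/train", 66) and link_first_n("waymo/validation", "data/val", 10), with hypothetical paths.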
Index (idx) files need to be generated for the tfrecord files using
- Use
to create a folder of idx files by processing a folder of tfrecord files.
- Generate the train idx files.
- Generate the val idx files.
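The tfrecord package writes an index as one line per record holding the record's byte offset and its total length in bytes (its own tool runs as python -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>). Below is a stdlib-only sketch of the same idea for a whole folder. It assumes the standard TFRecord framing and skips CRC checks; build_index and index_folder are hypothetical names, and the <file name>.idx naming may not match what this repo's loader expects.

```python
import struct
from pathlib import Path


def build_index(tfrecord_path: str, index_path: str) -> int:
    """Hypothetical helper: write one "offset length" line per record, return the count.

    Each record is framed as: uint64 payload length, uint32 CRC of the length,
    the payload bytes, uint32 CRC of the payload (little-endian). CRCs are not verified.
    """
    count = 0
    with open(tfrecord_path, "rb") as src, open(index_path, "w") as out:
        while True:
            start = src.tell()
            header = src.read(8)
            if not header:
                break  # clean end of file
            if len(header) != 8:
                raise ValueError(f"truncated length field at byte {start}")
            (length,) = struct.unpack("<Q", header)
            rest = src.read(4 + length + 4)  # length CRC + payload + payload CRC
            if len(rest) != 4 + length + 4:
                raise ValueError(f"truncated record at byte {start}")
            out.write(f"{start} {src.tell() - start}\n")
            count += 1
    return count


def index_folder(tfrecord_dir: str, index_dir: str, pattern: str = "*tfrecord*") -> dict:
    """Hypothetical helper: index every matching file into index_dir as <file name>.idx."""
    out_dir = Path(index_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    counts = {}
    for path in sorted(Path(tfrecord_dir).glob(pattern)):
        if path.is_file():
            counts[path.name] = build_index(str(path), str(out_dir / f"{path.name}.idx"))
    return counts
```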
Visualize a sample of the training data using
- Sample motion data:
- Summary of the motion dataset:
- TFRecord tutorial:
- Waymo Open Dataset paper:
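The idx files make random access cheap: a single record (one scene) can be read by seeking straight to its offset instead of scanning the whole file. This sketch assumes the tfrecord package's index format of one "offset length" line per record; read_index and read_record are hypothetical names. The returned payload is still a serialized protocol buffer that must be decoded (for example with the tfrecord package or TensorFlow) before it can be visualized.

```python
import struct
from typing import List, Tuple


def read_index(index_path: str) -> List[Tuple[int, ...]]:
    """Hypothetical helper: parse "offset length" lines into (offset, length) tuples."""
    with open(index_path) as f:
        return [tuple(int(x) for x in line.split()[:2]) for line in f if line.strip()]


def read_record(tfrecord_path: str, offset: int) -> bytes:
    """Hypothetical helper: return the raw payload of the record at `offset` (CRCs not checked)."""
    with open(tfrecord_path, "rb") as f:
        f.seek(offset)
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(f"no record header at byte {offset}")
        (length,) = struct.unpack("<Q", header)
        f.seek(4, 1)  # skip the CRC of the length field
        payload = f.read(length)
    if len(payload) != length:
        raise ValueError(f"truncated record at byte {offset}")
    return payload
```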
- Run this in order to train, test or validate the model.
- conf/config.yaml
- Set the configuration of the dataset, model and training. Check the comments for definitions.
- datautil
- Use it to create indices from tfrecord files.
- Defines the dataset class, the dataloader and the collate function.
- model
- Defines the encoder of
- Defines the decoder of
- Defines the training, validation and testing steps. Also contains visualization.
- Model/data/training utilities.
- tfrecordutils: tfrecord utilities.
- outputs/: outputs of train/test runs are stored here, organized by the date and time of the run.
The training code loads data through the dataloader and trains the model by backpropagating a loss computed between the ground truth and the model's predictions.
- Download the dataset and prepare the train and val data and their idx files in separate folders, following the procedure above.
- In the config.yaml file, set mode to train.
- Set the dataset paths under the dataset field of the config.
- Set the model's parameters under the model field.
- Make sure the GPUs specified in the config are available.
- Run
to start training the model
- The code first runs a sample validation pass and then trains the model. Train (and val) losses, metrics and visualizations are logged to tf events files. Use
to visualize the relevant plots and validation images.
- Training can be resumed from a previous checkpoint by setting the checkpoint path in the resume field of
- Depending on the installed PyTorch Lightning version, resuming training may fail with an error. In that case, comment out the
and resume training.
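To fill the resume field, the newest checkpoint under outputs/ can be located with a small helper. This sketch assumes checkpoints are written as .ckpt files (PyTorch Lightning's default extension) somewhere below outputs/; the function name latest_checkpoint is hypothetical.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(outputs_dir: str = "outputs") -> Optional[str]:
    """Hypothetical helper: newest .ckpt file below outputs_dir, or None if there is none."""
    root = Path(outputs_dir)
    if not root.is_dir():
        return None
    ckpts = [p for p in root.rglob("*.ckpt") if p.is_file()]
    if not ckpts:
        return None
    # Pick by modification time, since run folders are named by date and time.
    return str(max(ckpts, key=lambda p: p.stat().st_mtime))


if __name__ == "__main__":
    print(latest_checkpoint())
```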
The test code processes a folder of tfrecord files and produces an output folder containing a visualization of each scene.
- Generate idx files for the test data as described above, using
- In the config.yaml file, set mode to test.
- Under the test field of the dataset section of the config file, do the following:
- Set the paths to the test tfrecord folder and the idx folder.
- Set the number of seconds into the future the model should predict, using
- Set the path to the model weights using
- Run to start testing the model.
- The evaluation metrics on the test data (minADE, minFDE, inference time, etc.) are printed at the end, and the output visualizations (.png files) are saved in the outputs folder.
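For reference, minADE and minFDE compare K predicted trajectories of an agent with its ground truth and keep the best candidate. The sketch below follows one common convention (each minimum taken independently over the K candidates, Euclidean error in the dataset's coordinate units), which may differ from the repo's implementation, for example if it masks invalid timesteps or reports the ADE of the lowest-FDE trajectory. min_ade_fde is a hypothetical name.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def min_ade_fde(preds: Sequence[Sequence[Point]], gt: Sequence[Point]) -> Tuple[float, float]:
    """Hypothetical helper: (minADE, minFDE) of K candidate trajectories against one ground truth.

    ADE_k is the mean Euclidean error of candidate k over all timesteps and
    FDE_k its error at the final timestep; each minimum is taken independently.
    """
    if not preds or not gt:
        raise ValueError("need at least one prediction and a non-empty ground truth")
    ades, fdes = [], []
    for traj in preds:
        if len(traj) != len(gt):
            raise ValueError("prediction and ground truth must have the same length")
        errors = [math.dist(p, g) for p, g in zip(traj, gt)]
        ades.append(sum(errors) / len(errors))
        fdes.append(errors[-1])
    return min(ades), min(fdes)
```

For example, with ground truth [(0, 0), (1, 0), (2, 0)] and candidates [(0, 1), (1, 1), (2, 1)] and [(0, 0), (1, 0), (2, 2)], the first candidate has ADE 1 and FDE 1 and the second has ADE 2/3 and FDE 2, so minADE is 2/3 and minFDE is 1, taken from different candidates.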
The training setup differs from the one used in the paper as follows:
- : 256 -> 64
- batch_size: 64 -> 2
- in_feat_dim: 7 -> 9
- Data augmentation: None (the paper uses data augmentation)
- Partial training and validation data (~7% of the data used in the paper; details are given above in Data Preparation)
- An MSE loss is used instead of the multiple losses (displacement, classification, Laplace and heading losses) used in the paper
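For a single predicted trajectory, an MSE loss averages the squared coordinate errors over all timesteps and both coordinates; in PyTorch this is torch.nn.functional.mse_loss(pred, gt) with the default reduction="mean" on (T, 2) tensors. The plain-Python sketch below shows the same arithmetic; how the repo applies it when the model outputs several candidate trajectories is not covered, and trajectory_mse is a hypothetical name.

```python
from typing import Sequence, Tuple

Point = Tuple[float, float]


def trajectory_mse(pred: Sequence[Point], gt: Sequence[Point]) -> float:
    """Hypothetical helper: mean squared error over every timestep and both coordinates."""
    if not gt or len(pred) != len(gt):
        raise ValueError("pred and gt must be non-empty and of equal length")
    squared = [(px - gx) ** 2 + (py - gy) ** 2 for (px, py), (gx, gy) in zip(pred, gt)]
    # Divide by the number of scalar elements (T * 2), as a mean-reduced MSE does.
    return sum(squared) / (2 * len(gt))
```

For example, predictions [(0, 0), (1, 1)] against ground truth [(0, 0), (1, 3)] give squared errors 0, 0, 0 and 4, so the MSE is 4 / 4 = 1.0.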