This code is based on the STCN repository.
The installation script is given below:
python -m venv ENV
source ENV/bin/activate
pip install torch torchvision
pip install pyyaml
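As an optional sanity check, the short Python sketch below (not part of the repository) reports which of the packages installed by the script cannot be found in the active environment. Note that the pip package pyyaml is imported under the name yaml.

```python
# Optional sanity check (a sketch, not part of the repository): report which of
# the packages installed above cannot be found in the active environment.
import importlib.util

# Import names of the installed packages; pip's "pyyaml" is imported as "yaml".
REQUIRED = ("torch", "torchvision", "yaml")


def missing_packages(modules=REQUIRED):
    """Return the names in `modules` for which no import spec is found."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
```

Run it from the activated ENV so that it inspects the same interpreter that will run the evaluation scripts.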
We evaluated our test-time training strategy on four datasets: DAVIS-2017, YouTube-VOS, MOSE and DAVIS-C (DAVIS with corrupted RGB frames).
Please see DATA_PREPARATION for instructions on downloading the datasets.
The test-time training strategy can be run on top of any STCN model trained offline. In our study, we evaluated it on top of the three models provided in the original repository:
- stcn_s01, model trained without real videos (model used in the sim2real transfer study)
- stcn, model trained with real videos (model used in the corrupted examples study)
- stcn_s0, model trained with static videos
For each model, you can run the test-time training strategy using three losses:
- the tt-mcc loss, our proposed method, which uses temporal information through mask cycle consistency
- the tt-ae loss, an auto-encoder-based variant
- the tt-ent loss, an entropy-based variant
For each combination of an offline-trained model and a loss, we saved the parameters used to run our test-time training strategy in a configuration file in the ttt_configs folder.
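Assuming the configuration files follow the <model>_<loss>.yaml pattern of stcn_s01_mcc.yaml and stcn_mcc.yaml (the ae and ent suffixes below are a guess, not confirmed by this README), a small Python sketch can list the nine expected names and report which ones are missing from ttt_configs:

```python
# Sketch: enumerate the expected config names for every (model, loss) pair and
# report which ones are absent from the config folder.
# ASSUMPTION: files follow the <model>_<loss>.yaml pattern seen in
# stcn_s01_mcc.yaml and stcn_mcc.yaml; the "ae" and "ent" suffixes are guesses.
from itertools import product
from pathlib import Path

MODELS = ("stcn_s01", "stcn", "stcn_s0")
LOSSES = ("mcc", "ae", "ent")  # tt-mcc, tt-ae, tt-ent


def expected_config_names(models=MODELS, losses=LOSSES):
    """Return one config name per (model, loss) combination."""
    return [f"{model}_{loss}" for model, loss in product(models, losses)]


def missing_configs(config_dir="ttt_configs"):
    """Return the expected names with no matching .yaml file in config_dir."""
    folder = Path(config_dir)
    return [name for name in expected_config_names()
            if not (folder / f"{name}.yaml").is_file()]


if __name__ == "__main__":
    missing = missing_configs()
    print("Missing configs:", ", ".join(missing) if missing else "none")
```

If the listed names do not match the files you find, the naming assumption is wrong; adjust LOSSES to the actual suffixes.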
You can save your own config by running:
python ttt/config/save_config.py --config_name <config_name> <your parameters list>
We provide below the scripts to run the sim2real transfer study on DAVIS, YouTube-VOS and MOSE.
python eval_all_datasets_ttt.py --config_name stcn_s01_mcc --dataset_name davis --split val --dataset_dir $DATA_DIR --output_dir $OUTPUT_DIR --seed $SEED
python eval_all_datasets_ttt.py --config_name stcn_s01_mcc --dataset_name youtube --split valid --dataset_dir $DATA_DIR --output_dir $OUTPUT_DIR --seed $SEED
python eval_all_datasets_ttt.py --config_name stcn_s01_mcc --dataset_name mose --split valid --dataset_dir $DATA_DIR --output_dir $OUTPUT_DIR --seed $SEED
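To loop over datasets and seeds, a minimal Python sketch could build and launch the same eval_all_datasets_ttt.py commands. It assumes one dataset root per dataset (see DATASET_DIR), uses the val/valid splits listed for our main study, and writes each run to a hypothetical <output_root>/<dataset>/seed_<seed> folder; the paths in the __main__ block are placeholders. By default it only prints the commands.

```python
# Sketch: build (and optionally launch) the sim2real runs for several datasets
# and seeds. The flags mirror the eval_all_datasets_ttt.py calls above; the
# output layout and the placeholder paths are hypothetical.
import os
import subprocess
import sys

# Splits used in our main study.
SPLITS = {"davis": "val", "youtube": "valid", "mose": "valid"}


def build_command(config_name, dataset_name, dataset_dir, output_dir, seed):
    """Return the argument list for one eval_all_datasets_ttt.py run."""
    return [
        sys.executable, "eval_all_datasets_ttt.py",
        "--config_name", config_name,
        "--dataset_name", dataset_name,
        "--split", SPLITS[dataset_name],
        "--dataset_dir", dataset_dir,
        "--output_dir", output_dir,
        "--seed", str(seed),
    ]


def sim2real_commands(dataset_dirs, output_root, seeds=(1, 5, 10),
                      config_name="stcn_s01_mcc"):
    """Build one command per (dataset, seed); dataset_dirs maps name -> root."""
    commands = []
    for dataset_name, dataset_dir in dataset_dirs.items():
        for seed in seeds:
            # Hypothetical layout: <output_root>/<dataset>/seed_<seed>
            output_dir = os.path.join(output_root, dataset_name, f"seed_{seed}")
            commands.append(build_command(config_name, dataset_name,
                                          dataset_dir, output_dir, seed))
    return commands


def run(commands, dry_run=True):
    """Print each command; only execute it when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Placeholder dataset roots: replace with your own paths.
    dataset_dirs = {
        "davis": "/path/to/DAVIS/2017",
        "youtube": "/path/to/YouTube",
        "mose": "/path/to/MOSE",
    }
    run(sim2real_commands(dataset_dirs, "outputs_sim2real"))
```

Keeping the dry run as the default lets you check the generated flags before starting any long test-time training job.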
We also provide below the script to run the corrupted examples analysis.
python eval_all_datasets_ttt.py --config_name stcn_mcc --dataset_name davis --split val --dataset_dir $DATA_DIR --corrupted_image_dir $CORRUPTED_IMAGE_DIR --output_dir $OUTPUT_DIR --seed $SEED
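Similarly, a hedged Python sketch for the corrupted examples analysis could run the command above once per corruption folder. It assumes DAVIS-C is organized as <root>/<strength>/<corruption>, following the DAVIS-C/med/brightness example, and discovers the corruption folders on disk instead of hard-coding their names; the output layout and the paths are placeholders.

```python
# Sketch: run the corrupted examples analysis once per corruption folder.
# ASSUMPTION: DAVIS-C is organized as <root>/<strength>/<corruption>, following
# the DAVIS-C/med/brightness example; folder names are read from disk.
import subprocess
import sys
from pathlib import Path


def corruption_dirs(davis_c_root, strength="med"):
    """Return the sorted corruption sub-folders for one strength level."""
    base = Path(davis_c_root) / strength
    if not base.is_dir():
        return []
    return sorted(p for p in base.iterdir() if p.is_dir())


def corrupted_commands(dataset_dir, corr_dirs, output_root, seed,
                       config_name="stcn_mcc"):
    """Build one eval_all_datasets_ttt.py command per corruption folder."""
    commands = []
    for corr_dir in map(Path, corr_dirs):
        # Hypothetical layout: <output_root>/<strength>/<corruption>
        output_dir = Path(output_root) / corr_dir.parent.name / corr_dir.name
        commands.append([
            sys.executable, "eval_all_datasets_ttt.py",
            "--config_name", config_name,
            "--dataset_name", "davis",
            "--split", "val",
            "--dataset_dir", str(dataset_dir),
            "--corrupted_image_dir", str(corr_dir),
            "--output_dir", str(output_dir),
            "--seed", str(seed),
        ])
    return commands


if __name__ == "__main__":
    dry_run = True  # set to False to actually launch the runs
    # Placeholder paths: replace with your own DAVIS-2017 and DAVIS-C roots.
    corr = corruption_dirs("/path/to/DAVIS-C", strength="med")
    for cmd in corrupted_commands("/path/to/DAVIS/2017", corr,
                                  "outputs_corrupted", seed=1):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
```

Reading the folder names from disk avoids guessing corruption names that may not exist in your copy of DAVIS-C.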
To run test-time training, select a configuration (which depends on the offline-trained STCN model and on the test-time training loss) and the dataset you want to run it on.
Choose your configuration file from the ttt_configs folder (it can also be one you saved yourself).
For our main study, we used:
- stcn_s01_mcc.yaml (sim2real)
- stcn_mcc.yaml (corrupted examples)
For each dataset, you will need to specify the following parameters:
- DATASET_NAME: one of davis, youtube, mose or davis-c
- DATASET_DIR: the root directory where your dataset is saved. For example, for the DAVIS-2017 dataset, it will be <your_dataset_folder>/DAVIS/2017
- CORRUPTED_IMAGE_DIR: the root directory containing the corrupted RGB frames. For example, for the medium strength of the brightness corruption, it will be <your_dataset_folder>/DAVIS-C/med/brightness.
- SPLIT: the name of the split you want to test on (for our main study we used the val split for DAVIS and the valid split for YouTube and MOSE)
- OUTPUT_DIR: the output directory where the predicted masks are saved
To compute the mean and standard deviation over multiple runs, we ran each model with several seeds. This requires an additional parameter:
- SEED: the random seed (we experimented with seeds 1, 5 and 10)
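Once each seed has produced a score, the mean and standard deviation can be computed with Python's statistics module. This sketch assumes you have already extracted one scalar score per seed (for example J&F) from the evaluation outputs; it uses the sample standard deviation (statistics.stdev), and statistics.pstdev can be swapped in for the population version.

```python
# Sketch: aggregate one score per seed into a mean and standard deviation.
# ASSUMPTION: a single scalar score per seed (e.g. J&F) has already been
# extracted from the evaluation outputs; parsing those files is not shown.
import statistics


def summarize(scores_by_seed):
    """Return (mean, sample standard deviation) over a {seed: score} mapping."""
    scores = list(scores_by_seed.values())
    if len(scores) < 2:
        raise ValueError("need at least two seeds to compute a standard deviation")
    # statistics.stdev divides by n - 1; use statistics.pstdev for the
    # population standard deviation instead.
    return statistics.mean(scores), statistics.stdev(scores)


if __name__ == "__main__":
    # Dummy values for illustration only, not results from our study.
    mean, std = summarize({1: 1.0, 5: 2.0, 10: 3.0})
    print(f"mean={mean:.2f} std={std:.2f}")  # prints mean=2.00 std=1.00
```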