diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..905ee7a --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +venv/ +*.pyc +.idea/ +build/ +__pycache__/ diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..16b6923 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,97 @@ +# License Guide + +When you create a new open-source repo, you need to pay attention to the right way to create or update a LICENSE file. + +The copyright year should state the year in which your work began. + +The copyright notice should include the year in which you finished preparing the release (so if you finished it in 1998 but didn't post it until 1999, use 1998). You should add the proper year for each release; for example, “Copyright 1998, 1999 Terry Jones” if some versions were finished in 1998 and some were finished in 1999. If several people helped write the code, use all their names. + +For software with several releases over multiple years, it's okay to use a range (“2008-2010”) instead of listing individual years (“2008, 2009, 2010”) if and only if every year in the range, inclusive, really is a “copyrightable” year that would be listed individually; and you make an explicit statement in your documentation about this usage. + +If you made copyrightable changes in a given year, do include that year in the comma-separated list in your copyright notice. + +If you did not make copyrightable changes in a given year, do not include that year in your copyright notice. + +### Useful links + +- [What's the right license for me?](http://choosealicense.com) +- [Is renewal of MIT license needed on github at the beginning of each year?](http://programmers.stackexchange.com/a/210491) + +### Examples + +``` +The MIT License (MIT) + +Copyright (c) 2016 Pagar.me Pagamentos S/A + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.
+``` + +> Example of an MIT license for projects in constant, active development (like our libraries) + +``` +The MIT License (MIT) + +Copyright (c) 2013-present Pagar.me Pagamentos S/A + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +``` + +> Example of an MIT license when you're the new maintainer of a deprecated or discontinued project. + +``` +The MIT License (MIT) + +Copyright: (c) 2016 Pagar.me Pagamentos S/A + (c) 2010 John Doe + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE.
+``` diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..1e3cd58 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,7 @@ +include gazenet/configs/infer_configs/*.json +include gazenet/configs/train_configs/*.json +include gazenet/readers/visualization/assets/*.css +global-include *.sh +exclude *.zip *.7z *.tar.gz *.ptb *.ptb.tar *.npy *.npz *.hd5 *.txt *.jpg *.png *.avi *.gif *.mp4 *.wav *.mp3 +include datasets/processed/center_bias.jpg +include datasets/processed/center_bias_bw.jpg diff --git a/README.md b/README.md new file mode 100644 index 0000000..9ada318 --- /dev/null +++ b/README.md @@ -0,0 +1,220 @@ +# GASP: Gated Attention for Saliency Prediction + +[\[Project Page: KT\]](http://software.knowledge-technology.info/#gasp) | [\[Abstract\]](https://www.ijcai.org/proceedings/2021/81) | [\[Paper\]](https://www.ijcai.org/proceedings/2021/0081.pdf) | [\[BibTeX\]](https://www.ijcai.org/proceedings/2021/bibtex/81) + +This is the official [GASP](http://software.knowledge-technology.info/#gasp) code for our paper, presented at +[IJCAI 2021](https://ijcai-21.org). If you find this work useful, please cite our [paper](https://www2alt.informatik.uni-hamburg.de/wtm/publications/2021/AWW21/index.php): + +``` +@inproceedings{abawi2021gasp, + title={{GASP: Gated Attention for Saliency Prediction}}, + author={Abawi, Fares and Weber, Tom and Wermter, Stefan}, + booktitle={Proceedings of the Thirtieth International Joint Conference on Artificial Intelligence (IJCAI)}, + pages={584--591}, + year={2021}, + doi={10.24963/ijcai.2021/81}, + publisher={IJCAI Organization} +} +``` + +## Architecture Overview + +![GASP model architecture showing the feature extraction pipeline in the first stage followed by the feature integration in the second stage](showcase/multimodalsaliency.png) + +## Environment Variables and Preparation + +Configuring the training and inference pipelines is done through Python scripts to maintain flexibility +when adding functionality within the configuration itself. Configurations can also be added externally in the +form of `json` files found in [infer_configs](gazenet/configs/infer_configs) and [train_configs](gazenet/configs/train_configs). +In general, all the configurations for reproducing the paper experiments can be found in [infer_config.py](gazenet/configs/infer_config.py) and [train_config.py](gazenet/configs/train_config.py). + +We recommend using [comet.ml](https://comet.ml) for tracking your experiments during training. You can opt to use `tensorboard` as well, +but the logger then needs to be specified as an argument (`--logger_name tensorboard`) to `gasp_train` or changed in the training configuration itself: + +``` +logger = 'tensorboard' # logger = '' -> does not log the experiment +``` + +When choosing to use comet.ml by specifying `--logger_name comet`, set the following environment variables: + +``` +export COMET_WORKSPACE=<workspace_name> +export COMET_KEY=<api_key> +``` + +replacing `<workspace_name>` with your workspace name and `<api_key>` with your comet.ml API key. + +It is recommended to create a separate working space outside of this repository. This can be done by setting: + +``` +export GASP_CWD=<path_to_working_directory> +# CAN BE SET TO $(pwd) TO INSTALL IN THE CODE REPO: +# export GASP_CWD=$(pwd) +``` + +## Setup + +The following packages need to be installed: + +``` +sudo apt-get install pv jq libportaudio2 ffmpeg +``` + +GASP is implemented in PyTorch and trained using the PyTorch Lightning library.
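+ +As a minimal sketch (this assumes Python's built-in `venv` module; the directory name `venv` matches the `.gitignore` entry, but any virtual environment tool works), an isolated environment can be prepared with: + +``` +python3 -m venv venv # create an isolated environment in ./venv +source venv/bin/activate # activate it for the current shell session +```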
To install the GASP requirements, create a virtual environment (e.g., as sketched above) and run: + +``` +python3 setup.py install +``` + +### Preprocessing and Generating Social Cues + Saliency Prediction Representations (SCD) + +| Saliency Prediction (DAVE) + Video | Gaze Direction Estimation (Gaze360) | Gaze Following (VideoGaze) | Facial Expression Recognition (ESR9) | +| ------------------------- | ------------------------- | ------------------------- |:------------------------- | +| | | | | + +**Download** + +You can download the preprocessed spatiotemporal maps directly without the need to process the training data locally: + +``` +gasp_download_manager --working_dir $GASP_CWD \ + --datasets processed/Grouped_frames/coutrot1 \ + processed/Grouped_frames/coutrot2 \ + processed/Grouped_frames/diem +``` + + +**Preprocess Locally** + +Alternatively, generate the modality representations using the provided scripts. +Note that this might take upwards of a day depending on your CPU and/or GPU. + +1. Download the datasets and pretrained social cue parameters directly by running the following script + (it invokes bash scripts and has been tested on Ubuntu 20.04): + + ``` + gasp_download_manager --working_dir $GASP_CWD \ + --datasets ave/database1 ave/database2 ave/diem stavis_preprocessed \ + --models emotion_recognition/esr9/checkpoints/pretrained_esr9_orig \ + gaze_estimation/gaze360/checkpoints/pretrained_gaze360_orig \ + gaze_following/videogaze/checkpoints/pretrained_videogaze_orig \ + saliency_prediction/dave/checkpoints/pretrained_dave_orig + ``` + + *Note*: You can instead navigate to `datasets/` to download individual datasets and run the corresponding `download_dataset.sh` bash file directly. + +2. Once the download completes, generate the dataset (run from within the `--working_dir $GASP_CWD` specified in the previous step): + + ``` + gasp_infer --infer_config InferGeneratorAllModelsCoutrot1 + gasp_infer --infer_config InferGeneratorAllModelsCoutrot2 + gasp_infer --infer_config InferGeneratorAllModelsDIEM + ``` + +3.
Finally, you can choose to replace the ground-truth fixation density maps and fixation points with the preprocessed + maps generated by [STAViS: Tsiami et al.](https://github.com/atsiami/STAViS). Note that we use these ground-truth maps in all our experiments: + + ``` + gasp_scripts --working_dir $GASP_CWD --scripts postprocess_get_from_stavis + ``` + +## Training + +To train the best-performing sequential model **(DAM + LARGMU; Context Size = 10)**, invoke the training script with the corresponding configuration on the social event subset of the [AVE dataset \[Tavakoli et al.\]](https://hrtavakoli.github.io/AVE/): + +``` +gasp_train --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_10Norm \ + --infer_configs InferMetricsGASPTrain \ + --checkpoint_save_every_n_epoch 50 --checkpoint_save_n_top 5 --check_val_every_n_epoch 49 --max_epochs 2000 \ + --gpus "0," --logger_name "comet" --val_store_image_samples --compute_metrics + +``` + +or specify a JSON configuration file: + +``` +gasp_train --train_config_file $GASP_CWD/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_10Norm.json \ + --infer_config_files $GASP_CWD/gazenet/configs/infer_configs/InferMetricsGASPTrain.json \ + --checkpoint_save_every_n_epoch 50 --checkpoint_save_n_top 5 --check_val_every_n_epoch 49 --max_epochs 2000 \ + --gpus "0," --logger_name "comet" --val_store_image_samples --compute_metrics + +``` + +The `--compute_metrics` argument will run the inference script on completion and store the metrics results in [logs/metrics](logs/metrics) in the working directory. + +*Note*: We treat a single peek (covering the context size of GASP) into each of the videos in the dataset as an epoch, since we visualize validation samples at short intervals rather than after entire epochs. This should not be misconstrued as 2000 epochs over the dataset. + +## Inference + +![coutrot2 clip13 with sequential GASP DAM + LARGMU (Context Size: 10) overlaid on a video of 4 men in a meeting scenario](showcase/coutrot2_clip13_compressed.gif) + +**WARNING: Always run** `gasp_scripts --working_dir $GASP_CWD --scripts clean_temp` **before executing any inference script when changing dataset splits. +As a precaution, always delete the temporary files before executing inference scripts, provided reprocessing does not take too long.** + +The inferer can run and visualize all integrated models as well as the GASP variants. To download all GASP variants: + +``` +gasp_download_manager --working_dir $GASP_CWD \ + --models "saliency_prediction/gasp/<...>" +``` + +To run the GASP inference, select a configuration class or JSON file and execute the inference script: + +``` +gasp_infer --infer_config InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm --gpu 0 +``` + +*Note*: Remove the `--gpu 0` argument to run on the CPU.
+ +*Tip*: To try out specific videos, create a new split (the .csv split files are found in [datasets/processed](datasets/processed)), e.g.: + +``` +video_id,fps,scene_type,dataset +clip_11,25,Other,coutrot1 +``` + +and in the configuration file, e.g., [InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm](gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm.json), replace: + +``` +"datasplitter_properties": { + "train_csv_file": "datasets/processed/test_ave.csv", + "val_csv_file": null, + "test_csv_file": null + }, +``` + +with the new split's file name: + +``` + "train_csv_file": "datasets/processed/<new_split_name>.csv" +``` + +## TODOs + +- [ ] Support parallelizing inference models on multiple GPUs +- [ ] Support parallelizing inference models on multiple machines using middleware +- [ ] Support realtime inference for GASP (currently works for selected models) +- [ ] Restructure configuration files for more consistency +- [ ] Support intermediate invocation of external applications within the inference model pipeline + +## Attribution + +This work relies on several packages and codebases which we have modified to fit our framework. +If any attributions are missing, please notify us by [Email](mailto:fares.abawi@uni-hamburg.de?subject=[GitHub]%20Missing%20GASP%20Attribution). +The following is a list of repositories which have a substantial portion of their content included in this work: + +* [STAViS: Spatio-Temporal AudioVisual Saliency Network](https://github.com/atsiami/STAViS) +* [Unified Image and Video Saliency Modeling](https://github.com/rdroste/unisal) +* [TASED-Net: Temporally-Aggregating Spatial Encoder-Decoder Network for Video Saliency Detection](https://github.com/MichiganCOG/TASED-Net) +* [ViNet: Pushing the limits of Visual Modality for Audio-Visual Saliency Prediction](https://github.com/samyak0210/ViNet) +* [DAVE: A Deep Audio-Visual Embedding for Dynamic Saliency Prediction](https://github.com/hrtavakoli/DAVE) +* [Gaze360: Physically Unconstrained Gaze Estimation in the Wild Dataset](https://github.com/erkil1452/gaze360) +* [Following Gaze in Video](https://github.com/recasens/Gaze-Following) +* [Efficient Facial Feature Learning with Wide Ensemble-based Convolutional Neural Networks](https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks) +* [Saliency Metrics](https://github.com/tarunsharma1/saliency_metrics) + +## Acknowledgement + +This work was supported by the German +Research Foundation (DFG) under project [CML (TRR 169)](https://www.crossmodal-learning.org/).
diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datasets/ave/database1/download_dataset.sh b/datasets/ave/database1/download_dataset.sh new file mode 100644 index 0000000..b10f03d --- /dev/null +++ b/datasets/ave/database1/download_dataset.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +./megadown "https://mega.nz/#!At8DWR7L!yf5k0jVwL961-jI4FJ2DGUAUqAu-yNbq3s3i6b52M2I" +wget -O coutrot_database1.mat "http://antoinecoutrot.magix.net/public/assets/coutrot_database1.mat" +7za x "ERB3_Stimuli.zip" +rm "ERB3_Stimuli.zip" \ No newline at end of file diff --git a/datasets/ave/database1/megadown b/datasets/ave/database1/megadown new file mode 100644 index 0000000..731a828 --- /dev/null +++ b/datasets/ave/database1/megadown @@ -0,0 +1,704 @@ +#!/bin/bash + +VERSION="1.9.46" + +# code from: https://github.com/tonikelope/megadown +# need to install: sudo apt-get install pv + +HERE=$(dirname "$0") +SCRIPT=$(readlink -f "$0") +BASH_MIN_VER="3" + +MEGA_API_URL="https://g.api.mega.co.nz" +OPENSSL_AES_CTR_128_DEC="openssl enc -d -aes-128-ctr" +OPENSSL_AES_CBC_128_DEC="openssl enc -a -A -d -aes-128-cbc" +OPENSSL_AES_CBC_256_DEC="openssl enc -a -A -d -aes-256-cbc" +OPENSSL_MD5="openssl md5" + +if [ ! -d ".megadown" ]; then + mkdir ".megadown" +fi + +# 1:message_error +function showError { + echo -e "\n$1\n" 1>&2 + exit 1 +} + +function showHelp { + echo -e "\nmegadown $VERSION - https://github.com/tonikelope/megadown" + echo -e "\ncli downloader for mega.nz and megacrypter" + echo -e "\nSingle url mode: megadown [OPTION]... 'URL'\n" + echo -e "\tOptions:" + echo -e "\t-o,\t--output FILE_NAME Store file with this name." + echo -e "\t-s,\t--speed SPEED Download speed limit (integer values: 500B, K, 2M)." + echo -e "\t-p,\t--password PASSWORD Password for MegaCrypter links." + echo -e "\t-q,\t--quiet Quiet mode." + echo -e "\t-m,\t--metadata Prints file metadata in JSON format and exits." + echo -e "\n\nMulti url mode: megadown [OPTION]... -l|--list FILE\n" + echo -e "\tOptions:" + echo -e "\t-s,\t--speed SPEED Download speed limit (integer values: 500B, 500K, 2M)." + echo -e "\t-p,\t--password PASSWORD Password for MegaCrypter links (same for every link in a list)." + echo -e "\t-q,\t--quiet Quiet mode." + echo -e "\t-m,\t--metadata Prints file metadata in JSON format and exits." + echo -e "\tFile line format: URL [optional_file_name]\n" +} + +function check_deps { + + local dep_error=0 + + if [ -n "$(command -v curl 2>&1)" ]; then + DL_COM="curl --fail -s" + DL_COM_PDATA="--data" + elif [ -n "$(command -v wget 2>&1)" ]; then + DL_COM="wget -q -O -" + DL_COM_PDATA="--post-data" + else + echo "wget OR curl is required and it's not installed" + dep_error=1 + fi + + for i in openssl pv jq; do + + if [ -z "$(command -v "$i" 2>&1)" ]; then + + echo "[$i] is required and it's not installed" + dep_error=1 + + else + + case "$i" in + + openssl) + + openssl_sup=$(openssl enc -ciphers 2>&1) + + for i in "aes-128-ctr" "aes-128-cbc" "aes-256-cbc"; do + + if [ -z "$(echo -n "$openssl_sup" | grep -o "$i" | head -n1)" ]; then + + echo "Your openssl binary does not support ${i}" + dep_error=1 + + fi + + done + ;; + esac + fi + + done + + if [ -z "$(command -v python 2>&1)" ]; then + echo "WARNING: python is required for MegaCrypter password protected links and it's not installed." 
+ fi + + if [[ "$(echo -n "$BASH_VERSION" | grep -o -E "[0-9]+" | head -n1)" < "$BASH_MIN_VER" ]]; then + echo "bash >= ${BASH_MIN_VER} is required" + dep_error=1 + fi + + if [ $dep_error -ne 0 ]; then + showError "ERROR: there are dependencies not present!" + fi +} + + +# 2:url +function urldecode { + + : "${*//+/ }"; echo -e "${_//%/\\x}"; +} + +# 1:b64_encoded_string +function urlb64_to_b64 { + local b64=$(echo -n "$1" | tr '\-_' '+/' | tr -d ',') + local pad=$(((4-${#1}%4)%4)) + + for i in $(seq 1 $pad); do + b64="${b64}=" + done + + echo -n "$b64" +} + +# 1:mega://enc link +function decrypt_md_link { + + local data=$(regex_imatch "^.*?mega:\/\/enc[0-9]*?\?([a-z0-9_,-]+).*?$" "$1" 1) + + local iv="79F10A01844A0B27FF5B2D4E0ED3163E" + + if [ $(echo -n "$1" | grep 'mega://enc?') ]; then + + key="6B316F36416C2D316B7A3F217A30357958585858585858585858585858585858" + + elif [ $(echo -n "$1" | grep 'mega://enc2?') ];then + + key="ED1F4C200B35139806B260563B3D3876F011B4750F3A1A4A5EFD0BBE67554B44" + fi + + echo -n "https://mega.nz/#"$(echo -n "$(urlb64_to_b64 "$data")" | $OPENSSL_AES_CBC_256_DEC -K "$key" -iv "$iv") +} + +# 1:hex_raw_key +function hrk2hk { + declare -A hk + hk[0]=$(( 0x${1:0:16} ^ 0x${1:32:16} )) + hk[1]=$(( 0x${1:16:16} ^ 0x${1:48:16} )) + printf "%016x%016x" ${hk[0]} ${hk[1]} +} + +# 1:link +function get_mc_link_info { + + local MC_API_URL=$(echo -n "$1" | grep -i -E -o 'https?://[^/]+')"/api" + + local download_exit_code=1 + + local info_link=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "{\"m\":\"info\", \"link\":\"$1\"}" "$MC_API_URL") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + echo -e "ERROR: Oooops, something went bad. EXIT CODE (${download_exit_code})" + return 1 + fi + + if [ $(echo $info_link | grep '"error"') ]; then + local error_code=$(echo "$info_link" | jq -r .error) + echo -e "MEGACRYPTER ERROR $error_code" + return 1 + fi + + local expire=$(echo "$info_link" | jq -r .expire) + + if [ "$expire" != "false" ]; then + + IFS='#' read -a array <<< "$expire" + + local no_exp_token=${array[1]} + else + local no_exp_token="$expire" + fi + + local file_name=$(echo "$info_link" | jq -r .name | base64 -w 0 -i 2>/dev/null) + + local path=$(echo "$info_link" | jq -r .path) + + if [ "$path" != "false" ]; then + path=$(echo -n "$path" | base64 -w 0 -i 2>/dev/null) + fi + + local mc_pass=$(echo "$info_link" | jq -r .pass) + + local file_size=$(echo "$info_link" | jq -r .size) + + local key=$(echo "$info_link" | jq -r .key) + + echo -n "${file_name}@${path}@${file_size}@${mc_pass}@${key}@${no_exp_token}" +} + +# 1:file_name 2:file_size 3:formatted_file_size [4:md5_mclink] +function check_file_exists { + + if [ -f "$1" ]; then + + local actual_size=$(stat -c %s "$1") + + if [ "$actual_size" == "$2" ]; then + + if [ -n "$4" ] && [ -f ".megadown/${4}" ]; then + rm ".megadown/${4}" + fi + + showError "WARNING: File $1 exists. Download aborted!" + fi + + DL_MSG="\nFile $1 exists but with different size (${2} vs ${actual_size} bytes). 
Downloading [${3}] ...\n" + + else + + DL_MSG="\nDownloading $1 [${3}] ...\n" + + fi +} + +# 1:file_size +function format_file_size { + + if [ "$1" -ge 1073741824 ]; then + local file_size_f=$(awk "BEGIN { rounded = sprintf(\"%.1f\", ${1}/1073741824); print rounded }")" GB" + elif [ "$1" -ge 1048576 ];then + local file_size_f=$(awk "BEGIN { rounded = sprintf(\"%.1f\", ${1}/1048576); print rounded }")" MB" + else + local file_size_f="${1} bytes" + fi + + echo -ne "$file_size_f" +} + +# 1:password 2:salt 3:iterations +function mc_pbkdf2 { + + echo -e "import sys,hashlib,base64\nprint(base64.b64encode(hashlib.pbkdf2_hmac('sha256', b'${1}', base64.b64decode(b'${2}'), ${3})).decode())" | python +} + +# 1:mc_pass_info 2:pass_to_check +function mc_pass_check { + + IFS='#' read -a array <<< "$1" + + local iter_log2=${array[0]} + + local key_check=${array[1]} + + local salt=${array[2]} + + local iv=${array[3]} + + local mc_pass_hash=$(mc_pbkdf2 "$password" "$salt" $((2**$iter_log2))) + + mc_pass_hash=$(echo -n "$mc_pass_hash" | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + iv=$(echo -n "$iv" | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + if [ "$(echo -n "$key_check" | $OPENSSL_AES_CBC_256_DEC -K "$mc_pass_hash" -iv "$iv" 2>/dev/null | od -v -An -t x1 | tr -d '\n ')" != "$mc_pass_hash" ]; then + echo -n "0" + else + echo -n "${mc_pass_hash}#${iv}" + fi +} + +#1:string +function trim { + + if [[ "$1" =~ \ *([^ ]|[^ ].*[^ ])\ * ]]; then + echo -n "${BASH_REMATCH[1]}" + fi +} + +#1:pattern 2:subject 3:group +function regex_match { + + if [[ "$2" =~ $1 ]]; then + echo -n "${BASH_REMATCH[$3]}" + fi +} + +#1:pattern 2:subject 3:group +function regex_imatch { + + shopt -s nocasematch + + if [[ "$2" =~ $1 ]]; then + echo -n "${BASH_REMATCH[$3]}" + fi + + shopt -u nocasematch +} + +#MAIN STARTS HERE: +check_deps + +if [ -z "$1" ]; then + showHelp + exit 1 +fi + +eval set -- "$(getopt -o "l:p:k:o:s:qm" -l "list:,password:,key:,output:,speed:,quiet,metadata" -n ${0} -- "$@")" + +while true; do + case "$1" in + -l|--list) list="$2"; shift 2;; + -p|--password) password="$2"; shift 2;; + -o|--output) output="$2"; shift 2;; + -s|--speed) speed="$2"; shift 2;; + -q|--quiet) quiet=true; shift 1;; + -m|--metadata) metadata=true; shift 1;; + + --) shift; break;; + + *) + showHelp + exit 1;; + esac +done + +p1=$(trim $(urldecode "$1")) + +if [[ "$p1" =~ ^http ]] || [[ "$p1" =~ ^mega:// ]]; then + link="$p1" +fi + +if [ -z "$link" ]; then + + if [ -z "$list" ]; then + + showHelp + + showError "ERROR: MEGA/MC link or --list parameter is required" + + elif [ ! -f "$list" ]; then + + showHelp + + showError "ERROR: list file ${list} not found" + fi + + if [ ! $quiet ]; then + echo -ne "\n(Pre)reading mc links info..." + fi + + link_count=0 + + while IFS='' read -r line || [ -n "$line" ]; do + + if [ -n "$line" ] && ! [ $(echo -n "$line" | grep -E -o 'mega://enc') ];then + + link=$(regex_imatch "^.*?(https?\:\/\/[^\/]+\/[#!0-9a-z_-]+).*$" "$line" 1) + + if [ $(echo -n "$link" | grep -E -o 'https?://[^/]+/!') ]; then + + md5=$(echo -n "$link" | $OPENSSL_MD5 | grep -E -o '[0-9a-f]{32}') + + if [ ! -f ".megadown/${md5}" ];then + + mc_link_info=$(get_mc_link_info "$link") + + if ! [ "$?" 
-eq 1 ];then + echo -n "$mc_link_info" >> ".megadown/${md5}" + fi + fi + + link_count=$((link_count + 1)) + fi + fi + + done < "$list" + + echo -ne " OK(${link_count} MC links found)\n" + + while IFS='' read -r line || [ -n "$line" ]; do + + if [ -n "$line" ];then + + if [ $(echo -n "$line" | grep -E -o 'mega://enc') ]; then + + link=$(regex_imatch "^.*?(mega:\/\/enc\d*?\?[a-z0-9_-]+).*$" "$line" 1) + + output=$(regex_imatch "^.*?mega:\/\/enc\d*?\?[a-z0-9_-]+(.*)$" "$line" 1 1) + + + elif [ $(echo -n "$line" | grep -E -o 'https?://') ]; then + + link=$(regex_imatch ".*?(https?\:\/\/[^\/]+\/[#!0-9a-z_-]+).*$" "$line" 1) + + output=$(regex_imatch "^.*?https?\:\/\/[^\/]+\/[#!0-9a-z_-]+(.*)$" "$line" 1 1) + + else + continue + fi + + $SCRIPT "$link" --output="$output" --password="$password" --speed="$speed" + + fi + + done < "$list" + + exit 0 +fi + +if [ $(echo -n "$link" | grep -E -o 'mega://enc') ]; then + link=$(decrypt_md_link "$link") +fi + +if [ ! $quiet ]; then + echo -e "\nReading link metadata..." +fi + +if [ $(echo -n "$link" | grep -E -o 'mega(\.co)?\.nz') ]; then + + #MEGA.CO.NZ LINK + + file_id=$(regex_match "^.*\/#.*?!(.+)!.*$" "$link" 1) + + file_key=$(regex_match "^.*\/#.*?!.+!(.+)$" "$link" 1) + + hex_raw_key=$(echo -n $(urlb64_to_b64 "$file_key") | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + if [ $(echo -n "$link" | grep -E -o 'mega(\.co)?\.nz/#!') ]; then + + mega_req_json="[{\"a\":\"g\", \"p\":\"${file_id}\"}]" + + mega_req_url="${MEGA_API_URL}/cs?id=&ak=" + + elif [ $(echo -n "$link" | grep -E -o -i 'mega(\.co)?\.nz/#N!') ]; then + + mega_req_json="[{\"a\":\"g\", \"n\":\"${file_id}\"}]" + + folder_id=$(regex_match "###n\=(.+)$" "$link" 1) + + mega_req_url="${MEGA_API_URL}/cs?id=&ak=&n=${folder_id}" + fi + + mega_res_json=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "$mega_req_json" "$mega_req_url") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, something went bad. EXIT CODE (${download_exit_code})" + fi + + if [ $(echo -n "$mega_res_json" | grep -E -o '\[ *\-[0-9]+ *\]') ]; then + showError "MEGA ERROR $(echo -n "$mega_res_json" | grep -E -o '\-[0-9]+')" + fi + + file_size=$(echo "$mega_res_json" | jq -r .[0].s) + + at=$(echo "$mega_res_json" | jq -r .[0].at) + + hex_key=$(hrk2hk "$hex_raw_key") + + at_dec_json=$(echo -n $(urlb64_to_b64 "$at") | $OPENSSL_AES_CBC_128_DEC -K "$hex_key" -iv "00000000000000000000000000000000" -nopad | tr -d '\0') + + if [ ! $(echo -n "$at_dec_json" | grep -E -o 'MEGA') ]; then + showError "MEGA bad link" + fi + + if [ -z "$output" ]; then + file_name=$(echo -n "$at_dec_json" | grep -E -o '\{.+\}' | jq -r .n) + else + file_name="$output" + fi + + if [ $metadata ]; then + echo "{\"file_name\" : \"${file_name}\", \"file_size\" : ${file_size}}" + exit 0 + fi + + check_file_exists "$file_name" "$file_size" "$(format_file_size "$file_size")" + + if [ $(echo -n "$link" | grep -E -o 'mega(\.co)?\.nz/#!') ]; then + mega_req_json="[{\"a\":\"g\", \"g\":\"1\", \"p\":\"$file_id\"}]" + elif [ $(echo -n "$link" | grep -E -o -i 'mega(\.co)?\.nz/#N!') ]; then + mega_req_json="[{\"a\":\"g\", \"g\":\"1\", \"n\":\"$file_id\"}]" + fi + + mega_res_json=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "$mega_req_json" "$mega_req_url") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, something went bad. 
EXIT CODE (${download_exit_code})" + fi + + dl_temp_url=$(echo "$mega_res_json" | jq -r .[0].g) +else + + #MEGACRYPTER LINK + + MC_API_URL=$(echo -n "$link" | grep -i -E -o 'https?://[^/]+')"/api" + + md5=$(echo -n "$link" | $OPENSSL_MD5 | grep -E -o '[0-9a-f]{32}') + + if [ -f ".megadown/${md5}" ];then + mc_link_info=$(cat ".megadown/${md5}") + else + mc_link_info=$(get_mc_link_info "$link") + + if [ "$?" -eq 1 ];then + echo -e "$mc_link_info" + exit 1 + fi + + echo -n "$mc_link_info" >> ".megadown/${md5}" + fi + + IFS='@' read -a array <<< "$mc_link_info" + + if [ -z "$output" ];then + file_name=$(echo -n "${array[0]}" | base64 -d -i 2>/dev/null) + else + file_name="$output" + fi + + path=${array[1]} + + if [ "$path" != "false" ]; then + path=$(echo -n "$path" | base64 -d -i 2>/dev/null) + fi + + file_size=${array[2]} + + mc_pass=${array[3]} + + key=${array[4]} + + no_exp_token=${array[5]} + + if [ "$mc_pass" != "false" ]; then + + if [ -z "$(command -v python 2>&1)" ]; then + + echo "ERROR: python is required for MegaCrypter password protected links and it's not installed." + exit 1 + + fi + + echo -ne "\nLink is password protected. " + + if [ -n "$password" ]; then + + pass_hash=$(mc_pass_check "$mc_pass" "$password") + + fi + + if [ -z "$pass_hash" ] || [ "$pass_hash" == "0" ]; then + + echo -ne "\n\n" + + read -e -p "Enter password: " pass + + pass_hash=$(mc_pass_check "$mc_pass" "$pass") + + until [ "$pass_hash" != "false" ]; do + read -e -p "Wrong password! Try again: " pass + pass_hash=$(mc_pass_check "$mc_pass" "$pass") + done + fi + + echo -ne "\nPassword is OK. Decrypting metadata...\n" + + IFS='#' read -a array <<< "$pass_hash" + + pass_hash=${array[0]} + + iv=${array[1]} + + hex_raw_key=$(echo -n "$key" | $OPENSSL_AES_CBC_256_DEC -K "$pass_hash" -iv "$iv" | od -v -An -t x1 | tr -d '\n ') + + if [ -z "$output" ]; then + file_name=$(echo -n "$file_name" | $OPENSSL_AES_CBC_256_DEC -K "$pass_hash" -iv "$iv") + fi + else + hex_raw_key=$(echo -n $(urlb64_to_b64 "$key") | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + fi + + if [ $metadata ]; then + echo "{\"file_name\" : \"${file_name}\", \"file_size\" : ${file_size}}" + exit 0 + fi + + if [ "$path" != "false" ] && [ "$path" != "" ]; then + + if [ ! -d "$path" ]; then + + mkdir -p "$path" + fi + + file_name="${path}${file_name}" + fi + + check_file_exists "$file_name" "$file_size" "$(format_file_size "$file_size")" "$md5" + + hex_key=$(hrk2hk "$hex_raw_key") + + dl_link=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "{\"m\":\"dl\", \"link\":\"$link\", \"noexpire\":\"$no_exp_token\"}" "$MC_API_URL") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, something went bad. EXIT CODE (${download_exit_code})" + fi + + if [ $(echo $dl_link | grep '"error"') ]; then + + error_code=$(echo "$dl_link" | jq -r .error) + + showError "MEGACRYPTER ERROR $error_code" + fi + + dl_temp_url=$(echo "$dl_link" | jq -r .url) + + if [ "$mc_pass" != "false" ]; then + + iv=$(echo "$dl_link" | jq -r .pass | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + dl_temp_url=$(echo -n "$dl_temp_url" | $OPENSSL_AES_CBC_256_DEC -K "$pass_hash" -iv "$iv") + fi +fi + +if [ -z "$speed" ]; then + DL_COMMAND="$DL_COM" +else + DL_COMMAND="$DL_COM --limit-rate $speed" +fi + +if [ "$output" == "-" ]; then + + hex_iv="${hex_raw_key:32:16}0000000000000000" + + $DL_COMMAND "$dl_temp_url" | $OPENSSL_AES_CTR_128_DEC -K "$hex_key" -iv "$hex_iv" + + exit 0 +fi + +if [ ! 
$quiet ]; then + echo -e "$DL_MSG" +fi + +if [ ! $quiet ]; then + PV_CMD="pv" +else + PV_CMD="pv -q" +fi + +download_exit_code=1 + +until [ "$download_exit_code" -eq 0 ]; do + + if [ -f "${file_name}.temp" ]; then + + echo -e "(Resuming previous download ...)\n" + + temp_size=$(stat -c %s "${file_name}.temp") + + offset=$(($temp_size-$(($temp_size%16)))) + + iv_forward=$(printf "%016x" $(($offset/16))) + + hex_iv="${hex_raw_key:32:16}$iv_forward" + + truncate -s $offset "${file_name}.temp" + + $DL_COMMAND "$dl_temp_url/$offset" | $PV_CMD -s $(($file_size-$offset)) | $OPENSSL_AES_CTR_128_DEC -K "$hex_key" -iv "$hex_iv" >> "${file_name}.temp" + else + hex_iv="${hex_raw_key:32:16}0000000000000000" + + $DL_COMMAND "$dl_temp_url" | $PV_CMD -s $file_size | $OPENSSL_AES_CTR_128_DEC -K "$hex_key" -iv "$hex_iv" > "${file_name}.temp" + fi + + download_exit_code=${PIPESTATUS[0]} + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, download failed! EXIT CODE (${download_exit_code})" + fi +done + +if [ ! -f "${file_name}.temp" ]; then + showError "ERROR: FILE COULD NOT BE DOWNLOADED :(!" +fi + +mv "${file_name}.temp" "${file_name}" + +if [ -f ".megadown/${md5}" ];then + rm ".megadown/${md5}" +fi + +if [ ! $quiet ]; then + echo -e "\nFILE DOWNLOADED!\n" +fi + +exit 0 \ No newline at end of file diff --git a/datasets/ave/database2/download_dataset.sh b/datasets/ave/database2/download_dataset.sh new file mode 100644 index 0000000..71658e4 --- /dev/null +++ b/datasets/ave/database2/download_dataset.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +./megadown "https://mega.nz/#!58dARTDR!AII7nCEktkeMqbZ2XqXBFyDsAeMKqBf9reXR-MHPuKk" +wget -O coutrot_database2.mat "http://antoinecoutrot.magix.net/public/assets/coutrot_database2.mat" +7za x "ERB4_Stimuli.zip" +rm "ERB4_Stimuli.zip" \ No newline at end of file diff --git a/datasets/ave/database2/megadown b/datasets/ave/database2/megadown new file mode 100644 index 0000000..731a828 --- /dev/null +++ b/datasets/ave/database2/megadown @@ -0,0 +1,704 @@ +#!/bin/bash + +VERSION="1.9.46" + +# code from: https://github.com/tonikelope/megadown +# need to install: sudo apt-get install pv + +HERE=$(dirname "$0") +SCRIPT=$(readlink -f "$0") +BASH_MIN_VER="3" + +MEGA_API_URL="https://g.api.mega.co.nz" +OPENSSL_AES_CTR_128_DEC="openssl enc -d -aes-128-ctr" +OPENSSL_AES_CBC_128_DEC="openssl enc -a -A -d -aes-128-cbc" +OPENSSL_AES_CBC_256_DEC="openssl enc -a -A -d -aes-256-cbc" +OPENSSL_MD5="openssl md5" + +if [ ! -d ".megadown" ]; then + mkdir ".megadown" +fi + +# 1:message_error +function showError { + echo -e "\n$1\n" 1>&2 + exit 1 +} + +function showHelp { + echo -e "\nmegadown $VERSION - https://github.com/tonikelope/megadown" + echo -e "\ncli downloader for mega.nz and megacrypter" + echo -e "\nSingle url mode: megadown [OPTION]... 'URL'\n" + echo -e "\tOptions:" + echo -e "\t-o,\t--output FILE_NAME Store file with this name." + echo -e "\t-s,\t--speed SPEED Download speed limit (integer values: 500B, K, 2M)." + echo -e "\t-p,\t--password PASSWORD Password for MegaCrypter links." + echo -e "\t-q,\t--quiet Quiet mode." + echo -e "\t-m,\t--metadata Prints file metadata in JSON format and exits." + echo -e "\n\nMulti url mode: megadown [OPTION]... -l|--list FILE\n" + echo -e "\tOptions:" + echo -e "\t-s,\t--speed SPEED Download speed limit (integer values: 500B, 500K, 2M)." + echo -e "\t-p,\t--password PASSWORD Password for MegaCrypter links (same for every link in a list)." + echo -e "\t-q,\t--quiet Quiet mode." 
+ echo -e "\t-m,\t--metadata Prints file metadata in JSON format and exits." + echo -e "\tFile line format: URL [optional_file_name]\n" +} + +function check_deps { + + local dep_error=0 + + if [ -n "$(command -v curl 2>&1)" ]; then + DL_COM="curl --fail -s" + DL_COM_PDATA="--data" + elif [ -n "$(command -v wget 2>&1)" ]; then + DL_COM="wget -q -O -" + DL_COM_PDATA="--post-data" + else + echo "wget OR curl is required and it's not installed" + dep_error=1 + fi + + for i in openssl pv jq; do + + if [ -z "$(command -v "$i" 2>&1)" ]; then + + echo "[$i] is required and it's not installed" + dep_error=1 + + else + + case "$i" in + + openssl) + + openssl_sup=$(openssl enc -ciphers 2>&1) + + for i in "aes-128-ctr" "aes-128-cbc" "aes-256-cbc"; do + + if [ -z "$(echo -n "$openssl_sup" | grep -o "$i" | head -n1)" ]; then + + echo "Your openssl binary does not support ${i}" + dep_error=1 + + fi + + done + ;; + esac + fi + + done + + if [ -z "$(command -v python 2>&1)" ]; then + echo "WARNING: python is required for MegaCrypter password protected links and it's not installed." + fi + + if [[ "$(echo -n "$BASH_VERSION" | grep -o -E "[0-9]+" | head -n1)" < "$BASH_MIN_VER" ]]; then + echo "bash >= ${BASH_MIN_VER} is required" + dep_error=1 + fi + + if [ $dep_error -ne 0 ]; then + showError "ERROR: there are dependencies not present!" + fi +} + + +# 2:url +function urldecode { + + : "${*//+/ }"; echo -e "${_//%/\\x}"; +} + +# 1:b64_encoded_string +function urlb64_to_b64 { + local b64=$(echo -n "$1" | tr '\-_' '+/' | tr -d ',') + local pad=$(((4-${#1}%4)%4)) + + for i in $(seq 1 $pad); do + b64="${b64}=" + done + + echo -n "$b64" +} + +# 1:mega://enc link +function decrypt_md_link { + + local data=$(regex_imatch "^.*?mega:\/\/enc[0-9]*?\?([a-z0-9_,-]+).*?$" "$1" 1) + + local iv="79F10A01844A0B27FF5B2D4E0ED3163E" + + if [ $(echo -n "$1" | grep 'mega://enc?') ]; then + + key="6B316F36416C2D316B7A3F217A30357958585858585858585858585858585858" + + elif [ $(echo -n "$1" | grep 'mega://enc2?') ];then + + key="ED1F4C200B35139806B260563B3D3876F011B4750F3A1A4A5EFD0BBE67554B44" + fi + + echo -n "https://mega.nz/#"$(echo -n "$(urlb64_to_b64 "$data")" | $OPENSSL_AES_CBC_256_DEC -K "$key" -iv "$iv") +} + +# 1:hex_raw_key +function hrk2hk { + declare -A hk + hk[0]=$(( 0x${1:0:16} ^ 0x${1:32:16} )) + hk[1]=$(( 0x${1:16:16} ^ 0x${1:48:16} )) + printf "%016x%016x" ${hk[0]} ${hk[1]} +} + +# 1:link +function get_mc_link_info { + + local MC_API_URL=$(echo -n "$1" | grep -i -E -o 'https?://[^/]+')"/api" + + local download_exit_code=1 + + local info_link=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "{\"m\":\"info\", \"link\":\"$1\"}" "$MC_API_URL") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + echo -e "ERROR: Oooops, something went bad. 
EXIT CODE (${download_exit_code})" + return 1 + fi + + if [ $(echo $info_link | grep '"error"') ]; then + local error_code=$(echo "$info_link" | jq -r .error) + echo -e "MEGACRYPTER ERROR $error_code" + return 1 + fi + + local expire=$(echo "$info_link" | jq -r .expire) + + if [ "$expire" != "false" ]; then + + IFS='#' read -a array <<< "$expire" + + local no_exp_token=${array[1]} + else + local no_exp_token="$expire" + fi + + local file_name=$(echo "$info_link" | jq -r .name | base64 -w 0 -i 2>/dev/null) + + local path=$(echo "$info_link" | jq -r .path) + + if [ "$path" != "false" ]; then + path=$(echo -n "$path" | base64 -w 0 -i 2>/dev/null) + fi + + local mc_pass=$(echo "$info_link" | jq -r .pass) + + local file_size=$(echo "$info_link" | jq -r .size) + + local key=$(echo "$info_link" | jq -r .key) + + echo -n "${file_name}@${path}@${file_size}@${mc_pass}@${key}@${no_exp_token}" +} + +# 1:file_name 2:file_size 3:formatted_file_size [4:md5_mclink] +function check_file_exists { + + if [ -f "$1" ]; then + + local actual_size=$(stat -c %s "$1") + + if [ "$actual_size" == "$2" ]; then + + if [ -n "$4" ] && [ -f ".megadown/${4}" ]; then + rm ".megadown/${4}" + fi + + showError "WARNING: File $1 exists. Download aborted!" + fi + + DL_MSG="\nFile $1 exists but with different size (${2} vs ${actual_size} bytes). Downloading [${3}] ...\n" + + else + + DL_MSG="\nDownloading $1 [${3}] ...\n" + + fi +} + +# 1:file_size +function format_file_size { + + if [ "$1" -ge 1073741824 ]; then + local file_size_f=$(awk "BEGIN { rounded = sprintf(\"%.1f\", ${1}/1073741824); print rounded }")" GB" + elif [ "$1" -ge 1048576 ];then + local file_size_f=$(awk "BEGIN { rounded = sprintf(\"%.1f\", ${1}/1048576); print rounded }")" MB" + else + local file_size_f="${1} bytes" + fi + + echo -ne "$file_size_f" +} + +# 1:password 2:salt 3:iterations +function mc_pbkdf2 { + + echo -e "import sys,hashlib,base64\nprint(base64.b64encode(hashlib.pbkdf2_hmac('sha256', b'${1}', base64.b64decode(b'${2}'), ${3})).decode())" | python +} + +# 1:mc_pass_info 2:pass_to_check +function mc_pass_check { + + IFS='#' read -a array <<< "$1" + + local iter_log2=${array[0]} + + local key_check=${array[1]} + + local salt=${array[2]} + + local iv=${array[3]} + + local mc_pass_hash=$(mc_pbkdf2 "$password" "$salt" $((2**$iter_log2))) + + mc_pass_hash=$(echo -n "$mc_pass_hash" | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + iv=$(echo -n "$iv" | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + if [ "$(echo -n "$key_check" | $OPENSSL_AES_CBC_256_DEC -K "$mc_pass_hash" -iv "$iv" 2>/dev/null | od -v -An -t x1 | tr -d '\n ')" != "$mc_pass_hash" ]; then + echo -n "0" + else + echo -n "${mc_pass_hash}#${iv}" + fi +} + +#1:string +function trim { + + if [[ "$1" =~ \ *([^ ]|[^ ].*[^ ])\ * ]]; then + echo -n "${BASH_REMATCH[1]}" + fi +} + +#1:pattern 2:subject 3:group +function regex_match { + + if [[ "$2" =~ $1 ]]; then + echo -n "${BASH_REMATCH[$3]}" + fi +} + +#1:pattern 2:subject 3:group +function regex_imatch { + + shopt -s nocasematch + + if [[ "$2" =~ $1 ]]; then + echo -n "${BASH_REMATCH[$3]}" + fi + + shopt -u nocasematch +} + +#MAIN STARTS HERE: +check_deps + +if [ -z "$1" ]; then + showHelp + exit 1 +fi + +eval set -- "$(getopt -o "l:p:k:o:s:qm" -l "list:,password:,key:,output:,speed:,quiet,metadata" -n ${0} -- "$@")" + +while true; do + case "$1" in + -l|--list) list="$2"; shift 2;; + -p|--password) password="$2"; shift 2;; + -o|--output) output="$2"; shift 2;; + -s|--speed) speed="$2"; shift 2;; + -q|--quiet) 
quiet=true; shift 1;; + -m|--metadata) metadata=true; shift 1;; + + --) shift; break;; + + *) + showHelp + exit 1;; + esac +done + +p1=$(trim $(urldecode "$1")) + +if [[ "$p1" =~ ^http ]] || [[ "$p1" =~ ^mega:// ]]; then + link="$p1" +fi + +if [ -z "$link" ]; then + + if [ -z "$list" ]; then + + showHelp + + showError "ERROR: MEGA/MC link or --list parameter is required" + + elif [ ! -f "$list" ]; then + + showHelp + + showError "ERROR: list file ${list} not found" + fi + + if [ ! $quiet ]; then + echo -ne "\n(Pre)reading mc links info..." + fi + + link_count=0 + + while IFS='' read -r line || [ -n "$line" ]; do + + if [ -n "$line" ] && ! [ $(echo -n "$line" | grep -E -o 'mega://enc') ];then + + link=$(regex_imatch "^.*?(https?\:\/\/[^\/]+\/[#!0-9a-z_-]+).*$" "$line" 1) + + if [ $(echo -n "$link" | grep -E -o 'https?://[^/]+/!') ]; then + + md5=$(echo -n "$link" | $OPENSSL_MD5 | grep -E -o '[0-9a-f]{32}') + + if [ ! -f ".megadown/${md5}" ];then + + mc_link_info=$(get_mc_link_info "$link") + + if ! [ "$?" -eq 1 ];then + echo -n "$mc_link_info" >> ".megadown/${md5}" + fi + fi + + link_count=$((link_count + 1)) + fi + fi + + done < "$list" + + echo -ne " OK(${link_count} MC links found)\n" + + while IFS='' read -r line || [ -n "$line" ]; do + + if [ -n "$line" ];then + + if [ $(echo -n "$line" | grep -E -o 'mega://enc') ]; then + + link=$(regex_imatch "^.*?(mega:\/\/enc\d*?\?[a-z0-9_-]+).*$" "$line" 1) + + output=$(regex_imatch "^.*?mega:\/\/enc\d*?\?[a-z0-9_-]+(.*)$" "$line" 1 1) + + + elif [ $(echo -n "$line" | grep -E -o 'https?://') ]; then + + link=$(regex_imatch ".*?(https?\:\/\/[^\/]+\/[#!0-9a-z_-]+).*$" "$line" 1) + + output=$(regex_imatch "^.*?https?\:\/\/[^\/]+\/[#!0-9a-z_-]+(.*)$" "$line" 1 1) + + else + continue + fi + + $SCRIPT "$link" --output="$output" --password="$password" --speed="$speed" + + fi + + done < "$list" + + exit 0 +fi + +if [ $(echo -n "$link" | grep -E -o 'mega://enc') ]; then + link=$(decrypt_md_link "$link") +fi + +if [ ! $quiet ]; then + echo -e "\nReading link metadata..." +fi + +if [ $(echo -n "$link" | grep -E -o 'mega(\.co)?\.nz') ]; then + + #MEGA.CO.NZ LINK + + file_id=$(regex_match "^.*\/#.*?!(.+)!.*$" "$link" 1) + + file_key=$(regex_match "^.*\/#.*?!.+!(.+)$" "$link" 1) + + hex_raw_key=$(echo -n $(urlb64_to_b64 "$file_key") | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + if [ $(echo -n "$link" | grep -E -o 'mega(\.co)?\.nz/#!') ]; then + + mega_req_json="[{\"a\":\"g\", \"p\":\"${file_id}\"}]" + + mega_req_url="${MEGA_API_URL}/cs?id=&ak=" + + elif [ $(echo -n "$link" | grep -E -o -i 'mega(\.co)?\.nz/#N!') ]; then + + mega_req_json="[{\"a\":\"g\", \"n\":\"${file_id}\"}]" + + folder_id=$(regex_match "###n\=(.+)$" "$link" 1) + + mega_req_url="${MEGA_API_URL}/cs?id=&ak=&n=${folder_id}" + fi + + mega_res_json=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "$mega_req_json" "$mega_req_url") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, something went bad. EXIT CODE (${download_exit_code})" + fi + + if [ $(echo -n "$mega_res_json" | grep -E -o '\[ *\-[0-9]+ *\]') ]; then + showError "MEGA ERROR $(echo -n "$mega_res_json" | grep -E -o '\-[0-9]+')" + fi + + file_size=$(echo "$mega_res_json" | jq -r .[0].s) + + at=$(echo "$mega_res_json" | jq -r .[0].at) + + hex_key=$(hrk2hk "$hex_raw_key") + + at_dec_json=$(echo -n $(urlb64_to_b64 "$at") | $OPENSSL_AES_CBC_128_DEC -K "$hex_key" -iv "00000000000000000000000000000000" -nopad | tr -d '\0') + + if [ ! 
$(echo -n "$at_dec_json" | grep -E -o 'MEGA') ]; then + showError "MEGA bad link" + fi + + if [ -z "$output" ]; then + file_name=$(echo -n "$at_dec_json" | grep -E -o '\{.+\}' | jq -r .n) + else + file_name="$output" + fi + + if [ $metadata ]; then + echo "{\"file_name\" : \"${file_name}\", \"file_size\" : ${file_size}}" + exit 0 + fi + + check_file_exists "$file_name" "$file_size" "$(format_file_size "$file_size")" + + if [ $(echo -n "$link" | grep -E -o 'mega(\.co)?\.nz/#!') ]; then + mega_req_json="[{\"a\":\"g\", \"g\":\"1\", \"p\":\"$file_id\"}]" + elif [ $(echo -n "$link" | grep -E -o -i 'mega(\.co)?\.nz/#N!') ]; then + mega_req_json="[{\"a\":\"g\", \"g\":\"1\", \"n\":\"$file_id\"}]" + fi + + mega_res_json=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "$mega_req_json" "$mega_req_url") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, something went bad. EXIT CODE (${download_exit_code})" + fi + + dl_temp_url=$(echo "$mega_res_json" | jq -r .[0].g) +else + + #MEGACRYPTER LINK + + MC_API_URL=$(echo -n "$link" | grep -i -E -o 'https?://[^/]+')"/api" + + md5=$(echo -n "$link" | $OPENSSL_MD5 | grep -E -o '[0-9a-f]{32}') + + if [ -f ".megadown/${md5}" ];then + mc_link_info=$(cat ".megadown/${md5}") + else + mc_link_info=$(get_mc_link_info "$link") + + if [ "$?" -eq 1 ];then + echo -e "$mc_link_info" + exit 1 + fi + + echo -n "$mc_link_info" >> ".megadown/${md5}" + fi + + IFS='@' read -a array <<< "$mc_link_info" + + if [ -z "$output" ];then + file_name=$(echo -n "${array[0]}" | base64 -d -i 2>/dev/null) + else + file_name="$output" + fi + + path=${array[1]} + + if [ "$path" != "false" ]; then + path=$(echo -n "$path" | base64 -d -i 2>/dev/null) + fi + + file_size=${array[2]} + + mc_pass=${array[3]} + + key=${array[4]} + + no_exp_token=${array[5]} + + if [ "$mc_pass" != "false" ]; then + + if [ -z "$(command -v python 2>&1)" ]; then + + echo "ERROR: python is required for MegaCrypter password protected links and it's not installed." + exit 1 + + fi + + echo -ne "\nLink is password protected. " + + if [ -n "$password" ]; then + + pass_hash=$(mc_pass_check "$mc_pass" "$password") + + fi + + if [ -z "$pass_hash" ] || [ "$pass_hash" == "0" ]; then + + echo -ne "\n\n" + + read -e -p "Enter password: " pass + + pass_hash=$(mc_pass_check "$mc_pass" "$pass") + + until [ "$pass_hash" != "false" ]; do + read -e -p "Wrong password! Try again: " pass + pass_hash=$(mc_pass_check "$mc_pass" "$pass") + done + fi + + echo -ne "\nPassword is OK. Decrypting metadata...\n" + + IFS='#' read -a array <<< "$pass_hash" + + pass_hash=${array[0]} + + iv=${array[1]} + + hex_raw_key=$(echo -n "$key" | $OPENSSL_AES_CBC_256_DEC -K "$pass_hash" -iv "$iv" | od -v -An -t x1 | tr -d '\n ') + + if [ -z "$output" ]; then + file_name=$(echo -n "$file_name" | $OPENSSL_AES_CBC_256_DEC -K "$pass_hash" -iv "$iv") + fi + else + hex_raw_key=$(echo -n $(urlb64_to_b64 "$key") | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + fi + + if [ $metadata ]; then + echo "{\"file_name\" : \"${file_name}\", \"file_size\" : ${file_size}}" + exit 0 + fi + + if [ "$path" != "false" ] && [ "$path" != "" ]; then + + if [ ! 
-d "$path" ]; then + + mkdir -p "$path" + fi + + file_name="${path}${file_name}" + fi + + check_file_exists "$file_name" "$file_size" "$(format_file_size "$file_size")" "$md5" + + hex_key=$(hrk2hk "$hex_raw_key") + + dl_link=$($DL_COM --header 'Content-Type: application/json' $DL_COM_PDATA "{\"m\":\"dl\", \"link\":\"$link\", \"noexpire\":\"$no_exp_token\"}" "$MC_API_URL") + + download_exit_code=$? + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, something went bad. EXIT CODE (${download_exit_code})" + fi + + if [ $(echo $dl_link | grep '"error"') ]; then + + error_code=$(echo "$dl_link" | jq -r .error) + + showError "MEGACRYPTER ERROR $error_code" + fi + + dl_temp_url=$(echo "$dl_link" | jq -r .url) + + if [ "$mc_pass" != "false" ]; then + + iv=$(echo "$dl_link" | jq -r .pass | base64 -d -i 2>/dev/null | od -v -An -t x1 | tr -d '\n ') + + dl_temp_url=$(echo -n "$dl_temp_url" | $OPENSSL_AES_CBC_256_DEC -K "$pass_hash" -iv "$iv") + fi +fi + +if [ -z "$speed" ]; then + DL_COMMAND="$DL_COM" +else + DL_COMMAND="$DL_COM --limit-rate $speed" +fi + +if [ "$output" == "-" ]; then + + hex_iv="${hex_raw_key:32:16}0000000000000000" + + $DL_COMMAND "$dl_temp_url" | $OPENSSL_AES_CTR_128_DEC -K "$hex_key" -iv "$hex_iv" + + exit 0 +fi + +if [ ! $quiet ]; then + echo -e "$DL_MSG" +fi + +if [ ! $quiet ]; then + PV_CMD="pv" +else + PV_CMD="pv -q" +fi + +download_exit_code=1 + +until [ "$download_exit_code" -eq 0 ]; do + + if [ -f "${file_name}.temp" ]; then + + echo -e "(Resuming previous download ...)\n" + + temp_size=$(stat -c %s "${file_name}.temp") + + offset=$(($temp_size-$(($temp_size%16)))) + + iv_forward=$(printf "%016x" $(($offset/16))) + + hex_iv="${hex_raw_key:32:16}$iv_forward" + + truncate -s $offset "${file_name}.temp" + + $DL_COMMAND "$dl_temp_url/$offset" | $PV_CMD -s $(($file_size-$offset)) | $OPENSSL_AES_CTR_128_DEC -K "$hex_key" -iv "$hex_iv" >> "${file_name}.temp" + else + hex_iv="${hex_raw_key:32:16}0000000000000000" + + $DL_COMMAND "$dl_temp_url" | $PV_CMD -s $file_size | $OPENSSL_AES_CTR_128_DEC -K "$hex_key" -iv "$hex_iv" > "${file_name}.temp" + fi + + download_exit_code=${PIPESTATUS[0]} + + if [ "$download_exit_code" -ne 0 ]; then + showError "Oooops, download failed! EXIT CODE (${download_exit_code})" + fi +done + +if [ ! -f "${file_name}.temp" ]; then + showError "ERROR: FILE COULD NOT BE DOWNLOADED :(!" +fi + +mv "${file_name}.temp" "${file_name}" + +if [ -f ".megadown/${md5}" ];then + rm ".megadown/${md5}" +fi + +if [ ! 
$quiet ]; then + echo -e "\nFILE DOWNLOADED!\n" +fi + +exit 0 \ No newline at end of file diff --git a/datasets/ave/diem/download_dataset.sh b/datasets/ave/diem/download_dataset.sh new file mode 100644 index 0000000..d0b6fbe --- /dev/null +++ b/datasets/ave/diem/download_dataset.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# need to manually download the files from http://www.mediafire.com/?mpu3ot0m2o384 +wget -O diem.tar.gz https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/datasets/ave/diem.tar.gz +tar -xvf diem.tar.gz +mv informatik3/wtm/datasets/External\ Datasets/DIEM/* ./ +rm -rf informatik3 +rm diem.tar.gz +7za x "*.7z" -o* +rm *.7z diff --git a/datasets/aveyetracking/download_dataset.sh b/datasets/aveyetracking/download_dataset.sh new file mode 100644 index 0000000..b04a507 --- /dev/null +++ b/datasets/aveyetracking/download_dataset.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +wget -O SumMe_ETMD.zip "http://cvsp.cs.ntua.gr/research/aveyetracking/SumMe_ETMD.zip" +unzip SumMe_ETMD.zip +rm SumMe_ETMD.zip \ No newline at end of file diff --git a/datasets/findwho/download_dataset.sh b/datasets/findwho/download_dataset.sh new file mode 100644 index 0000000..655fab5 --- /dev/null +++ b/datasets/findwho/download_dataset.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget --no-check-certificate --content-disposition -O - http://github.com/yufanLIU/find/archive/master.tar.gz | tar xz --strip=2 "find-master/Our_database" \ No newline at end of file diff --git a/datasets/processed/Grouped_frames/coutrot1/download_dataset.sh b/datasets/processed/Grouped_frames/coutrot1/download_dataset.sh new file mode 100644 index 0000000..cde62b5 --- /dev/null +++ b/datasets/processed/Grouped_frames/coutrot1/download_dataset.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +wget -O coutrot1.tar.gz https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/datasets/processed/Grouped_frames/coutrot1.tar.gz + +mv coutrot1/* ./ +rm -rf coutrot1/ +tar -xf coutrot1.tar.gz +rm coutrot1.tar.gz diff --git a/datasets/processed/Grouped_frames/coutrot2/download_dataset.sh b/datasets/processed/Grouped_frames/coutrot2/download_dataset.sh new file mode 100644 index 0000000..bfd21fc --- /dev/null +++ b/datasets/processed/Grouped_frames/coutrot2/download_dataset.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +wget -O coutrot2.tar.gz https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/datasets/processed/Grouped_frames/coutrot2.tar.gz + +mv coutrot2/* ./ +rm -rf coutrot2/ +tar -xf coutrot2.tar.gz +rm coutrot2.tar.gz diff --git a/datasets/processed/Grouped_frames/diem/download_dataset.sh b/datasets/processed/Grouped_frames/diem/download_dataset.sh new file mode 100644 index 0000000..14517c7 --- /dev/null +++ b/datasets/processed/Grouped_frames/diem/download_dataset.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +wget -O diem.tar.gz https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/datasets/processed/Grouped_frames/diem.tar.gz + +mv diem/* ./ +rm -rf diem/ +tar -xf diem.tar.gz +rm diem.tar.gz diff --git a/datasets/processed/VideosPerCategory.xlsx b/datasets/processed/VideosPerCategory.xlsx new file mode 100644 index 0000000..1f7c120 Binary files /dev/null and b/datasets/processed/VideosPerCategory.xlsx differ diff --git a/datasets/processed/center_bias.jpg b/datasets/processed/center_bias.jpg new file mode 100644 index 0000000..5872b16 Binary files /dev/null and b/datasets/processed/center_bias.jpg differ diff --git a/datasets/processed/center_bias_bw.jpg b/datasets/processed/center_bias_bw.jpg new file mode 100644 index 0000000..87fa27a Binary files /dev/null and 
b/datasets/processed/center_bias_bw.jpg differ diff --git a/datasets/processed/test.csv b/datasets/processed/test.csv new file mode 100644 index 0000000..ea45c44 --- /dev/null +++ b/datasets/processed/test.csv @@ -0,0 +1,28 @@ +video_id,fps,scene_type,dataset +advert_bbc4_library_1024x576,30,Other,diem +arctic_bears_1066x710,30,Nature,diem +BBC_wildlife_serpent_1280x704,30,Nature,diem +clip_12,25,Nature,coutrot1 +clip_2,25,Other,coutrot1 +clip_23,25,Other,coutrot1 +clip_30,25,Other,coutrot1 +clip_35,25,Nature,coutrot1 +clip_45,25,Nature,coutrot1 +clip_58,25,Social,coutrot1 +clip_60,25,Social,coutrot1 +clip_9,25,Other,coutrot1 +documentary_coral_reef_adventure_1280x720,30,Nature,diem +documentary_planet_earth_1280x704,30,Nature,diem +game_trailer_wrath_lich_king_shortened_subtitles_1280x548,30,Other,diem +movie_trailer_alice_in_wonderland_1280x682,30,Other,diem +music_red_hot_chili_peppers_shortened_1024x576,30,Other,diem +news_bee_parasites_768x576,30,Nature,diem +news_us_election_debate_1080x600,30,Social,diem +nightlife_in_mozambique_1280x580,30,Nature,diem +pingpong_long_shot_960x720,30,Social,diem +planet_earth_jungles_monkeys_1280x704,30,Nature,diem +sport_cricket_ashes_2007_1252x720,30,Social,diem +sport_poker_1280x640,30,Social,diem +sport_surfing_in_thurso_900x720,30,Social,diem +sport_wimbledon_murray_1280x704,30,Social,diem +tv_ketch2_672x544,30,Social,diem diff --git a/datasets/processed/test_ave.csv b/datasets/processed/test_ave.csv new file mode 100644 index 0000000..dcdbef9 --- /dev/null +++ b/datasets/processed/test_ave.csv @@ -0,0 +1,30 @@ +video_id,fps,scene_type,dataset +advert_bbc4_library_1024x576,30,Other,diem +arctic_bears_1066x710,30,Nature,diem +BBC_wildlife_serpent_1280x704,30,Nature,diem +clip_12,25,Nature,coutrot1 +clip_2,25,Other,coutrot1 +clip_23,25,Other,coutrot1 +clip_30,25,Other,coutrot1 +clip_35,25,Nature,coutrot1 +clip_45,25,Nature,coutrot1 +clip_58,25,Social,coutrot1 +clip_60,25,Social,coutrot1 +clip_9,25,Other,coutrot1 +clip_2,25,Social,coutrot2 +documentary_coral_reef_adventure_1280x720,30,Nature,diem +documentary_planet_earth_1280x704,30,Nature,diem +game_trailer_wrath_lich_king_shortened_subtitles_1280x548,30,Other,diem +home_movie_Charlie_bit_my_finger_again_960x720,30,Social,diem +movie_trailer_alice_in_wonderland_1280x682,30,Other,diem +music_red_hot_chili_peppers_shortened_1024x576,30,Other,diem +news_bee_parasites_768x576,30,Nature,diem +news_us_election_debate_1080x600,30,Social,diem +nightlife_in_mozambique_1280x580,30,Nature,diem +pingpong_long_shot_960x720,30,Social,diem +planet_earth_jungles_monkeys_1280x704,30,Nature,diem +sport_cricket_ashes_2007_1252x720,30,Social,diem +sport_poker_1280x640,30,Social,diem +sport_surfing_in_thurso_900x720,30,Social,diem +sport_wimbledon_murray_1280x704,30,Social,diem +tv_ketch2_672x544,30,Social,diem diff --git a/datasets/processed/test_stavis_1.csv b/datasets/processed/test_stavis_1.csv new file mode 100644 index 0000000..2a9bc99 --- /dev/null +++ b/datasets/processed/test_stavis_1.csv @@ -0,0 +1,43 @@ +video_id,fps,scene_type,dataset +BBC_life_in_cold_blood_1278x710,30,Nature,diem +BBC_wildlife_serpent_1280x704,30,Nature,diem +DIY_SOS_1280x712,30,Other,diem +advert_bbc4_bees_1024x576,30,Nature,diem +advert_bbc4_library_1024x576,30,Other,diem +advert_iphone_1272x720,30,Other,diem +harry_potter_6_trailer_1280x544,30,Social,diem +music_gummybear_880x720,30,Other,diem +music_trailer_nine_inch_nails_1280x720,30,Social,diem +nightlife_in_mozambique_1280x580,30,Nature,diem +one_show_1280x712,30,Social,diem 
+pingpong_angle_shot_960x720,30,Social,diem +pingpong_no_bodies_960x720,30,Other,diem +sport_scramblers_1280x720,30,Other,diem +sport_wimbledon_federer_final_1280x704,30,Social,diem +tv_uni_challenge_final_1280x712,30,Social,diem +university_forum_construction_ionic_1280x720,30,Other,diem +clip_4,25,Social,coutrot2 +clip_15,25,Social,coutrot2 +clip_11,25,Social,coutrot2 +clip_14,25,Social,coutrot2 +clip_2,25,Social,coutrot2 +clip_4,25,Other,coutrot1 +clip_5,25,Other,coutrot1 +clip_10,25,Other,coutrot1 +clip_14,25,Other,coutrot1 +clip_15,25,Other,coutrot1 +clip_16,25,Other,coutrot1 +clip_17,25,Other,coutrot1 +clip_22,25,Other,coutrot1 +clip_24,25,Other,coutrot1 +clip_26,25,Other,coutrot1 +clip_32,25,Nature,coutrot1 +clip_37,25,Nature,coutrot1 +clip_38,25,Nature,coutrot1 +clip_39,25,Nature,coutrot1 +clip_44,25,Nature,coutrot1 +clip_47,25,Social,coutrot1 +clip_48,25,Social,coutrot1 +clip_50,25,Social,coutrot1 +clip_56,25,Social,coutrot1 +clip_58,25,Social,coutrot1 diff --git a/datasets/processed/test_stavis_2.csv b/datasets/processed/test_stavis_2.csv new file mode 100644 index 0000000..50db22b --- /dev/null +++ b/datasets/processed/test_stavis_2.csv @@ -0,0 +1,43 @@ +video_id,fps,scene_type,dataset +BBC_life_in_cold_blood_1278x710,30,Nature,diem +BBC_wildlife_serpent_1280x704,30,Nature,diem +DIY_SOS_1280x712,30,Other,diem +advert_bbc4_bees_1024x576,30,Nature,diem +advert_bbc4_library_1024x576,30,Other,diem +advert_iphone_1272x720,30,Other,diem +harry_potter_6_trailer_1280x544,30,Social,diem +music_gummybear_880x720,30,Other,diem +music_trailer_nine_inch_nails_1280x720,30,Social,diem +nightlife_in_mozambique_1280x580,30,Nature,diem +one_show_1280x712,30,Social,diem +pingpong_angle_shot_960x720,30,Social,diem +pingpong_no_bodies_960x720,30,Other,diem +sport_scramblers_1280x720,30,Other,diem +sport_wimbledon_federer_final_1280x704,30,Social,diem +tv_uni_challenge_final_1280x712,30,Social,diem +university_forum_construction_ionic_1280x720,30,Other,diem +clip_12,25,Social,coutrot2 +clip_13,25,Social,coutrot2 +clip_10,25,Social,coutrot2 +clip_7,25,Social,coutrot2 +clip_5,25,Social,coutrot2 +clip_1,25,Other,coutrot1 +clip_3,25,Other,coutrot1 +clip_7,25,Other,coutrot1 +clip_8,25,Other,coutrot1 +clip_11,25,Other,coutrot1 +clip_25,25,Other,coutrot1 +clip_27,25,Other,coutrot1 +clip_28,25,Other,coutrot1 +clip_29,25,Nature,coutrot1 +clip_30,25,Other,coutrot1 +clip_31,25,Nature,coutrot1 +clip_33,25,Nature,coutrot1 +clip_36,25,Nature,coutrot1 +clip_42,25,Nature,coutrot1 +clip_43,25,Nature,coutrot1 +clip_49,25,Social,coutrot1 +clip_52,25,Social,coutrot1 +clip_54,25,Social,coutrot1 +clip_55,25,Social,coutrot1 +clip_57,25,Social,coutrot1 diff --git a/datasets/processed/test_stavis_3.csv b/datasets/processed/test_stavis_3.csv new file mode 100644 index 0000000..305ba0b --- /dev/null +++ b/datasets/processed/test_stavis_3.csv @@ -0,0 +1,43 @@ +video_id,fps,scene_type,dataset +BBC_life_in_cold_blood_1278x710,30,Nature,diem +BBC_wildlife_serpent_1280x704,30,Nature,diem +DIY_SOS_1280x712,30,Other,diem +advert_bbc4_bees_1024x576,30,Nature,diem +advert_bbc4_library_1024x576,30,Other,diem +advert_iphone_1272x720,30,Other,diem +harry_potter_6_trailer_1280x544,30,Social,diem +music_gummybear_880x720,30,Other,diem +music_trailer_nine_inch_nails_1280x720,30,Social,diem +nightlife_in_mozambique_1280x580,30,Nature,diem +one_show_1280x712,30,Social,diem +pingpong_angle_shot_960x720,30,Social,diem +pingpong_no_bodies_960x720,30,Other,diem +sport_scramblers_1280x720,30,Other,diem 
+sport_wimbledon_federer_final_1280x704,30,Social,diem +tv_uni_challenge_final_1280x712,30,Social,diem +university_forum_construction_ionic_1280x720,30,Other,diem +clip_3,25,Social,coutrot2 +clip_1,25,Social,coutrot2 +clip_9,25,Social,coutrot2 +clip_6,25,Social,coutrot2 +clip_8,25,Social,coutrot2 +clip_2,25,Other,coutrot1 +clip_6,25,Other,coutrot1 +clip_9,25,Other,coutrot1 +clip_12,25,Nature,coutrot1 +clip_13,25,Other,coutrot1 +clip_18,25,Nature,coutrot1 +clip_19,25,Other,coutrot1 +clip_20,25,Other,coutrot1 +clip_21,25,Other,coutrot1 +clip_23,25,Other,coutrot1 +clip_34,25,Nature,coutrot1 +clip_35,25,Nature,coutrot1 +clip_40,25,Nature,coutrot1 +clip_41,25,Nature,coutrot1 +clip_45,25,Nature,coutrot1 +clip_46,25,Social,coutrot1 +clip_51,25,Social,coutrot1 +clip_53,25,Social,coutrot1 +clip_59,25,Social,coutrot1 +clip_60,25,Social,coutrot1 diff --git a/datasets/processed/train.csv b/datasets/processed/train.csv new file mode 100644 index 0000000..60acefa --- /dev/null +++ b/datasets/processed/train.csv @@ -0,0 +1,89 @@ +video_id,fps,scene_type,dataset +50_people_brooklyn_1280x720,30,Social,diem +advert_bbc4_bees_1024x576,30,Nature,diem +advert_bravia_paint_1280x720,30,Other,diem +Antarctica_landscape_1246x720,30,Nature,diem +basketball_of_sorts_960x720,30,Social,diem +BBC_wildlife_eagle_930x720,30,Nature,diem +BBC_wildlife_special_tiger_1276x720,30,Nature,diem +clip_10,25,Other,coutrot1 +clip_11,25,Other,coutrot1 +clip_1,25,Other,coutrot1 +clip_13,25,Other,coutrot1 +clip_14,25,Other,coutrot1 +clip_15,25,Other,coutrot1 +clip_16,25,Other,coutrot1 +clip_17,25,Other,coutrot1 +clip_19,25,Other,coutrot1 +clip_20,25,Other,coutrot1 +clip_21,25,Other,coutrot1 +clip_22,25,Other,coutrot1 +clip_24,25,Other,coutrot1 +clip_25,25,Other,coutrot1 +clip_27,25,Other,coutrot1 +clip_28,25,Other,coutrot1 +clip_29,25,Nature,coutrot1 +clip_31,25,Nature,coutrot1 +clip_32,25,Nature,coutrot1 +clip_3,25,Other,coutrot1 +clip_34,25,Nature,coutrot1 +clip_36,25,Nature,coutrot1 +clip_37,25,Nature,coutrot1 +clip_38,25,Nature,coutrot1 +clip_39,25,Nature,coutrot1 +clip_41,25,Nature,coutrot1 +clip_4,25,Other,coutrot1 +clip_43,25,Nature,coutrot1 +clip_44,25,Nature,coutrot1 +clip_46,25,Social,coutrot1 +clip_47,25,Social,coutrot1 +clip_49,25,Social,coutrot1 +clip_50,25,Social,coutrot1 +clip_51,25,Social,coutrot1 +clip_5,25,Other,coutrot1 +clip_53,25,Social,coutrot1 +clip_54,25,Social,coutrot1 +clip_55,25,Social,coutrot1 +clip_56,25,Social,coutrot1 +clip_57,25,Social,coutrot1 +clip_59,25,Social,coutrot1 +clip_6,25,Other,coutrot1 +clip_10,25,Social,coutrot2 +clip_12,25,Social,coutrot2 +clip_13,25,Social,coutrot2 +clip_15,25,Social,coutrot2 +clip_3,25,Social,coutrot2 +clip_5,25,Social,coutrot2 +clip_7,25,Social,coutrot2 +clip_8,25,Social,coutrot2 +DIY_SOS_1280x712,30,Other,diem +documentary_adrenaline_rush_1280x720,30,Other,diem +documentary_discoverers_1280x720,30,Other,diem +documentary_mystery_nile_1280x720,30,Nature,diem +game_trailer_bullet_witch_1280x720,30,Other,diem +game_trailer_lego_indiana_jones_1280x720,30,Other,diem +growing_sweetcorn_1280x712,30,Social,diem +harry_potter_6_trailer_1280x544,30,Social,diem +hummingbirds_closeups_960x720,30,Nature,diem +hydraulics_1280x712,30,Other,diem +movie_trailer_ice_age_3_1280x690,30,Other,diem +music_gummybear_880x720,30,Other,diem +music_trailer_nine_inch_nails_1280x720,30,Social,diem +news_video_republic_960x720,30,Other,diem +nigella_chocolate_pears_1280x712,30,Social,diem +one_show_1280x712,30,Social,diem +pingpong_closeup_rallys_960x720,30,Social,diem 
+pingpong_miscues_1080x720,30,Social,diem +planet_earth_jungles_frogs_1280x704,30,Nature,diem +scottish_parliament_1152x864,30,Social,diem +sport_barcelona_extreme_1280x720,30,Other,diem +sport_F1_slick_tyres_1280x720,30,Other,diem +sport_golf_fade_a_driver_1280x720,30,Other,diem +sport_scramblers_1280x720,30,Other,diem +sport_slam_dunk_1280x704,30,Social,diem +sport_wimbledon_baltacha_1280x704,30,Social,diem +sport_wimbledon_magic_wand_1280x704,30,Social,diem +spotty_trifle_1280x712,30,Other,diem +tv_graduates_1280x720,30,Social,diem +tv_the_simpsons_860x528,30,Other,diem +university_forum_construction_ionic_1280x720,30,Other,diem diff --git a/datasets/processed/train_ave.csv b/datasets/processed/train_ave.csv new file mode 100644 index 0000000..5cd449a --- /dev/null +++ b/datasets/processed/train_ave.csv @@ -0,0 +1,94 @@ +video_id,fps,scene_type,dataset +50_people_brooklyn_1280x720,30,Social,diem +advert_bbc4_bees_1024x576,30,Nature,diem +advert_bravia_paint_1280x720,30,Other,diem +Antarctica_landscape_1246x720,30,Nature,diem +basketball_of_sorts_960x720,30,Social,diem +BBC_wildlife_eagle_930x720,30,Nature,diem +BBC_wildlife_special_tiger_1276x720,30,Nature,diem +clip_10,25,Other,coutrot1 +clip_11,25,Other,coutrot1 +clip_1,25,Other,coutrot1 +clip_13,25,Other,coutrot1 +clip_14,25,Other,coutrot1 +clip_15,25,Other,coutrot1 +clip_16,25,Other,coutrot1 +clip_17,25,Other,coutrot1 +clip_19,25,Other,coutrot1 +clip_20,25,Other,coutrot1 +clip_21,25,Other,coutrot1 +clip_22,25,Other,coutrot1 +clip_24,25,Other,coutrot1 +clip_25,25,Other,coutrot1 +clip_27,25,Other,coutrot1 +clip_28,25,Other,coutrot1 +clip_29,25,Nature,coutrot1 +clip_31,25,Nature,coutrot1 +clip_32,25,Nature,coutrot1 +clip_3,25,Other,coutrot1 +clip_34,25,Nature,coutrot1 +clip_36,25,Nature,coutrot1 +clip_37,25,Nature,coutrot1 +clip_38,25,Nature,coutrot1 +clip_39,25,Nature,coutrot1 +clip_41,25,Nature,coutrot1 +clip_4,25,Other,coutrot1 +clip_42,25,Nature,coutrot1 +clip_43,25,Nature,coutrot1 +clip_44,25,Nature,coutrot1 +clip_46,25,Social,coutrot1 +clip_47,25,Social,coutrot1 +clip_49,25,Social,coutrot1 +clip_50,25,Social,coutrot1 +clip_51,25,Social,coutrot1 +clip_5,25,Other,coutrot1 +clip_53,25,Social,coutrot1 +clip_54,25,Social,coutrot1 +clip_55,25,Social,coutrot1 +clip_56,25,Social,coutrot1 +clip_57,25,Social,coutrot1 +clip_59,25,Social,coutrot1 +clip_6,25,Other,coutrot1 +clip_10,25,Social,coutrot2 +clip_11,25,Social,coutrot2 +clip_12,25,Social,coutrot2 +clip_1,25,Social,coutrot2 +clip_13,25,Social,coutrot2 +clip_14,25,Social,coutrot2 +clip_15,25,Social,coutrot2 +clip_3,25,Social,coutrot2 +clip_5,25,Social,coutrot2 +clip_7,25,Social,coutrot2 +clip_8,25,Social,coutrot2 +clip_9,25,Social,coutrot2 +DIY_SOS_1280x712,30,Other,diem +documentary_adrenaline_rush_1280x720,30,Other,diem +documentary_discoverers_1280x720,30,Other,diem +documentary_mystery_nile_1280x720,30,Nature,diem +game_trailer_bullet_witch_1280x720,30,Other,diem +game_trailer_lego_indiana_jones_1280x720,30,Other,diem +growing_sweetcorn_1280x712,30,Social,diem +harry_potter_6_trailer_1280x544,30,Social,diem +hummingbirds_closeups_960x720,30,Nature,diem +hydraulics_1280x712,30,Other,diem +movie_trailer_ice_age_3_1280x690,30,Other,diem +music_gummybear_880x720,30,Other,diem +music_trailer_nine_inch_nails_1280x720,30,Social,diem +news_video_republic_960x720,30,Other,diem +nigella_chocolate_pears_1280x712,30,Social,diem +one_show_1280x712,30,Social,diem +pingpong_closeup_rallys_960x720,30,Social,diem +pingpong_miscues_1080x720,30,Social,diem 
+planet_earth_jungles_frogs_1280x704,30,Nature,diem +scottish_parliament_1152x864,30,Social,diem +sport_barcelona_extreme_1280x720,30,Other,diem +sport_F1_slick_tyres_1280x720,30,Other,diem +sport_golf_fade_a_driver_1280x720,30,Other,diem +sport_scramblers_1280x720,30,Other,diem +sport_slam_dunk_1280x704,30,Social,diem +sport_wimbledon_baltacha_1280x704,30,Social,diem +sport_wimbledon_magic_wand_1280x704,30,Social,diem +spotty_trifle_1280x712,30,Other,diem +tv_graduates_1280x720,30,Social,diem +tv_the_simpsons_860x528,30,Other,diem +university_forum_construction_ionic_1280x720,30,Other,diem diff --git a/datasets/processed/train_stavis_1.csv b/datasets/processed/train_stavis_1.csv new file mode 100644 index 0000000..a780631 --- /dev/null +++ b/datasets/processed/train_stavis_1.csv @@ -0,0 +1,115 @@ +video_id,fps,scene_type,dataset +50_people_brooklyn_1280x720,30,Social,diem +50_people_brooklyn_no_voices_1280x720,30,SocialEXCLUDE,diem +50_people_london_1280x720,30,Social,diem +50_people_london_no_voices_1280x720,30,SocialEXCLUDE,diem +Antarctica_landscape_1246x720,30,Nature,diem +BBC_wildlife_eagle_930x720,30,Nature,diem +BBC_wildlife_special_tiger_1276x720,30,Nature,diem +advert_bravia_paint_1280x720,30,Other,diem +ami_is1000a_closeup_720x576,30,Other,diem +ami_is1000a_left_720x576,30,Other,diem +arctic_bears_1066x710,30,Nature,diem +basketball_of_sorts_960x720,30,Social,diem +chilli_plasters_1280x712,30,Social,diem +documentary_adrenaline_rush_1280x720,30,Other,diem +documentary_coral_reef_adventure_1280x720,30,Nature,diem +documentary_discoverers_1280x720,30,Other,diem +documentary_dolphins_1280x720,30,Nature,diem +documentary_mystery_nile_1280x720,30,Nature,diem +documentary_planet_earth_1280x704,30,Nature,diem +game_trailer_bullet_witch_1280x720,30,Other,diem +game_trailer_ghostbusters_1280x720,30,Other,diem +game_trailer_lego_indiana_jones_1280x720,30,Other,diem +game_trailer_wrath_lich_king_shortened_subtitles_1280x548,30,Other,diem +growing_sweetcorn_1280x712,30,Social,diem +hairy_bikers_cabbage_1280x712,30,Social,diem +home_movie_Charlie_bit_my_finger_again_960x720,30,Social,diem +hummingbirds_closeups_960x720,30,Nature,diem +hummingbirds_narrator_960x720,30,Nature,diem +hydraulics_1280x712,30,Other,diem +movie_trailer_alice_in_wonderland_1280x682,30,Other,diem +movie_trailer_ice_age_3_1280x690,30,Other,diem +movie_trailer_quantum_of_solace_1280x688,30,Social,diem +music_red_hot_chili_peppers_shortened_1024x576,30,Other,diem +news_bee_parasites_768x576,30,Nature,diem +news_newsnight_othello_720x416,30,Other,diem +news_sherry_drinking_mice_768x576,30,Other,diem +news_us_election_debate_1080x600,30,Social,diem +news_video_republic_960x720,30,Other,diem +news_wimbledon_macenroe_shortened_768x576,30,Social,diem +nigella_chocolate_pears_1280x712,30,Social,diem +pingpong_closeup_rallys_960x720,30,Social,diem +pingpong_long_shot_960x720,30,Social,diem +pingpong_miscues_1080x720,30,Social,diem +planet_earth_jungles_frogs_1280x704,30,Nature,diem +planet_earth_jungles_monkeys_1280x70,30,Nature,diem +scottish_parliament_1152x864,30,Social,diem +scottish_starters_1280x712,30,Social,diem +sport_F1_slick_tyres_1280x720,30,Other,diem +sport_barcelona_extreme_1280x720,30,Other,diem +sport_cricket_ashes_2007_1252x720,30,Social,diem +sport_football_best_goals_976x720,30,Social,diem +sport_golf_fade_a_driver_1280x720,30,Other,diem +sport_poker_1280x640,30,Social,diem +sport_slam_dunk_1280x704,30,Social,diem +sport_surfing_in_thurso_900x720,30,Social,diem 
+sport_wimbledon_baltacha_1280x704,30,Social,diem +sport_wimbledon_magic_wand_1280x704,30,Social,diem +sport_wimbledon_murray_1280x704,30,Social,diem +sports_kendo_1280x710,30,Social,diem +spotty_trifle_1280x712,30,Other,diem +stewart_lee_1280x712,30,Social,diem +tv_graduates_1280x720,30,Social,diem +tv_ketch2_672x544,30,Social,diem +tv_the_simpsons_860x528,30,Other,diem +clip_12,25,Social,coutrot2 +clip_13,25,Social,coutrot2 +clip_10,25,Social,coutrot2 +clip_3,25,Social,coutrot2 +clip_1,25,Social,coutrot2 +clip_7,25,Social,coutrot2 +clip_9,25,Social,coutrot2 +clip_6,25,Social,coutrot2 +clip_5,25,Social,coutrot2 +clip_8,25,Social,coutrot2 +clip_1,25,Other,coutrot1 +clip_2,25,Other,coutrot1 +clip_3,25,Other,coutrot1 +clip_6,25,Other,coutrot1 +clip_7,25,Other,coutrot1 +clip_8,25,Other,coutrot1 +clip_9,25,Other,coutrot1 +clip_11,25,Other,coutrot1 +clip_12,25,Nature,coutrot1 +clip_13,25,Other,coutrot1 +clip_18,25,Nature,coutrot1 +clip_19,25,Other,coutrot1 +clip_20,25,Other,coutrot1 +clip_21,25,Other,coutrot1 +clip_23,25,Other,coutrot1 +clip_25,25,Other,coutrot1 +clip_27,25,Other,coutrot1 +clip_28,25,Other,coutrot1 +clip_29,25,Nature,coutrot1 +clip_30,25,Other,coutrot1 +clip_31,25,Nature,coutrot1 +clip_33,25,Nature,coutrot1 +clip_34,25,Nature,coutrot1 +clip_35,25,Nature,coutrot1 +clip_36,25,Nature,coutrot1 +clip_40,25,Nature,coutrot1 +clip_41,25,Nature,coutrot1 +clip_42,25,Nature,coutrot1 +clip_43,25,Nature,coutrot1 +clip_45,25,Nature,coutrot1 +clip_46,25,Social,coutrot1 +clip_49,25,Social,coutrot1 +clip_51,25,Social,coutrot1 +clip_52,25,Social,coutrot1 +clip_53,25,Social,coutrot1 +clip_54,25,Social,coutrot1 +clip_55,25,Social,coutrot1 +clip_57,25,Social,coutrot1 +clip_59,25,Social,coutrot1 +clip_60,25,Social,coutrot1 diff --git a/datasets/processed/train_stavis_2.csv b/datasets/processed/train_stavis_2.csv new file mode 100644 index 0000000..41919a0 --- /dev/null +++ b/datasets/processed/train_stavis_2.csv @@ -0,0 +1,115 @@ +video_id,fps,scene_type,dataset +50_people_brooklyn_1280x720,30,Social,diem +50_people_brooklyn_no_voices_1280x720,30,SocialEXCLUDE,diem +50_people_london_1280x720,30,Social,diem +50_people_london_no_voices_1280x720,30,SocialEXCLUDE,diem +Antarctica_landscape_1246x720,30,Nature,diem +BBC_wildlife_eagle_930x720,30,Nature,diem +BBC_wildlife_special_tiger_1276x720,30,Nature,diem +advert_bravia_paint_1280x720,30,Other,diem +ami_is1000a_closeup_720x576,30,Other,diem +ami_is1000a_left_720x576,30,Other,diem +arctic_bears_1066x710,30,Nature,diem +basketball_of_sorts_960x720,30,Social,diem +chilli_plasters_1280x712,30,Social,diem +documentary_adrenaline_rush_1280x720,30,Other,diem +documentary_coral_reef_adventure_1280x720,30,Nature,diem +documentary_discoverers_1280x720,30,Other,diem +documentary_dolphins_1280x720,30,Nature,diem +documentary_mystery_nile_1280x720,30,Nature,diem +documentary_planet_earth_1280x704,30,Nature,diem +game_trailer_bullet_witch_1280x720,30,Other,diem +game_trailer_ghostbusters_1280x720,30,Other,diem +game_trailer_lego_indiana_jones_1280x720,30,Other,diem +game_trailer_wrath_lich_king_shortened_subtitles_1280x548,30,Other,diem +growing_sweetcorn_1280x712,30,Social,diem +hairy_bikers_cabbage_1280x712,30,Social,diem +home_movie_Charlie_bit_my_finger_again_960x720,30,Social,diem +hummingbirds_closeups_960x720,30,Nature,diem +hummingbirds_narrator_960x720,30,Nature,diem +hydraulics_1280x712,30,Other,diem +movie_trailer_alice_in_wonderland_1280x682,30,Other,diem +movie_trailer_ice_age_3_1280x690,30,Other,diem 
+movie_trailer_quantum_of_solace_1280x688,30,Social,diem +music_red_hot_chili_peppers_shortened_1024x576,30,Other,diem +news_bee_parasites_768x576,30,Nature,diem +news_newsnight_othello_720x416,30,Other,diem +news_sherry_drinking_mice_768x576,30,Other,diem +news_us_election_debate_1080x600,30,Social,diem +news_video_republic_960x720,30,Other,diem +news_wimbledon_macenroe_shortened_768x576,30,Social,diem +nigella_chocolate_pears_1280x712,30,Social,diem +pingpong_closeup_rallys_960x720,30,Social,diem +pingpong_long_shot_960x720,30,Social,diem +pingpong_miscues_1080x720,30,Social,diem +planet_earth_jungles_frogs_1280x704,30,Nature,diem +planet_earth_jungles_monkeys_1280x70,30,Nature,diem +scottish_parliament_1152x864,30,Social,diem +scottish_starters_1280x712,30,Social,diem +sport_F1_slick_tyres_1280x720,30,Other,diem +sport_barcelona_extreme_1280x720,30,Other,diem +sport_cricket_ashes_2007_1252x720,30,Social,diem +sport_football_best_goals_976x720,30,Social,diem +sport_golf_fade_a_driver_1280x720,30,Other,diem +sport_poker_1280x640,30,Social,diem +sport_slam_dunk_1280x704,30,Social,diem +sport_surfing_in_thurso_900x720,30,Social,diem +sport_wimbledon_baltacha_1280x704,30,Social,diem +sport_wimbledon_magic_wand_1280x704,30,Social,diem +sport_wimbledon_murray_1280x704,30,Social,diem +sports_kendo_1280x710,30,Social,diem +spotty_trifle_1280x712,30,Other,diem +stewart_lee_1280x712,30,Social,diem +tv_graduates_1280x720,30,Social,diem +tv_ketch2_672x544,30,Social,diem +tv_the_simpsons_860x528,30,Other,diem +clip_3,25,Social,coutrot2 +clip_4,25,Social,coutrot2 +clip_15,25,Social,coutrot2 +clip_1,25,Social,coutrot2 +clip_11,25,Social,coutrot2 +clip_14,25,Social,coutrot2 +clip_9,25,Social,coutrot2 +clip_2,25,Social,coutrot2 +clip_6,25,Social,coutrot2 +clip_8,25,Social,coutrot2 +clip_2,25,Other,coutrot1 +clip_4,25,Other,coutrot1 +clip_5,25,Other,coutrot1 +clip_6,25,Other,coutrot1 +clip_9,25,Other,coutrot1 +clip_10,25,Other,coutrot1 +clip_12,25,Nature,coutrot1 +clip_13,25,Other,coutrot1 +clip_14,25,Other,coutrot1 +clip_15,25,Other,coutrot1 +clip_16,25,Other,coutrot1 +clip_17,25,Other,coutrot1 +clip_18,25,Nature,coutrot1 +clip_19,25,Other,coutrot1 +clip_20,25,Other,coutrot1 +clip_21,25,Other,coutrot1 +clip_22,25,Other,coutrot1 +clip_23,25,Other,coutrot1 +clip_24,25,Other,coutrot1 +clip_26,25,Other,coutrot1 +clip_32,25,Nature,coutrot1 +clip_34,25,Nature,coutrot1 +clip_35,25,Nature,coutrot1 +clip_37,25,Nature,coutrot1 +clip_38,25,Nature,coutrot1 +clip_39,25,Nature,coutrot1 +clip_40,25,Nature,coutrot1 +clip_41,25,Nature,coutrot1 +clip_44,25,Nature,coutrot1 +clip_45,25,Nature,coutrot1 +clip_46,25,Social,coutrot1 +clip_47,25,Social,coutrot1 +clip_48,25,Social,coutrot1 +clip_50,25,Social,coutrot1 +clip_51,25,Social,coutrot1 +clip_53,25,Social,coutrot1 +clip_56,25,Social,coutrot1 +clip_58,25,Social,coutrot1 +clip_59,25,Social,coutrot1 +clip_60,25,Social,coutrot1 diff --git a/datasets/processed/train_stavis_3.csv b/datasets/processed/train_stavis_3.csv new file mode 100644 index 0000000..83fc585 --- /dev/null +++ b/datasets/processed/train_stavis_3.csv @@ -0,0 +1,115 @@ +video_id,fps,scene_type,dataset +50_people_brooklyn_1280x720,30,Social,diem +50_people_brooklyn_no_voices_1280x720,30,SocialEXCLUDE,diem +50_people_london_1280x720,30,Social,diem +50_people_london_no_voices_1280x720,30,SocialEXCLUDE,diem +Antarctica_landscape_1246x720,30,Nature,diem +BBC_wildlife_eagle_930x720,30,Nature,diem +BBC_wildlife_special_tiger_1276x720,30,Nature,diem +advert_bravia_paint_1280x720,30,Other,diem 
+ami_is1000a_closeup_720x576,30,Other,diem +ami_is1000a_left_720x576,30,Other,diem +arctic_bears_1066x710,30,Nature,diem +basketball_of_sorts_960x720,30,Social,diem +chilli_plasters_1280x712,30,Social,diem +documentary_adrenaline_rush_1280x720,30,Other,diem +documentary_coral_reef_adventure_1280x720,30,Nature,diem +documentary_discoverers_1280x720,30,Other,diem +documentary_dolphins_1280x720,30,Nature,diem +documentary_mystery_nile_1280x720,30,Nature,diem +documentary_planet_earth_1280x704,30,Nature,diem +game_trailer_bullet_witch_1280x720,30,Other,diem +game_trailer_ghostbusters_1280x720,30,Other,diem +game_trailer_lego_indiana_jones_1280x720,30,Other,diem +game_trailer_wrath_lich_king_shortened_subtitles_1280x548,30,Other,diem +growing_sweetcorn_1280x712,30,Social,diem +hairy_bikers_cabbage_1280x712,30,Social,diem +home_movie_Charlie_bit_my_finger_again_960x720,30,Social,diem +hummingbirds_closeups_960x720,30,Nature,diem +hummingbirds_narrator_960x720,30,Nature,diem +hydraulics_1280x712,30,Other,diem +movie_trailer_alice_in_wonderland_1280x682,30,Other,diem +movie_trailer_ice_age_3_1280x690,30,Other,diem +movie_trailer_quantum_of_solace_1280x688,30,Social,diem +music_red_hot_chili_peppers_shortened_1024x576,30,Other,diem +news_bee_parasites_768x576,30,Nature,diem +news_newsnight_othello_720x416,30,Other,diem +news_sherry_drinking_mice_768x576,30,Other,diem +news_us_election_debate_1080x600,30,Social,diem +news_video_republic_960x720,30,Other,diem +news_wimbledon_macenroe_shortened_768x576,30,Social,diem +nigella_chocolate_pears_1280x712,30,Social,diem +pingpong_closeup_rallys_960x720,30,Social,diem +pingpong_long_shot_960x720,30,Social,diem +pingpong_miscues_1080x720,30,Social,diem +planet_earth_jungles_frogs_1280x704,30,Nature,diem +planet_earth_jungles_monkeys_1280x70,30,Nature,diem +scottish_parliament_1152x864,30,Social,diem +scottish_starters_1280x712,30,Social,diem +sport_F1_slick_tyres_1280x720,30,Other,diem +sport_barcelona_extreme_1280x720,30,Other,diem +sport_cricket_ashes_2007_1252x720,30,Social,diem +sport_football_best_goals_976x720,30,Social,diem +sport_golf_fade_a_driver_1280x720,30,Other,diem +sport_poker_1280x640,30,Social,diem +sport_slam_dunk_1280x704,30,Social,diem +sport_surfing_in_thurso_900x720,30,Social,diem +sport_wimbledon_baltacha_1280x704,30,Social,diem +sport_wimbledon_magic_wand_1280x704,30,Social,diem +sport_wimbledon_murray_1280x704,30,Social,diem +sports_kendo_1280x710,30,Social,diem +spotty_trifle_1280x712,30,Other,diem +stewart_lee_1280x712,30,Social,diem +tv_graduates_1280x720,30,Social,diem +tv_ketch2_672x544,30,Social,diem +tv_the_simpsons_860x528,30,Other,diem +clip_12,25,Social,coutrot2 +clip_13,25,Social,coutrot2 +clip_10,25,Social,coutrot2 +clip_4,25,Social,coutrot2 +clip_15,25,Social,coutrot2 +clip_11,25,Social,coutrot2 +clip_7,25,Social,coutrot2 +clip_14,25,Social,coutrot2 +clip_2,25,Social,coutrot2 +clip_5,25,Social,coutrot2 +clip_1,25,Other,coutrot1 +clip_3,25,Other,coutrot1 +clip_4,25,Other,coutrot1 +clip_5,25,Other,coutrot1 +clip_7,25,Other,coutrot1 +clip_8,25,Other,coutrot1 +clip_10,25,Other,coutrot1 +clip_11,25,Other,coutrot1 +clip_14,25,Other,coutrot1 +clip_15,25,Other,coutrot1 +clip_16,25,Other,coutrot1 +clip_17,25,Other,coutrot1 +clip_22,25,Other,coutrot1 +clip_24,25,Other,coutrot1 +clip_25,25,Other,coutrot1 +clip_26,25,Other,coutrot1 +clip_27,25,Other,coutrot1 +clip_28,25,Other,coutrot1 +clip_29,25,Nature,coutrot1 +clip_30,25,Other,coutrot1 +clip_31,25,Nature,coutrot1 +clip_32,25,Nature,coutrot1 +clip_33,25,Nature,coutrot1 
+clip_36,25,Nature,coutrot1 +clip_37,25,Nature,coutrot1 +clip_38,25,Nature,coutrot1 +clip_39,25,Nature,coutrot1 +clip_42,25,Nature,coutrot1 +clip_43,25,Nature,coutrot1 +clip_44,25,Nature,coutrot1 +clip_47,25,Social,coutrot1 +clip_48,25,Social,coutrot1 +clip_49,25,Social,coutrot1 +clip_50,25,Social,coutrot1 +clip_52,25,Social,coutrot1 +clip_54,25,Social,coutrot1 +clip_55,25,Social,coutrot1 +clip_56,25,Social,coutrot1 +clip_57,25,Social,coutrot1 +clip_58,25,Social,coutrot1 diff --git a/datasets/processed/validation.csv b/datasets/processed/validation.csv new file mode 100644 index 0000000..c77df36 --- /dev/null +++ b/datasets/processed/validation.csv @@ -0,0 +1,29 @@ +video_id,fps,scene_type,dataset +50_people_london_1280x720,30,Social,diem +advert_iphone_1272x720,30,Other,diem +BBC_life_in_cold_blood_1278x710,30,Nature,diem +chilli_plasters_1280x712,30,Social,diem +clip_18,25,Nature,coutrot1 +clip_26,25,Other,coutrot1 +clip_33,25,Nature,coutrot1 +clip_40,25,Nature,coutrot1 +clip_48,25,Social,coutrot1 +clip_52,25,Social,coutrot1 +clip_7,25,Other,coutrot1 +clip_8,25,Other,coutrot1 +clip_4,25,Social,coutrot2 +clip_6,25,Social,coutrot2 +documentary_dolphins_1280x720,30,Nature,diem +game_trailer_ghostbusters_1280x720,30,Other,diem +hairy_bikers_cabbage_1280x712,30,Social,diem +hummingbirds_narrator_960x720,30,Nature,diem +movie_trailer_quantum_of_solace_1280x688,30,Social,diem +news_sherry_drinking_mice_768x576,30,Other,diem +news_wimbledon_macenroe_shortened_768x576,30,Social,diem +pingpong_angle_shot_960x720,30,Social,diem +pingpong_no_bodies_960x720,30,Other,diem +scottish_starters_1280x712,30,Social,diem +sport_football_best_goals_976x720,30,Social,diem +sports_kendo_1280x710,30,Social,diem +sport_wimbledon_federer_final_1280x704,30,Social,diem +stewart_lee_1280x712,30,Social,diem diff --git a/datasets/processed/validation_ave.csv b/datasets/processed/validation_ave.csv new file mode 100644 index 0000000..fd85233 --- /dev/null +++ b/datasets/processed/validation_ave.csv @@ -0,0 +1,30 @@ +video_id,fps,scene_type,dataset +50_people_london_1280x720,30,Social,diem +advert_iphone_1272x720,30,Other,diem +BBC_life_in_cold_blood_1278x710,30,Nature,diem +chilli_plasters_1280x712,30,Social,diem +clip_18,25,Nature,coutrot1 +clip_26,25,Other,coutrot1 +clip_33,25,Nature,coutrot1 +clip_40,25,Nature,coutrot1 +clip_48,25,Social,coutrot1 +clip_52,25,Social,coutrot1 +clip_7,25,Other,coutrot1 +clip_8,25,Other,coutrot1 +clip_4,25,Social,coutrot2 +clip_6,25,Social,coutrot2 +documentary_dolphins_1280x720,30,Nature,diem +game_trailer_ghostbusters_1280x720,30,Other,diem +hairy_bikers_cabbage_1280x712,30,Social,diem +hummingbirds_narrator_960x720,30,Nature,diem +movie_trailer_quantum_of_solace_1280x688,30,Social,diem +news_sherry_drinking_mice_768x576,30,Other,diem +news_wimbledon_macenroe_shortened_768x576,30,Social,diem +pingpong_angle_shot_960x720,30,Social,diem +pingpong_no_bodies_960x720,30,Other,diem +scottish_starters_1280x712,30,Social,diem +sport_football_best_goals_976x720,30,Social,diem +sports_kendo_1280x710,30,Social,diem +sport_wimbledon_federer_final_1280x704,30,Social,diem +stewart_lee_1280x712,30,Social,diem +tv_uni_challenge_final_1280x712,30,Social,diem diff --git a/datasets/stavis_preprocessed/download_dataset.sh b/datasets/stavis_preprocessed/download_dataset.sh new file mode 100644 index 0000000..f26c281 --- /dev/null +++ b/datasets/stavis_preprocessed/download_dataset.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +wget http://cvsp.cs.ntua.gr/research/stavis/data/annotations/DIEM.tar.gz +wget 
http://cvsp.cs.ntua.gr/research/stavis/data/annotations/Coutrot_db1.tar.gz +wget http://cvsp.cs.ntua.gr/research/stavis/data/annotations/Coutrot_db2.tar.gz + +tar -xf DIEM.tar.gz +rm DIEM.tar.gz +mv DIEM diem +tar -xf Coutrot_db1.tar.gz +rm Coutrot_db1.tar.gz +mv Coutrot_db1 coutrot1 +cd coutrot1 +for d in ./*/; do mv -v "$d" "${d/clip/clip_}"; done; +cd ../ +tar -xf Coutrot_db2.tar.gz +rm Coutrot_db2.tar.gz +mv Coutrot_db2 coutrot2 +cd coutrot2 +for d in ./*/; do mv -v "$d" "${d/clip/clip_}"; done; +cd ../ + + + diff --git a/gazenet/__init__.py b/gazenet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/bin/__init__.py b/gazenet/bin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/bin/download_manager.py b/gazenet/bin/download_manager.py new file mode 100644 index 0000000..0f7c38c --- /dev/null +++ b/gazenet/bin/download_manager.py @@ -0,0 +1,110 @@ +import argparse +import errno +import os +import glob +import shutil +import subprocess + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--working_dir", default="./", type=str, + help="The working directory in which the directory structure is " + "built and downloaded files are stored.") + + parser.add_argument("--datasets", type=str, nargs='+', required=False, + help="The list of dataset names to be downloaded") + + parser.add_argument("--models", type=str, nargs='+', required=False, + help="The list of model names to be downloaded") + + return parser.parse_args() + + +def copy_dir(src_dir, dst_dir): + try: + shutil.copytree(src_dir, dst_dir, False, None) + except (OSError, FileExistsError) as exc: # copy a single file when src is not a directory; ignore an already existing destination + if exc.errno == errno.ENOTDIR: + shutil.copy(src_dir, dst_dir) + else: + pass + + +def create_dir(parent_dir, dst_dir): + dst_path = os.path.join(parent_dir, dst_dir) + try: + os.makedirs(dst_path, exist_ok=True) + print("Directory '%s' creation succeeded" % dst_dir) + except OSError as e: + print("Directory '%s' creation failed" % dst_dir) + + +def main(): + args = parse_args() + + # some boilerplate dirs + struct_dirs = ["temp", "logs", os.path.join("logs", "metrics")] + for struct_dir in struct_dirs: + create_dir(args.working_dir, struct_dir) + + # copy datasets directory if not already there (to restore structure from repo delete datasets/ in working_dir) + if not os.path.isdir(os.path.join(args.working_dir, "datasets")): + copy_dir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "datasets"), + os.path.join(args.working_dir, "datasets")) + + # copy config directory if not already there (to restore structure from repo delete gazenet/configs/ in working_dir) + if not os.path.isdir(os.path.join(args.working_dir, "gazenet", "configs", "infer_configs")): + copy_dir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs", "infer_configs"), + os.path.join(args.working_dir, "gazenet", "configs", "infer_configs")) + if not os.path.isdir(os.path.join(args.working_dir, "gazenet", "configs", "train_configs")): + copy_dir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "configs", "train_configs"), + os.path.join(args.working_dir, "gazenet", "configs", "train_configs")) + + # download datasets + if args.datasets: + for dataset in args.datasets: + dataset_script_path = os.path.join(args.working_dir, "datasets", dataset) + bashCommand = "./download_dataset.sh" + process = subprocess.Popen(bashCommand.split(), + stdout=subprocess.PIPE, cwd=dataset_script_path, universal_newlines=True) + for stdout_line in
iter(process.stdout.readline, ""): + print(stdout_line) + process.stdout.close() + return_code = process.wait() + if return_code: + raise subprocess.CalledProcessError(return_code, bashCommand.split()) + # output, error = process.communicate() + + # download models + if args.models: + for models in args.models: + if "<...>" in models: + base_models_script_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "models") + models = glob.glob(os.path.join(base_models_script_path, + models.replace("<...>", "/**/download_model.sh")), recursive=True) + models = map(os.path.dirname, models) + models = [model.replace(base_models_script_path + os.sep, "") for model in models] + else: + models = [models] + + for model in models: + base_model_script_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "models", model) + model_script_path = os.path.join(args.working_dir, "gazenet", "models", model) + + copy_dir(base_model_script_path, model_script_path) + + bashCommand = "./download_model.sh" + process = subprocess.Popen(bashCommand.split(), + stdout=subprocess.PIPE, cwd=model_script_path, universal_newlines=True) + for stdout_line in iter(process.stdout.readline, ""): + print(stdout_line) + process.stdout.close() + return_code = process.wait() + if return_code: + raise subprocess.CalledProcessError(return_code, bashCommand.split()) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/gazenet/bin/infer.py b/gazenet/bin/infer.py new file mode 100644 index 0000000..8d59a17 --- /dev/null +++ b/gazenet/bin/infer.py @@ -0,0 +1,370 @@ +#!/bin/bash/python3 + +import argparse +from collections import deque +import re +import os +import json +import copy +import threading +import sys + +from tqdm import tqdm +import torch +import numpy as np +import pandas +import cv2 +from joblib import Parallel, delayed +import sounddevice as sd + +from gazenet.utils.registrar import * +import gazenet.utils.sample_processors as sp +from gazenet.utils.dataset_processors import DataWriter, DataSplitter +from gazenet.utils.annotation_plotter import OpenCV +from gazenet.utils.helpers import stack_images + +pandas.set_option("display.max_columns", 15) +read_lock = threading.Lock() + + +def exec_video_extraction(video, face_detector, preprocessed_data, max_w_size=1, video_properties={}, plotter=None, realtime_capture=False): + + if preprocessed_data is None: + preprocessed_data = {} + preprocessed_data["grabbed_video_list"] = deque(maxlen=max_w_size) + preprocessed_data["grouped_video_frames_list"] = deque(maxlen=max_w_size) + preprocessed_data["info_list"] = deque(maxlen=max_w_size) + preprocessed_data["properties_list"] = deque(maxlen=max_w_size) + preprocessed_data["video_frames_list"] = deque(maxlen=max_w_size) + preprocessed_data["faces_locations"] = deque(maxlen=max_w_size) + preprocessed_data = preprocessed_data.copy() + extracted_data_list = video.extract_frames(extract_audio=False, realtime_indexing=True if realtime_capture else False) + grabbed_video_list, grouped_video_frames_list, _, _, info_list, properties_list = \ + video.annotate_frames(extracted_data_list, plotter, **video_properties) + with read_lock: + preprocessed_data["grabbed_video_list"].extend(grabbed_video_list) + preprocessed_data["grouped_video_frames_list"].extend(grouped_video_frames_list) + preprocessed_data["info_list"].extend(info_list) + preprocessed_data["properties_list"].extend(properties_list) + + video_frames_list = stack_images(grouped_video_frames_list, grabbed_video_list, 
plot_override=[["captured"]]) + preprocessed_data["video_frames_list"].extend(video_frames_list) + preprocessed_data["faces_locations"].extend(face_detector.detect_frames(video_frames_list)) + return preprocessed_data + + + +def exec_audio_extraction(video, audio_feat_extractors, preprocessed_data, max_w_size=25, plotter=None): + + if preprocessed_data is None: + preprocessed_data = {} + preprocessed_data["grabbed_audio_list"] = deque(maxlen=max_w_size) + preprocessed_data["audio_frames_list"] = deque(maxlen=max_w_size) + preprocessed_data.update(**{audio_feat_name: deque(maxlen=max_w_size) for audio_feat_name in audio_feat_extractors.keys()}) + # preprocessed_data["audio_features"] = deque(maxlen=max_w_size) + # preprocessed_data["hann_audio_frames"] = deque(maxlen=max_w_size) + + extracted_data_list = video.extract_frames(extract_video=False) + _, _, grabbed_audio_list, audio_frames_list, _, _ = video.annotate_frames(extracted_data_list, plotter) + if any(grabbed_audio_list): + audio_idx = list(filter(lambda x: grabbed_audio_list[x], range(len(grabbed_audio_list)))) + with read_lock: + try: + audio_frames_list = audio_frames_list[audio_idx[0]] + preprocessed_data["audio_frames_list"].extend(audio_frames_list) + for audio_feat_name, audio_feat_extractor in audio_feat_extractors.items(): + audio_feat = audio_feat_extractor.waveform_to_feature(audio_frames_list, + rate=video.audio_cap.get(sp.AUCAP_PROP_SAMPLE_RATE)) + preprocessed_data[audio_feat_name].extend(audio_feat) + except: + audio_frames_list = np.zeros((1, video.audio_cap.get(sp.AUCAP_PROP_SAMPLE_RATE)), + dtype=np.float32) + preprocessed_data["audio_frames_list"].extend(audio_frames_list) + for audio_feat_name, audio_feat_extractor in audio_feat_extractors.items(): + audio_feat = audio_feat_extractor.waveform_to_feature(audio_frames_list, + rate=video.audio_cap.get(sp.AUCAP_PROP_SAMPLE_RATE)) + preprocessed_data[audio_feat_name].extend(audio_feat) + preprocessed_data["grabbed_audio_list"].extend(grabbed_audio_list * len(preprocessed_data["audio_frames_list"])) + return preprocessed_data + + +def exec_inference(inferer, plotter, preprocessed_data, previous_data, source_frames_idxs=None, + inference_properties={}, preproc_properties={}, postproc_properties={}): + with read_lock: + postprocessed_data = inferer.preprocess_frames(**preprocessed_data, previous_data=previous_data, **preproc_properties) + extracted_data_list = inferer.extract_frames(**postprocessed_data, source_frames_idxs=source_frames_idxs) + grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list = \ + inferer.annotate_frames(extracted_data_list, plotter, **inference_properties) + grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list = \ + inferer.postprocess_frames(grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list, **postproc_properties) + return grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list + + +def infer(args, config): + + accumulated_metrics = None + + # create the plotting helper + plotter = OpenCV() + + # create the video object from camera + if config.realtime_capture: + config.sampler_properties.update(enable_audio=True, width=config.width, height=config.height) + video = sp.SampleProcessor(**config.sampler_properties) + video.load({"video_name": 0, "audio_name": "0"}) + else: + # or create the video object from a 
database + config.reader_properties.update(mode="d") + video_source = ReaderRegistrar.registry[config.reader](**config.reader_properties) + config.sampler_properties.update(w_size=config.stride, enable_audio=config.enable_audio) + video = SampleRegistrar.registry[config.sampler](video_source, **config.sampler_properties) + # traverse dataset videos only + if config.process_dataset_videos_only: + config.datasplitter_properties.update(mode="r") + dataset_splitter = DataSplitter(**config.datasplitter_properties) + if video.short_name != "processed": + dataset_samples = dataset_splitter.samples[(dataset_splitter.samples["scene_type"] == "Social") & + (dataset_splitter.samples["dataset"] == video.short_name)] + else: + dataset_samples = dataset_splitter.samples[(dataset_splitter.samples["scene_type"] == "Social")] + dataset_iter = dataset_samples.iterrows() + dataset = next(dataset_iter) + if video.short_name != "processed": + video.goto(dataset[1]["video_id"], by_index=False) + else: + video.goto(os.path.join(dataset[1]["dataset"], dataset[1]["video_id"]), by_index=False) + else: + # video.goto("clip_50",by_index=False) # choose the first video instead: next(video) + video.goto(0) + # get the fps + w_fps = video.frames_per_sec() + + # create detectors + audio_feature_extraction = {audio_feat_name: AudioFeatureRegistrar.registry[audio_feat](hop_len_sec=1 / (w_fps)) + for audio_feat_name, audio_feat in config.audio_features.items()} + face_detection = FaceDetectorRegistrar.registry[config.face_detector](device=config.device) + + if config.write_images or config.write_annotations or config.write_videos: + # create the data writer + writer = DataWriter(video.short_name, video_name=video.reader.samples[video.index]["id"], + output_video_size=(video.reader.samples[video.index]["video_width"], + video.reader.samples[video.index]["video_height"]), + frames_per_sec=w_fps, + # output_video_size=(1232, 504), frames_per_sec=25, + write_images=config.write_images, + write_annotations=config.write_annotations, + write_videos=config.write_videos) + + # create the nice bar so we can look at something while processing happens + bar_writer = tqdm(desc="Write -> Video Nr: " + str(video.index), total=video.len_frames()) + + if config.compute_metrics: + # create the metrics + metrics_logger = MetricsRegistrar.registry[config.metrics](save_file=config.metrics_save_file, + dataset_name=video.short_name, + video_name=video.reader.samples[video.index]["id"], + metrics_list=config.metrics_list) + + bar_metrics = tqdm(desc="Metrics -> Video Nr: " + str(video.index), total=video.len_frames()) + + # create models + inferers = [[InferenceRegistrar.registry[model_data[0]] + (w_size=model_data[1], width=config.width, height=config.height, device=config.device, **model_data[3]) + for model_data in model_group] + for model_group in config.model_groups] + + # create returns placeholder + returns = [[]] * len(config.model_groups) + + # initially run preprocessors + preprocessed_vid_data_list = exec_video_extraction(video, face_detection, None, max_w_size=config.max_w_size, + video_properties=config.sampling_properties, plotter=plotter) + + preprocessed_aud_data_list = exec_audio_extraction(video, audio_feature_extraction, None, max_w_size=w_fps) + preprocessed_data_list = {**preprocessed_vid_data_list, **preprocessed_aud_data_list} + while True: + try: + for idx_model_group, model_group in enumerate(config.model_groups): + if idx_model_group == 0: + # create execution recipe for extracting data + execution_recipe = [ + 
delayed(exec_video_extraction)(video, face_detection, preprocessed_data_list, + video_properties=config.sampling_properties, plotter=plotter, + realtime_capture=config.realtime_capture), + delayed(exec_audio_extraction)(video, audio_feature_extraction, preprocessed_data_list), + ] + else: + execution_recipe = [] + for idx_model, model_data in enumerate(model_group): + execution_recipe.extend([delayed(exec_inference)(inferers[idx_model_group][idx_model], plotter, + preprocessed_data_list, + previous_data=returns[ + idx_model_group - 1] if idx_model_group != 0 else None, + source_frames_idxs=model_data[2], + inference_properties=config.inference_properties, + preproc_properties=model_data[4], + postproc_properties=model_data[5])]) + + returns[idx_model_group] = Parallel(n_jobs=config.n_jobs[idx_model_group], prefer="threads")(execution_recipe) + + if idx_model_group == 0: + # update preprocessed data in the first model_group iteration + preprocessed_data_list = {**returns[0][0], **returns[0][1]} + + # write images and annotations + if config.write_images or config.write_annotations or config.write_videos: + bar_writer.update(1) + if not any(returns[idx_model_group][idx_model_group]["grabbed_video_list"]): + raise IndexError + writer.add_detections(returns[idx_model_group], model_group) + + # compute metrics + if config.compute_metrics: + bar_metrics.update(1) + if not any(returns[idx_model_group][idx_model_group]["grabbed_video_list"]): + raise IndexError + frame_metrics = metrics_logger.add_metrics(returns[idx_model_group], model_group, config.metrics_mappings) + + # customize this plotter as you like + if config.visualize_images: + try: + new_plot = {"PLOT": [[]]} + for idx_model, model_data in enumerate(model_group): + grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, \ + info_list, properties_list = returns[idx_model_group][ + 2 + idx_model if idx_model_group == 0 else idx_model] + # if play_audio: + # sd.play(np.array(audio_frames_list).flatten(), video.audio_cap.get(sp.AUCAP_PROP_SAMPLE_RATE)) + # + if grouped_video_frames_list: + transformed_frames = stack_images(grouped_video_frames_list, grabbed_video_list) + for idx_frame, transformed_frame in enumerate(transformed_frames): + if transformed_frame is not None: + new_plot["transformed_" + model_data[0] + str(idx_frame)] = transformed_frame + # new_plot["transformed_" + model_data[0] + str(idx_frame)] = transformed_frame + new_plot["PLOT"][-1].extend(["transformed_" + model_data[0] + str(idx_frame)]) + else: + pass + else: + pass + cv2.imshow("target_" + str(idx_model_group), stack_images(new_plot, grabbed_video_list=True)) + # cv2.waitKey(int(1/w_fps * 100)) + cv2.waitKey(1) + except (cv2.error, AttributeError): + pass + + except (cv2.error, IndexError): + # iterate videos for offline datasets + if not config.realtime_capture: + if config.process_dataset_videos_only: + try: + dataset = next(dataset_iter) + if video.short_name != "processed": + video.goto(dataset[1]["video_id"], by_index=False) + else: + video.goto(os.path.join(dataset[1]["dataset"], dataset[1]["video_id"]), by_index=False) + except StopIteration: + break + else: + next(video) + + # get the fps + w_fps = video.frames_per_sec() + + # start new images and annotations writer + if config.write_images or config.write_annotations or config.write_videos: + if config.write_annotations: + writer.dump_annotations() + if config.write_videos: + writer.dump_videos() + bar_writer.close() + 
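+ # re-target the writer to the next video before processing resumes (descriptive comment; behavior inferred from the set_new_name call below)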
writer.set_new_name(video.reader.samples[video.index]["id"], + output_vid_size=(video.reader.samples[video.index]["video_width"], + video.reader.samples[video.index]["video_height"]), + fps=w_fps) + # writer.set_new_name(video.reader.samples[video.index]["id"], output_vid_size=(1232, 504), fps=25) + + bar_writer = tqdm(desc="Write -> Video Nr: " + str(video.index), total=video.len_frames()) + + if config.compute_metrics: + bar_metrics.close() + print(metrics_logger.set_new_name(video.reader.samples[video.index]["id"])) + bar_metrics = tqdm(desc="Metrics -> Video Nr: " + str(video.index), total=video.len_frames()) + + + # create detectors + audio_feature_extraction = { + audio_feat_name: AudioFeatureRegistrar.registry[audio_feat](hop_len_sec=1 / (w_fps)) + for audio_feat_name, audio_feat in config.audio_features.items()} + + # re-run preprocessors + preprocessed_vid_data_list = exec_video_extraction(video, face_detection, None, max_w_size=config.max_w_size, + video_properties=config.sampling_properties, + realtime_capture=config.realtime_capture) + + preprocessed_aud_data_list = exec_audio_extraction(video, audio_feature_extraction, None, max_w_size=w_fps) + preprocessed_data_list = {**preprocessed_vid_data_list, **preprocessed_aud_data_list} + + if config.compute_metrics: + metrics_logger.accumulate_metrics() + accumulated_metrics = metrics_logger.scores + + return accumulated_metrics + +def parse_args(): + inferer_summaries = "configuration summaries:" + for config_name in InferenceConfigRegistrar.registry.keys(): + config_summary = InferenceConfigRegistrar.registry[config_name].config_info()["summary"] + inferer_summaries += ("\n " + config_name + "\n " + config_summary) + + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=inferer_summaries) + + parser.add_argument("--infer_config", type=str, default="InferPlayground001", required=False, + choices=InferenceConfigRegistrar.registry.keys(), + help="The inference configuration. 
Select config from ../configs/infer_config.py") + parser.add_argument("--infer_config_file", type=str, required=False, + help="The json inference configuration file (overrides infer_config).") + parser.add_argument("--gpu", type=int, required=False, + help="The GPU index with CUDA support for running the inferer.") + return parser.parse_args() + + +def main(): + InferenceConfigRegistrar.scan() + args = parse_args() + if args.infer_config_file: + with open(args.infer_config_file) as fp: + data = json.load(fp) + config = InferenceConfigRegistrar.registry["InferGeneratorAllModelsBase"] + config.__name__ = os.path.splitext(os.path.basename(args.infer_config_file))[0] + for data_key, data_val in data.items(): + setattr(config, data_key, data_val) + else: + config = InferenceConfigRegistrar.registry[args.infer_config] + + # update config with args + if args.gpu is not None: + setattr(config, "device", "cuda:"+str(args.gpu)) + + # scan the registrars + InferenceRegistrar.scan() + ReaderRegistrar.scan() + SampleRegistrar.scan() + FaceDetectorRegistrar.scan() + AudioFeatureRegistrar.scan() + + # create metrics if enabled + if config.compute_metrics: + # scan the metrics registrar + MetricsRegistrar.scan() + + metrics = infer(args, config) + if config.compute_metrics: + print(metrics) + + +if __name__ == "__main__": + main() diff --git a/gazenet/bin/scripts.py b/gazenet/bin/scripts.py new file mode 100644 index 0000000..8854819 --- /dev/null +++ b/gazenet/bin/scripts.py @@ -0,0 +1,112 @@ +import argparse +import subprocess +import inspect +import json +from shutil import copyfile, rmtree +import os + +from PIL import Image +import scipy.io as sio +import numpy as np + +from gazenet.utils.registrar import * + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--working_dir", default="./", type=str, + help="The working directory in which the directory structure is " + "built and downloaded files are stored.") + + parser.add_argument("--scripts", type=str, nargs='+', required=True, + help="The list of script names to be executed") + + return parser.parse_args() + + +def postprocess_get_from_stavis(dst_dir): + # after preprocessing, read preprocessed groundtruth from stavis and copy + # content to the preprocessed directory. 
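+ # copies each STAViS eyeMap_*.jpg as transformed_salmap_1.jpg and converts each fixMap_*.mat into a binary transformed_fixmap_1.jpg inside the corresponding frame directory (descriptive comment summarizing the loop below)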
+ + stavis_img_dir = os.path.join(dst_dir, "datasets", "stavis_preprocessed") + data_dir = os.path.join(dst_dir, "datasets", "processed", "Grouped_frames") + + for root, subdirs, files in os.walk(data_dir): + print("scanning", root) + for subdir in subdirs: + tgt_dir = root.replace(data_dir, "")[1:] + stavis_img_file = os.path.join(stavis_img_dir, tgt_dir, "maps", "eyeMap_" + f"{subdir:0>5}" + ".jpg") + if os.path.isfile(stavis_img_file): + copyfile(stavis_img_file, os.path.join(root, subdir, "transformed_salmap_1.jpg")) + stavis_mat_file = os.path.join(stavis_img_dir, tgt_dir, "fixMap_" + f"{subdir:0>5}" + ".mat") + if os.path.isfile(stavis_mat_file): + tmp_mat = sio.loadmat(stavis_mat_file) + binmap_np = np.array( + Image.fromarray(tmp_mat['eyeMap'].astype(float)).resize((120, 120), resample=Image.BILINEAR)) > 0 + fixmap = Image.fromarray((255 * binmap_np).astype('uint8')) + fixmap.save(os.path.join(root, subdir, "transformed_fixmap_1.jpg")) + + +def generate_config_files(dst_dir): + # generate the json files from the classes in infer_config.py and train_config.py + + # training configurations + TrainingConfigRegistrar.scan() + + for config_name, config in TrainingConfigRegistrar.registry.items(): + if inspect.isclass(config): + config_dict = {key: value for key, value in zip(dir(config), [getattr(config, k) for k in dir(config)]) if + not key.startswith('__') and not isinstance(value, classmethod) and not inspect.ismethod( + value)} + config_dict.update( + {key: getattr(config, key)() for key, value in + zip(dir(config), [getattr(config, k) for k in dir(config)]) + if + not key.startswith('__') and (isinstance(value, classmethod) or inspect.ismethod(value))}) + elif isinstance(config, dict): + config_dict = config + + with open(os.path.join(dst_dir, "gazenet", "configs", "train_configs", config_name + ".json"), 'w') as fp: + json.dump(config_dict, fp, indent=4) + + print("Generated train configs to %", os.path.join(dst_dir, "gazenet", "configs", "train_configs")) + + # inference configurations + InferenceConfigRegistrar.scan() + + for config_name, config in InferenceConfigRegistrar.registry.items(): + if inspect.isclass(config): + config_dict = {key: value for key, value in zip(dir(config), [getattr(config, k) for k in dir(config)]) if + not key.startswith('__') and not isinstance(value, classmethod) and not inspect.ismethod( + value)} + config_dict.update({key: getattr(config, key)() for key, value in + zip(dir(config), [getattr(config, k) for k in dir(config)]) if + not key.startswith('__') and ( + isinstance(value, classmethod) or inspect.ismethod(value))}) + elif isinstance(config, dict): + config_dict = config + + with open(os.path.join(dst_dir, "gazenet", "configs", "infer_configs", config_name + ".json"), 'w') as fp: + json.dump(config_dict, fp, indent=4) + + print("Generated infer configs to %", os.path.join(dst_dir, "gazenet", "configs", "infer_configs")) + +def clean_temp(dst_dir): + rmtree(os.path.join(dst_dir, "temp")) + os.mkdir(os.path.join(dst_dir, "temp")) + + +def main(): + args = parse_args() + + for script in args.scripts: + if script == "postprocess_get_from_stavis": + postprocess_get_from_stavis(args.working_dir) + if script == "generate_config_files": + generate_config_files(args.working_dir) + if script == "clean_temp": + clean_temp(args.working_dir) + + +if __name__ == "__main__": + main() diff --git a/gazenet/bin/train.py b/gazenet/bin/train.py new file mode 100644 index 0000000..7089304 --- /dev/null +++ b/gazenet/bin/train.py @@ -0,0 +1,231 @@ +import inspect 
+import argparse +import os +import json +import copy + +from comet_ml import Experiment +import torch +import pytorch_lightning as pl +from pytorch_lightning.loggers import * +from pytorch_lightning.callbacks import * + +from gazenet.utils.helpers import flatten_dict +from gazenet.utils.registrar import * +from gazenet.bin.infer import infer + + +def store_config_log(config, logger, prefix="", filename="train_config.cfg"): + config_dict = None + + if inspect.isclass(config): + config_dict = {key: value for key, value in config.__dict__.items() if + not key.startswith('__') and not isinstance(value, classmethod) and not inspect.ismethod( + value)} + config_dict.update({key: getattr(config, key)() for key, value in config.__dict__.items() if + not key.startswith('__') and (isinstance(value, classmethod) or inspect.ismethod(value))}) + elif isinstance(config, dict): + config_dict = config + + config_dict = flatten_dict(config_dict, "", {}) + log_path = config.log_dir + + if config_dict is not None: + if isinstance(logger, CometLogger): + logger.experiment.log_parameters(config_dict, prefix=prefix) + log_path = os.path.join(logger.save_dir, logger.experiment.project_name, + config.comet_experiment_key) + os.makedirs(log_path, exist_ok=True) + with open(os.path.join(log_path, filename), "w") as fp: + json.dump(config_dict, fp) + + if isinstance(logger, TensorBoardLogger): + logger.log_hyperparams(config_dict) + log_path = logger.log_dir + os.makedirs(log_path, exist_ok=True) + with open(os.path.join(log_path, filename), "w") as fp: + json.dump(config_dict, fp) + return log_path + + +def train(args, config, infer_configs=None): + metrics = None + log_path = config.log_dir + experiment_name = config.experiment_name + experiment_key = "_" + + # arguments made to CometLogger are passed on to the comet_ml.Experiment class + if config.logger == "comet": + logger = CometLogger( + api_key=os.environ["COMET_KEY"], + workspace=os.environ["COMET_WORKSPACE"], # Optional + project_name=config.project_name, # Optional + save_dir=config.log_dir, + experiment_name=config.experiment_name # Optional + ) + setattr(config, "comet_experiment_key", logger.experiment.id) + experiment_key = logger.experiment.id + log_path = store_config_log(config, logger=logger, prefix="train_config.") # " + args.train_config + "." 
+ + elif config.logger == "tensorboard": + logger = TensorBoardLogger(save_dir=config.log_dir, name=config.project_name) + setattr(config, "tensorboard_experiment_key", logger.log_dir) + experiment_key = logger.log_dir.split("/")[-1] + log_path = store_config_log(config, logger=logger, prefix="train_config.") + else: + logger = False + + # saves the checkpoint + checkpoint_path = os.path.join(config.checkpoint_model_dir, experiment_name, experiment_key) + os.makedirs(checkpoint_path, exist_ok=True) + checkpoint_callback = ModelCheckpoint( + monitor='val_loss', + dirpath=checkpoint_path, + filename="model-{epoch:02d}-{val_loss:.2f}", + save_top_k=args.checkpoint_save_n_top, + period=args.checkpoint_save_every_n_epoch, + mode='min') + + # train + trainer = pl.Trainer.from_argparse_args(args, checkpoint_callback=checkpoint_callback, logger=logger) + config.model_properties.update(train_dataset_properties=config.train_dataset_properties, + val_dataset_properties=config.val_dataset_properties, + test_dataset_properties=config.test_dataset_properties) + + if args.auto_lr_find: + model_data = None + model = ModelRegistrar.registry[config.model_name](**config.model_properties, + val_store_image_samples=args.val_store_image_samples) + trainer.tune(model) + else: + if hasattr(config, "model_data_name"): + model_data = ModelDataRegistrar.registry[config.model_data_name](**config.model_properties) + model = ModelRegistrar.registry[config.model_name](**config.model_properties, + val_store_image_samples=args.val_store_image_samples, + **model_data.get_attributes()) + trainer.fit(model, model_data) + else: + model_data = None + model = ModelRegistrar.registry[config.model_name](**config.model_properties, + val_store_image_samples=args.val_store_image_samples) + trainer.fit(model) + + last_checkpoint_file = os.path.join(checkpoint_path, "last_model.pt") + torch.save(model.state_dict(), last_checkpoint_file) + + # run the models in inference + for infer_config in infer_configs: + if model_data is None: + updated_infer_config = model.update_infer_config(log_path, last_checkpoint_file, config, + copy.deepcopy(infer_config), device=args.gpus) + else: + updated_infer_config = model_data.update_infer_config(log_path, last_checkpoint_file, config, + copy.deepcopy(infer_config), device=args.gpus) + updated_infer_config.compute_metrics = args.compute_metrics + inferer_metrics = infer(args, updated_infer_config) + if metrics is None: + metrics = inferer_metrics + else: + metrics = metrics.append(inferer_metrics, ignore_index=True) + return metrics + + +def parse_args(): + trainer_summaries = "training configuration summaries:" + for config_name in TrainingConfigRegistrar.registry.keys(): + config_summary = TrainingConfigRegistrar.registry[config_name].config_info()["summary"] + config_example = TrainingConfigRegistrar.registry[config_name].config_info()["example"] + trainer_summaries += ("\n " + config_name + "\n " + config_summary + "\n example: " + config_example) + + inferer_summaries = "inference configuration summaries:" + for config_name in InferenceConfigRegistrar.registry.keys(): + config_summary = InferenceConfigRegistrar.registry[config_name].config_info()["summary"] + inferer_summaries += ("\n " + config_name + "\n " + config_summary) + + summaries = trainer_summaries + "\n\n\n\n" + inferer_summaries + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=summaries) + + parser.add_argument("--train_config", type=str, default="TrainPlayground001", 
required=False,
+                        choices=TrainingConfigRegistrar.registry.keys(),
+                        help="The training configuration. Select config from ../configs/train_config.py")
+
+    parser.add_argument("--infer_configs", type=str, default=[], nargs='+', required=False,  # "InferMetricsGASPTrain"
+                        choices=InferenceConfigRegistrar.registry.keys(),
+                        help="The list of inference configurations. "
+                             "This is needed for the metrics computation. Select config from ../configs/infer_config.py")
+
+    parser.add_argument("--train_config_file", type=str, required=False,
+                        help="The JSON training configuration file (overrides train_config).")
+
+    parser.add_argument("--infer_config_files", type=str, nargs='+', required=False,  # "InferMetricsGASPTrain"
+                        help="The list of JSON inference configuration files (overrides infer_configs). "
+                             "This is needed for the metrics computation.")
+
+    parser.add_argument("--logger_name", type=str, required=False,
+                        choices=["comet", "tensorboard", ""],
+                        help="The logging framework name")
+
+    parser.add_argument('--checkpoint_save_every_n_epoch', type=int, default=1000, help='Save model every n epochs')
+
+    parser.add_argument('--checkpoint_save_n_top', type=int, default=3, help='Save top n model checkpoints')
+
+    parser.add_argument('--val_store_image_samples', help='Store sampled validation images to logger',
+                        action='store_true')
+
+    parser.add_argument('--compute_metrics', help='Compute the metrics',
+                        action='store_true')
+
+    parser = pl.Trainer.add_argparse_args(parser)
+    return parser.parse_args()
+
+
+def main():
+    TrainingConfigRegistrar.scan()
+    InferenceConfigRegistrar.scan()
+    args = parse_args()
+
+    if args.train_config_file:
+        with open(args.train_config_file) as fp:
+            data = json.load(fp)
+            config = TrainingConfigRegistrar.registry["TrainerBase"]
+            config.__name__ = os.path.splitext(os.path.basename(args.train_config_file))[0]
+            for data_key, data_val in data.items():
+                setattr(config, data_key, data_val)
+    else:
+        config = TrainingConfigRegistrar.registry[args.train_config]
+
+    # update config with args
+    setattr(config, "compute_metrics", args.compute_metrics)
+    if args.logger_name is not None:
+        setattr(config, "logger", args.logger_name)
+
+    # inference configs
+    infer_configs = []
+    for infer_config_name in args.infer_configs:
+        infer_configs.append(InferenceConfigRegistrar.registry[infer_config_name])
+
+    # scan the registrars
+    InferenceRegistrar.scan()
+    ReaderRegistrar.scan()
+    SampleRegistrar.scan()
+    FaceDetectorRegistrar.scan()
+    AudioFeatureRegistrar.scan()
+
+    ModelRegistrar.scan()
+    ModelDataRegistrar.scan()
+
+    # create metrics if enabled
+    if args.compute_metrics:
+        # scan the metrics registrar
+        MetricsRegistrar.scan()
+
+    # train
+    metrics = train(args, config, infer_configs=infer_configs)
+    if config.compute_metrics:
+        print(metrics)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/gazenet/configs/__init__.py b/gazenet/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/gazenet/configs/infer_config.py b/gazenet/configs/infer_config.py
new file mode 100644
index 0000000..37b2bee
--- /dev/null
+++ b/gazenet/configs/infer_config.py
@@ -0,0 +1,633 @@
+# TODO (fabawi): variables marked as AUTO should update automatically, but this won't happen.
Place all vars in init + +from gazenet.utils.registrar import * + + +@InferenceConfigRegistrar.register +class InferGeneratorAllModelsBase(object): + # define the reader + reader = "" + sampler = "" + + reader_properties = {} + sampler_properties = {} + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": False, "color_map": "bone"} + # sampler_properties = {"show_saliency_map": True} + + # define the face detector + face_detector = "SFDFaceDetection" # "MTCNNFaceDetection", "DlibFaceDetection" + # define audio features needed by the models + audio_features = {"audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures"} + + # define the models + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + # model_name, window_size, source_frames_idxs, model_properties, preproc_properties, postproc_properties -> + # each in its own model_group: model groups are executed in order + model_groups = [ + [["DAVEInference", 16, [15], {}, {}, + dict(postproc_properties, **{"plot_override": [["captured", "transformed_salmap", "transformed_fixmap", "det_transformed_dave"]]})], + ["ESR9Inference", 16, [15], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_source_esr9", "det_transformed_esr9"]]})], + ["Gaze360Inference", 7, [3], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_transformed_gaze360"]]})], + ["VideoGazeInference", 7, [3], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_transformed_vidgaze"]]})]] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "jet"} + + # define the metrics calculator properties (only needed when compute_metrics=True) + metrics = "SaliencyPredictionMetrics" + metrics_list = ["aucj", "aucs", "cc", "nss", "sim"] + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/default.csv" + + # define the datasplitter properties (only needed when process_dataset_videos_only=True) + datasplitter_properties = {"train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv"} + + # constants + width, height = 500, 500 # the frame's width and height + stride = 1 # the number of frames to capture per inference iteration. Should be lte than max_w_size + max_w_size = 16 # AUTO: the largest window needed by any model + enable_audio = True # if only one of the models needs audio, then this should be set to True + play_audio = False # if any of the models employing audio has no source_frames_idxs. Check keep_audio in postproc_properties. 
DOES NOT WORK AT THE MOMENT + realtime_capture = False # capture audio/video in realtime (cam/mic) + visualize_images = False # visualize the plotters + write_images = False # if only realtime capture is False + write_videos = False # if only realtime capture is False + write_annotations = False # always set to False, since annotations not needed for training the models + process_dataset_videos_only = True # process videos only if they exist in the train,val,test sets if only realtime capture is False + compute_metrics = False # enable the metrics computation + device = "cpu" # the pytorch device to use for all models + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. Extraction in group[0] + + @classmethod + def config_info(cls): + return {"summary": "This generates the datasets images needed for a majority of the experiments. " + "Only dataset samples (Social) are generated. " + "The 'reader' and 'sampler' need to be set and does not write automatically. "} + + +@InferenceConfigRegistrar.register +class InferGeneratorAllModelsCoutrot1(InferGeneratorAllModelsBase): + # define the reader + reader = "Coutrot1SampleReader" + sampler = "CoutrotSample" + + write_images = True # if only realtime capture is False + + @classmethod + def config_info(cls): + return {"summary": "This generates the datasets images needed for a majority of the experiments. " + "Only dataset samples (Social) are generated. " + "It runs the 4 social cue modalities for Coutrot1. "} + + +@InferenceConfigRegistrar.register +class InferGeneratorAllModelsCoutrot2(InferGeneratorAllModelsBase): + # define the reader + reader = "Coutrot2SampleReader" + sampler = "CoutrotSample" + + write_images = True # if only realtime capture is False + + @classmethod + def config_info(cls): + return {"summary": "This generates the datasets images needed for a majority of the experiments. " + "Only dataset samples (Social) are generated. " + "It runs the 4 social cue modalities for Coutrot2. "} + + +@InferenceConfigRegistrar.register +class InferGeneratorAllModelsDIEM(InferGeneratorAllModelsBase): + # define the reader + reader = "DIEMSampleReader" + sampler = "DIEMSample" + + write_images = True # if only realtime capture is False + + @classmethod + def config_info(cls): + return {"summary": "This generates the datasets images needed for a majority of the experiments. " + "Only dataset samples (Social) are generated. " + "It runs the 4 social cue modalities for DIEM. 
"} + + +@InferenceConfigRegistrar.register +class InferGeneratorFindWho(InferGeneratorAllModelsBase): + # define the reader + reader = "FindWhoSampleReader" + sampler = "FindWhoSample" + + # define the models + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + model_groups = [ + [ + ["DAVEInference", 16, [15], {}, {}, + dict(postproc_properties, **{ + "plot_override": [["captured", "transformed_salmap", "transformed_fixmap", "det_transformed_dave"]]})], + ["ESR9Inference", 16, [15], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_source_esr9", "det_transformed_esr9"]]})], + ["Gaze360Inference", 7, [3], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_transformed_gaze360"]]})], + # ["VideoGazeInference", 7, [3], {}, {}, + # dict(postproc_properties, **{"plot_override": [["det_transformed_vidgaze"]]})] + ], + ] + width, height = 512, 320 # the frame's width and height + visualize_images = True # visualize the plotters + write_images = True # if only realtime capture is False + write_videos = False # if only realtime capture is False + write_annotations = True # always set to False, since annotations not needed for training the models + process_dataset_videos_only = False # process videos only if they exist in the train,val,test sets if only realtime capture is False + + @classmethod + def config_info(cls): + return {"summary": "This generates the datasets annotation needed for gaze prediction experiments. " + "Only dataset samples (Social) are generated. " + "It runs the DAVE for FindWhos. "} + + +@InferenceConfigRegistrar.register +class InferMetricsGASP(InferGeneratorAllModelsBase): + # define the reader + reader = "DataSampleReader" + sampler = "DataSample" + + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": False, "color_map": "bone", + "img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + model_groups = [ + [["GASPInference", 16, [15], {}, + {"inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"]}, + dict(postproc_properties, **{ + "plot_override": [["captured", "transformed_salmap", "transformed_fixmap", "det_transformed_gazenet"]]})]] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "bone"} + + compute_metrics = True + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_gazenet", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/defaultgazenet.csv" + + datasplitter_properties = {"train_csv_file": None, + "val_csv_file": None, + "test_csv_file": "datasets/processed/test_ave.csv"} + + process_dataset_videos_only = True + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. Extraction in group[0] + + @classmethod + def config_info(cls): + return {"summary": "This measures the saliency metrics on GASP. " + "Only dataset samples (Social) are generated. 
"} + + +@InferenceConfigRegistrar.register +class InferMetricsGASPTrain(InferMetricsGASP): + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + model_groups = [ + [["GASPInference", -1, -1, {}, {"inp_img_names_list": None}, + dict(postproc_properties, **{ + "plot_override": [["captured", "transformed_salmap", "transformed_fixmap", "det_transformed_gazenet"]]})]] + ] + + +@InferenceConfigRegistrar.register +class InferMetricsSTAViS(InferGeneratorAllModelsBase): + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + model_groups = [ + [["STAViSInference", 16, [15], {"audiovisual": True}, {}, + dict(postproc_properties, **{"plot_override": [["captured", "transformed_salmap", "transformed_fixmap", "det_transformed_stavis"]]})], + ["ESR9Inference", 16, [15], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_source_esr9", "det_transformed_esr9"]]})], + ["Gaze360Inference", 7, [3], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_transformed_gaze360"]]})], + ["VideoGazeInference", 7, [3], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_transformed_vidgaze"]]})]] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "bone"} + + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. Extraction in group[0] + + compute_metrics = True + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_stavis", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/defaultstavis.csv" + + @classmethod + def config_info(cls): + return {"summary": "This measures the saliency metrics on STAViS. " + "Only dataset samples (Social) are generated. "} + + +@InferenceConfigRegistrar.register +class InferMetricsSTAViS_VisOnly(InferGeneratorAllModelsBase): + # define the reader + reader = "DataSampleReader" + sampler = "DataSample" + + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": False, "color_map": "bone", + "img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + model_groups = [ + [("STAViSInference", 16, [15], {"audiovisual": False}, {}, + dict(postproc_properties, **{ + "plot_override": [["captured", "transformed_salmap", "transformed_fixmap", "det_transformed_stavis"]]}))] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "bone"} + + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. 
Extraction in group[0] + + compute_metrics = True + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_stavis", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/defaultstavis_vis.csv" + + @classmethod + def config_info(cls): + return {"summary": "This measures the saliency metrics on STAViS. " + "Only dataset samples (Social) are generated. "} + + +@InferenceConfigRegistrar.register +class InferMetricsDAVE(InferGeneratorAllModelsBase): + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + model_groups = [ + [["DAVEInference", 16, [15], {}, {}, + dict(postproc_properties, **{ + "plot_override": [["captured", "transformed_salmap", "transformed_fixmap", "det_transformed_dave"]]})], + ["ESR9Inference", 16, [15], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_source_esr9", "det_transformed_esr9"]]})], + ["Gaze360Inference", 7, [3], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_transformed_gaze360"]]})], + ["VideoGazeInference", 7, [3], {}, {}, + dict(postproc_properties, **{"plot_override": [["det_transformed_vidgaze"]]})]] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "bone"} + + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. Extraction in group[0] + + compute_metrics = True + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/defaultdave.csv" + + @classmethod + def config_info(cls): + return {"summary": "This measures the saliency metrics on DAVE. " + "Only dataset samples (Social) are generated. "} + + +@InferenceConfigRegistrar.register +class InferMetricsDAVE_VisOnly(InferGeneratorAllModelsBase): + # define the reader + reader = "DataSampleReader" + sampler = "DataSample" + + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": False, "color_map": "bone", + "img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + model_groups = [ + [["DAVEInference", 16, [15], {}, {}, + dict(postproc_properties, **{ + "plot_override": [["captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave"]]})]] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "bone"} + + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. 
Extraction in group[0] + + compute_metrics = True + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/defaultdave_vis.csv" + + @classmethod + def config_info(cls): + return {"summary": "This measures the saliency metrics on DAVE. " + "Only dataset samples (Social) are generated. "} + + +@InferenceConfigRegistrar.register +class InferMetricsTASED_VisOnly(InferGeneratorAllModelsBase): + # define the reader + reader = "DataSampleReader" + sampler = "DataSample" + + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": False, "color_map": "bone", + "img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + model_groups = [ + [["TASEDInference", 32, [31], {}, {}, + dict(postproc_properties, **{"plot_override": [["captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_tased"]]})]] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "bone"} + + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. Extraction in group[0] + + compute_metrics = True + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_tased", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/defaulttased_vis.csv" + + @classmethod + def config_info(cls): + return {"summary": "This measures the saliency metrics on TASED. " + "Only dataset samples (Social) are generated. "} + + +@InferenceConfigRegistrar.register +class InferMetricsUNISAL_VisOnly(InferGeneratorAllModelsBase): + # define the reader + reader = "DataSampleReader" + sampler = "DataSample" + + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": False, "color_map": "bone", + "img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + model_groups = [ + [["UNISALInference", 12, [11], {}, {}, + dict(postproc_properties, **{"plot_override": [["captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_unisal"]]})]] + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays": False, "color_map": "bone"} + + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[ + 1:]] # AUTO: number of jobs to run in parallel per model group. 
Extraction in group[0] + + compute_metrics = True + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_unisal", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/defaultunisals_vis.csv" + + @classmethod + def config_info(cls): + return {"summary": "This measures the saliency metrics on UNISAL. " + "Only dataset samples (Social) are generated. "} + + + +@InferenceConfigRegistrar.register +class InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm(object): + + # define the reader + reader = "DataSampleReader" + sampler = "DataSample" + + reader_properties = {} + sampler_properties = {} + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": True, "color_map": "jet", + "img_names_list": ["transformed_salmap", "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360"]} + + # define the face detector + face_detector = "SFDFaceDetection" # "MTCNNFaceDetection", "DlibFaceDetection" + # define audio features needed by the models + audio_features = {"audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures"} + + # define the models + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + # model_name, window_size, source_frames_idxs, enable_audio, preproc_properties, postproc_properties -> + model_groups = [ + [ + # ["GASPInference", 10, [9], {"weights_file": "seqdamalstmgmu_110nofer", "modalities": 4, "batch_size": 1, "sequence_len": 10, "sequence_norm": True}, + ["GASPInference", 10, [9], {"weights_file": "seqdamalstmgmu", "modalities": 5, "batch_size": 1, "sequence_len": 10, "sequence_norm": True}, + # ["GASPInference", 1, [0], {"modalities": 5, "batch_size": 1, "model_name": "GASPDAMEncGMUConv", "frames_len": 1, "weights_file": "damgmu"}, + {"inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"]}, + dict(postproc_properties, **{"plot_override": [["transformed_fixmap", + "det_transformed_esr9", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp"]]})] + ], + ] + + inference_properties = {"show_det_saliency_map": True, "enable_transform_overlays":True, "color_map": "jet"} + # inference_properties = {"show_saliency_map": True} + + # define the metrics calculator + metrics = "SaliencyPredictionMetrics" + metrics_list = ["aucj", "aucs", "cc", "nss", "sim"] + metrics_mappings = {"gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", # "gt_baseline": "transformed_fixmap" + "scores_info": ["gate_scores"]} + metrics_save_file = "logs/metrics/default.csv" + + datasplitter_properties = {"train_csv_file": "datasets/processed/test_ave.csv", + "val_csv_file": None, + "test_csv_file": None} + # constants + width, height = 500, 500 # the frame's width and height + stride = 1 # the number of frames to capture per inference iteration. 
Should be lte than max_w_size + max_w_size = 10 # AUTO: the largest window needed by any model + enable_audio = True # AUTO: if only one of the models needs audio, then this will automatically be True + play_audio = False # if any of the models employing audio has no source_frames_idxs. Check keep_audio in postproc_properties + realtime_capture = False # capture audio/video in realtime (cam/mic) + visualize_images = True # visualize the plotters + write_images = False # if only realtime capture is False + write_videos = True # if only realtime capture is False + write_annotations = False # always set to False, since annotations not needed for training the models + process_dataset_videos_only = True # process videos only if they exist in the train,val,test sets if only realtime capture is False + compute_metrics = False # enable the metrics computation + device = "cpu" # the pytorch device to use for all models + n_jobs = [len(model_groups[0]) + 2] + [len(model_group) for model_group in model_groups[1:]] # AUTO: number of jobs to run in parallel per model group. Extraction in group[0] + + @classmethod + def config_info(cls): + return {"summary": "This visualizes the sequential GASP model (DAM + LARGMU; Context Size = 10)"} + + +@InferenceConfigRegistrar.register +class InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm_110(InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm): + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": True, "color_map": "jet", + "img_names_list": ["transformed_salmap", "transformed_fixmap", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360"]} + + # define the models + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + # model_name, window_size, source_frames_idxs, enable_audio, preproc_properties, postproc_properties -> + model_groups = [ + [ + ["GASPInference", 10, [9], {"weights_file": "seqdamalstmgmu_110nofer", "modalities": 4, "batch_size": 1, "sequence_len": 10, "sequence_norm": True}, + {"inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"]}, + dict(postproc_properties, **{"plot_override": [["transformed_fixmap", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp"]]})] + ], + ] + + # constants + max_w_size = 10 # AUTO: the largest window needed by any model + + @classmethod + def config_info(cls): + return {"summary": "This visualizes the sequential GASP model (DAM + LARGMU; Context Size = 10) " + "excluding the FER modality"} + + +@InferenceConfigRegistrar.register +class InferVisualizeGASPDAMGMU1x1Conv(InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm): + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": True, "color_map": "jet", + "img_names_list": ["transformed_salmap", "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360"]} + + # define the models + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + # model_name, window_size, source_frames_idxs, enable_audio, preproc_properties, postproc_properties -> + model_groups = [ + [ + ["GASPInference", 1, [0], {"weights_file": "damgmu", "modalities": 5, "batch_size": 1, "model_name": "GASPDAMEncGMUConv", "frames_len": 1}, 
+ {"inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"]}, + dict(postproc_properties, **{"plot_override": [["transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp"]]})] + ], + ] + + # constants + max_w_size = 1 # AUTO: the largest window needed by any model + + @classmethod + def config_info(cls): + return {"summary": "This visualizes the static GASP model (DAM + GMU)"} + + +@InferenceConfigRegistrar.register +class InferVisualizeGASPSeqDAMGMUALSTM1x1Conv_10Norm(InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm): + sampling_properties = {"show_fixation_locations": True, "show_saliency_map": True, + "enable_transform_overlays": True, "color_map": "jet", + "img_names_list": ["transformed_salmap", "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360"]} + + # define the models + postproc_properties = {"keep_properties": False, "keep_audio": False, + "keep_plot_frames_only": True, "resize_frames": True} + + # model_name, window_size, source_frames_idxs, enable_audio, preproc_properties, postproc_properties -> + model_groups = [ + [ + ["GASPInference", 10, [9], {"weights_file": "seqdamgmualstm", "modalities": 5, "batch_size": 1, "model_name": "SequenceGASPDAMEncGMUALSTMConv", "sequence_len": 10, "sequence_norm": True}, + {"inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"]}, + dict(postproc_properties, **{"plot_override": [["transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp"]]})] + ], + ] + + # constants + max_w_size = 10 # AUTO: the largest window needed by any model + + @classmethod + def config_info(cls): + return {"summary": "This visualizes the static GASP model (DAM + GMU)"} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferGeneratorAllModelsBase.json b/gazenet/configs/infer_configs/InferGeneratorAllModelsBase.json new file mode 100644 index 0000000..d3acfd3 --- /dev/null +++ b/gazenet/configs/infer_configs/InferGeneratorAllModelsBase.json @@ -0,0 +1,159 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "jet" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "DAVEInference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + 
"plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave" + ] + ] + } + ], + [ + "ESR9Inference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_source_esr9", + "det_transformed_esr9" + ] + ] + } + ], + [ + "Gaze360Inference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_gaze360" + ] + ] + } + ], + [ + "VideoGazeInference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_vidgaze" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 6 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This generates the datasets images needed for a majority of the experiments. Only dataset samples (Social) are generated. The 'reader' and 'sampler' need to be set and does not write automatically. " + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferGeneratorAllModelsCoutrot1.json b/gazenet/configs/infer_configs/InferGeneratorAllModelsCoutrot1.json new file mode 100644 index 0000000..865d040 --- /dev/null +++ b/gazenet/configs/infer_configs/InferGeneratorAllModelsCoutrot1.json @@ -0,0 +1,159 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "jet" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "DAVEInference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave" + ] + ] + } + ], + [ + "ESR9Inference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + 
"plot_override": [ + [ + "det_source_esr9", + "det_transformed_esr9" + ] + ] + } + ], + [ + "Gaze360Inference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_gaze360" + ] + ] + } + ], + [ + "VideoGazeInference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_vidgaze" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 6 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "Coutrot1SampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "CoutrotSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": true, + "write_videos": false, + "config_info": { + "summary": "This generates the datasets images needed for a majority of the experiments. Only dataset samples (Social) are generated. It runs the 4 social cue modalities for Coutrot1. " + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferGeneratorAllModelsCoutrot2.json b/gazenet/configs/infer_configs/InferGeneratorAllModelsCoutrot2.json new file mode 100644 index 0000000..b8dfa00 --- /dev/null +++ b/gazenet/configs/infer_configs/InferGeneratorAllModelsCoutrot2.json @@ -0,0 +1,159 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "jet" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "DAVEInference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave" + ] + ] + } + ], + [ + "ESR9Inference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_source_esr9", + "det_transformed_esr9" + ] + ] + } + ], + [ + "Gaze360Inference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + 
"det_transformed_gaze360" + ] + ] + } + ], + [ + "VideoGazeInference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_vidgaze" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 6 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "Coutrot2SampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "CoutrotSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": true, + "write_videos": false, + "config_info": { + "summary": "This generates the datasets images needed for a majority of the experiments. Only dataset samples (Social) are generated. It runs the 4 social cue modalities for Coutrot2. " + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferGeneratorAllModelsDIEM.json b/gazenet/configs/infer_configs/InferGeneratorAllModelsDIEM.json new file mode 100644 index 0000000..5e2428a --- /dev/null +++ b/gazenet/configs/infer_configs/InferGeneratorAllModelsDIEM.json @@ -0,0 +1,159 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "jet" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "DAVEInference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave" + ] + ] + } + ], + [ + "ESR9Inference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_source_esr9", + "det_transformed_esr9" + ] + ] + } + ], + [ + "Gaze360Inference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_gaze360" + ] + ] + } + ], + [ + "VideoGazeInference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_vidgaze" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 6 + 
], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DIEMSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DIEMSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": true, + "write_videos": false, + "config_info": { + "summary": "This generates the datasets images needed for a majority of the experiments. Only dataset samples (Social) are generated. It runs the 4 social cue modalities for DIEM. " + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferGeneratorFindWho.json b/gazenet/configs/infer_configs/InferGeneratorFindWho.json new file mode 100644 index 0000000..a2fe692 --- /dev/null +++ b/gazenet/configs/infer_configs/InferGeneratorFindWho.json @@ -0,0 +1,139 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 320, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "jet" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "DAVEInference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave" + ] + ] + } + ], + [ + "ESR9Inference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_source_esr9", + "det_transformed_esr9" + ] + ] + } + ], + [ + "Gaze360Inference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_gaze360" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 6 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": false, + "reader": "FindWhoSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "FindWhoSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "stride": 1, + "visualize_images": true, + "width": 512, + 
"write_annotations": true, + "write_images": true, + "write_videos": false, + "config_info": { + "summary": "This generates the datasets annotation needed for gaze prediction experiments. Only dataset samples (Social) are generated. It runs the DAVE for FindWhos. " + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsDAVE.json b/gazenet/configs/infer_configs/InferMetricsDAVE.json new file mode 100644 index 0000000..458d119 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsDAVE.json @@ -0,0 +1,159 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaultdave.csv", + "model_groups": [ + [ + [ + "DAVEInference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave" + ] + ] + } + ], + [ + "ESR9Inference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_source_esr9", + "det_transformed_esr9" + ] + ] + } + ], + [ + "Gaze360Inference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_gaze360" + ] + ] + } + ], + [ + "VideoGazeInference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_vidgaze" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 6 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on DAVE. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsDAVE_VisOnly.json b/gazenet/configs/infer_configs/InferMetricsDAVE_VisOnly.json new file mode 100644 index 0000000..6d6ed60 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsDAVE_VisOnly.json @@ -0,0 +1,102 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaultdave_vis.csv", + "model_groups": [ + [ + [ + "DAVEInference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on DAVE. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsGASP.json b/gazenet/configs/infer_configs/InferMetricsGASP.json new file mode 100644 index 0000000..6a5b6d5 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsGASP.json @@ -0,0 +1,110 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": null, + "val_csv_file": null, + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_gazenet", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaultgazenet.csv", + "model_groups": [ + [ + [ + "GASPInference", + 16, + [ + 15 + ], + {}, + { + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_gazenet" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on GASP. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsGASPTrain.json b/gazenet/configs/infer_configs/InferMetricsGASPTrain.json new file mode 100644 index 0000000..681aaf3 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsGASPTrain.json @@ -0,0 +1,102 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": null, + "val_csv_file": null, + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_gazenet", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaultgazenet.csv", + "model_groups": [ + [ + [ + "GASPInference", + -1, + -1, + {}, + { + "inp_img_names_list": null + }, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_gazenet" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on GASP. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsSTAViS.json b/gazenet/configs/infer_configs/InferMetricsSTAViS.json new file mode 100644 index 0000000..3e39829 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsSTAViS.json @@ -0,0 +1,161 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_stavis", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaultstavis.csv", + "model_groups": [ + [ + [ + "STAViSInference", + 16, + [ + 15 + ], + { + "audiovisual": true + }, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_stavis" + ] + ] + } + ], + [ + "ESR9Inference", + 16, + [ + 15 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_source_esr9", + "det_transformed_esr9" + ] + ] + } + ], + [ + "Gaze360Inference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_gaze360" + ] + ] + } + ], + [ + "VideoGazeInference", + 7, + [ + 3 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "det_transformed_vidgaze" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 6 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on STAViS. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsSTAViS_VisOnly.json b/gazenet/configs/infer_configs/InferMetricsSTAViS_VisOnly.json new file mode 100644 index 0000000..6c83573 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsSTAViS_VisOnly.json @@ -0,0 +1,104 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_stavis", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaultstavis_vis.csv", + "model_groups": [ + [ + [ + "STAViSInference", + 16, + [ + 15 + ], + { + "audiovisual": false + }, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_stavis" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on STAViS. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsTASED_VisOnly.json b/gazenet/configs/infer_configs/InferMetricsTASED_VisOnly.json new file mode 100644 index 0000000..6cf91e4 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsTASED_VisOnly.json @@ -0,0 +1,102 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_tased", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaulttased_vis.csv", + "model_groups": [ + [ + [ + "TASEDInference", + 32, + [ + 31 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_tased" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on TASED. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferMetricsUNISAL_VisOnly.json b/gazenet/configs/infer_configs/InferMetricsUNISAL_VisOnly.json new file mode 100644 index 0000000..cf79551 --- /dev/null +++ b/gazenet/configs/infer_configs/InferMetricsUNISAL_VisOnly.json @@ -0,0 +1,102 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": true, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/train_ave.csv", + "val_csv_file": "datasets/processed/validation_ave.csv", + "test_csv_file": "datasets/processed/test_ave.csv" + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone" + }, + "max_w_size": 16, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_unisal", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/defaultunisals_vis.csv", + "model_groups": [ + [ + [ + "UNISALInference", + 12, + [ + 11 + ], + {}, + {}, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "captured", + "transformed_salmap", + "transformed_fixmap", + "det_transformed_unisal" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": false, + "color_map": "bone", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "stride": 1, + "visualize_images": false, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": false, + "config_info": { + "summary": "This measures the saliency metrics on UNISAL. Only dataset samples (Social) are generated. 
" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferVisualizeGASPDAMGMU1x1Conv.json b/gazenet/configs/infer_configs/InferVisualizeGASPDAMGMU1x1Conv.json new file mode 100644 index 0000000..1229478 --- /dev/null +++ b/gazenet/configs/infer_configs/InferVisualizeGASPDAMGMU1x1Conv.json @@ -0,0 +1,122 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/test_ave.csv", + "val_csv_file": null, + "test_csv_file": null + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet" + }, + "max_w_size": 1, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "GASPInference", + 1, + [ + 0 + ], + { + "weights_file": "damgmu", + "modalities": 5, + "batch_size": 1, + "model_name": "GASPDAMEncGMUConv", + "frames_len": 1 + }, + { + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + "stride": 1, + "visualize_images": true, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": true, + "config_info": { + "summary": "This visualizes the static GASP model (DAM + GMU)" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm.json b/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm.json new file mode 100644 index 0000000..89d3c30 --- /dev/null +++ b/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm.json @@ -0,0 +1,122 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/test_ave.csv", + "val_csv_file": null, + "test_csv_file": null + }, + "device": "cpu", 
+ "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet" + }, + "max_w_size": 10, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "GASPInference", + 10, + [ + 9 + ], + { + "weights_file": "seqdamalstmgmu", + "modalities": 5, + "batch_size": 1, + "sequence_len": 10, + "sequence_norm": true + }, + { + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "transformed_fixmap", + "det_transformed_esr9", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + "stride": 1, + "visualize_images": true, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": true, + "config_info": { + "summary": "This visualizes the sequential GASP model (DAM + LARGMU; Context Size = 10)" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm_110.json b/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm_110.json new file mode 100644 index 0000000..16ba68f --- /dev/null +++ b/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMALSTMGMU1x1Conv_10Norm_110.json @@ -0,0 +1,119 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/test_ave.csv", + "val_csv_file": null, + "test_csv_file": null + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet" + }, + "max_w_size": 10, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": 
"logs/metrics/default.csv", + "model_groups": [ + [ + [ + "GASPInference", + 10, + [ + 9 + ], + { + "weights_file": "seqdamalstmgmu_110nofer", + "modalities": 4, + "batch_size": 1, + "sequence_len": 10, + "sequence_norm": true + }, + { + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "transformed_fixmap", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + "stride": 1, + "visualize_images": true, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": true, + "config_info": { + "summary": "This visualizes the sequential GASP model (DAM + LARGMU; Context Size = 10) excluding the FER modality" + } +} \ No newline at end of file diff --git a/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMGMUALSTM1x1Conv_10Norm.json b/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMGMUALSTM1x1Conv_10Norm.json new file mode 100644 index 0000000..402a759 --- /dev/null +++ b/gazenet/configs/infer_configs/InferVisualizeGASPSeqDAMGMUALSTM1x1Conv_10Norm.json @@ -0,0 +1,123 @@ +{ + "audio_features": { + "audio_features": "MFCCAudioFeatures", + "hann_audio_frames": "WindowedAudioFeatures" + }, + "compute_metrics": false, + "datasplitter_properties": { + "train_csv_file": "datasets/processed/test_ave.csv", + "val_csv_file": null, + "test_csv_file": null + }, + "device": "cpu", + "enable_audio": true, + "face_detector": "SFDFaceDetection", + "height": 500, + "inference_properties": { + "show_det_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet" + }, + "max_w_size": 10, + "metrics": "SaliencyPredictionMetrics", + "metrics_list": [ + "aucj", + "aucs", + "cc", + "nss", + "sim" + ], + "metrics_mappings": { + "gt_salmap": "transformed_salmap", + "gt_fixmap": "transformed_fixmap", + "pred_salmap": "det_transformed_dave", + "gt_baseline": "datasets/processed/center_bias_bw.jpg", + "scores_info": [ + "gate_scores" + ] + }, + "metrics_save_file": "logs/metrics/default.csv", + "model_groups": [ + [ + [ + "GASPInference", + 10, + [ + 9 + ], + { + "weights_file": "seqdamgmualstm", + "modalities": 5, + "batch_size": 1, + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "sequence_len": 10, + "sequence_norm": true + }, + { + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true, + "plot_override": [ + [ + "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + 
"det_transformed_vidgaze", + "det_transformed_gaze360", + "det_transformed_gasp" + ] + ] + } + ] + ] + ], + "n_jobs": [ + 3 + ], + "play_audio": false, + "postproc_properties": { + "keep_properties": false, + "keep_audio": false, + "keep_plot_frames_only": true, + "resize_frames": true + }, + "process_dataset_videos_only": true, + "reader": "DataSampleReader", + "reader_properties": {}, + "realtime_capture": false, + "sampler": "DataSample", + "sampler_properties": {}, + "sampling_properties": { + "show_fixation_locations": true, + "show_saliency_map": true, + "enable_transform_overlays": true, + "color_map": "jet", + "img_names_list": [ + "transformed_salmap", + "transformed_fixmap", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ] + }, + "stride": 1, + "visualize_images": true, + "width": 500, + "write_annotations": false, + "write_images": false, + "write_videos": true, + "config_info": { + "summary": "This visualizes the static GASP model (DAM + GMU)" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_config.py b/gazenet/configs/train_config.py new file mode 100644 index 0000000..5f8af40 --- /dev/null +++ b/gazenet/configs/train_config.py @@ -0,0 +1,2490 @@ +from gazenet.utils.registrar import * + + +@TrainingConfigRegistrar.register +class TrainerBase(object): + inferer_name = "" + model_name = "" + model_properties = {} + log_dir = "" + logger = "" # comet, tensorboard + project_name = "" + experiment_name = model_name + checkpoint_model_dir = "" + train_dataset_properties = {} + val_dataset_properties = {} + test_dataset_properties = {} + + @classmethod + def config_info(cls): + return {"summary": "This is base class and cannot be used directly.", + "example": "None"} + +@TrainingConfigRegistrar.register +class GASPExp001_1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPEncConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + @classmethod + def config_info(cls): + return {"summary": "Static GASP: 1x1 convolutional variant. 
Not included in the paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_ALSTM1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: ALSTM variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_Add1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPEncAddConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", 
"transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: Additive variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_SE1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPSEEncConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: SE variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + + +@TrainingConfigRegistrar.register +class GASPExp001_GMU1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPEncGMUConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", 
"det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: GMU variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_GMUALSTM1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: AGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_ALSTMGMU1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": 
["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: LAGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_DAMALSTMGMU1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: DAM + LAGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_DAMGMU1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPDAMEncGMUConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": 
"datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: DAM + GMU variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp001_DAMGMUALSTM1x1Conv(object): + + inferer_name = "GASPInference" + model_name = "GASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Static GASP: DAM + AGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_2(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": False} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + 
"det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: Sequential ALSTM variant. Exp_Variant_[N]{Norm} ->" + "N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_4(GASPExp002_SeqALSTM1x1Conv_2): + + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_6(GASPExp002_SeqALSTM1x1Conv_2): + + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_8(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_10(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_12(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_14(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_16(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_2Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_4Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_6Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + 
model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_8Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_10Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_12Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_14Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTM1x1Conv_16Norm(GASPExp002_SeqALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_2(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": False} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->" + "N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_4(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_6(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_8(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_10(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_12(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_14(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_16(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_2Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_4Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_6Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_8Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + 
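+# The classes below follow the same pattern as the non-normalized sweep above:
+# each is a thin subclass of GASPExp002_SeqALSTMGMU1x1Conv_2 that only
+# overrides model_properties (context size and temporal normalization), while
+# the datasets, logger and checkpoint directory are inherited. A minimal
+# sketch of how a registered config could be resolved from the name passed to
+# --train_config is given here; `registry` and `get_config` are illustrative
+# assumptions, not the actual API of gazenet.utils.registrar:
+#
+#   def get_config(name):
+#       config_cls = TrainingConfigRegistrar.registry[name]   # hypothetical name -> class lookup
+#       print(config_cls.config_info()["summary"])            # e.g. "Sequential GASP: LARGMU variant. ..."
+#       return config_cls
+#
+#   cfg = get_config("GASPExp002_SeqALSTMGMU1x1Conv_10Norm")
+#   model_kwargs = dict(cfg.model_properties)                 # {"modalities": 5, ..., "sequence_norm": True}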
+@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_10Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_12Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_14Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqALSTMGMU1x1Conv_16Norm(GASPExp002_SeqALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_2(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": False} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->" + "N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_4(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_6(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": False} + + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_8(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_10(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_12(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_14(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_16(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_2Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_4Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_6Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_8Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + 
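+# As with the LARGMU sweep, each ARGMU class below registers one point of the
+# context-size sweep (sequence_len 2-16; the "Norm" suffix sets
+# sequence_norm=True). Training a single point follows the command given by
+# config_info(), e.g. for the normalized context size of 10:
+#
+#   python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_10Norm \
+#       --gpus "0" --check_val_every_n_epoch 100 --max_epochs 10000 \
+#       --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3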
+@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_10Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_12Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_14Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqGMUALSTM1x1Conv_16Norm(GASPExp002_SeqGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_2(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": False} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->" + "N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_4(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_6(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_8(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_10(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_12(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_14(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_16(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_2Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_4Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_6Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_8Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 
5, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_10Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_12Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_14Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMGMUALSTM1x1Conv_16Norm(GASPExp002_SeqDAMGMUALSTM1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_2(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": False} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->" + "N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_4(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_6(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_8(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 8, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_10(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_12(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_14(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_16(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": False} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_2Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 2, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_4Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 4, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_6Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 6, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_8Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 
5, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_10Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_12Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 12, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_14Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 14, "sequence_norm": True} + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMALSTMGMU1x1Conv_16Norm(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 16, "sequence_norm": True} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_2(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 2} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->" + "N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_4(GASPExp002_SeqRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 4} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_6(GASPExp002_SeqRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 6} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_8(GASPExp002_SeqRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 8} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_10(GASPExp002_SeqRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 10} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_12(GASPExp002_SeqRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 12} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_14(GASPExp002_SeqRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 14} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqRGMU1x1Conv_16(GASPExp002_SeqRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 16} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_2(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 2} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": 
["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + RGMU variant. Exp_Variant_[N]{Norm} ->" + "N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_4(GASPExp002_SeqDAMRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 4} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_6(GASPExp002_SeqDAMRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 6} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_8(GASPExp002_SeqDAMRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 8} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_10(GASPExp002_SeqDAMRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 10} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_12(GASPExp002_SeqDAMRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 12} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_14(GASPExp002_SeqDAMRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 14} + + +@TrainingConfigRegistrar.register +class GASPExp002_SeqDAMRGMU1x1Conv_16(GASPExp002_SeqDAMRGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncRGMUConv" + model_properties = {"modalities": 5, "in_channels": 3, "batch_size": 4, "sequence_len": 16} + + +# different Stage-1 SP +@TrainingConfigRegistrar.register +class GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_TASED(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_tased", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant inferring on TASED. 
", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_UNISAL(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_unisal", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant inferring on UNISAL.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_1(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + train_dataset_properties = {"csv_file": "datasets/processed/train_stavis_1.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/test_stavis_1.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_stavis_1.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant running on STAViS (1st fold).", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_2(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + train_dataset_properties = {"csv_file": "datasets/processed/train_stavis_2.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + 
"gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/test_stavis_2.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_stavis_2.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant running on STAViS (2nd fold).", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_3(GASPExp002_SeqDAMALSTMGMU1x1Conv_2): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 5, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + train_dataset_properties = {"csv_file": "datasets/processed/train_stavis_3.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/test_stavis_3.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_stavis_3.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_stavis", "det_transformed_esr9", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant running on STAViS (3rd fold).", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +# ablation experiments {GE}{GF}{FER} +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_000(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 2, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": 
"datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. " + "Exp_variant_{GE: False}{GF: False}{FER: False}. Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_001(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. " + "Exp_variant_{GE: False}{GF: False}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_010(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. " + "Exp_variant_{GE: False}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_011(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. " + "Exp_variant_{GE: False}{GF: True}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_100(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. " + "Exp_variant_{GE: True}{GF: False}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_101(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. " + "Exp_variant_{GE: True}{GF: False}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_110(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 8, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. " + "Exp_variant_{GE: True}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_000(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 2, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. " + "Exp_variant_{GE: False}{GF: False}{FER: False}. Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_001(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. 
" + "Exp_variant_{GE: False}{GF: False}{FER: True}. Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_010(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. " + "Exp_variant_{GE: False}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_011(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. " + "Exp_variant_{GE: False}{GF: True}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_100(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. " + "Exp_variant_{GE: True}{GF: False}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_101(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. " + "Exp_variant_{GE: True}{GF: False}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_110(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncGMUALSTMConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. " + "Exp_variant_{GE: True}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_000(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 2, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. " + "Exp_variant_{GE: False}{GF: False}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_001(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
" + "Exp_variant_{GE: False}{GF: False}{FER: True}.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_010(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
" + "Exp_variant_{GE: False}{GF: True}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_011(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
" + "Exp_variant_{GE: False}{GF: True}{FER: True}.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_100(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 3, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
" + "Exp_variant_{GE: True}{GF: False}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_101(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
" + "Exp_variant_{GE: True}{GF: False}{FER: True}.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_110(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPDAMEncALSTMGMUConv" + model_properties = {"modalities": 4, "batch_size": 4, "sequence_len": 10, "sequence_norm": True} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "gasp_runs" + experiment_name = model_name + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", + "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
" + "Exp_variant_{GE: True}{GF: True}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 " + "--checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3"} + + +@TrainingConfigRegistrar.register +class TrainPlayground001(object): + + inferer_name = "GASPInference" + model_name = "SequenceGASPEncALSTMGMUConv" + model_properties = {"modalities": 5, "sequence_len": 5, "sequence_norm": False} + + log_dir = "logs" + logger = "comet" # comet, tensorboard + project_name = "testing_gasp" + experiment_name = model_name + "_TRAINONLY_BADVAL" + + checkpoint_model_dir = os.path.join("gazenet", "models", "saliency_prediction", + "gasp", "checkpoints", "pretrained_" + + str.lower(experiment_name)) + + train_dataset_properties = {"csv_file": "datasets/processed/train_ave.csv", "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured","det_transformed_dave", "det_transformed_esr9", "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + val_dataset_properties = {"csv_file": "datasets/processed/validation_ave.csv", "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + test_dataset_properties = {"csv_file": "datasets/processed/test_ave.csv", "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["captured", "det_transformed_dave", "det_transformed_esr9", "det_transformed_vidgaze", "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + @classmethod + def config_info(cls): + return {"summary": "This is a playground configuration which is unstable " + "but can be used for quickly testing and visualizing models", + "example": "python3 gazenet/bin/train.py --train_config " + cls.__name__ + + " --gpus \"0\" --check_val_every_n_epoch 500 --max_epochs 5000 " + "--checkpoint_save_every_n_epoch 1000 --checkpoint_save_n_top 3"} diff --git a/gazenet/configs/train_configs/GASPExp001_1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_1x1Conv.json new file mode 100644 index 0000000..ebd9c48 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspencconv", + "experiment_name": "GASPEncConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPEncConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, 
+ "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: 1x1 convolutional variant. Not included in the paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_ALSTM1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_ALSTM1x1Conv.json new file mode 100644 index 0000000..c410a3a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_ALSTM1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspencalstmconv", + "experiment_name": "GASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: ALSTM variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_ALSTM1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_ALSTMGMU1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_ALSTMGMU1x1Conv.json new file mode 100644 index 0000000..0533e90 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_ALSTMGMU1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspencalstmgmuconv", + "experiment_name": "GASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + 
"inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: LAGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_ALSTMGMU1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_Add1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_Add1x1Conv.json new file mode 100644 index 0000000..a071546 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_Add1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspencaddconv", + "experiment_name": "GASPEncAddConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPEncAddConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: Additive variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_Add1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_DAMALSTMGMU1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_DAMALSTMGMU1x1Conv.json new file mode 100644 index 0000000..9db68e3 --- /dev/null +++ 
b/gazenet/configs/train_configs/GASPExp001_DAMALSTMGMU1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspdamencalstmgmuconv", + "experiment_name": "GASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: DAM + LAGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_DAMALSTMGMU1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_DAMGMU1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_DAMGMU1x1Conv.json new file mode 100644 index 0000000..23b8ca7 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_DAMGMU1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspdamencgmuconv", + "experiment_name": "GASPDAMEncGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPDAMEncGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" 
+ ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: DAM + GMU variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_DAMGMU1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_DAMGMUALSTM1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_DAMGMUALSTM1x1Conv.json new file mode 100644 index 0000000..b505fd1 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_DAMGMUALSTM1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspdamencgmualstmconv", + "experiment_name": "GASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: DAM + AGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_DAMGMUALSTM1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_GMU1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_GMU1x1Conv.json new file mode 100644 index 0000000..6bd6a0f --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_GMU1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspencgmuconv", + "experiment_name": "GASPEncGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPEncGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": 
"datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: GMU variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_GMU1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_GMUALSTM1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_GMUALSTM1x1Conv.json new file mode 100644 index 0000000..1f37fcf --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_GMUALSTM1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspencgmualstmconv", + "experiment_name": "GASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "GASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: AGMU variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_GMUALSTM1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp001_SE1x1Conv.json b/gazenet/configs/train_configs/GASPExp001_SE1x1Conv.json new file mode 100644 index 0000000..eb381e2 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp001_SE1x1Conv.json @@ -0,0 +1,62 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspseencconv", + "experiment_name": "GASPSEEncConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + 
"model_name": "GASPSEEncConv", + "model_properties": { + "modalities": 5, + "batch_size": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Static GASP: SE variant.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp001_SE1x1Conv --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_10.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_10.json new file mode 100644 index 0000000..ce2d57d --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_10.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_10 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_10Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_10Norm.json new file mode 100644 index 0000000..363a105 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_10Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_10Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_12.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_12.json new file mode 100644 index 0000000..a0dea0a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_12.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_12 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_12Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_12Norm.json new file mode 100644 index 0000000..4f7c1f9 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_12Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_12Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_14.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_14.json new file mode 100644 index 0000000..8b7dc92 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_14.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_14 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_14Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_14Norm.json new file mode 100644 index 0000000..af4c600 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_14Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_14Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_16.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_16.json new file mode 100644 index 0000000..8df0909 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_16.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_16 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_16Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_16Norm.json new file mode 100644 index 0000000..e3e3341 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_16Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_16Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_2.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_2.json new file mode 100644 index 0000000..4d48533 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_2Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_2Norm.json new file mode 100644 index 0000000..27b42c0 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_2Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_2Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_4.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_4.json new file mode 100644 index 0000000..7afc8f0 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_4.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_4 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_4Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_4Norm.json new file mode 100644 index 0000000..e2da519 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_4Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_4Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_6.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_6.json new file mode 100644 index 0000000..d63e578 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_6.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_6 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_6Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_6Norm.json new file mode 100644 index 0000000..a6b7ac4 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_6Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_6Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_8.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_8.json new file mode 100644 index 0000000..988725e --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_8.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_8 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_8Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_8Norm.json new file mode 100644 index 0000000..a27e9cf --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTM1x1Conv_8Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmconv", + "experiment_name": "SequenceGASPEncALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: Sequential ALSTM variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTM1x1Conv_8Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_10.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_10.json new file mode 100644 index 0000000..822be83 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_10.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_10 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_10Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_10Norm.json new file mode 100644 index 0000000..8a6c750 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_10Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_10Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_12.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_12.json new file mode 100644 index 0000000..b353027 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_12.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_12 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_12Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_12Norm.json new file mode 100644 index 0000000..7fde74f --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_12Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_12Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_14.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_14.json new file mode 100644 index 0000000..c5c0438 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_14.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_14 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_14Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_14Norm.json new file mode 100644 index 0000000..e375b58 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_14Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_14Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_16.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_16.json new file mode 100644 index 0000000..ee55354 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_16.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_16 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_16Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_16Norm.json new file mode 100644 index 0000000..b899ec1 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_16Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_16Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_2.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_2.json new file mode 100644 index 0000000..43dbb68 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_2Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_2Norm.json new file mode 100644 index 0000000..148a918 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_2Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_2Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_4.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_4.json new file mode 100644 index 0000000..7fb004e --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_4.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_4 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_4Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_4Norm.json new file mode 100644 index 0000000..aeac7b2 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_4Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_4Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_6.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_6.json new file mode 100644 index 0000000..aef66f4 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_6.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_6 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_6Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_6Norm.json new file mode 100644 index 0000000..df4356e --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_6Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_6Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_8.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_8.json new file mode 100644 index 0000000..9529ed6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_8.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_8 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_8Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_8Norm.json new file mode 100644 index 0000000..b5520ea --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqALSTMGMU1x1Conv_8Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv", + "experiment_name": "SequenceGASPEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqALSTMGMU1x1Conv_8Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_10.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_10.json new file mode 100644 index 0000000..ea5a109 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_10.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_10 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_10Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_10Norm.json new file mode 100644 index 0000000..569b6f0 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_10Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_10Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_12.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_12.json new file mode 100644 index 0000000..e48de25 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_12.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_12 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_12Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_12Norm.json new file mode 100644 index 0000000..7678cfa --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_12Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_12Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_14.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_14.json new file mode 100644 index 0000000..36ffc4f --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_14.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_14 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_14Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_14Norm.json new file mode 100644 index 0000000..cf530be --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_14Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_14Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_16.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_16.json new file mode 100644 index 0000000..0120fc6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_16.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_16 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_16Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_16Norm.json new file mode 100644 index 0000000..2f7e942 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_16Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_16Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_2.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_2.json new file mode 100644 index 0000000..de8c190 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_2Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_2Norm.json new file mode 100644 index 0000000..f901926 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_2Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_2Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_4.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_4.json new file mode 100644 index 0000000..fbb25b6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_4.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_4 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_4Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_4Norm.json new file mode 100644 index 0000000..cb055b6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_4Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_4Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_6.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_6.json new file mode 100644 index 0000000..fd7e21a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_6.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_6 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_6Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_6Norm.json new file mode 100644 index 0000000..b1ddebb --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_6Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_6Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_8.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_8.json new file mode 100644 index 0000000..20d1ef0 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_8.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_8 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_8Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_8Norm.json new file mode 100644 index 0000000..4e0eaf4 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMALSTMGMU1x1Conv_8Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMALSTMGMU1x1Conv_8Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_10.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_10.json new file mode 100644 index 0000000..571b917 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_10.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_10 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_10Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_10Norm.json new file mode 100644 index 0000000..043f523 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_10Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_10Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_12.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_12.json new file mode 100644 index 0000000..d378cf6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_12.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_12 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_12Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_12Norm.json new file mode 100644 index 0000000..2f62b7d --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_12Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_12Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_14.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_14.json new file mode 100644 index 0000000..e48b232 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_14.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_14 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_14Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_14Norm.json new file mode 100644 index 0000000..fd4c834 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_14Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_14Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_16.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_16.json new file mode 100644 index 0000000..ea14322 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_16.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_16 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_16Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_16Norm.json new file mode 100644 index 0000000..9e31bca --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_16Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_16Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_2.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_2.json new file mode 100644 index 0000000..d3cfab6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_2Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_2Norm.json new file mode 100644 index 0000000..bf6930f --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_2Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_2Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_4.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_4.json new file mode 100644 index 0000000..0ca57fd --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_4.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_4 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_4Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_4Norm.json new file mode 100644 index 0000000..b30a279 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_4Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_4Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_6.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_6.json new file mode 100644 index 0000000..51c7520 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_6.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_6 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_6Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_6Norm.json new file mode 100644 index 0000000..f3a38d7 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_6Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_6Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_8.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_8.json new file mode 100644 index 0000000..5c6b40a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_8.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_8 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_8Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_8Norm.json new file mode 100644 index 0000000..107b27e --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMGMUALSTM1x1Conv_8Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMGMUALSTM1x1Conv_8Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_10.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_10.json new file mode 100644 index 0000000..3452576 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_10.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 10 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_10 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_12.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_12.json new file mode 100644 index 0000000..3c657bf --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_12.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 12 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_12 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_14.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_14.json new file mode 100644 index 0000000..fddbdfa --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_14.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 14 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_14 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_16.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_16.json new file mode 100644 index 0000000..98293b4 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_16.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 16 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_16 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_2.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_2.json new file mode 100644 index 0000000..0e851ba --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 2 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_4.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_4.json new file mode 100644 index 0000000..981ff0a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_4.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_4 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_6.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_6.json new file mode 100644 index 0000000..ab42393 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_6.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 6 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_6 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_8.json b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_8.json new file mode 100644 index 0000000..32cd7e2 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqDAMRGMU1x1Conv_8.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencrgmuconv", + "experiment_name": "SequenceGASPDAMEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 8 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + RGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqDAMRGMU1x1Conv_8 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_10.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_10.json new file mode 100644 index 0000000..25e3a9a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_10.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_10 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_10Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_10Norm.json new file mode 100644 index 0000000..b7c8959 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_10Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_10Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_12.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_12.json new file mode 100644 index 0000000..2e4a2aa --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_12.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_12 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_12Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_12Norm.json new file mode 100644 index 0000000..7713ab1 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_12Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 12, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_12Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_14.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_14.json new file mode 100644 index 0000000..d643f99 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_14.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_14 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_14Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_14Norm.json new file mode 100644 index 0000000..98fca35 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_14Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 14, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_14Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_16.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_16.json new file mode 100644 index 0000000..2d90bd9 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_16.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_16 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_16Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_16Norm.json new file mode 100644 index 0000000..93777ba --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_16Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 16, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_16Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_2.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_2.json new file mode 100644 index 0000000..7cbd2a3 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_2Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_2Norm.json new file mode 100644 index 0000000..ea53fa2 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_2Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 2, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_2Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_4.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_4.json new file mode 100644 index 0000000..f3b04b8 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_4.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_4 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_4Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_4Norm.json new file mode 100644 index 0000000..6d6b00d --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_4Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 4, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_4Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_6.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_6.json new file mode 100644 index 0000000..9d994fd --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_6.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_6 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_6Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_6Norm.json new file mode 100644 index 0000000..ab87290 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_6Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 6, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_6Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_8.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_8.json new file mode 100644 index 0000000..bf9d716 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_8.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": false + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_8 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_8Norm.json b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_8Norm.json new file mode 100644 index 0000000..472829d --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqGMUALSTM1x1Conv_8Norm.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencgmualstmconv", + "experiment_name": "SequenceGASPEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncGMUALSTMConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: ARGMU variant. 
Exp_Variant_[N]{Norm} ->N: Context size; Norm: Enable temporal normalization;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqGMUALSTM1x1Conv_8Norm --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_10.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_10.json new file mode 100644 index 0000000..75908e0 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_10.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 10 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_10 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_12.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_12.json new file mode 100644 index 0000000..d0b85c7 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_12.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 12 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_12 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_14.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_14.json new file mode 100644 index 0000000..a244b6e --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_14.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 14 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_14 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_16.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_16.json new file mode 100644 index 0000000..b23357c --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_16.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 16 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_16 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_2.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_2.json new file mode 100644 index 0000000..a746b94 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 2 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_4.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_4.json new file mode 100644 index 0000000..cc697a3 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_4.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 4 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_4 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_6.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_6.json new file mode 100644 index 0000000..7103e91 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_6.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 6 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_6 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_8.json b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_8.json new file mode 100644 index 0000000..b7e9406 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp002_SeqRGMU1x1Conv_8.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencrgmuconv", + "experiment_name": "SequenceGASPEncRGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncRGMUConv", + "model_properties": { + "modalities": 5, + "in_channels": 3, + "batch_size": 4, + "sequence_len": 8 + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: RGMU variant. 
Exp_Variant_[N] ->N: Context size;", + "example": "python3 gazenet/bin/train.py --train_config GASPExp002_SeqRGMU1x1Conv_8 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_1.json b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_1.json new file mode 100644 index 0000000..e1d9681 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_1.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_stavis_1.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_stavis_1.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/test_stavis_1.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant running on STAViS (1st fold).", + "example": "python3 gazenet/bin/train.py --train_config GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_1 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_2.json b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_2.json new file mode 100644 index 0000000..f121f1c --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_2.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_stavis_2.csv", + "video_dir": "datasets/processed/Grouped_frames", + 
"inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_stavis_2.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/test_stavis_2.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant running on STAViS (2nd fold).", + "example": "python3 gazenet/bin/train.py --train_config GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_2 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_3.json b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_3.json new file mode 100644 index 0000000..45fc1cd --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_3.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_stavis_3.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_stavis_3.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/test_stavis_3.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_stavis", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant running on STAViS (3rd fold).", + "example": "python3 gazenet/bin/train.py --train_config GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_STAViS_3 --gpus \"0\" 
--check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_TASED.json b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_TASED.json new file mode 100644 index 0000000..605505a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_TASED.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_tased", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant inferring on TASED. 
", + "example": "python3 gazenet/bin/train.py --train_config GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_TASED --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_UNISAL.json b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_UNISAL.json new file mode 100644 index 0000000..ecd6fec --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_UNISAL.json @@ -0,0 +1,64 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_unisal", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant inferring on UNISAL.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp003_SeqDAMALSTMGMU1x1Conv_10Norm_UNISAL --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_000.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_000.json new file mode 100644 index 0000000..cefcc12 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_000.json @@ -0,0 +1,55 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 2, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + 
"gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. Exp_variant_{GE: False}{GF: False}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_000 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_001.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_001.json new file mode 100644 index 0000000..13796ff --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_001.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
Exp_variant_{GE: False}{GF: False}{FER: True}.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_001 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_010.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_010.json new file mode 100644 index 0000000..f87fe47 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_010.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
Exp_variant_{GE: False}{GF: True}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_010 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_011.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_011.json new file mode 100644 index 0000000..7b888f6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_011.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
Exp_variant_{GE: False}{GF: True}{FER: True}.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_011 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_100.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_100.json new file mode 100644 index 0000000..9aea43c --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_100.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
Exp_variant_{GE: True}{GF: False}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_100 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_101.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_101.json new file mode 100644 index 0000000..e1de9e0 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_101.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
Exp_variant_{GE: True}{GF: False}{FER: True}.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_101 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_110.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_110.json new file mode 100644 index 0000000..eb8b111 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_110.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv", + "experiment_name": "SequenceGASPDAMEncALSTMGMUConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncALSTMGMUConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + LARGMU (Context Size = 10) variant. 
Exp_variant_{GE: True}{GF: True}{FER: False}.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMALSTMGMU1x1Conv_10Norm_110 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_000.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_000.json new file mode 100644 index 0000000..3d44d46 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_000.json @@ -0,0 +1,55 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 2, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. Exp_variant_{GE: False}{GF: False}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_000 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_001.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_001.json new file mode 100644 index 0000000..f1efc50 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_001.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. Exp_variant_{GE: False}{GF: False}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_001 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_010.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_010.json new file mode 100644 index 0000000..9b1e29d --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_010.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. Exp_variant_{GE: False}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_010 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_011.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_011.json new file mode 100644 index 0000000..28b876a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_011.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. Exp_variant_{GE: False}{GF: True}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_011 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_100.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_100.json new file mode 100644 index 0000000..84b256d --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_100.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. Exp_variant_{GE: True}{GF: False}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_100 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_101.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_101.json new file mode 100644 index 0000000..720a7b6 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_101.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. Exp_variant_{GE: True}{GF: False}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_101 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_110.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_110.json new file mode 100644 index 0000000..f5c37dc --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_110.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 10, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 10) variant. Exp_variant_{GE: True}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_10Norm_110 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_000.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_000.json new file mode 100644 index 0000000..e5eb6e5 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_000.json @@ -0,0 +1,55 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 2, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. Exp_variant_{GE: False}{GF: False}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_000 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_001.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_001.json new file mode 100644 index 0000000..58df926 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_001.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. Exp_variant_{GE: False}{GF: False}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_001 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_010.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_010.json new file mode 100644 index 0000000..60fb043 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_010.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. Exp_variant_{GE: False}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_010 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_011.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_011.json new file mode 100644 index 0000000..64f967a --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_011.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. Exp_variant_{GE: False}{GF: True}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_011 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_100.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_100.json new file mode 100644 index 0000000..047a394 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_100.json @@ -0,0 +1,58 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 3, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. Exp_variant_{GE: True}{GF: False}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_100 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_101.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_101.json new file mode 100644 index 0000000..303eb0e --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_101.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. Exp_variant_{GE: True}{GF: False}{FER: True}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_101 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_110.json b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_110.json new file mode 100644 index 0000000..099af76 --- /dev/null +++ b/gazenet/configs/train_configs/GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_110.json @@ -0,0 +1,61 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv", + "experiment_name": "SequenceGASPDAMEncGMUALSTMConv", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPDAMEncGMUALSTMConv", + "model_properties": { + "modalities": 4, + "batch_size": 4, + "sequence_len": 8, + "sequence_norm": true + }, + "project_name": "gasp_runs", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "Sequential GASP: DAM + ARGMU (Context Size = 8) variant. Exp_variant_{GE: True}{GF: True}{FER: False}. 
Not in paper.", + "example": "python3 gazenet/bin/train.py --train_config GASPExp004_SeqDAMGMUALSTM1x1Conv_8Norm_110 --gpus \"0\" --check_val_every_n_epoch 100 --max_epochs 10000 --checkpoint_save_every_n_epoch 100 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/TrainPlayground001.json b/gazenet/configs/train_configs/TrainPlayground001.json new file mode 100644 index 0000000..e5daaaa --- /dev/null +++ b/gazenet/configs/train_configs/TrainPlayground001.json @@ -0,0 +1,63 @@ +{ + "checkpoint_model_dir": "gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspencalstmgmuconv_trainonly_badval", + "experiment_name": "SequenceGASPEncALSTMGMUConv_TRAINONLY_BADVAL", + "inferer_name": "GASPInference", + "log_dir": "logs", + "logger": "comet", + "model_name": "SequenceGASPEncALSTMGMUConv", + "model_properties": { + "modalities": 5, + "sequence_len": 5, + "sequence_norm": false + }, + "project_name": "testing_gasp", + "test_dataset_properties": { + "csv_file": "datasets/processed/test_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "train_dataset_properties": { + "csv_file": "datasets/processed/train_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "val_dataset_properties": { + "csv_file": "datasets/processed/validation_ave.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": [ + "captured", + "det_transformed_dave", + "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360" + ], + "gt_img_names_list": [ + "transformed_salmap", + "transformed_fixmap" + ] + }, + "config_info": { + "summary": "This is a playground configuration which is unstable but can be used for quickly testing and visualizing models", + "example": "python3 gazenet/bin/train.py --train_config TrainPlayground001 --gpus \"0\" --check_val_every_n_epoch 500 --max_epochs 5000 --checkpoint_save_every_n_epoch 1000 --checkpoint_save_n_top 3" + } +} \ No newline at end of file diff --git a/gazenet/configs/train_configs/TrainerBase.json b/gazenet/configs/train_configs/TrainerBase.json new file mode 100644 index 0000000..17d01cb --- /dev/null +++ b/gazenet/configs/train_configs/TrainerBase.json @@ -0,0 +1,17 @@ +{ + "checkpoint_model_dir": "", + "experiment_name": "", + "inferer_name": "", + "log_dir": "", + "logger": "", + "model_name": "", + "model_properties": {}, + "project_name": "", + "test_dataset_properties": {}, + "train_dataset_properties": {}, + "val_dataset_properties": {}, + "config_info": { + "summary": "This is base class and cannot be used directly.", + "example": "None" + } +} \ No newline at end of file diff --git a/gazenet/models/__init__.py b/gazenet/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/emotion_recognition/__init__.py b/gazenet/models/emotion_recognition/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/emotion_recognition/esr9/__init__.py b/gazenet/models/emotion_recognition/esr9/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/gazenet/models/emotion_recognition/esr9/checkpoints/pretrained_esr9_orig/download_model.sh b/gazenet/models/emotion_recognition/esr9/checkpoints/pretrained_esr9_orig/download_model.sh new file mode 100644 index 0000000..41ce2f2 --- /dev/null +++ b/gazenet/models/emotion_recognition/esr9/checkpoints/pretrained_esr9_orig/download_model.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# shared representation +wget -O Net-Base-Shared_Representations.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Base-Shared_Representations.pt +# ensembles +wget -O Net-Branch_1.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_1.pt +wget -O Net-Branch_2.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_2.pt +wget -O Net-Branch_3.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_3.pt +wget -O Net-Branch_4.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_4.pt +wget -O Net-Branch_5.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_5.pt +wget -O Net-Branch_6.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_6.pt +wget -O Net-Branch_7.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_7.pt +wget -O Net-Branch_8.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_8.pt +wget -O Net-Branch_9.pt https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/raw/a827cae59a41eed58c4b94a094758a91097ef312/model/ml/trained_models/esr_9/Net-Branch_9.pt \ No newline at end of file diff --git a/gazenet/models/emotion_recognition/esr9/generator.py b/gazenet/models/emotion_recognition/esr9/generator.py new file mode 100644 index 0000000..6f9c468 --- /dev/null +++ b/gazenet/models/emotion_recognition/esr9/generator.py @@ -0,0 +1,18 @@ +import cv2 +import torchvision.transforms.functional as F + + +def pre_process_input_image(img, img_width, img_height, img_mean, img_std): + """ + Pre-processes an image for ESR-9. 
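+ Resizes the input image to the given width and height, converts BGR to RGB, normalizes it with the given mean and std, and adds a batch dimension.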
+ :param img: (ndarray) + :return: (ndarray) image + """ + + img = cv2.resize(img, (img_width, img_height)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = F.to_tensor(img) + img = F.normalize(img, img_mean, img_std) + # img = transforms.Normalize(mean=ESR.INPUT_IMAGE_NORMALIZATION_MEAN, std=ESR.INPUT_IMAGE_NORMALIZATION_STD)(transforms.ToTensor()(img)).unsqueeze(0) + img = img.unsqueeze(0) + return img \ No newline at end of file diff --git a/gazenet/models/emotion_recognition/esr9/infer.py b/gazenet/models/emotion_recognition/esr9/infer.py new file mode 100644 index 0000000..b8767dc --- /dev/null +++ b/gazenet/models/emotion_recognition/esr9/infer.py @@ -0,0 +1,246 @@ +import os + +import numpy as np +from scipy.stats import gmean +from skimage.filters import window +import torch + +from gazenet.utils.registrar import * +from gazenet.models.emotion_recognition.esr9.generator import pre_process_input_image +from gazenet.models.emotion_recognition.esr9.model import ESR +from gazenet.models.shared_components.gradcam.model import GradCAM +from gazenet.utils.sample_processors import InferenceSampleProcessor + +MODEL_PATHS = { + "esr9": os.path.join("gazenet", "models", "emotion_recognition", "esr9","checkpoints", "pretrained_esr9_orig"), + "esr9_shared": "Net-Base-Shared_Representations.pt", + "esr9_ensembles": "Net-Branch_{}.pt"} +TRG_CLASSES = { + 0: 'Neutral', + 1: 'Happy', + 2: 'Sad', + 3: 'Surprise', + 4: 'Fear', + 5: 'Disgust', + 6: 'Anger', + 7: 'Contempt'} + +INP_IMG_WIDTH = 96 +INP_IMG_HEIGHT = 96 +INP_IMG_MEAN = (0.0, 0.0, 0.0) +INP_IMG_STD = (1.0, 1.0, 1.0) +# INP_IMG_MEAN = [149.35457 / 255., 117.06477 / 255., 102.67609 / 255.] +# INP_IMG_STD = [69.18084 / 255., 61.907074 / 255., 60.435623 / 255.] +ENSEMBLES_NUM = 9 + + +@InferenceRegistrar.register +class ESR9Inference(InferenceSampleProcessor): + def __init__(self, weights_path=MODEL_PATHS['esr9'], + shared_weights_basename=MODEL_PATHS["esr9_shared"], + ensembles_weights_baseformat=MODEL_PATHS["esr9_ensembles"], + enable_gradcam=True, ensembles_num=ENSEMBLES_NUM, trg_classes=TRG_CLASSES, + inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, inp_img_mean=INP_IMG_MEAN, inp_img_std=INP_IMG_STD, + device="cuda:0", width=None, height=None, **kwargs): + super().__init__(width=width, height=height, **kwargs) + self.short_name = "esr9" + self._device = device + + self.ensembles_num = ensembles_num + self.trg_classes = trg_classes + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + + # load the model + self.model = ESR(ensembles_num=ensembles_num) + self.model.base.load_state_dict(torch.load(os.path.join(weights_path, shared_weights_basename), map_location=device)) + self.model.base.to(device) + for en_idx, ensemble in enumerate(self.model.convolutional_branches, start=1): + ensemble.load_state_dict(torch.load(os.path.join(weights_path, ensembles_weights_baseformat.format(en_idx)), map_location=device)) + ensemble.to(device) + print("ESR-9 model loaded from", weights_path) + self.model.to(device) + self.model.eval() + + # load gradcam model + if enable_gradcam: + self.model_gradcam = GradCAM(self.model, device=device) + else: + self.model_gradcam = None + + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + video_frames_list, faces_locations, source_frames_idxs, **kwargs): + frames_idxs = range(len(video_frames_list)) if source_frames_idxs is 
None else source_frames_idxs + for frame_id in frames_idxs: + info = {"frame_detections_" + self.short_name: { + "saliency_maps": [], # detected + "emotions": [], # detected + "affects": [], # detected + "h_bboxes": [] # processed + }} + img = video_frames_list[frame_id] + + for id, face_local in enumerate(faces_locations[frame_id]): + if not face_local: + continue + sample_emotions = [] + sample_emotions_idx = [] + # sample_saliency = [] + # sample_affect = None + + (top, right, bottom, left) = face_local + info["frame_detections_" + self.short_name]["h_bboxes"].append((left, top, right, bottom, id)) + + # crop face image + crop_img_face = img[top:bottom, left:right] + crop_img_face = pre_process_input_image(crop_img_face, self.inp_img_width, self.inp_img_height, + self.inp_img_mean, self.inp_img_std) + crop_img_face = crop_img_face.to(self._device, non_blocking=True) + # emotion, affect, emotion_idx = _predict(crop_img_face) + + # computes ensemble prediction for affect + emotion, affect = self.model(crop_img_face) + + # converts from Tensor to ndarray + affect = np.array([a[0].cpu().detach().numpy() for a in affect]) + + # normalizes arousal + affect[:, 1] = np.clip((affect[:, 1] + 1) / 2.0, 0, 1) + + # computes mean arousal and valence as the ensemble prediction + ensemble_affect = np.expand_dims(np.mean(affect, 0), axis=0) + + # concatenates the ensemble prediction to the list of affect predictions + sample_affect = np.concatenate((affect, ensemble_affect), axis=0) + info["frame_detections_" + self.short_name]["affects"].append((sample_affect, id)) + + # converts from Tensor to ndarray + emotion = np.array([e[0].cpu().detach().numpy() for e in emotion]) + + # gets number of classes + num_classes = emotion.shape[1] + + # computes votes and add label to the list of emotions + emotion_votes = np.zeros(num_classes) + for e in emotion: + e_idx = np.argmax(e) + sample_emotions_idx.append(e_idx) + sample_emotions.append(self.trg_classes[e_idx]) + emotion_votes[e_idx] += 1 + + # concatenates the ensemble prediction to the list of emotion predictions + sample_emotions.append(self.trg_classes[np.argmax(emotion_votes)]) + info["frame_detections_" + self.short_name]["emotions"].append((sample_emotions, id)) + + if self.model_gradcam is not None: + sample_saliency = self.model_gradcam.grad_cam(crop_img_face, sample_emotions_idx) + sample_saliency = np.array([s.cpu().detach().numpy() for s in sample_saliency]) + info["frame_detections_" + self.short_name]["saliency_maps"].append((sample_saliency, id)) + + info_list[frame_id].update(**info) + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + def preprocess_frames(self, video_frames_list=None, faces_locations=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + if faces_locations is not None: + faces_locations = list(faces_locations) + features["faces_locations"] = faces_locations[:lim] + [faces_locations[lim]] * pad + return features + + def annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + show_det_emotion_label=False, + show_det_valence_arousal_label=False, + show_det_head_bbox=True, + ensemble_aggregation="sum", + 
hanning_face=True, + keep_unmatching_ensembles=False, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False)), + "show_det_emotion_label": (show_det_emotion_label, "toggle", (True, False)), + "show_det_valence_arousal_label": (show_det_valence_arousal_label, "toggle", (True, False)), + "show_det_head_bbox": (show_det_head_bbox, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + if grabbed_video: + collated_saliency_maps = np.zeros_like(grouped_video_frames["captured"]) + if show_det_head_bbox or show_det_saliency_map or show_det_emotion_label: + for h_bbox, saliency_map, emotion, affect in zip(info["frame_detections_" + self.short_name]["h_bboxes"], + info["frame_detections_" + self.short_name]["saliency_maps"], + info["frame_detections_" + self.short_name]["emotions"], + info["frame_detections_" + self.short_name]["affects"],): + xmin_h_bbox, ymin_h_bbox, xmax_h_bbox, ymax_h_bbox, participant_id = h_bbox + if show_det_head_bbox: + frame_source = plotter.plot_bbox(grouped_video_frames["det_source_" + self.short_name], (xmin_h_bbox, ymin_h_bbox), + (xmax_h_bbox, ymax_h_bbox), color_id=participant_id) + grouped_video_frames["det_source_" + self.short_name] = frame_source + if show_det_saliency_map: + if keep_unmatching_ensembles: + filtered_saliency_map = saliency_map[0] + else: + # remove gradcams from wrong classifications + filter_idxs = [flt_idx for flt_idx, flt_cls in enumerate(emotion[0]) if + flt_cls == emotion[0][-1]] + filtered_saliency_map = np.take(saliency_map[0], filter_idxs[:-1], axis=0) + + # aggregate the ensembles into a single saliency plot + if ensemble_aggregation == "amean": + face_saliency = np.nanmean(filtered_saliency_map, axis=0) + elif ensemble_aggregation == "gmean": + face_saliency = gmean(filtered_saliency_map, axis=0) + elif ensemble_aggregation == "sum": + face_saliency = np.sum(filtered_saliency_map, axis=0) + else: + face_saliency = filtered_saliency_map[0] + # project saliency to full image + if hanning_face: + face_saliency = face_saliency * window("hann", face_saliency.shape) + face_saliency = plotter.plot_color_map(np.uint8(255 * face_saliency), color_map=color_map) + collated_saliency_maps += plotter.plot_alpha_overlay(collated_saliency_maps, face_saliency, + xy_min=(xmin_h_bbox, ymin_h_bbox), + xy_max=(xmax_h_bbox, ymax_h_bbox), alpha=1) + if show_det_emotion_label: + grouped_video_frames["det_transformed_" + self.short_name] = \ + plotter.plot_text(grouped_video_frames["det_transformed_" + self.short_name], + emotion[0][-1], + (xmin_h_bbox, ymax_h_bbox + 10), + color_id=emotion[1]) + if show_det_valence_arousal_label: + grouped_video_frames["det_transformed_" + self.short_name] = \ + plotter.plot_text(grouped_video_frames["det_transformed_" + self.short_name], + "valence: " + str(np.round(affect[0][-1][0],3)), + (xmin_h_bbox, ymax_h_bbox + 30), + color_id=emotion[1]) + grouped_video_frames["det_transformed_" + self.short_name] = \ + 
plotter.plot_text(grouped_video_frames["det_transformed_" + self.short_name], + "arousal: " + str(np.round(affect[0][-1][1], 3)), + (xmin_h_bbox, ymax_h_bbox + 50), + color_id=emotion[1]) + if show_det_saliency_map: + frame_transformed = \ + plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + self.short_name], + collated_saliency_maps, + alpha=0.4 if enable_transform_overlays else 1.0) + grouped_video_frames["det_transformed_" + self.short_name] = frame_transformed + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git a/gazenet/models/emotion_recognition/esr9/model.py b/gazenet/models/emotion_recognition/esr9/model.py new file mode 100644 index 0000000..39f30ce --- /dev/null +++ b/gazenet/models/emotion_recognition/esr9/model.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Implementation of ESR-9 (Siqueira et al., 2020) trained on AffectNet (Mollahosseini et al., 2017) for emotion +and affect perception. +Reference: + Siqueira, H., Magg, S. and Wermter, S., 2020. Efficient Facial Feature Learning with Wide Ensemble-based + Convolutional Neural Networks. Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence + (AAAI-20), pages 1–1, New York, USA. + Mollahosseini, A., Hasani, B. and Mahoor, M.H., 2017. AffectNet: A database for facial expression, valence, + and arousal computing in the wild. IEEE Transactions on Affective Computing, 10(1), pp.18-31. +""" + +__author__ = "Henrique Siqueira" +__email__ = "siqueira.hc@outlook.com" +__license__ = "MIT license" +__version__ = "1.0" + +# Standard libraries +from os import path + +# External libraries +import torch.nn.functional as F +import torch.nn as nn +import torch + + +class Base(nn.Module): + """ + The base of the network (Ensembles with Shared Representations, ESRs) is responsible for learning low- and + mid-level representations from the input data that are shared with an ensemble of convolutional branches + on top of the architecture. + In our paper (Siqueira et al., 2020), it is called shared layers or shared representations. + """ + + def __init__(self): + super(Base, self).__init__() + + # convolutional layers + self.conv1 = nn.Conv2d(3, 64, 5, 1) + self.conv2 = nn.Conv2d(64, 128, 3, 1) + self.conv3 = nn.Conv2d(128, 128, 3, 1) + self.conv4 = nn.Conv2d(128, 128, 3, 1) + + # batch-normalization layers + self.bn1 = nn.BatchNorm2d(64) + self.bn2 = nn.BatchNorm2d(128) + self.bn3 = nn.BatchNorm2d(128) + self.bn4 = nn.BatchNorm2d(128) + + # max-pooling layer + self.pool = nn.MaxPool2d(2, 2) + + def forward(self, x): + # convolutional, batch-normalization and pooling layers for representation learning + x_shared_representations = F.relu(self.bn1(self.conv1(x))) + x_shared_representations = self.pool(F.relu(self.bn2(self.conv2(x_shared_representations)))) + x_shared_representations = F.relu(self.bn3(self.conv3(x_shared_representations))) + x_shared_representations = self.pool(F.relu(self.bn4(self.conv4(x_shared_representations)))) + + return x_shared_representations + + +class ConvolutionalBranch(nn.Module): + """ + Convolutional branches that compose the ensemble in ESRs. Each branch was trained on a sub-training + set from the AffectNet dataset to learn complementary representations from the data (Siqueira et al., 2020). + Note that, the second last layer provides eight discrete emotion labels whereas the last layer provides + continuous values of arousal and valence levels. 
+ """ + + def __init__(self): + super(ConvolutionalBranch, self).__init__() + + # convolutional layers + self.conv1 = nn.Conv2d(128, 128, 3, 1) + self.conv2 = nn.Conv2d(128, 256, 3, 1) + self.conv3 = nn.Conv2d(256, 256, 3, 1) + self.conv4 = nn.Conv2d(256, 512, 3, 1, 1) + + # batch-normalization layers + self.bn1 = nn.BatchNorm2d(128) + self.bn2 = nn.BatchNorm2d(256) + self.bn3 = nn.BatchNorm2d(256) + self.bn4 = nn.BatchNorm2d(512) + + # second last, fully-connected layer related to discrete emotion labels + self.fc = nn.Linear(512, 8) + + # last, fully-connected layer related to continuous affect levels (arousal and valence) + self.fc_dimensional = nn.Linear(8, 2) + + # pooling layers + # Max-pooling layer + self.pool = nn.MaxPool2d(2, 2) + + # global average pooling layer + self.global_pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x_shared_representations): + # convolutional, batch-normalization and pooling layers + x_conv_branch = F.relu(self.bn1(self.conv1(x_shared_representations))) + x_conv_branch = self.pool(F.relu(self.bn2(self.conv2(x_conv_branch)))) + x_conv_branch = F.relu(self.bn3(self.conv3(x_conv_branch))) + x_conv_branch = self.global_pool(F.relu(self.bn4(self.conv4(x_conv_branch)))) + x_conv_branch = x_conv_branch.view(-1, 512) + + # fully connected layer for emotion perception + discrete_emotion = self.fc(x_conv_branch) + + # application of the ReLU function to neurons related to discrete emotion labels + x_conv_branch = F.relu(discrete_emotion) + + # fully connected layer for affect perception + continuous_affect = self.fc_dimensional(x_conv_branch) + + # returns activations of the discrete emotion output layer and arousal and valence levels + return discrete_emotion, continuous_affect + + def forward_to_last_conv_layer(self, x_shared_representations): + """ + Propagates activations to the last convolutional layer of the architecture. + This method is used to generate saliency maps with the Grad-CAM algorithm (Selvaraju et al., 2017). + Reference: + Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D. and Batra, D., 2017. + Grad-cam: Visual explanations from deep networks via gradient-based localization. + In Proceedings of the IEEE international conference on computer vision (pp. 618-626). + :param x_shared_representations: (ndarray) feature maps from shared layers + :return: feature maps of the last convolutional layer + """ + + # convolutional, batch-normalization and pooling layers + x_to_last_conv_layer = F.relu(self.bn1(self.conv1(x_shared_representations))) + x_to_last_conv_layer = self.pool(F.relu(self.bn2(self.conv2(x_to_last_conv_layer)))) + x_to_last_conv_layer = F.relu(self.bn3(self.conv3(x_to_last_conv_layer))) + x_to_last_conv_layer = F.relu(self.bn4(self.conv4(x_to_last_conv_layer))) + + # feature maps of the last convolutional layer + return x_to_last_conv_layer + + def forward_from_last_conv_layer_to_output_layer(self, x_from_last_conv_layer): + """ + Propagates activations to the second last, fully-connected layer (here referred as output layer). + This layer represents emotion labels. + :param x_from_last_conv_layer: (ndarray) feature maps from the last convolutional layer of this branch. 
+ :return: (ndarray) activations of the last second, fully-connected layer of the network + """ + + # global average polling and reshape + x_to_output_layer = self.global_pool(x_from_last_conv_layer) + x_to_output_layer = x_to_output_layer.view(-1, 512) + + # output layer: emotion labels + x_to_output_layer = self.fc(x_to_output_layer) + + # returns activations of the discrete emotion output layer + return x_to_output_layer + + +class ESR(nn.Module): + """ + ESR is the unified ensemble architecture composed of two building blocks the Base and ConvolutionalBranch + classes as described below by Siqueira et al. (2020): + 'An ESR consists of two building blocks. (1) The base (class Base) of the network is an array of convolutional + layers for low- and middle-level feature learning. (2) These informative features are then shared with + independent convolutional branches (class ConvolutionalBranch) that constitute the ensemble.' + """ + + def __init__(self, ensembles_num=9): + super(ESR, self).__init__() + + # base of ESR-9 as described in the docstring (see mark 1) + self.base = Base() + + # convolutional branches that composes ESR-9 as described in the docstring (see mark 2) + self.convolutional_branches = [] + for i in range(1, ensembles_num + 1): + self.convolutional_branches.append(ConvolutionalBranch()) + + def forward(self, x): + """ + Forward method of ESR-9. + :param x: (ndarray) Input data. + :return: A list of emotions and affect values from each convolutional branch in the ensemble. + """ + + # list of emotions and affect values from the ensemble + emotions = [] + affect_values = [] + + # get shared representations + x_shared_representations = self.base(x) + + # add to the lists of predictions outputs from each convolutional branch in the ensemble + for branch in self.convolutional_branches: + output_emotion, output_affect = branch(x_shared_representations) + emotions.append(output_emotion) + affect_values.append(output_affect) + + return emotions, affect_values + + def __len__(self): + """ + ESR with nine branches trained on AffectNet (Siqueira et al., 2020). + :return: (int) Size of the ensemble + """ + return len(self.convolutional_branches) diff --git a/gazenet/models/gaze_estimation/__init__.py b/gazenet/models/gaze_estimation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/gaze_estimation/gaze360/__init__.py b/gazenet/models/gaze_estimation/gaze360/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/gaze_estimation/gaze360/checkpoints/pretrained_gaze360_orig/download_model.sh b/gazenet/models/gaze_estimation/gaze360/checkpoints/pretrained_gaze360_orig/download_model.sh new file mode 100644 index 0000000..0463e4f --- /dev/null +++ b/gazenet/models/gaze_estimation/gaze360/checkpoints/pretrained_gaze360_orig/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget -O model.pth.tar http://gaze360.csail.mit.edu/files/gaze360_model.pth.tar \ No newline at end of file diff --git a/gazenet/models/gaze_estimation/gaze360/generator.py b/gazenet/models/gaze_estimation/gaze360/generator.py new file mode 100644 index 0000000..c972af9 --- /dev/null +++ b/gazenet/models/gaze_estimation/gaze360/generator.py @@ -0,0 +1,34 @@ +import cv2 +import torch +import torchvision.transforms.functional as F +from PIL import Image + +def pre_process_input_image(img, img_width, img_height, img_mean, img_std): + """ + Pre-processes an image for gaze360. 
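+ Resizes the input image to the given width and height, converts BGR to RGB, and normalizes it with the given mean and std.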
+ :param img: (ndarray) + :return: (ndarray) image + """ + + img = cv2.resize(img, (img_width, img_height)) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = F.to_tensor(img) + img = F.normalize(img, img_mean, img_std) + # img = transforms.Normalize(mean=ESR.INPUT_IMAGE_NORMALIZATION_MEAN, std=ESR.INPUT_IMAGE_NORMALIZATION_STD)(transforms.ToTensor()(img)).unsqueeze(0) + return img + + +def spherical_to_compatible_form(tpr): + ptr = torch.zeros(tpr.size(0), 3) + ptr[:, 0] = tpr[:, 1] + ptr[:, 1] = tpr[:, 0] + ptr[:, 2] = 1 + return ptr + + +def spherical_to_cartesian(tpr): + xyz = torch.zeros(tpr.size(0),3) + xyz[:,2] = -torch.cos(tpr[:,1])*torch.cos(tpr[:,0]) + xyz[:,0] = torch.cos(tpr[:,1])*torch.sin(tpr[:,0]) + xyz[:,1] = torch.sin(tpr[:,1]) + return xyz \ No newline at end of file diff --git a/gazenet/models/gaze_estimation/gaze360/infer.py b/gazenet/models/gaze_estimation/gaze360/infer.py new file mode 100644 index 0000000..bb21680 --- /dev/null +++ b/gazenet/models/gaze_estimation/gaze360/infer.py @@ -0,0 +1,170 @@ +import os + +import numpy as np +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torchvision.transforms as transforms + +from gazenet.utils.registrar import * +from gazenet.models.gaze_estimation.gaze360.generator import spherical_to_cartesian, spherical_to_compatible_form, pre_process_input_image +from gazenet.models.gaze_estimation.gaze360.model import GazeLSTM +from gazenet.utils.sample_processors import InferenceSampleProcessor +from gazenet.utils.helpers import spherical_to_euler + +MODEL_PATHS = { + "gaze360": os.path.join("gazenet", "models", "gaze_estimation", "gaze360", "checkpoints", "pretrained_gaze360_orig", "model.pth.tar")} + +INP_IMG_WIDTH = 224 +INP_IMG_HEIGHT = 224 +INP_IMG_MEAN = (0.485, 0.456, 0.406) +INP_IMG_STD = (0.229, 0.224, 0.225) + + +@InferenceRegistrar.register +class Gaze360Inference(InferenceSampleProcessor): + def __init__(self, weights_file=MODEL_PATHS['gaze360'], w_fps=30, inp_img_width=INP_IMG_WIDTH, + inp_img_height=INP_IMG_HEIGHT, inp_img_mean=INP_IMG_MEAN, inp_img_std=INP_IMG_STD, + device="cuda:0", width=None, height=None, **kwargs): + super().__init__(width=width, height=height, **kwargs) + self.short_name = "gaze360" + # the original implementation samples surrounding frames with a stride of an eighth of the frame rate (w_fps // 8) + # self.w_fps_div = max(int(w_fps // 8), 1) + # here we use a stride of a single frame + self.w_fps_div = 1 + self._device = device + + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + # extra grouping list + self.faces_locations = [] + + # load the model + self.model = GazeLSTM() + model = torch.nn.DataParallel(self.model).to(device) + model.to(device) + checkpoint = torch.load(weights_file) + model.load_state_dict(checkpoint['state_dict']) + print("Gaze360 model loaded from", weights_file) + model.eval() + + + # adapted from: https://colab.research.google.com/drive/1AUvmhpHklM9BNt0Mn5DjSo3JRuqKkU4y#scrollTo=FKESCskkymbs + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + video_frames_list, faces_locations, source_frames_idxs=None, **kwargs): + frames_idxs = range(len(video_frames_list)) if source_frames_idxs is None else source_frames_idxs + for frame_id in frames_idxs: + info = {"frame_detections_" + self.short_name: { + "CART_gaze_poses": [], # detected + "SPHERE_gaze_poses": 
[], # detected + "h_bboxes": [] # processed + }} + # read image + input_image = torch.zeros(7, 3, self.inp_img_width, self.inp_img_height) + + for id, face_local in enumerate(faces_locations[frame_id]): + if not face_local: + continue + (top, right, bottom, left) = face_local + info["frame_detections_" + self.short_name]["h_bboxes"].append((left, top, right, bottom, id)) + + count = 0 + for j in range(frame_id - 3 * self.w_fps_div, frame_id + 4 * self.w_fps_div, self.w_fps_div): + if j < 0 or j >= len(faces_locations) or not faces_locations[j]: + face_img = video_frames_list[frame_id].copy() + else: + if id < len(faces_locations[j]): + face_img = video_frames_list[j].copy() + face_local = faces_locations[j][id] + else: + face_img = video_frames_list[frame_id].copy() + + (top, right, bottom, left) = face_local + # crop face image + crop_img_face = face_img[top:bottom, left:right] + + # fill the images + input_image[count, :, :, :] = pre_process_input_image(crop_img_face, + self.inp_img_width, self.inp_img_height, + self.inp_img_mean, self.inp_img_std) + count = count + 1 + + # bbox, eyes = tracking_id[i][id_t] + # bbox = np.asarray(bbox).astype(int) + output_gaze, _ = self.model(input_image.view(1, 7, 3, + self.inp_img_width, self.inp_img_height).to(self._device)) + gaze = spherical_to_cartesian(output_gaze).detach().numpy() + gaze = gaze.reshape((-1)) + info["frame_detections_" + self.short_name]["CART_gaze_poses"].append((gaze[0], gaze[1], gaze[2], id)) + gaze = spherical_to_compatible_form(output_gaze).detach().numpy() + gaze = gaze.reshape((-1)) + info["frame_detections_" + self.short_name]["SPHERE_gaze_poses"].append((gaze[0], gaze[1], gaze[2], id)) + info_list[frame_id].update(**info) + + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + def preprocess_frames(self, video_frames_list=None, faces_locations=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + if faces_locations is not None: + faces_locations = list(faces_locations) + features["faces_locations"] = faces_locations[:lim] + [faces_locations[lim]] * pad + return features + + def annotate_frame(self, input_data, plotter, + show_det_gaze_axis=False, + show_det_gaze_direction_field=True, + show_det_head_bbox=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_gaze_axis": (show_det_gaze_axis, "toggle", (True, False)), + "show_det_gaze_direction_field": (show_det_gaze_direction_field, "toggle", (True, False)), + "show_det_head_bbox": (show_det_head_bbox, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + if grabbed_video: + if show_det_head_bbox or show_det_gaze_axis or 
show_det_gaze_direction_field: + for h_bbox, gaze_pose_cart, gaze_pose_sphere in zip(info["frame_detections_" + self.short_name]["h_bboxes"], + info["frame_detections_" + self.short_name]["CART_gaze_poses"], + info["frame_detections_" + self.short_name]["SPHERE_gaze_poses"]): + xmin_h_bbox, ymin_h_bbox, xmax_h_bbox, ymax_h_bbox, participant_id = h_bbox + if show_det_head_bbox: + frame_source = plotter.plot_bbox(grouped_video_frames["det_source_" + self.short_name], + (xmin_h_bbox, ymin_h_bbox), + (xmax_h_bbox, ymax_h_bbox), color_id=participant_id) + grouped_video_frames["det_source_" + self.short_name] = frame_source + if show_det_gaze_axis: + frame_transformed = plotter.plot_axis(grouped_video_frames["det_transformed_" + self.short_name], + (xmin_h_bbox, ymin_h_bbox), + (xmax_h_bbox, ymax_h_bbox), spherical_to_euler(gaze_pose_sphere)) + grouped_video_frames["det_transformed_" + self.short_name] = frame_transformed + if show_det_gaze_direction_field: + gaze_orig = (xmin_h_bbox, ymin_h_bbox, 0) + frame_transformed = \ + plotter.plot_conic_field(grouped_video_frames["det_transformed_" + self.short_name], gaze_orig, + gaze_pose_cart, radius_orig=1, radius_tgt=60, color_map=color_map) + grouped_video_frames["det_transformed_" + self.short_name] = \ + plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + self.short_name], + frame_transformed, + alpha=0.4 if enable_transform_overlays else 0.5) + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git a/gazenet/models/gaze_estimation/gaze360/model.py b/gazenet/models/gaze_estimation/gaze360/model.py new file mode 100644 index 0000000..7807118 --- /dev/null +++ b/gazenet/models/gaze_estimation/gaze360/model.py @@ -0,0 +1,71 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +from torch.nn.init import normal, constant +import numpy as np + +from gazenet.models.shared_components.resnet.model import resnet18 + + +class GazeLSTM(nn.Module): + def __init__(self): + super(GazeLSTM, self).__init__() + self.img_feature_dim = 256 # the dimension of the CNN feature to represent each frame + + self.base_model = resnet18(pretrained=True) + + self.base_model.fc2 = nn.Linear(1000, self.img_feature_dim) + + self.lstm = nn.LSTM(self.img_feature_dim, self.img_feature_dim,bidirectional=True,num_layers=2,batch_first=True) + + # The linear layer that maps the LSTM with the 3 outputs + self.last_layer = nn.Linear(2*self.img_feature_dim, 3) + + + def forward(self, input): + + base_out = self.base_model(input.view((-1, 3) + input.size()[-2:])) + + base_out = base_out.view(input.size(0),7,self.img_feature_dim) + + lstm_out, _ = self.lstm(base_out) + lstm_out = lstm_out[:,3,:] + output = self.last_layer(lstm_out).view(-1,3) + + + angular_output = output[:,:2] + angular_output[:,0:1] = math.pi*nn.Tanh()(angular_output[:,0:1]) + angular_output[:,1:2] = (math.pi/2)*nn.Tanh()(angular_output[:,1:2]) + + var = math.pi*nn.Sigmoid()(output[:,2:3]) + var = var.view(-1,1).expand(var.size(0),2) + + return angular_output,var + + +class PinBallLoss(nn.Module): + def __init__(self): + super(PinBallLoss, self).__init__() + self.q1 = 0.1 + self.q9 = 1-self.q1 + + def forward(self, output_o,target_o,var_o): + q_10 = target_o-(output_o-var_o) + q_90 = target_o-(output_o+var_o) + + loss_10 = 
torch.max(self.q1*q_10, (self.q1-1)*q_10) + loss_90 = torch.max(self.q9*q_90, (self.q9-1)*q_90) + + + loss_10 = torch.mean(loss_10) + loss_90 = torch.mean(loss_90) + + return loss_10+loss_90 diff --git a/gazenet/models/gaze_following/__init__.py b/gazenet/models/gaze_following/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/gaze_following/videogaze/__init__.py b/gazenet/models/gaze_following/videogaze/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/gaze_following/videogaze/checkpoints/pretrained_videogaze_orig/download_model.sh b/gazenet/models/gaze_following/videogaze/checkpoints/pretrained_videogaze_orig/download_model.sh new file mode 100644 index 0000000..b728402 --- /dev/null +++ b/gazenet/models/gaze_following/videogaze/checkpoints/pretrained_videogaze_orig/download_model.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +#wget http://videogazefollow.csail.mit.edu/downloads/model.pth.tar +wget -O model.pth.tar https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/gazenet/models/gaze_following/videogaze/checkpoints/pretrained_videogaze_orig/model.pth.tar diff --git a/gazenet/models/gaze_following/videogaze/infer.py b/gazenet/models/gaze_following/videogaze/infer.py new file mode 100644 index 0000000..1b4ddde --- /dev/null +++ b/gazenet/models/gaze_following/videogaze/infer.py @@ -0,0 +1,223 @@ +import os +from itertools import zip_longest + +import cv2 +import numpy as np +import torch +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torchvision.transforms as transforms + +from gazenet.utils.registrar import * +from gazenet.models.gaze_following.videogaze.model import VideoGaze +from gazenet.utils.sample_processors import InferenceSampleProcessor + +MODEL_PATHS = { + "videogaze": os.path.join("gazenet", "models", "gaze_following", "videogaze", "checkpoints", "pretrained_videogaze_orig", "model.pth.tar")} + +INP_IMG_WIDTH = 227 +INP_IMG_HEIGHT = 227 +TRG_IMG_SIDE = 20 + + +@InferenceRegistrar.register +class VideoGazeInference(InferenceSampleProcessor): + def __init__(self, weights_file=MODEL_PATHS['videogaze'], + batch_size=1, w_fps=30, w_size=2, + trg_img_side=TRG_IMG_SIDE, inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, + device="cuda:0", width=None, height=None, **kwargs): + super().__init__(width=width, height=height, w_size=w_size, **kwargs) + self.short_name = "vidgaze" + # the original implementation skips frames + # self.w_fps = w_fps + # we skip one frame at a time + self.w_fps = 1 + self._device = device + + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.trg_img_side = trg_img_side + + # load the model + self.model = VideoGaze(batch_size=batch_size, side=trg_img_side) + checkpoint = torch.load(weights_file) + self.model.load_state_dict(checkpoint['state_dict']) + print("VideoGaze model loaded from", weights_file) + self.model.to(device) + cudnn.benchmark = True + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + video_frames_list, faces_locations, source_frames_idxs=None, **kwargs): + trans = transforms.ToTensor() + + target_frame = torch.FloatTensor(self.w_size, 3, self.inp_img_width, self.inp_img_height) + target_frame = target_frame.to(self._device) + + eyes = torch.zeros(self.w_size, 3) + eyes = eyes.to(self._device) + + # info_list = [] + # initialize the info_list with the info structure since target sal. 
maps can appear in frames other than curr. + for f_idx in range(len(info_list)): + info_list[f_idx].update(**{"frame_detections_" + self.short_name: + { + "saliency_maps": [], # detected + "h_bboxes": [] # processed + }}) + + frames_idxs = range(len(video_frames_list)) if source_frames_idxs is None else source_frames_idxs + for frame_id in frames_idxs: + # info = {"frame_detections_" + self.short_name: { + # "saliency_maps": [], # detected + # "h_bboxes": [] # processed + # }} + + # print('Processing of frame %d out of %d' % (i, len(video_frames_list))) + + # avoid the problems with the video limit + # if self.w_fps * (self.w_size - 1) // 2 < frame_id < (len(video_frames_list) - self.w_fps * (self.w_size - 1) // 2) \ + # and len(faces_locations) > 0: + if len(faces_locations) > 0: + # read the image + img = video_frames_list[frame_id] + if img is None: + continue + h, w, c = img.shape + + # resize image + img_resized = cv2.resize(img, (self.inp_img_width, self.inp_img_height)) + for id, face_local in enumerate(faces_locations[frame_id]): + if not face_local: + continue + (top, right, bottom, left) = face_local + info_list[frame_id]["frame_detections_" + self.short_name]["h_bboxes"].append((left, top, right, bottom, id)) + # crop face image + crop_img_face = img[top:bottom, left:right] + crop_img_face = cv2.resize(crop_img_face, (self.inp_img_width, self.inp_img_height)) + + # compute the center of the head and estimate the eyes location + eyes[:, 0] = (right + left) / (2 * w) + eyes[:, 1] = (top + bottom) / (2 * h) + + # fill the tensors for the exploring window. Face and source frame are the same + source_frame = trans(img_resized).view(1, 3, self.inp_img_width, self.inp_img_height) + face_frame = trans(crop_img_face).view(1, 3, self.inp_img_width, self.inp_img_height) + for j in range(self.w_size - 1): + trans_im = trans(img_resized).view(1, 3, self.inp_img_width, self.inp_img_height) + source_frame = torch.cat((source_frame, trans_im), 0) + crop_img = trans(crop_img_face).view(1, 3, self.inp_img_width, self.inp_img_height) + face_frame = torch.cat((face_frame, crop_img), 0) + + # fill the targets for the exploring window. 
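# The head-centre heuristic used in the loop above, in isolation: VideoGaze expects
# "eyes" as image-normalised (x, y) coordinates, so the face-box centre is divided
# by the frame width/height. Frame and box values are illustrative.
import torch

h, w = 480, 640                                   # frame size (illustrative)
top, right, bottom, left = 100, 400, 220, 300     # face box (illustrative)
eyes = torch.zeros(2, 3)                          # one row per frame in the w_size window
eyes[:, 0] = (right + left) / (2 * w)             # normalised x of the head centre
eyes[:, 1] = (top + bottom) / (2 * h)             # normalised y of the head centre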
+ for j in range(self.w_size): + # target_im = video_frames_list[frame_id + self.w_fps * (j - ((self.w_size - 1) // 2))] + target_im = video_frames_list[j] + target_im = cv2.resize(target_im, (self.inp_img_width, self.inp_img_height)) + target_im = trans(target_im) + target_frame[j, :, :, :] = target_im + + # run the model + source_frame = source_frame.to(self._device, non_blocking=True) + target_frame = target_frame.to(self._device, non_blocking=True) + face_frame = face_frame.to(self._device, non_blocking=True) + eyes = eyes.to(self._device, non_blocking=True) + source_frame_var = torch.autograd.Variable(source_frame) + target_frame_var = torch.autograd.Variable(target_frame) + face_frame_var = torch.autograd.Variable(face_frame) + eyes_var = torch.autograd.Variable(eyes) + output, sigmoid = self.model(source_frame_var, target_frame_var, face_frame_var, eyes_var) + + # recover the data from the variables + sigmoid = sigmoid.data + output = output.data + + # pick the maximum value for the frame selection + v, ids = torch.sort(sigmoid, dim=0, descending=True) + index_target = ids[0, 0] + + # pick the video_frames_list corresponding to the maximum value + # target_im = frame_list[i + w_fps * (index_target - ((N - 1) // 2))].copy() + output_target = output[index_target, :, :, :].view(self.trg_img_side, self.trg_img_side).cpu().numpy() + + # compute the gaze location + # heatmaps += output_target + # info["frame_detections_" + self.short_name]["saliency_maps"].append((output_target, id)) + # info_list[i + w_fps * (index_target.cpu().numpy() - ((self.w_size - 1) // 2))][ + info_list[index_target.cpu().numpy()][ + "frame_detections_" + self.short_name]["saliency_maps"].append((output_target, id)) + # info_list.append(info) + # kept_data = self.keep_extracted_frames_data(None, grabbed_video_list, grouped_video_frames_list, + # grabbed_audio_list, audio_frames_list, info_list, properties_list) + return grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list + + def preprocess_frames(self, video_frames_list=None, faces_locations=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + if faces_locations is not None: + faces_locations = list(faces_locations) + features["faces_locations"] = faces_locations[:lim] + [faces_locations[lim]] * pad + return features + + def annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + show_det_gaze_target=True, + show_det_head_bbox=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {"show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False)), + "show_det_gaze_target": (show_det_gaze_target, "toggle", (True, False)), + "show_det_head_bbox": (show_det_head_bbox, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + if 
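# The frame-selection step above, reduced to its essence: the model scores every
# candidate target frame with a sigmoid and only the saliency map of the
# best-scoring frame is kept. Shapes follow TRG_IMG_SIDE = 20 and the w_size
# window; the tensors are random stand-ins.
import torch

w_size, side = 2, 20
output = torch.rand(w_size, 1, side, side)   # per-frame gaze heatmaps
sigmoid = torch.rand(w_size, 1)              # per-frame "gaze target visible here" score
_, ids = torch.sort(sigmoid, dim=0, descending=True)
index_target = ids[0, 0]                     # best-scoring frame in the window
output_target = output[index_target].view(side, side).numpy()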
grabbed_video: + collated_saliency_maps = np.zeros((self.trg_img_side, self.trg_img_side)) + if show_det_head_bbox or show_det_saliency_map or show_det_gaze_target: + for h_bbox, saliency_map in zip_longest(info["frame_detections_" + self.short_name]["h_bboxes"], + info["frame_detections_" + self.short_name]["saliency_maps"]): + if h_bbox is not None: + xmin_h_bbox, ymin_h_bbox, xmax_h_bbox, ymax_h_bbox, participant_id = h_bbox + if show_det_head_bbox: + frame_source = plotter.plot_bbox(grouped_video_frames["det_source_" + self.short_name], + (xmin_h_bbox, ymin_h_bbox), + (xmax_h_bbox, ymax_h_bbox), color_id=participant_id) + grouped_video_frames["det_source_" + self.short_name] = frame_source + if saliency_map is not None: + if show_det_saliency_map: + collated_saliency_maps += saliency_map[0] + if show_det_gaze_target: + map = np.reshape(saliency_map[0], (self.trg_img_side * self.trg_img_side)) + int_class = np.argmax(map) + x_class = int_class % self.trg_img_side + y_class = (int_class - x_class) // self.trg_img_side + y_float = y_class / self.trg_img_side + x_float = x_class / self.trg_img_side + x_point = np.floor(x_float * grouped_video_frames["det_transformed_" + self.short_name].shape[1]).astype(np.int32) + y_point = np.floor(y_float * grouped_video_frames["det_transformed_" + self.short_name].shape[0]).astype(np.int32) + frame_transformed = plotter.plot_point(grouped_video_frames["det_transformed_" + self.short_name], + (x_point, y_point), color_id=saliency_map[1], radius=10) + grouped_video_frames["det_transformed_" + self.short_name] = frame_transformed + if show_det_saliency_map: + frame_transformed = plotter.plot_color_map(np.uint8(255 * collated_saliency_maps), + color_map=color_map) + if enable_transform_overlays: + frame_transformed = plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + self.short_name], + frame_transformed, alpha=0.4) + grouped_video_frames["det_transformed_" + self.short_name] = frame_transformed + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties + + diff --git a/gazenet/models/gaze_following/videogaze/model.py b/gazenet/models/gaze_following/videogaze/model.py new file mode 100644 index 0000000..77fad83 --- /dev/null +++ b/gazenet/models/gaze_following/videogaze/model.py @@ -0,0 +1,296 @@ +# TODO (fabawi): move Alexnet to the shared components package +import argparse +import os +import shutil +import time +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +import numpy as np +import torch.utils.model_zoo as model_zoo +from torch.autograd.variable import Variable + + +MODEL_URLS = { + 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', +} + + +class AlexNet(nn.Module): + + def __init__(self, num_classes=1000): + super(AlexNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + #nn.MaxPool2d(kernel_size=3, 
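# The gaze-target lookup in annotate_frame above, in isolation: the 20x20 map is
# flattened, the argmax cell is converted back to (x, y) grid indices and scaled to
# pixel coordinates of the annotated frame. Frame size is illustrative.
import numpy as np

side = 20
saliency_map = np.random.rand(side, side)     # one detected map (stand-in)
frame_h, frame_w = 480, 640

flat = np.reshape(saliency_map, side * side)
int_class = np.argmax(flat)
x_class = int_class % side
y_class = (int_class - x_class) // side
x_point = int(np.floor((x_class / side) * frame_w))
y_point = int(np.floor((y_class / side) * frame_h))
print(x_point, y_point)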
stride=2), + ) + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), 256,13,13) + return x + + +class HeadPoseAlexnet(nn.Module): + + def __init__(self, num_classes=1000): + + super(HeadPoseAlexnet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + x = self.features(x) + return x + + +class HeadPose(nn.Module): + + def __init__(self, num_classes=1000): + super(HeadPose, self).__init__() + self.alexnet = HeadPoseAlexnet(num_classes=num_classes) + self.alexnet.load_state_dict(model_zoo.load_url(MODEL_URLS['alexnet'])) + self.linear1 = nn.Linear(256*6*6,500) + self.threshold1 = nn.Threshold(0, 1e-6) + self.linear2 = nn.Linear(500,200) + self.threshold2 = nn.Threshold(0, 1e-6) + self.linear3 = nn.Linear(200,4) + + def forward(self, x): + x = self.alexnet(x) + x = x.view(x.size(0), 256*6*6) + x = self.linear1(x) + x = self.threshold1(x) + x = self.linear2(x) + x = self.threshold2(x) + x = self.linear3(x) + x = x.view(x.size(0), 4) + return x + + +class TransformationPathway(nn.Module): + + def __init__(self, num_classes=1000): + super(TransformationPathway, self).__init__() + self.alexnet = HeadPoseAlexnet(num_classes=num_classes) + self.alexnet.load_state_dict(model_zoo.load_url(MODEL_URLS['alexnet'])) + + self.conv_features = nn.Sequential( + nn.Conv2d(512,100,kernel_size=3,padding=1,stride=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2)) + self.linear_features = nn.Sequential( + nn.Linear(400, 200), + nn.ReLU(inplace=True), + nn.Linear(200, 100), + nn.ReLU(inplace=True) + ) + self.final_linear = nn.Linear(100,7) + self.final_linear.bias.data[0] = 0 + self.final_linear.bias.data[1] = 0 + self.final_linear.bias.data[2] = 0.3 + self.final_linear.bias.data[3] = 0 + self.final_linear.bias.data[4] = 0 + self.final_linear.bias.data[5] = 0 + self.final_linear.bias.data[6] = 0.4 + + + def forward(self, source, target): + source_conv = self.alexnet(source) + target_conv = self.alexnet(target) + all_conv = torch.cat((source_conv,target_conv),1) + conv_output = self.conv_features(all_conv).view(-1,400) + fc_output = self.linear_features(conv_output) + fc_output = self.final_linear(fc_output) + angles = torch.mul(nn.Tanh()(fc_output[:,3:6]),np.pi) + R = self.rotation_tensor(angles[:,0],angles[:,1] , angles[:,2], angles.size(0)) + x_t = fc_output[:,0:3] + sigmoid = nn.Hardtanh(0,1)(fc_output[:,6]) + return R,x_t,sigmoid + + def rotation_tensor(self,theta, phi, psi, n_comps): + rot_x = Variable(torch.zeros(n_comps, 3, 3).cuda()) + rot_y = Variable(torch.zeros(n_comps, 3, 3).cuda()) + rot_z = Variable(torch.zeros(n_comps, 3, 
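# Shape contract of the truncated AlexNet above: with the final max-pool left
# commented out, the feature map stays at 13x13, which is why the rest of VideoGaze
# reasons over 13*13 = 169 spatial cells. Quick check with untrained weights:
import torch
from gazenet.models.gaze_following.videogaze.model import AlexNet

net = AlexNet()
feat = net(torch.randn(1, 3, 227, 227))   # 227x227 crops, as used by infer.py
print(feat.shape)                          # torch.Size([1, 256, 13, 13])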
3).cuda()) + + rot_x[:, 0, 0] = 1 + rot_x[:, 1, 1] = theta.cos() + rot_x[:, 1, 2] = theta.sin() + rot_x[:, 2, 1] = -theta.sin() + rot_x[:, 2, 2] = theta.cos() + + rot_y[:, 0, 0] = phi.cos() + rot_y[:, 0, 2] = -phi.sin() + rot_y[:, 1, 1] = 1 + rot_y[:, 2, 0] = phi.sin() + rot_y[:, 2, 2] = phi.cos() + + rot_z[:, 0, 0] = psi.cos() + rot_z[:, 0, 1] = -psi.sin() + rot_z[:, 1, 0] = psi.sin() + rot_z[:, 1, 1] = psi.cos() + rot_z[:, 2, 2] = 1 + rot_2 = torch.bmm(rot_y, rot_x) + return torch.bmm(rot_z, rot_2) + + +class ConeProjection(nn.Module): + def __init__(self, batch_size=100): + super(ConeProjection, self).__init__() + self.batch_size = batch_size + + + def forward(self,eyes,v,R,t,alpha): + + P = Variable(torch.zeros(eyes.size(0), 169, 3).cuda()) + for b in range(eyes.size(0)): + for i in range(13): + for j in range(13): + k = 13*i + j + P[b,k,0] = (i-6)/6 + P[b,k,1] = (j-6)/6 + P[b,k,2] = 1 + + + id_matrix = Variable(torch.zeros(eyes.size(0), 3, 3).cuda()) + id_matrix[:,0,0] = alpha + id_matrix[:,1,1] = alpha + id_matrix[:,2,2] = alpha + + #Normalize vector! + v = v / v.norm(2, 1).clamp(min=0.00000000000001).view(-1,1).expand_as(v) + + + v_matrix = torch.bmm(v.view(-1,3,1),v.view(-1,1,3)) + + + M = v_matrix-id_matrix + + sigma_matrix = Variable(torch.zeros(eyes.size(0), 3, 3).cuda()) + + v1 = R[:,:,0].contiguous().view(-1,3) + v2 = R[:,:,1].contiguous().view(-1,3) + + u_e = eyes + + v11 = v1.contiguous().view(-1,1,3) + v21 = v2.contiguous().view(-1,1,3) + v12 = v1.contiguous().view(-1,3,1) + v22 = v2.contiguous().view(-1,3,1) + u_e1 = u_e.contiguous().view(-1,1,3) + u_e2 = u_e.contiguous().view(-1,3,1) + t1 = t.contiguous().view(-1,1,3) + t2 = t.contiguous().view(-1,3,1) + + + sigma_matrix[:,0:1,0:1] = torch.bmm(v11,torch.bmm(M,v12)) + sigma_matrix[:,0:1,1:2] = torch.bmm(v11,torch.bmm(M,v22)) + sigma_matrix[:,0:1,2:3] = torch.bmm(v11,torch.bmm(M,(t2-u_e2))) + sigma_matrix[:,1:2,0:1] = torch.bmm(v21,torch.bmm(M,v12)) + sigma_matrix[:,1:2,1:2] = torch.bmm(v21,torch.bmm(M,v22)) + sigma_matrix[:,1:2,2:3] = torch.bmm(v21,torch.bmm(M,t2-u_e2)) + sigma_matrix[:,2:3,0:1] = torch.bmm(t1-u_e1,torch.bmm(M,v12)) + sigma_matrix[:,2:3,1:2] = torch.bmm(t1-u_e1,torch.bmm(M,v22)) + sigma_matrix[:,2:3,2:3] = torch.bmm(t1-u_e1,torch.bmm(M,t2-u_e2)) + + sigma_matrix_all = sigma_matrix.view(-1,1,3,3).expand(eyes.size(0),169,3,3).contiguous().view(-1,3,3) + P1 = P.contiguous().view(-1,1,3) + P2 = P.contiguous().view(-1,3,1) + sum_all = torch.bmm(P1,torch.bmm(sigma_matrix_all,P2)).contiguous().view(-1,169) + + return sum_all + + +class VideoGaze(nn.Module): + def __init__(self, batch_size=200, side=20): + super(VideoGaze, self).__init__() + self.saliency_pathway = AlexNet() + self.saliency_pathway.load_state_dict(model_zoo.load_url(MODEL_URLS['alexnet'])) + self.last_conv = nn.Conv2d(256, 1, kernel_size=1, stride=1) + self.relu_saliency = nn.ReLU(inplace=True) + self.cone_pathway = HeadPose() + self.projection = ConeProjection(batch_size) + self.transformation_path = TransformationPathway() + self.linear_final = nn.Linear(169,side*side) + self.sigmoid1 = nn.Linear(169*2,200) + self.sigmoid2 = nn.Linear(200,1) + self.last_convolution = nn.Conv2d(1, 1, kernel_size=1, stride=1) + self.side = side + + def forward(self, source,target,face,eyes): + saliency_256 = self.saliency_pathway(target) + saliency_output = self.last_conv(saliency_256) + saliency_output = self.relu_saliency(saliency_output) + saliency_output = saliency_output.view(-1,169) + cone_parameters = self.cone_pathway(face) + head_v = 
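# The core of the cone projection above: for a (unit) gaze vector v and an aperture
# parameter alpha, a displacement d from the eye lies inside the gaze cone roughly
# when d^T (v v^T - alpha * I) d > 0, i.e. the component along v dominates. Small
# illustrative check (the alpha value here is made up):
import torch

v = torch.tensor([0.0, 0.0, 1.0])                # gaze direction
alpha = 0.8                                       # cone "tightness"
M = torch.outer(v, v) - alpha * torch.eye(3)

d_inside = torch.tensor([0.1, 0.0, 1.0])          # nearly along the gaze direction
d_outside = torch.tensor([1.0, 0.0, 0.2])         # far off-axis
print((d_inside @ M @ d_inside).item() > 0)       # True
print((d_outside @ M @ d_outside).item() > 0)     # False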
cone_parameters[:,0:3] + variance = nn.Hardtanh(0.5, 0.99)(cone_parameters[:,3]) + R,t,sigmoid = self.transformation_path(source,target) + projection = self.projection(eyes,head_v,R,t,variance) + projection_simoid = torch.mul(projection,sigmoid.view(-1,1).expand_as(projection)) + + input_sigmoid = torch.cat((saliency_output,projection_simoid),1) + output_sigmoid_l1 = nn.ReLU()(self.sigmoid1(input_sigmoid)) + output_sigmoid_l2 = self.sigmoid2(output_sigmoid_l1) + output_sigmoid_l2 = nn.Sigmoid()(output_sigmoid_l2) + + output = torch.mul(projection_simoid,saliency_output) + output = self.linear_final(output) + # softmax on output + output = nn.Softmax(1)(output) + # or min max normalize the output + # output -= output.min(1, keepdim=True)[0] + # output /= output.max(1, keepdim=True)[0] + + output = output.view(-1,1,self.side,self.side) + + return output,output_sigmoid_l2 + diff --git a/gazenet/models/saliency_prediction/__init__.py b/gazenet/models/saliency_prediction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/saliency_prediction/avinet/__init__.py b/gazenet/models/saliency_prediction/avinet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/saliency_prediction/avinet/checkpoints/pretrained_avinet_orig/download_model.sh b/gazenet/models/saliency_prediction/avinet/checkpoints/pretrained_avinet_orig/download_model.sh new file mode 100644 index 0000000..e9eccc0 --- /dev/null +++ b/gazenet/models/saliency_prediction/avinet/checkpoints/pretrained_avinet_orig/download_model.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +# Here , the downloaded DAVE weights refer to the avinet model trained on the AVE dataset and **NOT** the DAVE model +# https://iiitaphyd-my.sharepoint.com/personal/samyak_j_research_iiit_ac_in/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fsamyak%5Fj%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FVideo%20Saliency%2FAViNet%20Pretrained%2Ezip&parent=%2Fpersonal%2Fsamyak%5Fj%5Fresearch%5Fiiit%5Fac%5Fin%2FDocuments%2FVideo%20Saliency&originalPath=aHR0cHM6Ly9paWl0YXBoeWQtbXkuc2hhcmVwb2ludC5jb20vOnU6L2cvcGVyc29uYWwvc2FteWFrX2pfcmVzZWFyY2hfaWlpdF9hY19pbi9FWFlxNVdpU2JoOUtxOVJfbi1HcjN5QUJSeUtQU2t4TTdST0xnLXpQRFhWX3FBP3J0aW1lPU5XN3BjZzdHMkVn + +wget -O AViNet_Dave.pt https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/gazenet/models/saliency_prediction/avinet/checkpoints/pretrained_avinet_orig/AViNet_Dave.pt +wget -O ViNet_Dave.pt https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/gazenet/models/saliency_prediction/avinet/checkpoints/pretrained_avinet_orig/ViNet_Dave.pt diff --git a/gazenet/models/saliency_prediction/avinet/generator.py b/gazenet/models/saliency_prediction/avinet/generator.py new file mode 100644 index 0000000..dd1eeea --- /dev/null +++ b/gazenet/models/saliency_prediction/avinet/generator.py @@ -0,0 +1,60 @@ +import torch +import torchvision.transforms.functional as F +import numpy as np +import librosa as sf +from PIL import Image +import cv2 + +def normalize_data(data): + data_min = np.min(data) + data_max = np.max(data) + data_norm = np.clip((data - data_min) * + (255.0 / (data_max - data_min)), + 0, 255).astype(np.uint8) + return data_norm + +def create_data_packet(in_data, frame_number, frames_len=16): + in_data = np.array(in_data) + n_frame = in_data.shape[0] + # if the frame number is larger, we just use the last sound one heard + frame_number = min(frame_number, n_frame) + starting_frame = frame_number - frames_len + 1 + # ensure we do not have any negative video_frames_list + starting_frame = max(0, 
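# End-to-end sketch of the VideoGaze module defined above (untrained apart from the
# ImageNet AlexNet weights it pulls in; illustrative shapes only). A CUDA device is
# required because ConeProjection and rotation_tensor hard-code .cuda().
import torch
from gazenet.models.gaze_following.videogaze.model import VideoGaze

device = "cuda:0"
model = VideoGaze(batch_size=2, side=20).to(device).eval()
source = torch.randn(2, 3, 227, 227, device=device)   # frame containing the person
target = torch.randn(2, 3, 227, 227, device=device)   # candidate target frame
face = torch.randn(2, 3, 227, 227, device=device)     # head crop
eyes = torch.rand(2, 3, device=device)                 # normalised head centre, as in infer.py
with torch.no_grad():
    heatmap, score = model(source, target, face, eyes)
print(heatmap.shape, score.shape)   # torch.Size([2, 1, 20, 20]) torch.Size([2, 1])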
starting_frame) + data_pack = in_data + # data_pack = in_data[starting_frame:frame_number+1, :] + return data_pack, frames_len # frame_number + + +def get_wav_features(features, frame_number, frames_len=16): + + audio_data, valid_frame_number = create_data_packet(features, frame_number, frames_len=frames_len) + return torch.from_numpy(audio_data).float().view(1,-1,1), valid_frame_number + + +def load_video_frames(frames_list, last_frame_idx, valid_frame_idx, img_mean, img_std, img_width, img_height, frames_len=16): + # load video video_frames_list, process them and return a suitable tensor + frame_number = min(last_frame_idx, valid_frame_idx) + start_frame_number = frame_number - frames_len + 1 + start_frame_number = max(0, start_frame_number) + frames_list_idx = [f for f in range(start_frame_number, frame_number)] + if len(frames_list_idx) < frames_len: + nsh = frames_len - len(frames_list_idx) + frames_list_idx = np.concatenate((np.tile(frames_list_idx[0], (nsh)), frames_list_idx), axis=0) + frames = [] + for i in range(len(frames_list_idx)): + # TODO (fabawi): loading the first frame on failure is not ideal. Find a better way + try: + img = cv2.resize(frames_list[frames_list_idx[i]].copy(), (img_width, img_height)) + except: + try: + img = cv2.resize(frames_list[frames_list_idx[0]].copy(), (img_width, img_height)) + except: + img = np.zeros((img_height, img_width, 3), dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # img = img.convert('RGB') + img = F.to_tensor(img) + img = F.normalize(img, img_mean, img_std) + frames.append(img) + frames = torch.stack(frames, dim=0) + return frames.permute(1, 0, 2, 3) # clip = clip.permute((0,2,1,3,4)) diff --git a/gazenet/models/saliency_prediction/avinet/infer.py b/gazenet/models/saliency_prediction/avinet/infer.py new file mode 100644 index 0000000..51c6d94 --- /dev/null +++ b/gazenet/models/saliency_prediction/avinet/infer.py @@ -0,0 +1,174 @@ + +import re +import os + +import torch +import numpy as np +import torch.backends.cudnn as cudnn + +from gazenet.utils.registrar import * +from gazenet.utils.sample_processors import InferenceSampleProcessor + +import gazenet.models.saliency_prediction.avinet.model as avinet_model +from gazenet.models.saliency_prediction.avinet.generator import load_video_frames, get_wav_features, normalize_data + +MODEL_PATHS = { + "avinet": os.path.join("gazenet", "models", "saliency_prediction", "avinet", "checkpoints", "pretrained_avinet_orig", "AViNet_Dave.pt"), + "vinet": os.path.join("gazenet", "models", "saliency_prediction", "avinet", "checkpoints", "pretrained_avinet_orig", "ViNet_Dave.pt")} + +INP_IMG_WIDTH = 384 +INP_IMG_HEIGHT = 224 +INP_IMG_MEAN = (0.485, 0.456, 0.406) +INP_IMG_STD = (0.229, 0.224, 0.225) + +# IMG_MEAN = [0,0,0] +# IMG_STD = [1,1,1] +# AUD_MEAN = [114.7748 / 255.0, 107.7354 / 255.0, 99.4750 / 255.0] +FRAMES_LEN = 32 + + +@InferenceRegistrar.register +class AViNetInference(InferenceSampleProcessor): + + def __init__(self, weights_file=None, w_size=32, audiovisual=False, + frames_len=FRAMES_LEN, + inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, inp_img_mean=INP_IMG_MEAN, inp_img_std=INP_IMG_STD, + device="cuda:0", width=None, height=None, **kwargs): + self.short_name = "avinet" + self._device = device + + self.frames_len = frames_len + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + + if weights_file is None: + if audiovisual: + weights_file = MODEL_PATHS['avinet'] + else: + 
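# What load_video_frames above hands to the saliency model: one clip tensor of shape
# (3, frames_len, H, W); windows shorter than frames_len are padded by repeating the
# earliest frame. Illustrative call with synthetic frames and the INP_IMG_MEAN /
# INP_IMG_STD statistics defined above:
import numpy as np
from gazenet.models.saliency_prediction.avinet.generator import load_video_frames

frames = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for _ in range(10)]
clip = load_video_frames(frames, last_frame_idx=9, valid_frame_idx=9,
                         img_mean=(0.485, 0.456, 0.406), img_std=(0.229, 0.224, 0.225),
                         img_width=384, img_height=224, frames_len=32)
print(clip.shape)   # torch.Size([3, 32, 224, 384])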
weights_file = MODEL_PATHS['vinet'] + + super().__init__(width=width, height=height, w_size=w_size, **kwargs) + + # load the model + if audiovisual: + self.model = avinet_model.VideoAudioSaliencyModel( + transformer_in_channel=32, + nhead=4, + use_transformer=False, # False + num_encoder_layers=3, + use_upsample=True, + num_hier=3, + num_clips=frames_len + ) + + else: + self.model = avinet_model.VideoSaliencyModel( + use_upsample=True, + num_hier=3, + num_clips=frames_len + ) + + # self.model.load_state_dict(self.load_state_dict(weights_file, device), strict=True) + self.model.load_state_dict(torch.load(weights_file)) + print("AViNet model loaded from", weights_file) + self.model = self.model.to(device) + cudnn.benchmarks = False + self.model.eval() + + @staticmethod + def _load_state_dict_(filepath, device): + if os.path.isfile(filepath): + # print("=> loading checkpoint '{}'".format(filepath)) + checkpoint = torch.load(filepath, map_location=torch.device(device)) + + pattern = re.compile(r'module+\.*') + state_dict = checkpoint['state_dict'] + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = re.sub('module.', '', key) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + video_frames_list, hann_audio_frames, valid_audio_frames_len=None, source_frames_idxs=None, **kwargs): + if valid_audio_frames_len is None: + valid_audio_frames_len = self.frames_len + audio_data = hann_audio_frames.to(self._device) + audio_data = torch.unsqueeze(audio_data, 0) + frames_idxs = range(len(grouped_video_frames_list)) if source_frames_idxs is None else source_frames_idxs + for f_idx, frame_id in enumerate(frames_idxs): + info = {"frame_detections_" + self.short_name: { + "saliency_maps": [], "video_saliency_maps": [], "audio_saliency_maps": [], # detected + }} + video_frames_tensor = load_video_frames(video_frames_list[:frame_id+1], + frame_id+1, + valid_audio_frames_len, + img_width=self.inp_img_width, img_height=self.inp_img_height, + img_mean=self.inp_img_mean, img_std=self.inp_img_std, + frames_len=self.frames_len) + with torch.no_grad(): + video_frames = video_frames_tensor.to(self._device) + video_frames = torch.unsqueeze(video_frames, 0) + prediction = self.model(video_frames, audio_data) + + prediction_l = prediction + prediction_l = torch.sigmoid(prediction_l) + saliency = prediction_l.cpu().data.numpy() + saliency = np.squeeze(saliency) + saliency = normalize_data(saliency) + info["frame_detections_" + self.short_name]["saliency_maps"].append((saliency, -1)) + info_list[frame_id].update(**info) + + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + def preprocess_frames(self, video_frames_list=None, hann_audio_frames=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + if hann_audio_frames is not None: + features["hann_audio_frames"], features["valid_audio_frames_len"] = \ + get_wav_features(list(hann_audio_frames), self.frames_len, frames_len=self.frames_len) + return features + + def 
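# Why _load_state_dict_ above rewrites keys: checkpoints saved from nn.DataParallel
# models prefix every parameter name with "module.", and stripping that prefix lets
# the weights load into a plain single-device model. Standalone equivalent (the keys
# are illustrative):
import re

state_dict = {"module.backbone.base1.0.weight": 0, "decoder.convtsp1.0.weight": 1}
pattern = re.compile(r"module+\.*")
for key in list(state_dict.keys()):
    if pattern.match(key):
        state_dict[re.sub("module.", "", key)] = state_dict.pop(key)
print(sorted(state_dict))   # ['backbone.base1.0.weight', 'decoder.convtsp1.0.weight']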
annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + for saliency_map_name, frame_name in zip(["saliency_maps"],[""]): + if grabbed_video: + if show_det_saliency_map: + saliency_map = info["frame_detections_" + self.short_name][saliency_map_name][0][0] + frame_transformed = plotter.plot_color_map(saliency_map, color_map=color_map) + if enable_transform_overlays: + frame_transformed = plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + + frame_name + self.short_name], + frame_transformed, alpha=0.4) + else: + frame_transformed = plotter.resize(frame_transformed, + height=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[0], + width=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[1]) + grouped_video_frames["det_transformed_" + frame_name + self.short_name] = frame_transformed + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git a/gazenet/models/saliency_prediction/avinet/model.py b/gazenet/models/saliency_prediction/avinet/model.py new file mode 100644 index 0000000..dbff570 --- /dev/null +++ b/gazenet/models/saliency_prediction/avinet/model.py @@ -0,0 +1,694 @@ +import os +from collections import OrderedDict + +import torch +from torch import nn + +from gazenet.models.shared_components.conv3d import model as conv3d +from gazenet.models.shared_components.transformer import model as transformer +from gazenet.models.shared_components.soundnet8 import model as soundnet + + +class VideoSaliencyModel(nn.Module): + def __init__(self, + transformer_in_channel=32, + nhead=4, + use_upsample=True, + num_hier=3, + num_clips=32 + ): + super(VideoSaliencyModel, self).__init__() + + self.backbone = BackBoneS3D() + self.num_hier = num_hier + if use_upsample: + if num_hier == 0: + self.decoder = DecoderConvUpNoHier() + elif num_hier == 1: + self.decoder = DecoderConvUp1Hier() + elif num_hier == 2: + self.decoder = DecoderConvUp2Hier() + elif num_hier == 3: + if num_clips == 8: + self.decoder = DecoderConvUp8() + elif num_clips == 16: + self.decoder = DecoderConvUp16() + elif num_clips == 32: + self.decoder = DecoderConvUp() + elif num_clips == 48: + self.decoder = DecoderConvUp48() + else: + # TODO (fabawi): this decoder does not exist but it's not used anyways + # self.decoder = DecoderConvT() + pass + + def forward(self, x): + [y0, y1, y2, y3] = self.backbone(x) + if self.num_hier == 0: + return self.decoder(y0) + if self.num_hier == 1: + return self.decoder(y0, y1) + if self.num_hier == 2: + return self.decoder(y0, y1, y2) + if self.num_hier == 3: + return self.decoder(y0, y1, y2, y3) + + +class VideoAudioSaliencyFusionModel(nn.Module): + def __init__(self, + use_transformer=True, + transformer_in_channel=512, + num_encoder_layers=3, + nhead=4, + use_upsample=True, + num_hier=3, + num_clips=32 + ): + 
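# Visual-only sketch of the ViNet branch defined above: a 32-frame RGB clip at the
# 224x384 resolution used by infer.py goes in, a single-channel saliency map at the
# same spatial size comes out (per the decoder comments further down). Untrained
# weights; assumes the shared conv3d S3D blocks keep the usual channel counts.
import torch
from gazenet.models.saliency_prediction.avinet.model import VideoSaliencyModel

model = VideoSaliencyModel(use_upsample=True, num_hier=3, num_clips=32).eval()
clip = torch.randn(1, 3, 32, 224, 384)   # (batch, channels, frames, height, width)
with torch.no_grad():
    saliency = model(clip)
print(saliency.shape)                     # expected: torch.Size([1, 224, 384])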
super(VideoAudioSaliencyFusionModel, self).__init__() + self.use_transformer = use_transformer + self.visual_model = VideoSaliencyModel( + transformer_in_channel=transformer_in_channel, + nhead=nhead, + use_upsample=use_upsample, + num_hier=num_hier, + num_clips=num_clips + ) + + self.conv_in_1x1 = nn.Conv3d(in_channels=1024, out_channels=transformer_in_channel, kernel_size=1, stride=1, + bias=True) + self.transformer = transformer.Transformer( + transformer_in_channel, + hidden_size=transformer_in_channel, + nhead=nhead, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=-1, + max_len=4 * 7 * 12 + 3, + ) + + self.audionet = soundnet.SoundNet() + self.audio_conv_1x1 = nn.Conv2d(in_channels=1024, out_channels=transformer_in_channel, kernel_size=1, stride=1, + bias=True) + self.audionet.load_state_dict(torch.load(os.path.join("gazenet", "models", "saliency_prediction", "avinet", "checkpoints", 'soundnet8_final.pth'))) + print("Loaded SoundNet Weights") + for param in self.audionet.parameters(): + param.requires_grad = True + + self.maxpool = nn.MaxPool3d((4, 1, 1), stride=(2, 1, 2), padding=(0, 0, 0)) + self.bilinear = nn.Bilinear(42, 3, 4 * 7 * 12) + + def forward(self, x, audio): + audio = self.audionet(audio) + # print(audio.size()) + audio = self.audio_conv_1x1(audio) + audio = audio.flatten(2) + # print("audio", audio.shape) + + [y0, y1, y2, y3] = self.visual_model.backbone(x) + y0 = self.conv_in_1x1(y0) + y0 = y0.flatten(2) + # print("video", y0.shape) + + fused_out = torch.cat((y0, audio), 2) + # print("fused_out", fused_out.size()) + fused_out = fused_out.permute((2, 0, 1)) + fused_out = self.transformer(fused_out, -1) + + fused_out = fused_out.permute((1, 2, 0)) + + video_features = fused_out[..., :4 * 7 * 12] + audio_features = fused_out[..., 4 * 7 * 12:] + + # print("separate", video_features.shape, audio_features.shape) + + video_features = video_features.view(video_features.size(0), video_features.size(1), 4, 7, 12) + audio_features = torch.mean(audio_features, dim=2) + + audio_features = audio_features.view(audio_features.size(0), audio_features.size(1), 1, 1, 1).repeat(1, 1, 4, 7, + 12) + + final_out = torch.cat((video_features, audio_features), 1) + + # print(final_out.size()) + + return self.visual_model.decoder(final_out, y1, y2, y3) + + +class VideoAudioSaliencyModel(nn.Module): + def __init__(self, + use_transformer=False, + transformer_in_channel=32, + num_encoder_layers=3, + nhead=4, + use_upsample=True, + num_hier=3, + num_clips=32 + ): + super(VideoAudioSaliencyModel, self).__init__() + self.use_transformer = use_transformer + self.visual_model = VideoSaliencyModel( + transformer_in_channel=transformer_in_channel, + nhead=nhead, + use_upsample=use_upsample, + num_hier=num_hier, + num_clips=num_clips + ) + + if self.use_transformer: + self.conv_in_1x1 = nn.Conv3d(in_channels=1024, out_channels=transformer_in_channel, kernel_size=1, stride=1, + bias=True) + self.conv_out_1x1 = nn.Conv3d(in_channels=32, out_channels=1024, kernel_size=1, stride=1, bias=True) + self.transformer = transformer.Transformer( + 4 * 7 * 12, + hidden_size=4 * 7 * 12, + nhead=nhead, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=-1, + max_len=transformer_in_channel, + ) + + self.audionet = soundnet.SoundNet() + self.audionet.load_state_dict(torch.load(os.path.join("gazenet", "models", "saliency_prediction", "avinet", "checkpoints", 'soundnet8_final.pth'))) + print("Loaded SoundNet Weights") + for param in self.audionet.parameters(): + param.requires_grad = True + + 
self.maxpool = nn.MaxPool3d((4, 1, 1), stride=(2, 1, 2), padding=(0, 0, 0)) + self.bilinear = nn.Bilinear(42, 3, 4 * 7 * 12) + + def forward(self, x, audio): + audio = self.audionet(audio) + [y0, y1, y2, y3] = self.visual_model.backbone(x) + y0 = self.maxpool(y0) + fused_out = self.bilinear(y0.flatten(2), audio.flatten(2)) + fused_out = fused_out.view(fused_out.size(0), fused_out.size(1), 4, 7, 12) + + if self.use_transformer: + fused_out = self.conv_in_1x1(fused_out) + fused_out = fused_out.flatten(2) + fused_out = fused_out.permute((1, 0, 2)) + # print("fused_out", fused_out.shape) + fused_out = self.transformer(fused_out, -1) + fused_out = fused_out.permute((1, 0, 2)) + fused_out = fused_out.view(fused_out.size(0), fused_out.size(1), 4, 7, 12) + fused_out = self.conv_out_1x1(fused_out) + + return self.visual_model.decoder(fused_out, y1, y2, y3) + + +class DecoderConvUp(nn.Module): + def __init__(self): + super(DecoderConvUp, self).__init__() + self.upsampling = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp2 = nn.Sequential( + nn.Conv3d(832, 480, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp3 = nn.Sequential( + nn.Conv3d(480, 192, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp4 = nn.Sequential( + nn.Conv3d(192, 64, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 112 x 192 + + nn.Conv3d(64, 32, kernel_size=(2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 224 x 384 + + # 4 time dimension + nn.Conv3d(32, 32, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False), + nn.ReLU(), + nn.Conv3d(32, 1, kernel_size=1, stride=1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, y0, y1, y2, y3): + z = self.convtsp1(y0) + # print('convtsp1', z.shape) + + z = torch.cat((z, y1), 2) + # print('cat_convtsp1', z.shape) + + z = self.convtsp2(z) + # print('convtsp2', z.shape) + + z = torch.cat((z, y2), 2) + # print('cat_convtsp2', z.shape) + + z = self.convtsp3(z) + # print('convtsp3', z.shape) + + z = torch.cat((z, y3), 2) + # print("cat_convtsp3", z.shape) + + z = self.convtsp4(z) + # print('convtsp4', z.shape) + + z = z.view(z.size(0), z.size(3), z.size(4)) + # print('output', z.shape) + + return z + + +class DecoderConvUp16(nn.Module): + def __init__(self): + super(DecoderConvUp16, self).__init__() + self.upsampling = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp2 = nn.Sequential( + nn.Conv3d(832, 480, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp3 = nn.Sequential( + nn.Conv3d(480, 192, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp4 = nn.Sequential( + nn.Conv3d(192, 64, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 112 x 192 + + nn.Conv3d(64, 32, kernel_size=(2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 224 x 384 + + # 4 time dimension + 
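# The audio-visual fusion used by VideoAudioSaliencyModel above, in isolation: the
# pooled S3D features (42 cells per channel) and the SoundNet features (3 steps per
# channel) are mixed by a bilinear layer back into the 4x7x12 volume the decoder
# expects. The feature tensors here are random stand-ins.
import torch
from torch import nn

bilinear = nn.Bilinear(42, 3, 4 * 7 * 12)
video_feat = torch.randn(1, 1024, 42)      # pooled backbone output, flattened
audio_feat = torch.randn(1, 1024, 3)       # SoundNet output, flattened
fused = bilinear(video_feat, audio_feat)   # (1, 1024, 336)
fused = fused.view(1, 1024, 4, 7, 12)      # reshaped for the hierarchical decoder
print(fused.shape)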
nn.Conv3d(32, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=True), + # nn.ReLU(), + # nn.Conv3d(32, 1, kernel_size=1, stride=1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, y0, y1, y2, y3): + z = self.convtsp1(y0) + # print('convtsp1', z.shape) + + z = torch.cat((z, y1), 2) + # print('cat_convtsp1', z.shape) + + z = self.convtsp2(z) + # print('convtsp2', z.shape) + + z = torch.cat((z, y2), 2) + # print('cat_convtsp2', z.shape) + + z = self.convtsp3(z) + # print('convtsp3', z.shape) + + z = torch.cat((z, y3), 2) + # print("cat_convtsp3", z.shape) + + z = self.convtsp4(z) + # print('convtsp4', z.shape) + + z = z.view(z.size(0), z.size(3), z.size(4)) + # print('output', z.shape) + + return z + + +class DecoderConvUp8(nn.Module): + def __init__(self): + super(DecoderConvUp8, self).__init__() + self.upsampling = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp2 = nn.Sequential( + nn.Conv3d(832, 480, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp3 = nn.Sequential( + nn.Conv3d(480, 192, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp4 = nn.Sequential( + nn.Conv3d(192, 64, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 112 x 192 + + nn.Conv3d(64, 32, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 224 x 384 + + # 4 time dimension + nn.Conv3d(32, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=True), + # nn.ReLU(), + # nn.Conv3d(32, 1, kernel_size=1, stride=1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, y0, y1, y2, y3): + z = self.convtsp1(y0) + # print('convtsp1', z.shape) + + z = torch.cat((z, y1), 2) + # print('cat_convtsp1', z.shape) + + z = self.convtsp2(z) + # print('convtsp2', z.shape) + + z = torch.cat((z, y2), 2) + # print('cat_convtsp2', z.shape) + + z = self.convtsp3(z) + # print('convtsp3', z.shape) + + z = torch.cat((z, y3), 2) + # print("cat_convtsp3", z.shape) + + z = self.convtsp4(z) + # print('convtsp4', z.shape) + + z = z.view(z.size(0), z.size(3), z.size(4)) + # print('output', z.shape) + + return z + + +class DecoderConvUp48(nn.Module): + def __init__(self): + super(DecoderConvUp48, self).__init__() + self.upsampling = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp2 = nn.Sequential( + nn.Conv3d(832, 480, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp3 = nn.Sequential( + nn.Conv3d(480, 192, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp4 = nn.Sequential( + nn.Conv3d(192, 64, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 112 x 192 + + nn.Conv3d(64, 32, kernel_size=(2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 224 x 384 + + # 4 time dimension + nn.Conv3d(32, 32, kernel_size=(3, 1, 1), stride=(3, 1, 1), bias=True), + nn.ReLU(), + nn.Conv3d(32, 1, kernel_size=1, stride=1, bias=True), + 
nn.Sigmoid(), + ) + + def forward(self, y0, y1, y2, y3): + # print(y0.shape) + z = self.convtsp1(y0) + # print('convtsp1', z.shape) + + z = torch.cat((z, y1), 2) + # print('cat_convtsp1', z.shape) + + z = self.convtsp2(z) + # print('convtsp2', z.shape) + + z = torch.cat((z, y2), 2) + # print('cat_convtsp2', z.shape) + + z = self.convtsp3(z) + # print('convtsp3', z.shape) + + z = torch.cat((z, y3), 2) + # print("cat_convtsp3", z.shape) + + z = self.convtsp4(z) + # print('convtsp4', z.shape) + + z = z.view(z.size(0), z.size(3), z.size(4)) + # print('output', z.shape) + + return z + + +class DecoderConvUpNoHier(nn.Module): + def __init__(self): + super(DecoderConvUpNoHier, self).__init__() + self.upsampling = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp2 = nn.Sequential( + nn.Conv3d(832, 480, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp3 = nn.Sequential( + nn.Conv3d(480, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp4 = nn.Sequential( + nn.Conv3d(192, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 112 x 192 + + nn.Conv3d(64, 32, kernel_size=(2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 224 x 384 + + # 4 time dimension + nn.Conv3d(32, 32, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False), + nn.ReLU(), + nn.Conv3d(32, 1, kernel_size=1, stride=1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, y0): + z = self.convtsp1(y0) + # print('convtsp1', z.shape) + + # z = torch.cat((z,y1), 2) + # print('cat_convtsp1', z.shape) + + z = self.convtsp2(z) + # print('convtsp2', z.shape) + + # z = torch.cat((z,y2), 2) + # print('cat_convtsp2', z.shape) + + z = self.convtsp3(z) + # print('convtsp3', z.shape) + + # z = torch.cat((z,y3), 2) + # print("cat_convtsp3", z.shape) + + z = self.convtsp4(z) + # print('convtsp4', z.shape) + + z = z.view(z.size(0), z.size(3), z.size(4)) + # print('output', z.shape) + + return z + + +class DecoderConvUp1Hier(nn.Module): + def __init__(self): + super(DecoderConvUp1Hier, self).__init__() + self.upsampling = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp2 = nn.Sequential( + nn.Conv3d(832, 480, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp3 = nn.Sequential( + nn.Conv3d(480, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp4 = nn.Sequential( + nn.Conv3d(192, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 112 x 192 + + nn.Conv3d(64, 32, kernel_size=(2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 224 x 384 + + # 4 time dimension + nn.Conv3d(32, 32, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False), + nn.ReLU(), + nn.Conv3d(32, 1, kernel_size=1, stride=1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, y0, y1): + z = self.convtsp1(y0) + # print('convtsp1', z.shape, y1.shape) + + z = 
torch.cat((z, y1), 2) + # print('cat_convtsp1', z.shape) + + z = self.convtsp2(z) + # print('convtsp2', z.shape) + + # z = torch.cat((z,y2), 2) + # print('cat_convtsp2', z.shape) + + z = self.convtsp3(z) + # print('convtsp3', z.shape) + + # z = torch.cat((z,y3), 2) + # print("cat_convtsp3", z.shape) + + z = self.convtsp4(z) + # print('convtsp4', z.shape) + + z = z.view(z.size(0), z.size(3), z.size(4)) + # print('output', z.shape) + + return z + + +class DecoderConvUp2Hier(nn.Module): + def __init__(self): + super(DecoderConvUp2Hier, self).__init__() + self.upsampling = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp2 = nn.Sequential( + nn.Conv3d(832, 480, kernel_size=(3, 3, 3), stride=(3, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp3 = nn.Sequential( + nn.Conv3d(480, 192, kernel_size=(5, 3, 3), stride=(5, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling + ) + self.convtsp4 = nn.Sequential( + nn.Conv3d(192, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 112 x 192 + + nn.Conv3d(64, 32, kernel_size=(2, 3, 3), stride=(2, 1, 1), padding=(0, 1, 1), bias=False), + nn.ReLU(), + self.upsampling, # 224 x 384 + + # 4 time dimension + nn.Conv3d(32, 32, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False), + nn.ReLU(), + nn.Conv3d(32, 1, kernel_size=1, stride=1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, y0, y1, y2): + z = self.convtsp1(y0) + # print('convtsp1', z.shape) + + z = torch.cat((z, y1), 2) + # print('cat_convtsp1', z.shape) + + z = self.convtsp2(z) + # print('convtsp2', z.shape) + + z = torch.cat((z, y2), 2) + # print('cat_convtsp2', z.shape) + + z = self.convtsp3(z) + # print('convtsp3', z.shape) + + # z = torch.cat((z,y3), 2) + # print("cat_convtsp3", z.shape) + + z = self.convtsp4(z) + # print('convtsp4', z.shape) + + z = z.view(z.size(0), z.size(3), z.size(4)) + # print('output', z.shape) + + return z + + +class BackBoneS3D(nn.Module): + def __init__(self): + super(BackBoneS3D, self).__init__() + + self.base1 = nn.Sequential( + conv3d.SepConv3d(3, 64, kernel_size=7, stride=2, padding=3), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)), + conv3d.BasicConv3d(64, 64, kernel_size=1, stride=1), + conv3d.SepConv3d(64, 192, kernel_size=3, stride=1, padding=1), + ) + self.maxp2 = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) + self.base2 = nn.Sequential( + conv3d.Mixed_3b(), + conv3d.Mixed_3c(), + ) + self.maxp3 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)) + self.base3 = nn.Sequential( + conv3d.Mixed_4b(), + conv3d.Mixed_4c(), + conv3d.Mixed_4d(), + conv3d.Mixed_4e(), + conv3d.Mixed_4f(), + ) + self.maxt4 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + self.maxp4 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 0, 0)) + self.base4 = nn.Sequential( + conv3d.Mixed_5b(), + conv3d.Mixed_5c(), + ) + + def forward(self, x): + # print('input', x.shape) + y3 = self.base1(x) + # print('base1', y3.shape) + + y = self.maxp2(y3) + # print('maxp2', y.shape) + + y2 = self.base2(y) + # print('base2', y2.shape) + + y = self.maxp3(y2) + # print('maxp3', y.shape) + + y1 = self.base3(y) + # print('base3', y1.shape) + + y = self.maxt4(y1) + y = self.maxp4(y) + # 
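# How the hierarchy fits together: BackBoneS3D above returns four feature levels
# whose channel counts match the decoder convolutions (1024, 832, 480, 192), and
# DecoderConvUp concatenates each level along the temporal axis (dim 2) before the
# next strided 3D convolution collapses time again. Expected shapes for a
# (1, 3, 32, 224, 384) clip, assuming standard S3D blocks in the shared conv3d
# components:
#   y0: (1, 1024,  4,  7,  12)
#   y1: (1,  832,  8, 14,  24)
#   y2: (1,  480, 16, 28,  48)
#   y3: (1,  192, 16, 56,  96)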
print('maxt4p4', y.shape) + + y0 = self.base4(y) + + return [y0, y1, y2, y3] + diff --git a/gazenet/models/saliency_prediction/dave/__init__.py b/gazenet/models/saliency_prediction/dave/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/saliency_prediction/dave/checkpoints/pretrained_dave_orig/download_model.sh b/gazenet/models/saliency_prediction/dave/checkpoints/pretrained_dave_orig/download_model.sh new file mode 100644 index 0000000..e904992 --- /dev/null +++ b/gazenet/models/saliency_prediction/dave/checkpoints/pretrained_dave_orig/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=1hVf2PKp9UQNYMeG-tyT__0qJDHkj79sV' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1hVf2PKp9UQNYMeG-tyT__0qJDHkj79sV" -O model.pth.tar && rm -rf /tmp/cookies.txt diff --git a/gazenet/models/saliency_prediction/dave/generator.py b/gazenet/models/saliency_prediction/dave/generator.py new file mode 100644 index 0000000..594a562 --- /dev/null +++ b/gazenet/models/saliency_prediction/dave/generator.py @@ -0,0 +1,62 @@ +import torch +import torchvision.transforms.functional as F +import numpy as np +import librosa as sf +from PIL import Image +import cv2 + + +def create_data_packet(in_data, frame_number, frames_len=16): + in_data = np.array(in_data) + n_frame = in_data.shape[0] + # if the frame number is larger, we just use the last sound one heard + frame_number = min(frame_number, n_frame) + starting_frame = frame_number - frames_len + 1 + # ensure we do not have any negative video_frames_list + starting_frame = max(0, starting_frame) + data_pack = in_data[starting_frame:frame_number+1, :, :] + n_pack = data_pack.shape[0] + + if n_pack < frames_len: + nsh = frames_len - n_pack + data_pack = np.concatenate((np.tile(data_pack[0,:,:], (nsh, 1, 1)), data_pack), axis=0) + + assert data_pack.shape[0] == frames_len + + data_pack = np.tile(data_pack, (3, 1, 1, 1)) + + return data_pack, frame_number + + +def get_wav_features(features, frame_number, frames_len=16): + + audio_data, valid_frame_number = create_data_packet(features, frame_number, frames_len=frames_len) + return torch.from_numpy(audio_data).float(), valid_frame_number + + +def load_video_frames(frames_list, last_frame_idx, valid_frame_idx, img_mean, img_std, img_width, img_height, frames_len=16): + # load video video_frames_list, process them and return a suitable tensor + frame_number = min(last_frame_idx, valid_frame_idx) + start_frame_number = frame_number - frames_len + 1 + start_frame_number = max(0, start_frame_number) + frames_list_idx = [f for f in range(start_frame_number, frame_number)] + if len(frames_list_idx) < frames_len: + nsh = frames_len - len(frames_list_idx) + frames_list_idx = np.concatenate((np.tile(frames_list_idx[0], (nsh)), frames_list_idx), axis=0) + frames = [] + for i in range(len(frames_list_idx)): + # TODO (fabawi): loading the first frame on failure is not ideal. 
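# What the DAVE audio packet produced above looks like: per-frame 64x64 feature
# maps, of which the last frames_len are kept (short sequences are padded by
# repeating the first entry) and then tiled to 3 channels, matching the 3x16x64x64
# audio clip described in the DAVE model further down. Synthetic features for
# illustration:
import numpy as np
from gazenet.models.saliency_prediction.dave.generator import get_wav_features

features = [np.random.rand(64, 64).astype(np.float32) for _ in range(10)]
audio_tensor, valid_len = get_wav_features(features, frame_number=9, frames_len=16)
print(audio_tensor.shape, valid_len)   # torch.Size([3, 16, 64, 64]) 9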
Find a better way + try: + img = cv2.resize(frames_list[frames_list_idx[i]].copy(), (img_width, img_height)) + except: + try: + img = cv2.resize(frames_list[frames_list_idx[0]].copy(), (img_width, img_height)) + except: + img = np.zeros((img_height, img_width, 3), dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # img = img.convert('RGB') + img = F.to_tensor(img) + img = F.normalize(img, img_mean, img_std) + frames.append(img) + frames = torch.stack(frames, dim=0) + return frames.permute([1, 0, 2, 3]) diff --git a/gazenet/models/saliency_prediction/dave/infer.py b/gazenet/models/saliency_prediction/dave/infer.py new file mode 100644 index 0000000..87c3d01 --- /dev/null +++ b/gazenet/models/saliency_prediction/dave/infer.py @@ -0,0 +1,161 @@ +# +# DAVE: A Deep Audio-Visual Embedding for Dynamic Saliency Prediction +# https://arxiv.org/abs/1905.10693 +# https://hrtavakoli.github.io/DAVE/ +# +# Copyright by Hamed Rezazadegan Tavakoli +# + +import re +import os + +import torch +import numpy as np + +from gazenet.utils.registrar import * +from gazenet.models.saliency_prediction.dave.generator import get_wav_features, load_video_frames +from gazenet.models.saliency_prediction.dave.model import DAVE +from gazenet.utils.sample_processors import InferenceSampleProcessor + + +MODEL_PATHS = { + "dave": os.path.join("gazenet", "models", "saliency_prediction", "dave", "checkpoints", "pretrained_dave_orig", "model.pth.tar")} + +INP_IMG_WIDTH = 320 +INP_IMG_HEIGHT = 256 +# TRG_IMG_WIDTH = 40 +# TRG_IMG_HEIGHT = 32 +INP_IMG_MEAN = (110.63666788 / 255.0, 103.16065604 / 255.0, 96.29023126 / 255.0) +INP_IMG_STD = (38.7568578 / 255.0, 37.88248729 / 255.0, 40.02898126 / 255.0) +FRAMES_LEN = 16 + + +@InferenceRegistrar.register +class DAVEInference(InferenceSampleProcessor): + + def __init__(self, weights_file=MODEL_PATHS['dave'], w_size=16, + frames_len=FRAMES_LEN, + inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, inp_img_mean=INP_IMG_MEAN, inp_img_std=INP_IMG_STD, + device="cuda:0", width=None, height=None, **kwargs): + super().__init__(width=width, height=height, w_size=w_size, **kwargs) + self.short_name = "dave" + self._device = device + + self.frames_len = frames_len + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + + # load the model + self.model = DAVE(frames_len=frames_len) + self.model.load_state_dict(self._load_state_dict_(weights_file, device), strict=True) + print("DAVE model loaded from", weights_file) + self.model = self.model.to(device) + self.model.eval() + + @staticmethod + def _load_state_dict_(filepath, device): + if os.path.isfile(filepath): + # print("=> loading checkpoint '{}'".format(filepath)) + checkpoint = torch.load(filepath, map_location=torch.device(device)) + + pattern = re.compile(r'module+\.*') + state_dict = checkpoint['state_dict'] + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = re.sub('module.', '', key) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + video_frames_list, audio_features=None, valid_audio_frames_len=None, source_frames_idxs=None, **kwargs): + if valid_audio_frames_len is None: + valid_audio_frames_len = self.frames_len + + if audio_features is not None: + audio_data = audio_features.to(self._device) + audio_data = 
torch.unsqueeze(audio_data, 0) + else: + audio_data = None + frames_idxs = range(len(grouped_video_frames_list)) if source_frames_idxs is None else source_frames_idxs + for f_idx, frame_id in enumerate(frames_idxs): + info = {"frame_detections_" + self.short_name: { + "saliency_maps": [], "video_saliency_maps": [], "audio_saliency_maps": [], # detected + # "audio_features": [] # processed (this is inefficient since the features are the same for all frames) + }} + video_frames_tensor = load_video_frames(video_frames_list[:frame_id+1], + frame_id+1, + valid_audio_frames_len, + img_width=self.inp_img_width, img_height=self.inp_img_height, + img_mean=self.inp_img_mean, img_std=self.inp_img_std, + frames_len=self.frames_len) + video_frames = video_frames_tensor.to(self._device) + video_frames = torch.unsqueeze(video_frames, 0) + final_prediction, video_prediction, audio_prediction = self.model(video_frames, audio_data, + return_latent_streams=True) + # get the visual and auditory feature maps + # for prediction, prediction_name in zip([final_prediction, video_prediction, audio_prediction], + # ["saliency_maps", "video_saliency_maps", "audio_saliency_maps"]): + for prediction, prediction_name in zip([final_prediction], ["saliency_maps"]): + saliency = prediction.cpu().data.numpy() + saliency = np.squeeze(saliency) + saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()) + info["frame_detections_" + self.short_name][prediction_name].append((saliency, -1)) + info_list[frame_id].update(**info) + + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + def preprocess_frames(self, video_frames_list=None, audio_features=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + if audio_features and audio_features is not None: + features["audio_features"], features["valid_audio_frames_len"] = \ + get_wav_features(list(audio_features), self.frames_len, frames_len=self.frames_len) + return features + + def annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + # grouped_video_frames["det_transformed_video_" + self.short_name] = grouped_video_frames["det_transformed_" + self.short_name].copy() + # grouped_video_frames["det_transformed_audio_" + self.short_name] = grouped_video_frames["det_transformed_" + self.short_name].copy() + + # for saliency_map_name, frame_name in zip(["saliency_maps", "video_saliency_maps", "audio_saliency_maps"], + # ["", "video_", "audio_"]): + for saliency_map_name, frame_name in 
zip(["saliency_maps"],[""]): + if grabbed_video: + if show_det_saliency_map: + saliency_map = info["frame_detections_" + self.short_name][saliency_map_name][0][0] + frame_transformed = plotter.plot_color_map(np.uint8(255 * saliency_map), color_map=color_map) + if enable_transform_overlays: + frame_transformed = plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + + frame_name + self.short_name], + frame_transformed, alpha=0.4) + else: + frame_transformed = plotter.resize(frame_transformed, + height=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[0], + width=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[1]) + grouped_video_frames["det_transformed_" + frame_name + self.short_name] = frame_transformed + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git a/gazenet/models/saliency_prediction/dave/model.py b/gazenet/models/saliency_prediction/dave/model.py new file mode 100644 index 0000000..2a81f17 --- /dev/null +++ b/gazenet/models/saliency_prediction/dave/model.py @@ -0,0 +1,86 @@ +# +# DAVE: A Deep Audio-Visual Embedding for Dynamic Saliency Prediction +# https://arxiv.org/abs/1905.10693 +# https://hrtavakoli.github.io/DAVE/ +# +# Copyright by Hamed Rezazadegan Tavakoli +# + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from gazenet.models.shared_components.resnet3d.model import resnet18 + + +class ScaleUp(nn.Module): + + def __init__(self, in_size, out_size): + super(ScaleUp, self).__init__() + + self.combine = nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1) + self.bn = nn.BatchNorm2d(out_size) + + self._weights_init() + + def _weights_init(self): + + nn.init.kaiming_normal_(self.combine.weight) + nn.init.constant_(self.combine.bias, 0.0) + + def forward(self, inputs): + output = F.interpolate(inputs, scale_factor=2, mode='bilinear', align_corners=True) + output = self.combine(output) + output = F.relu(output, inplace=True) + return output + + +class DAVE(nn.Module): + + def __init__(self, frames_len=16, num_classes_video=400, num_classes_audio=12): + super(DAVE, self).__init__() + + self.audio_branch = resnet18(shortcut_type='A', sample_size=64, sample_duration=frames_len, num_classes=num_classes_audio, last_fc=False, last_pool=True) + self.video_branch = resnet18(shortcut_type='A', sample_size=112, sample_duration=frames_len, num_classes=num_classes_video, last_fc=False, last_pool=False) + self.upscale1 = ScaleUp(512, 512) + self.upscale2 = ScaleUp(512, 128) + self.combinedEmbedding = nn.Conv2d(1024, 512, kernel_size=1) + self.saliency = nn.Conv2d(128, 1, kernel_size=1) + self._weights_init() + + def _weights_init(self): + + nn.init.kaiming_normal_(self.saliency.weight) + nn.init.constant_(self.saliency.bias, 0.0) + + nn.init.kaiming_normal_(self.combinedEmbedding.weight) + nn.init.constant_(self.combinedEmbedding.bias, 0.0) + + def forward(self, v, a=None, return_latent_streams=False): + # V video video_frames_list of 3x16x256x320 + # A audio video_frames_list of 3x16x64x64 + # return a map of 32x40 + + xV1 = self.video_branch(v) + if a is not None: + xA1 = self.audio_branch(a) + xA1 = xA1.expand_as(xV1) + else: + # replace audio branch with zeros + xA1 = torch.zeros_like(xV1) + xC = torch.cat((xV1, xA1), dim=1) + xC = torch.squeeze(xC, dim=2) + x = self.combinedEmbedding(xC) + x = F.relu(x, inplace=True) + + x = torch.squeeze(x, dim=2) + x = self.upscale1(x) + x = self.upscale2(x) + sal = self.saliency(x) + sal 
= F.relu(sal, inplace=True) + if return_latent_streams: + return sal, xV1, xA1 + else: + return sal + diff --git a/gazenet/models/saliency_prediction/gasp/__init__.py b/gazenet/models/saliency_prediction/gasp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspdamencgmuconv/GASPDAMEncGMUConv/c9dd4df04e87469ea9914ea67af764b6/download_model.sh b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspdamencgmuconv/GASPDAMEncGMUConv/c9dd4df04e87469ea9914ea67af764b6/download_model.sh new file mode 100644 index 0000000..3f776c6 --- /dev/null +++ b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspdamencgmuconv/GASPDAMEncGMUConv/c9dd4df04e87469ea9914ea67af764b6/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget -O last_model.pt https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_gaspdamencgmuconv/GASPDAMEncGMUConv/c9dd4df04e87469ea9914ea67af764b6/last_model.pt diff --git a/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_seqeuncegaspdamencgmualstmconv/SequenceGASPDAMEncGMUALSTMConv/53ea3d5639d647fc86e3974d6e1d1719/download_model.sh b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_seqeuncegaspdamencgmualstmconv/SequenceGASPDAMEncGMUALSTMConv/53ea3d5639d647fc86e3974d6e1d1719/download_model.sh new file mode 100644 index 0000000..ab9b688 --- /dev/null +++ b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_seqeuncegaspdamencgmualstmconv/SequenceGASPDAMEncGMUALSTMConv/53ea3d5639d647fc86e3974d6e1d1719/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget -O last_model.pt https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencgmualstmconv/SequenceGASPDAMEncGMUALSTMConv/53ea3d5639d647fc86e3974d6e1d1719/last_model.pt diff --git a/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/1e49c1fe600c43f09291c4e6710c769c/download_model.sh b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/1e49c1fe600c43f09291c4e6710c769c/download_model.sh new file mode 100644 index 0000000..227633e --- /dev/null +++ b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/1e49c1fe600c43f09291c4e6710c769c/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget -O last_model.pt https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/1e49c1fe600c43f09291c4e6710c769c/last_model.pt diff --git a/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/ba69921766274fe69363b9cd9b5d4b76/download_model.sh b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/ba69921766274fe69363b9cd9b5d4b76/download_model.sh new file mode 100644 index 0000000..249f978 --- /dev/null +++ b/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/ba69921766274fe69363b9cd9b5d4b76/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget -O last_model.pt 
https://www2.informatik.uni-hamburg.de/WTM/corpora/GASP/gazenet/models/saliency_prediction/gasp/checkpoints/pretrained_sequencegaspdamencalstmgmuconv/SequenceGASPDAMEncALSTMGMUConv/ba69921766274fe69363b9cd9b5d4b76/last_model.pt diff --git a/gazenet/models/saliency_prediction/gasp/generator.py b/gazenet/models/saliency_prediction/gasp/generator.py new file mode 100644 index 0000000..41413b8 --- /dev/null +++ b/gazenet/models/saliency_prediction/gasp/generator.py @@ -0,0 +1,219 @@ +import os +import random + +import cv2 +import numpy as np +import pandas as pd +import torch +import torchvision.transforms.functional as F +from torch.utils.data import Dataset, DataLoader +import torchvision + +try: # leverages intel IPP for accelerated image loading + from accimage import Image + torchvision.set_image_backend('accimage') +except: + from PIL import Image + +INP_IMG_MODE = "RGB" +TRG_IMG_MODE = "L" +DEFAULT_IMG_SIDE = 10 # just a place-holder + + +def load_video_frames(grouped_frames_list, last_frame_idx, inp_img_names_list, img_mean, img_std, img_width, img_height, frames_len=16): + # load video video_frames_list, process them and return a suitable tensor + frame_number = last_frame_idx + start_frame_number = frame_number - frames_len # + 1 + start_frame_number = max(0, start_frame_number) + frames_list_idx = [f for f in range(start_frame_number, frame_number)] + if len(frames_list_idx) < frames_len: + nsh = frames_len - len(frames_list_idx) + frames_list_idx = np.concatenate((np.tile(frames_list_idx[0], (nsh)), frames_list_idx), axis=0) + frames = [[] for _ in inp_img_names_list] + + for i in range(len(frames_list_idx)): + for frame_name, frame in grouped_frames_list[i].items(): + if not frame_name in inp_img_names_list: + continue + else: + group_idx = inp_img_names_list.index(frame_name) + # TODO (fabawi): loading the first frame on failure is not ideal. Find a better way + try: + img = cv2.resize(frame.copy(), (img_width, img_height)) + except: + try: + img = cv2.resize(frame.copy(), (img_width, img_height)) + except: + img = np.zeros((img_height, img_width, 3), dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = F.to_tensor(img) + img = F.normalize(img, img_mean, img_std) + frames[group_idx].append(img) + frames = [torch.stack(frames[group_idx], dim=0) for group_idx in range(len(inp_img_names_list))] + frames = torch.cat(frames, 1) # .permute(1, 0, 2, 3) + if len(frames_list_idx) > 1: + return frames + else: + return frames[0, ...] + + +class UnNormalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensor): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized image. + """ + for t, m, s in zip(tensor, self.mean, self.std): + t.mul_(s).add_(m) + # The normalize code -> t.sub_(m).div_(s) + return tensor + + +class GASPDataset(Dataset): + """ + Reads the DataWriter-generated dataset as pytorch compatible objects. When sequence_len is not 1, + any non consecutive images will be zero padded. 
Loads images in the form + + """ + def __init__(self, csv_file, video_dir, + inp_img_names_list, gt_img_names_list, + inp_img_transform=None, gt_img_transform=None, + sequence_len=1, exhaustive=False, cache_sequence_pointers=False + ): + """ + Args: + csv_file (string): Path to the csv file with the sample descriptions + video_dir (string): Directory with all the frames separated into folders + inp_img_names_list (list): List of the input image prefixes + gt_img_names_list (list): List of the ground-truth images prefixes + inp_img_transform (callable, optional): Optional transform applied to the input image + gt_img_transform (callable, optional): Optional transform applied to the ground-truth image + sequence_len (int): Length of an image sequence. If None, the sequence dimension is removed + exhaustive (boolean): Return gt images for the entire sequence. Otherwise, only the last gt image + cache_sequence_pointers (boolean): Remember the last sample frame and resume. Otherwise, samples randomly + """ + self.samples = pd.read_csv(csv_file) + self.samples = self.samples[(self.samples["scene_type"] == "Social")].reset_index() + self.video_dir = video_dir + self.input_img_names_list = inp_img_names_list + self.gt_img_names_list = gt_img_names_list + self.input_img_transform = inp_img_transform + self.gt_img_transform = gt_img_transform + self.sequence_len = sequence_len + self.exhaustive = exhaustive + self.cache_sequence_pointers = cache_sequence_pointers + self.sequence_pointer_cache = {dataset: {} for dataset in self.samples.dataset.unique()} # {dataset_name: video_name: (curr_idx, min_idx, max_idx)} + + def len_frames(self): + pass + + def __len__(self): + return len(self.samples) + + def __getfilenames__(self, idx): + ds_name = self.samples.loc[idx, "dataset"] + vid_name = self.samples.loc[idx, "video_id"] + vid_path = os.path.join(self.video_dir, ds_name, vid_name) + + curr_idx, min_idx, max_idx = self.sequence_pointer_cache[ds_name].get(vid_name, (1, 1, len(os.listdir(vid_path)))) + if self.cache_sequence_pointers: + next_idx = curr_idx + self.sequence_len + self.sequence_pointer_cache[ds_name][vid_name] = (next_idx if next_idx < max_idx else 1, min_idx, max_idx) + else: + curr_idx = random.randint(min_idx, max_idx) + imgs_paths = [os.path.join(vid_path, str(frame_idx)) for frame_idx in range(curr_idx, curr_idx + self.sequence_len)] + + return imgs_paths + + def __getimgs__(self, idx): + all_input_imgs = [] + all_gt_imgs_dict = {gt_img_name: [] for gt_img_name in self.gt_img_names_list} + if self.exhaustive: + all_gt_imgs_dict.update(**{"seq_"+gt_img_name: [] for gt_img_name in self.gt_img_names_list}) + + imgs_paths = self.__getfilenames__(idx) + + prev_input_imgs = [None] * len(self.input_img_names_list) + for seq_img_idx, seq_img_path in enumerate(imgs_paths): + + input_imgs = None + for input_img_idx, input_img_name in enumerate(self.input_img_names_list): + try: + input_img = Image.open(os.path.join(seq_img_path, input_img_name + "_1.jpg")).convert(INP_IMG_MODE) + prev_input_imgs[input_img_idx] = input_img.copy() + except: + if prev_input_imgs[input_img_idx] is None: + input_img = Image.new(INP_IMG_MODE, (DEFAULT_IMG_SIDE, DEFAULT_IMG_SIDE)) + input_img.putpixel((0, 0), tuple([1] * len(input_img.getpixel((0, 0))))) # to avoid nan + else: + input_img = prev_input_imgs[input_img_idx].copy() + if self.input_img_transform: + input_img = self.input_img_transform(input_img) + input_img = input_img if torch.is_tensor(input_img) else torchvision.transforms.ToTensor()(input_img) + 
input_imgs = torch.cat([input_imgs, input_img]) if input_imgs is not None else input_img + all_input_imgs.append(input_imgs) + + for gt_img_name in self.gt_img_names_list: + + try: + gt_img = Image.open(os.path.join(seq_img_path, gt_img_name + "_1.jpg")).convert(TRG_IMG_MODE) + if self.exhaustive: + gt_img_seq = gt_img.copy() + else: + gt_img_seq = None + except: + if self.exhaustive: + gt_img_seq = Image.new(TRG_IMG_MODE, (DEFAULT_IMG_SIDE, DEFAULT_IMG_SIDE)) + gt_img_seq.putpixel((0, 0), 1) + if not all_gt_imgs_dict[gt_img_name] and (seq_img_idx == len(imgs_paths) - 1): + gt_img = Image.new(TRG_IMG_MODE, (DEFAULT_IMG_SIDE, DEFAULT_IMG_SIDE)) + gt_img.putpixel((0, 0), 1) # to avoid nan + else: + gt_img = None + elif not all_gt_imgs_dict[gt_img_name] and (seq_img_idx == len(imgs_paths) - 1): + gt_img_seq = None + gt_img = Image.new(TRG_IMG_MODE, (DEFAULT_IMG_SIDE, DEFAULT_IMG_SIDE)) + gt_img.putpixel((0, 0), 1) # to avoid nan + else: + gt_img_seq = None + gt_img = None + if gt_img is not None: + if self.gt_img_transform: + gt_img = self.gt_img_transform(gt_img) + gt_img = gt_img if torch.is_tensor(gt_img) else torchvision.transforms.ToTensor()(gt_img) + if gt_img_seq is not None: + if self.gt_img_transform: + gt_img_seq = self.gt_img_transform(gt_img_seq) + gt_img_seq = gt_img_seq if torch.is_tensor(gt_img_seq) else torchvision.transforms.ToTensor()(gt_img_seq) + + if self.exhaustive: + all_gt_imgs_dict["seq_" + gt_img_name].append(gt_img_seq) + if gt_img is not None: + all_gt_imgs_dict[gt_img_name] = [gt_img] + elif gt_img is not None: + all_gt_imgs_dict[gt_img_name] = [gt_img] + + all_input_imgs = torch.cat(all_input_imgs) if self.sequence_len == 1 else torch.stack(all_input_imgs) + for k_gt_imgs, v_gt_imgs in all_gt_imgs_dict.items(): + if k_gt_imgs.startswith("seq_"): + all_gt_imgs_dict[k_gt_imgs] = torch.stack(v_gt_imgs) + # all_gt_imgs_dict[k_gt_imgs] = torch.cat(v_gt_imgs) if self.sequence_len == 1 or not self.exhaustive else torch.stack(v_gt_imgs) + else: + all_gt_imgs_dict[k_gt_imgs] = torch.cat(v_gt_imgs) + + return all_input_imgs, all_gt_imgs_dict + + def __getanno__(self, idx): + raise NotImplementedError("Annotations not needed since the images are already transformed for all models. 
" + "Might be necessary for other tasks w/o multitask learning or other modality inputs.") + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + return self.__getimgs__(idx) diff --git a/gazenet/models/saliency_prediction/gasp/infer.py b/gazenet/models/saliency_prediction/gasp/infer.py new file mode 100644 index 0000000..71b7877 --- /dev/null +++ b/gazenet/models/saliency_prediction/gasp/infer.py @@ -0,0 +1,159 @@ +import re +import os + +import torch +import numpy as np + +from gazenet.utils.registrar import * +from gazenet.models.saliency_prediction.gasp.generator import load_video_frames + +from gazenet.utils.sample_processors import InferenceSampleProcessor + + +MODEL_PATHS = { + "seqdamgmualstm": os.path.join("gazenet", "models", "saliency_prediction", "gasp", "checkpoints", + "pretrained_seqeuncegaspdamencgmualstmconv", + "SequenceGASPDAMEncGMUALSTMConv", "53ea3d5639d647fc86e3974d6e1d1719", "last_model.pt"), + "seqdamalstmgmu": os.path.join("gazenet", "models", "saliency_prediction", "gasp", "checkpoints", + "pretrained_sequencegaspdamencalstmgmuconv", + "SequenceGASPDAMEncALSTMGMUConv", "1e49c1fe600c43f09291c4e6710c769c", "last_model.pt"), #10 + "seqdamalstmgmu_110nofer": os.path.join("gazenet", "models", "saliency_prediction", "gasp", "checkpoints", + "pretrained_sequencegaspdamencalstmgmuconv", + "SequenceGASPDAMEncALSTMGMUConv", "ba69921766274fe69363b9cd9b5d4b76", "last_model.pt"), #10 + "damgmu": os.path.join("gazenet", "models", "saliency_prediction", "gasp", "checkpoints", + "pretrained_gaspdamencgmuconv", + "GASPDAMEncGMUConv", "c9dd4df04e87469ea9914ea67af764b6", "last_model.pt") +} + +INP_IMG_WIDTH = 120 +INP_IMG_HEIGHT = 120 +TRG_IMG_WIDTH = INP_IMG_WIDTH//2 +TRG_IMG_HEIGHT = INP_IMG_HEIGHT//2 +INP_IMG_MEAN = (110.63666788 / 255.0, 103.16065604 / 255.0, 96.29023126 / 255.0) +INP_IMG_STD = (38.7568578 / 255.0, 37.88248729 / 255.0, 40.02898126 / 255.0) +FRAMES_LEN = 10 + + +@InferenceRegistrar.register +class GASPInference(InferenceSampleProcessor): + + def __init__(self, weights_file=MODEL_PATHS['seqdamalstmgmu'], model_name="SequenceGASPDAMEncALSTMGMUConv", w_size=10, + frames_len=FRAMES_LEN, trg_img_width=TRG_IMG_WIDTH, trg_img_height=TRG_IMG_HEIGHT, + inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, inp_img_mean=INP_IMG_MEAN, inp_img_std=INP_IMG_STD, + device="cuda:0", width=None, height=None, **kwargs): + super().__init__(width=width, height=height, w_size=w_size, **kwargs) + self.short_name = "gasp" + self._device = device + + self.frames_len = frames_len + self.trg_img_width = trg_img_width + self.trg_img_height = trg_img_height + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + # scan model registry + ModelRegistrar.scan() + + # load the model + kwargs.update(batch_size=1) + self.model = ModelRegistrar.registry[model_name](**kwargs) + if weights_file in MODEL_PATHS.keys(): + weights_file = MODEL_PATHS[weights_file] + self.model.load_state_dict(torch.load(weights_file)) + print("GASP model loaded from", weights_file) + self.model = self.model.to(device) + self.model.eval() + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + compute_gate_scores=True, inp_img_names_list=None, source_frames_idxs=None, **kwargs): + + frames_idxs = range(len(grouped_video_frames_list)) if source_frames_idxs is None else source_frames_idxs + for f_idx, frame_id in 
enumerate(frames_idxs): + info = {"frame_detections_" + self.short_name: { + "saliency_maps": [], "gate_scores": [] + }} + + video_frames_tensor = load_video_frames(grouped_video_frames_list[:frame_id+1], + frame_id+1, + inp_img_names_list=inp_img_names_list, + img_width=self.inp_img_width, img_height=self.inp_img_height, + img_mean=self.inp_img_mean, img_std=self.inp_img_std, + frames_len=self.frames_len) + + video_frames = video_frames_tensor.to(self._device) + video_frames = torch.unsqueeze(video_frames, 0) + final_prediction = self.model(video_frames) + prediction = torch.sigmoid(final_prediction[0]) + saliency = prediction.cpu().data.numpy() + saliency = np.squeeze(saliency) + saliency = (saliency - saliency.min()) / (saliency.max() - saliency.min()) + info["frame_detections_" + self.short_name]["saliency_maps"].append((saliency, -1)) + if final_prediction[2] is not None and compute_gate_scores: + gate_scores = final_prediction[2][1].cpu().data.numpy() + gate_scores = gate_scores.mean(axis=(2,3,4) if len(gate_scores.shape) == 5 else (2,3)) + gate_scores = np.squeeze(gate_scores) + info["frame_detections_" + self.short_name]["gate_scores"].append((gate_scores, -1)) + info_list[frame_id].update(**info) + + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + def preprocess_frames(self, previous_data, inp_img_names_list=None, **kwargs): + features = super().preprocess_frames(**kwargs) + features["inp_img_names_list"] = inp_img_names_list + # previous_maps = {} + # if previous_data is None: + # previous_data = ((None, features["grouped_video_frames_list"], None, None, features["info_list"]),) + # + # for result in previous_data: + # if isinstance(result, tuple): + # # keep the frames object and extract the id from the info + # # if keep_frames is None or result[keep_frames + # for f_idx, frame_data in enumerate(result[4]): + # for plot_name, plot in result[1][f_idx].items(): + # if plot_name != "PLOT" and (keep_plot_names is None or plot_name in keep_plot_names): + # if frame_data["frame_info"]["frame_id"] in previous_maps: + # if plot_name in previous_maps[frame_data["frame_info"]["frame_id"]]: + # previous_maps[frame_data["frame_info"]["frame_id"]][plot_name].append(plot) + # else: + # previous_maps[frame_data["frame_info"]["frame_id"]].update(**{plot_name: [plot]}) + # else: + # previous_maps[frame_data["frame_info"]["frame_id"]] = {plot_name: [plot]} + # features["previous_maps"] = previous_maps + return features + + def annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + for saliency_map_name, frame_name in zip(["saliency_maps"],[""]): + if grabbed_video: + if show_det_saliency_map: + saliency_map = info["frame_detections_" + 
self.short_name][saliency_map_name][0][0] + frame_transformed = plotter.plot_color_map(np.uint8(255 * saliency_map), color_map=color_map) + if enable_transform_overlays: + frame_transformed = plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + + frame_name + self.short_name], + frame_transformed, alpha=0.4) + else: + frame_transformed = plotter.resize(frame_transformed, + height=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[0], + width=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[1]) + grouped_video_frames["det_transformed_" + frame_name + self.short_name] = frame_transformed + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git a/gazenet/models/saliency_prediction/gasp/model.py b/gazenet/models/saliency_prediction/gasp/model.py new file mode 100644 index 0000000..7b050b7 --- /dev/null +++ b/gazenet/models/saliency_prediction/gasp/model.py @@ -0,0 +1,1137 @@ +import os + +import torch +from torch import nn +import pytorch_lightning as pl +from pytorch_lightning.loggers import * +from torch.utils.data import DataLoader +from torchvision import transforms + +from gazenet.utils.registrar import * +from gazenet.models.shared_components.attentive_convlstm.model import AttentiveLSTM, SequenceAttentiveLSTM +from gazenet.models.shared_components.squeezeexcitation.model import SELayer +from gazenet.models.shared_components.gmu.model import GMUConv2d, RGMUConv2d +from gazenet.models.saliency_prediction.gasp.generator import * +from gazenet.models.saliency_prediction.losses import cross_entropy_loss, nss_score, cc_score + +LATENT_CONV_C = 32 + + +class DAMLayer(nn.Module): + def __init__(self, channel, reduction=2): + super(DAMLayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Linear(channel, channel // reduction, bias=False) + self.fc1_relu = nn.ReLU(inplace=True) + self.fc2 = nn.Linear(channel // reduction, channel, bias=False) + self.fc2_sig = nn.Sigmoid() + + self.conv = nn.Sequential(nn.Conv2d(channel, 256, kernel_size=1), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(256, 1, kernel_size=1)) + + def forward(self, x, detached=False): + if detached: + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = torch.nn.functional.linear(y, self.fc1.weight) + y = self.fc1_relu(y) + y = torch.nn.functional.linear(y, self.fc2.weight) + y = self.fc2_sig(y).view(b, c, 1, 1) + sal = x * y.expand_as(x) + sal = sal + else: + x = torch.log(1 / torch.nn.functional.softmax(x, dim=1)) + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc1(y) + y = self.fc1_relu(y) + y = self.fc2(y) + y = self.fc2_sig(y).view(b, c, 1, 1) + sal = x * y.expand_as(x) + sal = self.conv(sal) + return sal + + +# NOTE (fabawi): for all non-sequential models, the modality encoder is shared amongst modules; for sequential models, +# each modality has its own encoder, but shared across timesteps +class ModalityEncoder(nn.Module): + def __init__(self, in_channels, out_channels=LATENT_CONV_C, encoder="Conv"): + super().__init__() + self.encoder = encoder + + if encoder == "Deep": + self.enc = nn.Sequential( + + nn.Conv2d(in_channels, LATENT_CONV_C, kernel_size=3, stride=1, padding=1), + nn.Conv2d(LATENT_CONV_C, LATENT_CONV_C, kernel_size=3, stride=1, padding=1), + + # nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(LATENT_CONV_C, 64, kernel_size=3, stride=1, padding=1), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + + # 
nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + + # nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), + + # the decoder + nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.ConvTranspose2d(64, LATENT_CONV_C, kernel_size=3, stride=1, padding=1), + nn.ConvTranspose2d(LATENT_CONV_C, out_channels, kernel_size=3, stride=1, padding=1), + ) + elif encoder == "Conv": + self.enc = nn.Sequential( + nn.Conv2d(in_channels, LATENT_CONV_C, kernel_size=3, padding=1), + nn.Conv2d(LATENT_CONV_C, 64, kernel_size=3, padding=1), + + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.Conv2d(128, 128, kernel_size=3, padding=1), + + nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=1), + nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.ConvTranspose2d(64, LATENT_CONV_C, kernel_size=3, stride=1, padding=1), + nn.ConvTranspose2d(LATENT_CONV_C, out_channels, kernel_size=3, stride=1, padding=1), + ) + + elif encoder == "MobileNet": + raise NotImplementedError("MobileNet not yet integrated") + + def forward(self, x): + x = self.enc(x) + return x + + +class GASPBase(pl.LightningModule): + + def __init__(self, learning_rate=0.00014, batch_size=8, num_workers=16, loss_weights=(0.5, 2, 1), dam_loss_weight=0.5, + in_channels=3, modalities=4, out_channels=1, sequence_len=1, + trg_img_width=60, trg_img_height=60, + inp_img_width=120, inp_img_height=120, + inp_img_mean=(110.63666788 / 255.0, 103.16065604 / 255.0, 96.29023126 / 255.0), + inp_img_std=(38.7568578 / 255.0, 37.88248729 / 255.0, 40.02898126 / 255.0), + train_dataset_properties=None, val_dataset_properties=None, test_dataset_properties=None, + val_store_image_samples=False): + super(GASPBase, self).__init__() + + self.trg_img_width = trg_img_width + self.trg_img_height = trg_img_height + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + + self.dam_loss_weight = dam_loss_weight + self.loss_weights = loss_weights + # CAREFUL (fabawi): the following loss_weights_named are signed and inverted + self.loss_weights_named = {"bce_loss": 1/loss_weights[0], + "cc": -1/loss_weights[1], + "nss": -1/loss_weights[2], + "dam_bce_loss": 1/self.dam_loss_weight, + "loss": 1} + self.learning_rate = learning_rate + self.batch_size = batch_size + self.num_workers = num_workers + + # model and dataset mode dependent + self.in_channels = in_channels + self.modalities = modalities + self.out_channels = out_channels + self.sequence_len = sequence_len + self.exhaustive = False + + self.train_dataset_properties = train_dataset_properties + self.val_dataset_properties = val_dataset_properties + self.test_dataset_properties = test_dataset_properties + + self.val_store_image_samples = val_store_image_samples + + def forward(self, 
modules): + raise NotImplementedError("This is the base GASP class and cannot be inherited directly") + + def loss(self, logits, y): + bce_loss = cross_entropy_loss(logits, y["transformed_salmap"], self.loss_weights[0]) + cc_loss = cc_score(logits, y["transformed_salmap"], self.loss_weights[1]) + nss_loss = nss_score(logits, y["transformed_fixmap"], self.loss_weights[2]) + # the nss and cc are losses here still, but renamed for name matching with the fixed logging values + return {"bce_loss": bce_loss, "cc": cc_loss, "nss": nss_loss, + "loss": sum((bce_loss, cc_loss, nss_loss))} + # "loss": bce_loss} + + def training_step(self, train_batch, batch_idx): + x, y = train_batch + logits = self.forward(x) + + if len(logits) > 1 and logits[1] is not None: + # calculate dam loss if model supports it + if self.exhaustive: + dam_logit_tgt = "seq_transformed_salmap" + else: + dam_logit_tgt = "transformed_salmap" + dam_logits = torch.nn.functional.binary_cross_entropy_with_logits(logits[1], y[dam_logit_tgt]) * self.dam_loss_weight + else: + dam_logits = torch.tensor(0) + sal_logits = logits[0] + losses = self.loss(sal_logits, y) + losses.update(dam_bce_loss=dam_logits) + logs = {f'train_{k}': (v * self.loss_weights_named[k])/self.batch_size for k, v in losses.items()} + return {'loss': losses['loss'] + dam_logits, 'log': logs} + + def validation_step(self, val_batch, batch_idx): + x, y = val_batch + logits = self.forward(x) + sal_logits = logits[0] + if self.val_store_image_samples: + self.log_val_images(x, logits, y) + losses = self.loss(sal_logits, y) + logs = {f'val_{k}': (v * self.loss_weights_named[k])/self.batch_size for k, v in losses.items()} + return logs + + # TODO (fabawi): if the training is too slow and you don't care about logs, remove this + def training_epoch_end(self, outputs): + avg_logs = {} + for log_key in outputs[0]['log'].keys(): + self.log(f'avg_{log_key}', torch.stack([output['log'][log_key] for output in outputs]).mean()) + # Uncomment the lines below for older lightning versions + #avg_logs[f'avg_{log_key}'] = torch.stack([output['log'][log_key] for output in outputs]).mean() + #return{'avg_train_loss': avg_logs['avg_train_loss'], 'log': avg_logs} + + def validation_epoch_end(self, outputs): + # called at the end of the validation epoch + # outputs is an array with what you returned in validation_step for each batch + # outputs = [{'loss': batch_0_loss}, {'loss': batch_1_loss}, ..., {'loss': batch_n_loss}] + avg_logs = {} + for log_key in outputs[0].keys(): + self.log(f'avg_{log_key}', torch.stack([output[log_key] for output in outputs]).mean()) + # Uncomment the lines below for older lightning versions + #avg_logs[f'avg_{log_key}'] = torch.stack([output[log_key] for output in outputs]).mean() + #return {'avg_val_loss': avg_logs['avg_val_loss'], 'log': avg_logs} + + def prepare_data(self): + + # dataset properties + if self.train_dataset_properties is None: + self.train_dataset_properties = {"csv_file": "datasets/processed/train.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + if self.val_dataset_properties is None: + self.val_dataset_properties = {"csv_file": "datasets/processed/validation.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", + 
"det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + if self.test_dataset_properties is None: + self.test_dataset_properties = {"csv_file": "datasets/processed/test.csv", + "video_dir": "datasets/processed/Grouped_frames", + "inp_img_names_list": ["det_transformed_dave", "det_transformed_esr9", + "det_transformed_vidgaze", + "det_transformed_gaze360"], + "gt_img_names_list": ["transformed_salmap", "transformed_fixmap"]} + + # transforms for images + input_img_transform = transforms.Compose([ + transforms.Resize((self.inp_img_height, self.inp_img_width)), + transforms.ToTensor(), + transforms.Normalize(self.inp_img_mean, self.inp_img_std), + ]) + gt_img_transform = transforms.Compose([ + transforms.Resize((self.trg_img_height, self.trg_img_width)), + transforms.ToTensor() + ]) + + self.train_dataset_properties.update(sequence_len=self.sequence_len, gt_img_transform=gt_img_transform, + inp_img_transform=input_img_transform, exhaustive=self.exhaustive) + self.train_dataset = GASPDataset(**self.train_dataset_properties) + + self.val_dataset_properties.update(sequence_len=self.sequence_len, gt_img_transform=gt_img_transform, + inp_img_transform=input_img_transform, exhaustive=self.exhaustive) + self.val_dataset = GASPDataset(**self.val_dataset_properties) + + self.test_dataset_properties.update(sequence_len=self.sequence_len, gt_img_transform=gt_img_transform, + inp_img_transform=input_img_transform, exhaustive=self.exhaustive) + self.test_dataset = GASPDataset(**self.test_dataset_properties) + + # self.train_dataset, self.val_dataset = random_split(self.train_dataset, [55000, 5000]) + # update the batch size based on the train dataset if experimenting on small dummy data + if len(self.train_dataset) < self.batch_size: + self.batch_size = len(self.train_dataset) + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers) + + def test_dataloader(self): + return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + # optimizer = torch.optim.SGD(self.parameters(), lr=1e-1) + return optimizer + + # TODO (fabawi): this should be a callback instead and not part of the model, but will have to do for now. 
+ def log_val_images(self, x, logits, y): + if isinstance(self.logger, CometLogger): + # for 1 channel plotting + from PIL import Image + import numpy as np + from matplotlib import cm + + random_batch_item_idx = np.random.randint(0, max(x.shape[0]-1, 1)) if self.batch_size > 1 else 0 + all_imgs = [] + + + all_imgs.append([]) + # im.show() + + # the input + for seq_idx in range(0, max(self.sequence_len, 1)): + input_imgs = [] + if self.exhaustive: + DEBUG_y = np.squeeze(y["seq_transformed_salmap"].cpu()[random_batch_item_idx, seq_idx, ::].data.numpy()) + input_imgs.append(Image.fromarray(np.uint8(cm.jet(DEBUG_y) * 255)).resize((self.inp_img_width, self.inp_img_height))) + else: + input_imgs.append(Image.new(INP_IMG_MODE, (self.inp_img_width, self.inp_img_height))) + for mod_idx in range(0, self.modalities*self.in_channels, self.in_channels): + if self.sequence_len > 1: + x_mod = x[:,seq_idx, ::] + else: + x_mod = x + x_mod = UnNormalize(self.inp_img_mean, self.inp_img_std)(x_mod[random_batch_item_idx, mod_idx:mod_idx+self.in_channels,::]) + DEBUG_x = np.squeeze(x_mod.cpu().data.numpy()) + DEBUG_x = np.moveaxis(DEBUG_x, 0, -1) # change 0:2 to the input channels of the modality -> first modality 0:3 + print(DEBUG_x.shape) + im = Image.fromarray((DEBUG_x*255).astype(np.uint8)) + input_imgs.append(im) + + if seq_idx != max(self.sequence_len, 1) - 1: + input_imgs.append(Image.new(INP_IMG_MODE, (self.inp_img_width, self.inp_img_height))) + all_imgs[-1].extend(input_imgs) + all_imgs.append([]) + # im.show() + del all_imgs[-1] + + # the gt + DEBUG_y = np.squeeze(y["transformed_salmap"].cpu()[random_batch_item_idx, ::].data.numpy()) + im_gt = Image.fromarray(np.uint8(cm.jet(DEBUG_y) * 255)).resize((self.inp_img_width, self.inp_img_height)) + all_imgs[-1][0] = im_gt + + # the model output + logits_s = logits[0].cpu() + logits_s = logits_s[random_batch_item_idx, ::] + # logits_s_norm = torch.sigmoid(logits_s) + logits_s_norm = logits_s - torch.min(logits_s) + logits_s_norm /= torch.max(logits_s_norm) + DEBUG_y = np.squeeze(logits_s_norm.data.numpy()) + DEBUG_y = np.uint8(cm.jet(DEBUG_y) * 255) + im_pred = Image.fromarray(DEBUG_y).resize((self.inp_img_width, self.inp_img_height)) + all_imgs[-1].append(im_pred) + # im.show() + + # collate images + widths, heights = zip(*(z.size for i in all_imgs for z in i)) + + total_width = sum(widths) // max(self.sequence_len, 1) + max_height = max(heights) + total_height = max_height * max(self.sequence_len, 1) + + new_im = Image.new('RGB', (total_width, total_height)) + + y_offset = 0 + for im_y_idx in range(0, max(self.sequence_len, 1)): + x_offset = 0 + for im_x_idx in range(0, len(all_imgs[0])): + new_im.paste(all_imgs[im_y_idx][im_x_idx], (x_offset, y_offset)) + x_offset += all_imgs[im_y_idx][im_x_idx].size[0] + y_offset += all_imgs[im_y_idx][0].size[1] + + # log on comet_ml + self.logger.experiment.log_image(new_im, 'GT : INPUTS : LOGITS', step=self.current_epoch) + + # this section is for debugging only + logits_s_norm = torch.sigmoid(logits_s) # 1 + logits_s_norm = logits_s - torch.min(logits_s_norm) + logits_s_norm /= torch.max(logits_s_norm) + DEBUG_y = np.squeeze(logits_s_norm.data.numpy()) + DEBUG_y = np.uint8(cm.jet(DEBUG_y) * 255) + im_pred_sig_norm = Image.fromarray(DEBUG_y).resize((self.inp_img_width, self.inp_img_height)) + logits_s_norm = torch.sigmoid(logits_s) # 2 + DEBUG_y = np.squeeze(logits_s_norm.data.numpy()) + DEBUG_y = np.uint8(cm.jet(DEBUG_y) * 255) + im_pred_sig = Image.fromarray(DEBUG_y).resize((self.inp_img_width, 
self.inp_img_height)) + debug_new_im = Image.new('RGB', (im_pred.size[0] * 3, im_pred.size[1])) + debug_new_im.paste(im_pred, (0, 0)) + debug_new_im.paste(im_pred_sig, (im_pred.size[0], 0)) + debug_new_im.paste(im_pred_sig_norm, (im_pred.size[0] * 2, 0)) + self.logger.experiment.log_image(debug_new_im, 'PRED : PRED_SIGMOID: PRED_SIGMOID_NORM', + step=self.current_epoch) + + @staticmethod + def update_infer_config(log_path, checkpoint_file, train_config, infer_config, device): + # if isinstance(device, list): + # infer_config.device = "cuda:" + str(device[0]) + if isinstance(device, list): + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(device) + elif isinstance(device, str): + os.environ["CUDA_VISIBLE_DEVICES"] = device + + # update the datasplitter to include test file specified in the trainer + infer_config.datasplitter_properties.update(test_csv_file=train_config.test_dataset_properties["csv_file"]) + # update the metrics file path + infer_config.metrics_save_file = os.path.join(log_path, "metrics.csv") + # update the sampling names list. This is specific to DataSample + infer_config.sampling_properties.update( + img_names_list=train_config.test_dataset_properties["inp_img_names_list"] + + train_config.test_dataset_properties["gt_img_names_list"]) + + # update the targeted model properties + for mod_group in infer_config.model_groups: + for mod in mod_group: + if not train_config.inferer_name == mod[0]: + continue + + # try extracting the window size if it exists, otherwise, assume single frame + for w_size_name in ["w_size", "frames_len", "sequence_len"]: + if w_size_name in train_config.model_properties: + if mod[1] == -1: + mod[1] = train_config.model_properties[w_size_name] + if mod[2] == -1: + mod[2] = [mod[1] - 1] + break + if mod[1] == -1: + mod[1] = 1 + if mod[2] == -1: + mod[2] = [0] + + # infer the frames_len from sequence_len + if "frames_len" not in train_config.model_properties and "sequence_len" in train_config.model_properties: + train_config.model_properties["frames_len"] = train_config.model_properties["sequence_len"] + + # update the configuration + mod[3].update(**train_config.model_properties, + weights_file=checkpoint_file, + model_name=train_config.model_name) + + # update the input image names list. 
This is specific to gasp + if mod[4]["inp_img_names_list"] is None: + mod[4].update(inp_img_names_list=train_config.test_dataset_properties["inp_img_names_list"]) + + break + + return infer_config + + +#done +@ModelRegistrar.register +class GASPEncGMUALSTMConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPEncGMUALSTMConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + else: + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + + self.att_lstm = AttentiveLSTM(LATENT_CONV_C, LATENT_CONV_C, LATENT_CONV_C, 3, 3) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels * self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx + self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules + + sal, lateral = self.gmu(fusion) + sal = self.att_lstm(sal) + sal = self.saliency_out(sal) + return sal, None, lateral + + +#done +@ModelRegistrar.register +class GASPEncALSTMGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPEncALSTMGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) + self.att_lstm = AttentiveLSTM(LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, 3, 3) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + else: + self.att_lstm = AttentiveLSTM(in_channels * modalities, in_channels * modalities, in_channels * modalities, 3, 3) + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels * self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx + self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules + sal = self.att_lstm(fusion) + sal, lateral = self.gmu(sal) + sal = self.saliency_out(sal) + return sal, None, lateral # lateral is non-fusion + + +#done +@ModelRegistrar.register +class GASPEncALSTMConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPEncALSTMConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) + self.att_lstm = AttentiveLSTM(LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, 3, 3) + self.saliency_out = nn.Conv2d(LATENT_CONV_C * modalities, out_channels, kernel_size=1) + else: + self.att_lstm = AttentiveLSTM(in_channels * 
modalities, in_channels * modalities, in_channels * modalities, 3, 3) + self.saliency_out = nn.Conv2d(in_channels * modalities, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels * self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx + self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules + sal = self.att_lstm(fusion) + sal = self.saliency_out(sal) + return sal, None, None + + +#done +@ModelRegistrar.register +class GASPDAMEncGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPDAMEncGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + self.dam = DAMLayer(in_channels*modalities, reduction=2) + + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + else: + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + dam_sal = self.dam(modules) + modules = self.dam(modules, detached=True) + + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels*self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx+self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules # to operate without an encoder + + sal, lateral = self.gmu(fusion) + sal = self.saliency_out(sal) + return sal, dam_sal, lateral # lateral is non-fusion + +#done +@ModelRegistrar.register +class GASPDAMEncALSTMGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPDAMEncALSTMGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + self.dam = DAMLayer(in_channels*modalities, reduction=2) + + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) + self.att_lstm = AttentiveLSTM(LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, 3, 3) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + else: + self.att_lstm = AttentiveLSTM(in_channels * modalities, in_channels * modalities, in_channels * modalities, 3, 3) + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + + dam_sal = self.dam(modules) + modules = self.dam(modules, detached=True) + + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels*self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx+self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules # to operate without an encoder + + sal = self.att_lstm(fusion) + sal, lateral = self.gmu(sal) + sal = self.saliency_out(sal) + return sal, dam_sal, lateral # lateral is non-fusion + + +#done +@ModelRegistrar.register +class GASPDAMEncGMUALSTMConv(GASPBase): + def __init__(self, *args, in_channels=3, 
modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPDAMEncGMUALSTMConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + self.dam = DAMLayer(in_channels*modalities, reduction=2) + + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + + else: + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + self.att_lstm = AttentiveLSTM(LATENT_CONV_C, LATENT_CONV_C, LATENT_CONV_C, 3, 3) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + + dam_sal = self.dam(modules) + modules = self.dam(modules, detached=True) + + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels*self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx+self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules # to operate without an encoder + + sal, lateral = self.gmu(fusion) + sal = self.att_lstm(sal) + sal = self.saliency_out(sal) + return sal, dam_sal, lateral + + +#done +@ModelRegistrar.register +class GASPEncConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPEncConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder) + self.saliency_in = nn.Conv2d(LATENT_CONV_C * modalities, LATENT_CONV_C, kernel_size=3, padding=1) + else: + self.saliency_in = nn.Conv2d(in_channels*modalities, LATENT_CONV_C, kernel_size=3, padding=1) + + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels * self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx + self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules + sal = self.saliency_in(fusion) + sal = self.saliency_out(sal) + return sal, None, None + + +#done +@ModelRegistrar.register +class GASPEncAddConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPEncAddConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder) + self.saliency_in = nn.Conv2d(1, LATENT_CONV_C, kernel_size=3, padding=1) + else: + self.saliency_in = nn.Conv2d(1, LATENT_CONV_C, kernel_size=3, padding=1) + + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels * self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx + self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules + + sal = torch.sum(fusion, dim=1) + sal = torch.unsqueeze(sal, dim=1) + sal = self.saliency_in(sal) + sal = self.saliency_out(sal) + return sal, None, None + + +@ModelRegistrar.register 
+class GASPSEEncConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPSEEncConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + + self.se = SELayer(in_channels*modalities, reduction=2) + + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder) + self.saliency_in = nn.Conv2d(LATENT_CONV_C * modalities, LATENT_CONV_C, kernel_size=3, padding=1) + else: + self.saliency_in = nn.Conv2d(in_channels * modalities, LATENT_CONV_C, kernel_size=3, padding=1) + + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + modules = self.se(modules) + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels * self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx + self.in_channels, ::])) + + fusion = torch.cat(fusion, 1) + else: + fusion = modules + sal = fusion + sal = self.saliency_in(sal) + sal = self.saliency_out(sal) + return sal, None, None + + +#done +@ModelRegistrar.register +class SequenceGASPEncALSTMConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, sequence_len=16, sequence_norm=False, encoder="Conv", **kwargs): + super(SequenceGASPEncALSTMConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=sequence_len, + *args, **kwargs) + + # model and dataset mode dependent + assert sequence_len > 1, "Sequence length must be greater than 1" + self.sequence_len = sequence_len + self.exhaustive = False + + self.encoder = encoder + if encoder is not None: + self.enc_list = nn.ModuleList([ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) for _ in range(modalities)]) + self.att_lstm = SequenceAttentiveLSTM(LATENT_CONV_C * modalities, + LATENT_CONV_C * modalities, + LATENT_CONV_C * modalities, + 3, 3, + sequence_len=sequence_len, + sequence_norm=sequence_norm) + self.saliency_out = nn.Conv2d(LATENT_CONV_C * modalities, out_channels, kernel_size=1) + + else: + self.att_lstm = SequenceAttentiveLSTM(in_channels * modalities, + in_channels * modalities, + in_channels * modalities, + 3, 3, + sequence_len=sequence_len, + sequence_norm=sequence_norm) + self.saliency_out = nn.Conv2d(in_channels*modalities, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for seq_idx in range(0, self.sequence_len): + mod_fusion=[] + for mod_idx, mod_step in enumerate(range(0, self.in_channels*self.modalities, self.in_channels)): + mod_fusion.append(self.enc_list[mod_idx](modules[:, seq_idx, mod_step:mod_step+self.in_channels, ::])) + mod_fusion = torch.cat(mod_fusion, 1) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + else: + fusion = modules # to operate without an encoder + + + sal = self.att_lstm(fusion) + sal = self.saliency_out(sal) + return sal, None, None + + +#done +@ModelRegistrar.register +class SequenceGASPEncALSTMGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, sequence_len=16, sequence_norm=False, encoder="Conv", **kwargs): + super(SequenceGASPEncALSTMGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=sequence_len, + *args, **kwargs) + + # model and dataset mode dependent + assert sequence_len > 1, "Sequence 
length must be greater than 1" + self.sequence_len = sequence_len + self.exhaustive = False + self.encoder = encoder + if encoder is not None: + self.enc_list = nn.ModuleList([ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) for _ in range(modalities)]) + self.att_lstm = SequenceAttentiveLSTM(LATENT_CONV_C * modalities, + LATENT_CONV_C * modalities, + LATENT_CONV_C * modalities, + 3, 3, + sequence_len=sequence_len, + sequence_norm=sequence_norm) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + else: + self.att_lstm = SequenceAttentiveLSTM(in_channels * modalities, + in_channels * modalities, + in_channels * modalities, + 3, 3, + sequence_len=sequence_len, + sequence_norm=sequence_norm) + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for seq_idx in range(0, self.sequence_len): + mod_fusion=[] + for mod_idx, mod_step in enumerate(range(0, self.in_channels*self.modalities, self.in_channels)): + mod_fusion.append(self.enc_list[mod_idx](modules[:, seq_idx, mod_step:mod_step+self.in_channels, ::])) + mod_fusion = torch.cat(mod_fusion, 1) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + else: + fusion = modules # to operate without an encoder + + sal = self.att_lstm(fusion) + sal, lateral = self.gmu(sal) + sal = self.saliency_out(sal) + return sal, None, lateral # lateral is non-fusion + + +#done +@ModelRegistrar.register +class SequenceGASPEncGMUALSTMConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, sequence_len=16, sequence_norm=False, encoder="Conv", **kwargs): + super(SequenceGASPEncGMUALSTMConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=sequence_len, + *args, **kwargs) + + # model and dataset mode dependent + assert sequence_len > 1, "Sequence length must be greater than 1" + self.sequence_len = sequence_len + self.exhaustive = False + self.encoder = encoder + if encoder is not None: + self.enc_list = nn.ModuleList([ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) for _ in range(modalities)]) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + else: + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + + self.att_lstm = SequenceAttentiveLSTM(LATENT_CONV_C, LATENT_CONV_C, LATENT_CONV_C, 3, 3, + sequence_len=sequence_len, + sequence_norm=sequence_norm) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for seq_idx in range(0, self.sequence_len): + mod_fusion=[] + for mod_idx, mod_step in enumerate(range(0, self.in_channels*self.modalities, self.in_channels)): + mod_fusion.append(self.enc_list[mod_idx](modules[:, seq_idx, mod_step:mod_step+self.in_channels, ::])) + mod_fusion = torch.cat(mod_fusion, 1) + mod_fusion, lateral = self.gmu(mod_fusion) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + else: + fusion = [] # to operate without an encoder + for seq_idx in range(0, self.sequence_len): + mod_fusion, lateral = self.gmu(fusion[:, seq_idx, ::]) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + sal = self.att_lstm(fusion) + sal = self.saliency_out(sal) + return sal, None, lateral # lateral is only the last + 
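A quick shape check for the sequential variants (again an illustrative sketch; the output resolution is inferred from the single `MaxPool2d` in the `Conv` encoder rather than stated anywhere in the diff): the input is a `(batch, sequence_len, modalities * in_channels, height, width)` tensor, matching the slicing in the `forward` passes above.

```python
# Illustrative sketch only, not part of this diff.
import torch

from gazenet.utils.registrar import *

ModelRegistrar.scan()
model = ModelRegistrar.registry["SequenceGASPEncGMUALSTMConv"](sequence_len=10, batch_size=1)
model.eval()

# 4 modalities x 3 channels each at 120x120, the defaults used by GASPInference.
dummy = torch.randn(1, 10, 4 * 3, 120, 120)
with torch.no_grad():
    sal, dam_sal, lateral = model(dummy)

# sal should come out at roughly half the input resolution, e.g. (1, 1, 60, 60);
# dam_sal is None for this variant, and lateral holds the GMU gating output that
# GASPInference turns into gate scores.
print(sal.shape)
```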
+ +#done (not tested) +@ModelRegistrar.register +class SequenceGASPDAMEncGMUALSTMConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, sequence_len=16, sequence_norm=False, encoder="Conv", **kwargs): + super(SequenceGASPDAMEncGMUALSTMConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=sequence_len, + *args, **kwargs) + + # model and dataset mode dependent + assert sequence_len > 1, "Sequence length must be greater than 1" + assert encoder is not None, "Encoder must not be None when running sequential DAM" + self.sequence_len = sequence_len + self.exhaustive = True + self.encoder = encoder + self.dam = DAMLayer(in_channels * modalities, reduction=2) + self.enc_list = nn.ModuleList([ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) for _ in range(modalities)]) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + self.att_lstm = SequenceAttentiveLSTM(LATENT_CONV_C, LATENT_CONV_C, LATENT_CONV_C, 3, 3, + sequence_len=sequence_len, + sequence_norm=sequence_norm) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + + fusion = [] + dam_sals = [] + for seq_idx in range(0, self.sequence_len): + dam_sal = self.dam(modules[:, seq_idx, ::].clone()) + dam_sals.append(dam_sal) + modules[:, seq_idx, ::] = self.dam(modules[:, seq_idx, ::].clone(), detached=True) + mod_fusion=[] + for mod_idx, mod_step in enumerate(range(0, self.in_channels*self.modalities, self.in_channels)): + mod_fusion.append(self.enc_list[mod_idx](modules[:, seq_idx, mod_step:mod_step+self.in_channels, ::].clone())) + mod_fusion = torch.cat(mod_fusion, 1) + mod_fusion, lateral = self.gmu(mod_fusion) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + dam_sals = torch.stack(dam_sals, 1) + + sal = self.att_lstm(fusion) + sal = self.saliency_out(sal) + return sal, dam_sals, lateral # lateral is only the last + + +#done (not tested) +@ModelRegistrar.register +class SequenceGASPDAMEncALSTMGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, sequence_len=16, sequence_norm=False, encoder="Conv", **kwargs): + super(SequenceGASPDAMEncALSTMGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=sequence_len, + *args, **kwargs) + + # model and dataset mode dependent + assert sequence_len > 1, "Sequence length must be greater than 1" + assert encoder is not None, "Encoder must not be None when running sequential DAM" + self.sequence_len = sequence_len + self.exhaustive = True + self.encoder = encoder + self.dam = DAMLayer(in_channels * modalities, reduction=2) + self.enc_list = nn.ModuleList([ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) for _ in range(modalities)]) + self.att_lstm = SequenceAttentiveLSTM(LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, LATENT_CONV_C * modalities, 3, 3, + sequence_len=sequence_len, + sequence_norm=sequence_norm) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + + fusion = [] + dam_sals = [] + for seq_idx in range(0, self.sequence_len): + dam_sal = self.dam(modules[:, seq_idx, ::].clone()) + dam_sals.append(dam_sal) + modules[:, seq_idx, ::] = self.dam(modules[:, seq_idx, ::].clone(), detached=True) + mod_fusion=[] + for mod_idx, 
mod_step in enumerate(range(0, self.in_channels*self.modalities, self.in_channels)): + mod_fusion.append(self.enc_list[mod_idx](modules[:, seq_idx, mod_step:mod_step+self.in_channels, ::].clone())) + mod_fusion = torch.cat(mod_fusion, 1) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + dam_sals = torch.stack(dam_sals, 1) + + sal = self.att_lstm(fusion) + sal, lateral = self.gmu(sal) + sal = self.saliency_out(sal) + return sal, dam_sals, lateral # lateral is non-fusion + + +#done +@ModelRegistrar.register +class GASPEncGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", **kwargs): + super(GASPEncGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=1, + *args, **kwargs) + self.encoder = encoder + if encoder is not None: + self.enc = ModalityEncoder(in_channels, LATENT_CONV_C, encoder) + self.gmu = GMUConv2d(LATENT_CONV_C, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + else: + self.gmu = GMUConv2d(in_channels, LATENT_CONV_C, modalities, kernel_size=3, padding=1) + + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for mod_idx in range(0, self.in_channels*self.modalities, self.in_channels): + fusion.append(self.enc(modules[:, mod_idx:mod_idx+self.in_channels, ::])) + fusion = torch.cat(fusion, 1) + else: + fusion = modules # to operate without an encoder + sal, lateral = self.gmu(fusion) + sal = self.saliency_out(sal) + return sal, None, lateral + + +@ModelRegistrar.register +class SequenceGASPEncRGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, out_channels=1, encoder="Conv", sequence_len=16, **kwargs): + super(SequenceGASPEncRGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=sequence_len, + *args, **kwargs) + + # model and dataset mode dependent + assert sequence_len > 1, "Sequence length must be greater than 1" + self.sequence_len = sequence_len + self.exhaustive = False + self.encoder = encoder + if encoder is not None: + self.enc_list = nn.ModuleList( + [ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) for _ in range(modalities)]) + self.gmu = RGMUConv2d(LATENT_CONV_C, LATENT_CONV_C, kernel_size=3, modalities=modalities, + input_size=(self.trg_img_height, self.trg_img_width), padding=1) + else: + self.gmu = RGMUConv2d(in_channels, LATENT_CONV_C, kernel_size=3, modalities=modalities, + input_size=(self.inp_img_height, self.inp_img_width), padding=1) + + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + if self.encoder is not None: + fusion = [] + for seq_idx in range(0, self.sequence_len): + mod_fusion = [] + for mod_idx, mod_step in enumerate(range(0, self.in_channels * self.modalities, self.in_channels)): + mod_fusion.append( + self.enc_list[mod_idx](modules[:, seq_idx, mod_step:mod_step + self.in_channels, ::])) + mod_fusion = torch.cat(mod_fusion, 1) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + else: + fusion = modules + h_l, z_l = self.gmu.initialize_lateral_state() + lateral = (h_l.to("cuda"), z_l.to("cuda")) + sal, lateral = self.gmu(fusion, lateral) + sal = self.saliency_out(sal) + return sal, None, lateral # lateral is all + + +@ModelRegistrar.register +class SequenceGASPDAMEncRGMUConv(GASPBase): + def __init__(self, *args, in_channels=3, modalities=4, 
out_channels=1, encoder="Conv", sequence_len=16, **kwargs): + super(SequenceGASPDAMEncRGMUConv, self).__init__(in_channels=in_channels, + modalities=modalities, + out_channels=out_channels, + sequence_len=sequence_len, + *args, **kwargs) + + # model and dataset mode dependent + assert sequence_len > 1, "Sequence length must be greater than 1" + assert encoder is not None, "Encoder must not be None when running sequential DAM" + self.sequence_len = sequence_len + self.exhaustive = True + self.encoder = encoder + self.dam = DAMLayer(in_channels * modalities, reduction=2) + self.enc_list = nn.ModuleList( + [ModalityEncoder(in_channels, LATENT_CONV_C, encoder=encoder) for _ in range(modalities)]) + self.gmu = RGMUConv2d(LATENT_CONV_C, LATENT_CONV_C, kernel_size=3, modalities=modalities, + input_size=(self.trg_img_height, self.trg_img_width), padding=1) + self.saliency_out = nn.Conv2d(LATENT_CONV_C, out_channels, kernel_size=1) + + def forward(self, modules): + fusion = [] + dam_sals = [] + for seq_idx in range(0, self.sequence_len): + dam_sal = self.dam(modules[:, seq_idx, ::].clone()) + dam_sals.append(dam_sal) + modules[:, seq_idx, ::] = self.dam(modules[:, seq_idx, ::].clone(), detached=True) + mod_fusion = [] + for mod_idx, mod_step in enumerate(range(0, self.in_channels * self.modalities, self.in_channels)): + mod_fusion.append( + self.enc_list[mod_idx](modules[:, seq_idx, mod_step:mod_step + self.in_channels, ::].clone())) + mod_fusion = torch.cat(mod_fusion, 1) + fusion.append(mod_fusion) + fusion = torch.stack(fusion, 1) + dam_sals = torch.stack(dam_sals, 1) + + h_l, z_l = self.gmu.initialize_lateral_state() + lateral = (h_l.to("cuda"), z_l.to("cuda")) + sal, lateral = self.gmu(fusion, lateral) + sal = self.saliency_out(sal) + return sal, dam_sals, lateral # lateral is all diff --git a/gazenet/models/saliency_prediction/losses.py b/gazenet/models/saliency_prediction/losses.py new file mode 100644 index 0000000..164da47 --- /dev/null +++ b/gazenet/models/saliency_prediction/losses.py @@ -0,0 +1,98 @@ +""" +code from: https://github.com/atsiami/STAViS/blob/master/models/sal_losses.py +""" + +import numpy as np +import torch +import torch.nn.functional as F + + +def logit(x): + return np.log(x / (1 - x + 1e-08) + 1e-08) + + +def sigmoid_np(x): + return 1 / (1 + np.exp(-x)) + + +def cc_score(x, y, weights, batch_average=False, reduce=True): + x = x.squeeze(1) + x = torch.sigmoid(x) + y = y.squeeze(1) + mean_x = torch.mean(torch.mean(x, 1, keepdim=True), 2, keepdim=True) + mean_y = torch.mean(torch.mean(y, 1, keepdim=True), 2, keepdim=True) + xm = x.sub(mean_x) + ym = y.sub(mean_y) + r_num = torch.sum(torch.sum(torch.mul(xm, ym), 1, keepdim=True), 2, keepdim=True) + r_den_x = torch.sum(torch.sum(torch.mul(xm, xm), 1, keepdim=True), 2, keepdim=True) + r_den_y = torch.sum(torch.sum(torch.mul(ym, ym), 1, keepdim=True), 2, keepdim=True) + np.asscalar( + np.finfo(np.float32).eps) + r_val = torch.div(r_num, torch.sqrt(torch.mul(r_den_x, r_den_y))) + r_val = torch.mul(r_val.squeeze(), weights) + if batch_average: + r_val = -torch.sum(r_val) / torch.sum(weights) + else: + if reduce: + r_val = -torch.sum(r_val) + else: + r_val = -r_val + return r_val + + +def nss_score(x, y, weights, batch_average=False, reduce=True): + x = x.squeeze(1) + x = torch.sigmoid(x) + y = y.squeeze(1) + y = torch.gt(y, 0.0).float() + + mean_x = torch.mean(torch.mean(x, 1, keepdim=True), 2, keepdim=True) + std_x = torch.sqrt(torch.mean(torch.mean(torch.pow(torch.sub(x, mean_x), 2), 1, keepdim=True), 2, keepdim=True)) + 
x_norm = torch.div(torch.sub(x, mean_x), std_x) + r_num = torch.sum(torch.sum(torch.mul(x_norm, y), 1, keepdim=True), 2, keepdim=True) + r_den = torch.sum(torch.sum(y, 1, keepdim=True), 2, keepdim=True) + r_val = torch.div(r_num, r_den + np.asscalar(np.finfo(np.float32).eps)) + r_val = torch.mul(r_val.squeeze(), weights) + if batch_average: + r_val = -torch.sum(r_val) / torch.sum(weights) + else: + if reduce: + r_val = -torch.sum(r_val) + else: + r_val = -r_val + return r_val + + +def batch_image_sum(x): + x = torch.sum(torch.sum(x, 1, keepdim=True), 2, keepdim=True) + return x + + +def batch_image_mean(x): + x = torch.mean(torch.mean(x, 1, keepdim=True), 2, keepdim=True) + return x + + +def cross_entropy_loss(pred, target, weights, batch_average=False, reduce=True): + batch_size = pred.size(0) + output = pred.view(batch_size, -1) + label = target.view(batch_size, -1) + + label = label / 255 + final_loss = F.binary_cross_entropy_with_logits(output, label, reduction="none").sum(1) + final_loss = final_loss * weights + + if reduce: + final_loss = torch.sum(final_loss) + if batch_average: + final_loss /= torch.sum(weights) + + return final_loss + + +def kld_loss(pred, target, weights, batch_average=False, reduce=True): + loss = F.kl_div(pred, target, reduction='none') + if reduce: + loss = loss.sum(-1).sum(-1).sum(-1) + if batch_average: + loss /= torch.sum(weights) + return loss \ No newline at end of file diff --git a/gazenet/models/saliency_prediction/metrics.py b/gazenet/models/saliency_prediction/metrics.py new file mode 100644 index 0000000..c0910ba --- /dev/null +++ b/gazenet/models/saliency_prediction/metrics.py @@ -0,0 +1,414 @@ +""" +code from: https://github.com/tarunsharma1/saliency_metrics/blob/master/salience_metrics.py +""" + +import random +import math + +import numpy as np +import pandas as pd +import cv2 + +from gazenet.utils.registrar import * + +def normalize_map(s_map): + # normalize the salience map (as done in MIT code) + norm_s_map = (s_map - np.min(s_map)) / (np.max(s_map) - np.min(s_map)) + return norm_s_map + + +def discretize_gt(gt): + import warnings + # warnings.warn('can improve the way GT is discretized') + gt[gt > 0] = 255 + return gt / 255 + + +@MetricsRegistrar.register +class SaliencyPredictionMetrics(object): + def __init__(self, save_file="logs/metrics/salpred.csv", dataset_name="", video_name="", + metrics_list=["sim", "aucj", "aucs", "aucb", "nss", "cc", "kld", "ifg"], map_key="frame_detections_gasp"): + + self.save_file = save_file + os.makedirs(os.path.dirname(save_file), exist_ok=True) + self.metrics_list = metrics_list + self.map_key = map_key + self._dataset_name = dataset_name + self._video_name = video_name + + if os.path.exists(save_file): + self.scores = pd.read_csv(save_file, header=0) + else: + self.scores = pd.DataFrame(columns=["video_id", "dataset", "frames_len"] + metrics_list) + self.accumulator = {metric: [] for metric in metrics_list} + + # NOTE: other_accumulator (gate_scores for example) cannot be loaded from a file, and will replace any previous runs + self.other_scores = {} + self.other_accumulator = {} + + def set_new_name(self, vid_name): + collated_metrics = self.accumulate_metrics() + self._video_name = vid_name + self.accumulator = {metric: [] for metric in self.metrics_list} + self.other_accumulator = {} + return collated_metrics + + def save(self): + self.scores.to_csv(self.save_file, index=False) + for o_score_name, o_scores in self.other_scores.items(): + o_scores.to_csv(self.save_file.replace(".csv", "_" + o_score_name + 
".csv"), index=False) + + def accumulate_metrics(self, intermed_save=True): + # accumulate metrics + collated_metrics = {} + collated_metrics["video_id"] = self._video_name + collated_metrics["dataset"] = self._dataset_name + collated_metrics["frames_len"] = 0 + + for metric_name, metric_vals in self.accumulator.items(): + collated_metrics["frames_len"] = max(collated_metrics["frames_len"], len(metric_vals)) + if metric_vals: + collated_metrics[metric_name] = np.nanmean(np.array(metric_vals)) + if collated_metrics: + self.scores = self.scores.append(collated_metrics, ignore_index=True) + + # accumulate other scores + other_scores = {o_score_name: {} for o_score_name in self.other_scores.keys()} + for o_score_name in other_scores.keys(): + other_scores[o_score_name]["video_id"] = self._video_name + other_scores[o_score_name]["dataset"] = self._dataset_name + other_scores[o_score_name]["frames_len"] = 0 + + other_scores[o_score_name]["frames_len"] = max(other_scores[o_score_name]["frames_len"], + len(self.other_accumulator[o_score_name])) + if self.other_accumulator[o_score_name]: + o_score_collated = np.nanmean(np.array(self.other_accumulator[o_score_name]), axis=0) + for o_score_idx, o_score in enumerate(o_score_collated.tolist()): + if o_score: + other_scores[o_score_name][str(o_score_idx)] = o_score + + self.other_scores[o_score_name] = self.other_scores[o_score_name].append(other_scores[o_score_name], ignore_index=True) + + + # save to file after every video + if intermed_save: + self.save() + return collated_metrics, other_scores + + def add_metrics(self, returns, models, mapping): + metrics_args = {} + eval_frame_id = 0 + baseline_imgs = [] + for idx_model, model_data in enumerate(models): + for i, frame_dict in enumerate(returns[2 + idx_model][4]): + # get the image frames + for img_name in returns[2 + idx_model][1][i].keys(): + if img_name in mapping.values(): + img = returns[2 + idx_model][1][i][img_name] + metrics_args[list(mapping.keys())[list(mapping.values()).index(img_name)]] = img + eval_frame_id = frame_dict["frame_info"]["frame_id"] + # get the scores info from the annotations if specified: scores should be a vector e.g. 
the gate scores
+                if "scores_info" in mapping.keys():
+                    for score_name in mapping["scores_info"]:
+                        try:
+                            if score_name in frame_dict[self.map_key]:
+                                if score_name in self.other_accumulator:
+                                    self.other_accumulator[score_name].append(frame_dict[self.map_key][score_name][0][0])
+                                else:
+                                    column_list = [str(idx) for idx in range(frame_dict[self.map_key][score_name][0][0].shape[0])]
+                                    if score_name not in self.other_scores:
+                                        self.other_scores[score_name] = pd.DataFrame(columns=["video_id", "dataset", "frames_len"] + column_list)
+                                    self.other_accumulator[score_name] = [frame_dict[self.map_key][score_name][0][0]]
+                        except:
+                            pass
+
+        # extract all the frames besides the evaluation frame for creating the baseline map if needed
+        if "gt_baseline" in mapping.keys():
+            baseline_name = mapping["gt_baseline"]
+            if "/" in baseline_name:
+                metrics_args["gt_baseline"] = cv2.imread(baseline_name)
+            else:
+                info_list = returns[0]["info_list"]
+                for i, info in enumerate(info_list):
+                    if info["frame_info"]["frame_id"] != eval_frame_id:
+                        baseline_img = returns[0]["grouped_video_frames_list"][i][baseline_name]
+                        if baseline_img is not None:
+                            baseline_imgs.append(baseline_img)
+                if baseline_imgs:
+                    baseline_imgs = np.nanmean(np.array(baseline_imgs), axis=0).astype(np.uint8)
+                    metrics_args["gt_baseline"] = baseline_imgs
+        metrics = self.compute_metrics(**metrics_args)
+        if metrics is not None:
+            for metric_name, metric_val in metrics.items():
+                self.accumulator[metric_name].append(metric_val)
+        return metrics
+
+    def compute_metrics(self, pred_salmap, gt_fixmap, gt_salmap, gt_baseline=None):
+        if pred_salmap is None or gt_fixmap is None or gt_salmap is None:
+            return None
+        else:
+            try:
+                pred_salmap = cv2.cvtColor(pred_salmap.copy(), cv2.COLOR_BGR2GRAY)
+            except cv2.error:
+                pass
+            try:
+                gt_fixmap = cv2.cvtColor(gt_fixmap.copy(), cv2.COLOR_BGR2GRAY)
+            except cv2.error:
+                pass
+            try:
+                gt_salmap = cv2.cvtColor(gt_salmap.copy(), cv2.COLOR_BGR2GRAY)
+            except cv2.error:
+                pass
+            if gt_baseline is not None:
+                try:
+                    gt_baseline = cv2.cvtColor(gt_baseline.copy(), cv2.COLOR_BGR2GRAY)
+                    gt_baseline = cv2.resize(gt_baseline, gt_salmap.shape)
+                except cv2.error:
+                    gt_baseline = None
+
+            metrics = {}
+            pred_salmap_minmax_norm = normalize_map(pred_salmap)
+
+            if "aucj" in self.metrics_list:
+                metrics["aucj"] = self.auc_judd(pred_salmap_minmax_norm, gt_fixmap)
+
+            if "aucb" in self.metrics_list:
+                metrics["aucb"] = self.auc_borji(pred_salmap_minmax_norm, gt_fixmap)
+
+            if "aucs" in self.metrics_list:
+                if gt_baseline is not None:
+                    metrics["aucs"] = self.auc_shuff(pred_salmap_minmax_norm, gt_fixmap, gt_baseline)
+                else:
+                    metrics["aucs"] = np.nan
+
+            if "nss" in self.metrics_list:
+                metrics["nss"] = self.nss(pred_salmap, gt_fixmap)
+
+            if "ifg" in self.metrics_list:
+                if gt_baseline is not None:
+                    metrics["ifg"] = self.infogain(pred_salmap_minmax_norm, gt_fixmap, gt_baseline)
+                else:
+                    metrics["ifg"] = np.nan
+
+            # continuous gts
+            if "sim" in self.metrics_list:
+                metrics["sim"] = self.similarity(pred_salmap, gt_salmap)
+
+            if "cc" in self.metrics_list:
+                metrics["cc"] = self.cc(pred_salmap, gt_salmap)
+
+            if "kld" in self.metrics_list:
+                metrics["kld"] = self.kldiv(pred_salmap, gt_salmap)
+
+            return metrics
+
+    @staticmethod
+    def similarity(s_map, gt):
+        # here gt is not discretized
+        s_map = normalize_map(s_map)
+        gt = normalize_map(gt)
+        s_map = s_map / (np.sum(s_map) * 1.0)
+        gt = gt / (np.sum(gt) * 1.0)
+        x, y = np.where(gt > 0)
+        sim = 0.0
+        for i in zip(x, y):
+            sim = sim + min(gt[i[0], i[1]], s_map[i[0], i[1]])
+        return
sim + + + @staticmethod + def auc_judd(s_map, gt): + # ground truth is discrete, s_map is continous and normalized + gt = discretize_gt(gt) + # thresholds are calculated from the salience map, only at places where fixations are present + thresholds = [] + for i in range(0, gt.shape[0]): + for k in range(0, gt.shape[1]): + if gt[i][k] > 0: + thresholds.append(s_map[i][k]) + + num_fixations = np.sum(gt) + # num fixations is no. of salience map values at gt >0 + + thresholds = sorted(set(thresholds)) + + # fp_list = [] + # tp_list = [] + area = [] + area.append((0.0, 0.0)) + for thresh in thresholds: + # in the salience map, keep only those pixels with values above threshold + temp = np.zeros(s_map.shape) + temp[s_map >= thresh] = 1.0 + if np.max(gt) != 1.0: + return np.nan + if np.max(s_map) != 1.0: + return np.nan + + num_overlap = np.where(np.add(temp, gt) == 2)[0].shape[0] + tp = num_overlap / (num_fixations * 1.0) + + # total number of pixels > threshold - number of pixels that overlap with gt / total number of non fixated pixels + # this becomes nan when gt is full of fixations..this won't happen + fp = (np.sum(temp) - num_overlap) / ((np.shape(gt)[0] * np.shape(gt)[1]) - num_fixations) + + area.append((round(tp, 4), round(fp, 4))) + + area.append((1.0, 1.0)) + # tp_list.append(1.0) + # fp_list.append(1.0) + # print tp_list + area.sort(key=lambda x: x[0]) + tp_list = [x[0] for x in area] + fp_list = [x[1] for x in area] + return np.trapz(np.array(tp_list), np.array(fp_list)) + + @staticmethod + def auc_shuff(s_map, gt, other_map, n_splits=100, stepsize=0.1): + + # If there are no fixations to predict, return NaN + if np.sum(gt) == 0: + return np.nan + + # normalize saliency map + # s_map = normalize_map(s_map) + + S = s_map.flatten() + F = gt.flatten() + Oth = other_map.flatten() + + Sth = S[F > 0] # sal map values at fixation locations + Nfixations = len(Sth) + + # for each fixation, sample Nsplits values from the sal map at locations + # specified by other_map + + ind = np.where(Oth > 0)[0] # find fixation locations on other images + + Nfixations_oth = min(Nfixations, len(ind)) + randfix = np.full((Nfixations_oth, n_splits), np.nan) + + for i in range(n_splits): + # randomize choice of fixation locations + randind = np.random.permutation(ind.copy()) + # sal map values at random fixation locations of other random images + randfix[:, i] = S[randind[:Nfixations_oth]] + + # calculate AUC per random split (set of random locations) + auc = np.full(n_splits, np.nan) + for s in range(n_splits): + + curfix = randfix[:, s] + + allthreshes = np.flip(np.arange(0, max(np.max(Sth), np.max(curfix)), stepsize)) + tp = np.zeros(len(allthreshes) + 2) + fp = np.zeros(len(allthreshes) + 2) + tp[-1] = 1 + fp[-1] = 1 + + for i in range(len(allthreshes)): + thresh = allthreshes[i] + tp[i + 1] = np.sum(Sth >= thresh) / Nfixations + fp[i + 1] = np.sum(curfix >= thresh) / Nfixations_oth + + auc[s] = np.trapz(np.array(tp), np.array(fp)) + + return np.mean(auc) + + @staticmethod + def auc_borji(s_map, gt, splits=100, stepsize=0.1): + gt = discretize_gt(gt) + num_fixations = np.sum(gt).astype(np.int) + + num_pixels = s_map.shape[0] * s_map.shape[1] + random_numbers = [] + for i in range(0, splits): + temp_list = [] + for k in range(0, num_fixations): + temp_list.append(np.random.randint(num_pixels)) + random_numbers.append(temp_list) + + aucs = [] + # for each split, calculate auc + for i in random_numbers: + r_sal_map = [] + for k in i: + r_sal_map.append(s_map[k % s_map.shape[0] - 1, k // s_map.shape[0]]) + # 
in these values, we need to find thresholds and calculate auc + thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] + + r_sal_map = np.array(r_sal_map) + + # once threshs are got + thresholds = sorted(set(thresholds)) + area = [] + area.append((0.0, 0.0)) + for thresh in thresholds: + # in the salience map, keep only those pixels with values above threshold + temp = np.zeros(s_map.shape) + temp[s_map >= thresh] = 1.0 + num_overlap = np.where(np.add(temp, gt) == 2)[0].shape[0] + tp = num_overlap / (num_fixations * 1.0) + + # fp = (np.sum(temp) - num_overlap)/((np.shape(gt)[0] * np.shape(gt)[1]) - num_fixations) + # number of values in r_sal_map, above the threshold, divided by num of random locations = num of fixations + fp = len(np.where(r_sal_map > thresh)[0]) / (num_fixations * 1.0) + + area.append((round(tp, 4), round(fp, 4))) + + area.append((1.0, 1.0)) + area.sort(key=lambda x: x[0]) + tp_list = [x[0] for x in area] + fp_list = [x[1] for x in area] + + aucs.append(np.trapz(np.array(tp_list), np.array(fp_list))) + + return np.mean(aucs) + + @staticmethod + def nss(s_map, gt): + gt = discretize_gt(gt) + s_map_std_norm = (s_map - np.mean(s_map)) / np.std(s_map) + + x, y = np.where(gt == 1) + temp = [] + for i in zip(x, y): + temp.append(s_map_std_norm[i[0], i[1]]) + return np.mean(temp) + + @staticmethod + def infogain(s_map, gt, baseline_map): + gt = discretize_gt(gt) + # assuming s_map and baseline_map are normalized + eps = 2.2204e-16 + + s_map = s_map / (np.sum(s_map) * 1.0) + baseline_map = baseline_map / (np.sum(baseline_map) * 1.0) + + # for all places where gt=1, calculate info gain + temp = [] + x, y = np.where(gt == 1) + for i in zip(x, y): + temp.append(np.log2(eps + s_map[i[0], i[1]]) - np.log2(eps + baseline_map[i[0], i[1]])) + + return np.mean(temp) + + @staticmethod + def cc(s_map, gt): + s_map_norm = (s_map - np.mean(s_map)) / np.std(s_map) + gt_norm = (gt - np.mean(gt)) / np.std(gt) + a = s_map_norm + b = gt_norm + r = (a * b).sum() / math.sqrt((a * a).sum() * (b * b).sum()) + return r + + @staticmethod + def kldiv(s_map, gt): + s_map = s_map / (np.sum(s_map) * 1.0) + gt = gt / (np.sum(gt) * 1.0) + eps = 2.2204e-16 + return np.sum(gt * np.log(eps + gt / (s_map + eps))) + + + + diff --git a/gazenet/models/saliency_prediction/stavis/__init__.py b/gazenet/models/saliency_prediction/stavis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/saliency_prediction/stavis/checkpoints/pretrained_stavis_orig/download_model.sh b/gazenet/models/saliency_prediction/stavis/checkpoints/pretrained_stavis_orig/download_model.sh new file mode 100644 index 0000000..94cc247 --- /dev/null +++ b/gazenet/models/saliency_prediction/stavis/checkpoints/pretrained_stavis_orig/download_model.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +wget -O pretrained_models.tar.gz http://cvsp.cs.ntua.gr/research/stavis/data/pretrained_models.tar.gz +tar -xzf pretrained_models.tar.gz +rm pretrained_models.tar.gz +mv pretrained_models/* . 
+rm -rf pretrained_models/ \ No newline at end of file diff --git a/gazenet/models/saliency_prediction/stavis/generator.py b/gazenet/models/saliency_prediction/stavis/generator.py new file mode 100644 index 0000000..09d81f3 --- /dev/null +++ b/gazenet/models/saliency_prediction/stavis/generator.py @@ -0,0 +1,60 @@ +import torch +import torchvision.transforms.functional as F +import numpy as np +import librosa as sf +from PIL import Image +import cv2 + +def normalize_data(data): + data_min = np.min(data) + data_max = np.max(data) + data_norm = np.clip((data - data_min) * + (255.0 / (data_max - data_min)), + 0, 255).astype(np.uint8) + return data_norm + +def create_data_packet(in_data, frame_number, frames_len=16): + in_data = np.array(in_data) + n_frame = in_data.shape[0] + # if the frame number is larger, we just use the last sound one heard + frame_number = min(frame_number, n_frame) + starting_frame = frame_number - frames_len + 1 + # ensure we do not have any negative video_frames_list + starting_frame = max(0, starting_frame) + data_pack = in_data + # data_pack = in_data[starting_frame:frame_number+1, :] + return data_pack, frames_len#frame_number + + +def get_wav_features(features, frame_number, frames_len=16): + + audio_data, valid_frame_number = create_data_packet(features, frame_number, frames_len=frames_len) + return torch.from_numpy(audio_data).float().view(1,1,-1), valid_frame_number + + +def load_video_frames(frames_list, last_frame_idx, valid_frame_idx, img_mean, img_std, img_width, img_height, frames_len=16): + # load video video_frames_list, process them and return a suitable tensor + frame_number = min(last_frame_idx, valid_frame_idx) + start_frame_number = frame_number - frames_len + 1 + start_frame_number = max(0, start_frame_number) + frames_list_idx = [f for f in range(start_frame_number, frame_number)] + if len(frames_list_idx) < frames_len: + nsh = frames_len - len(frames_list_idx) + frames_list_idx = np.concatenate((np.tile(frames_list_idx[0], (nsh)), frames_list_idx), axis=0) + frames = [] + for i in range(len(frames_list_idx)): + # TODO (fabawi): loading the first frame on failure is not ideal. 
Find a better way + try: + img = cv2.resize(frames_list[frames_list_idx[i]].copy(), (img_width, img_height)) + except: + try: + img = cv2.resize(frames_list[frames_list_idx[0]].copy(), (img_width, img_height)) + except: + img = np.zeros((img_height, img_width, 3), dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # img = img.convert('RGB') + img = F.to_tensor(img) + img = F.normalize(img, img_mean, img_std) + frames.append(img) + frames = torch.stack(frames, dim=0) + return frames.permute(1, 0, 2, 3) diff --git a/gazenet/models/saliency_prediction/stavis/infer.py b/gazenet/models/saliency_prediction/stavis/infer.py new file mode 100644 index 0000000..ba38500 --- /dev/null +++ b/gazenet/models/saliency_prediction/stavis/infer.py @@ -0,0 +1,156 @@ + +import re +import os + +import torch +import numpy as np +import torch.backends.cudnn as cudnn + +from gazenet.utils.registrar import * +from gazenet.utils.sample_processors import InferenceSampleProcessor + +import gazenet.models.saliency_prediction.stavis.model as stavis_model +from gazenet.models.saliency_prediction.stavis.generator import load_video_frames, get_wav_features, normalize_data + +MODEL_PATHS = { + "stavis_audvis": os.path.join("gazenet", "models", "saliency_prediction", "stavis", "checkpoints", "pretrained_stavis_orig", "stavis_audiovisual", "audiovisual_split1_save_60.pth"), + "stavis_vis": os.path.join("gazenet", "models", "saliency_prediction", "stavis", "checkpoints", "pretrained_stavis_orig", "stavis_visual_only", "visual_split1_save_60.pth")} + +INP_IMG_WIDTH = 112 +INP_IMG_HEIGHT = 112 +INP_IMG_MEAN = (110.63666788 / 255.0, 103.16065604 / 255.0, 96.29023126 / 255.0) +INP_IMG_STD = (38.7568578 / 255.0, 37.88248729 / 255.0, 40.02898126 / 255.0) +# IMG_MEAN = [0,0,0] +# IMG_STD = [1,1,1] +# AUD_MEAN = [114.7748 / 255.0, 107.7354 / 255.0, 99.4750 / 255.0] +FRAMES_LEN = 16 + + +@InferenceRegistrar.register +class STAViSInference(InferenceSampleProcessor): + + def __init__(self, weights_file=None, w_size=16, audiovisual=False, + frames_len=FRAMES_LEN, + inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, inp_img_mean=INP_IMG_MEAN, inp_img_std=INP_IMG_STD, + device="cuda:0", width=None, height=None, **kwargs): + self.short_name = "stavis" + self._device = device + + self.frames_len = frames_len + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + + if weights_file is None: + if audiovisual: + weights_file = MODEL_PATHS['stavis_audvis'] + else: + weights_file = MODEL_PATHS['stavis_vis'] + + super().__init__(width=width, height=height, w_size=w_size, **kwargs) + + # load the model + self.model = stavis_model.resnet50(shortcut_type="B", sample_size=inp_img_width, sample_duration=frames_len, audiovisual=audiovisual) + self.model.load_state_dict(self._load_state_dict_(weights_file, device), strict=False) + print("STAViS model loaded from", weights_file) + self.model = self.model.to(device) + # cudnn.benchmarks = True + # TODO (fabawi): restore eval for proper batch normalization adjustment + # self.model.eval() + + @staticmethod + def _load_state_dict_(filepath, device): + if os.path.isfile(filepath): + # print("=> loading checkpoint '{}'".format(filepath)) + checkpoint = torch.load(filepath, map_location=torch.device(device)) + + pattern = re.compile(r'module+\.*') + state_dict = checkpoint['state_dict'] + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = re.sub('module.', '', key) + 
state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + video_frames_list, hann_audio_frames, valid_audio_frames_len=None, source_frames_idxs=None, **kwargs): + if valid_audio_frames_len is None: + valid_audio_frames_len = self.frames_len + audio_data = hann_audio_frames.to(self._device) + audio_data = torch.unsqueeze(audio_data, 0) + frames_idxs = range(len(grouped_video_frames_list)) if source_frames_idxs is None else source_frames_idxs + for f_idx, frame_id in enumerate(frames_idxs): + info = {"frame_detections_" + self.short_name: { + "saliency_maps": [], "video_saliency_maps": [], "audio_saliency_maps": [], # detected + }} + video_frames_tensor = load_video_frames(video_frames_list[:frame_id+1], + frame_id+1, + valid_audio_frames_len, + img_width=self.inp_img_width, img_height=self.inp_img_height, + img_mean=self.inp_img_mean, img_std=self.inp_img_std, + frames_len=self.frames_len) + with torch.no_grad(): + video_frames = video_frames_tensor.to(self._device) + video_frames = torch.unsqueeze(video_frames, 0) + prediction = self.model(video_frames, audio_data) + + prediction_l = prediction["sal"][-1] + prediction_l = torch.sigmoid(prediction_l) + saliency = prediction_l.cpu().data.numpy() + saliency = np.squeeze(saliency) + saliency = normalize_data(saliency) + info["frame_detections_" + self.short_name]["saliency_maps"].append((saliency, -1)) + info_list[frame_id].update(**info) + + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + def preprocess_frames(self, video_frames_list=None, hann_audio_frames=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + if hann_audio_frames is not None: + features["hann_audio_frames"], features["valid_audio_frames_len"] = \ + get_wav_features(list(hann_audio_frames), self.frames_len, frames_len=self.frames_len) + return features + + def annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + for saliency_map_name, frame_name in zip(["saliency_maps"],[""]): + if grabbed_video: + if show_det_saliency_map: + saliency_map = info["frame_detections_" + self.short_name][saliency_map_name][0][0] + frame_transformed = plotter.plot_color_map(saliency_map, color_map=color_map) + if enable_transform_overlays: + frame_transformed = plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + + frame_name + 
self.short_name], + frame_transformed, alpha=0.4) + else: + frame_transformed = plotter.resize(frame_transformed, + height=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[0], + width=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[1]) + grouped_video_frames["det_transformed_" + frame_name + self.short_name] = frame_transformed + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git a/gazenet/models/saliency_prediction/stavis/model.py b/gazenet/models/saliency_prediction/stavis/model.py new file mode 100644 index 0000000..d1d774b --- /dev/null +++ b/gazenet/models/saliency_prediction/stavis/model.py @@ -0,0 +1,308 @@ +from functools import partial + +import torch +import torch.nn as nn +from torch.nn import functional as F +import numpy as np + +from gazenet.models.shared_components.resnet3d import model as resnet3d +from gazenet.models.shared_components.soundnet8 import model as soundnet8 + +TRAINING = False + + +class AVModule(nn.Module): + + def __init__(self, rgb_nfilters, audio_nfilters, img_size, temp_size, hidden_layers): + + super(AVModule, self).__init__() + + self.rgb_nfilters = rgb_nfilters + self.audio_nfilters = audio_nfilters + self.hidden_layers = hidden_layers + self.out_layers = 64 + self.img_size = img_size + self.avgpool_rgb = nn.AvgPool3d((temp_size, 1, 1), stride=1) + # Make the layers numbers equal + self.relu = nn.ReLU() + self.affine_rgb = nn.Linear(rgb_nfilters, hidden_layers) + self.affine_audio = nn.Linear(audio_nfilters, hidden_layers) + self.w_a_rgb = nn.Bilinear(hidden_layers, hidden_layers, self.out_layers, bias=True) + self.upscale_ = nn.Upsample(scale_factor=8, mode='bilinear') + + + def forward(self, rgb, audio, crop_h, crop_w): + + self.crop_w = crop_w + self.crop_h = crop_h + dgb = rgb[:,:,rgb.shape[2]//2-1:rgb.shape[2]//2+1,:,:] + rgb = self.avgpool_rgb(dgb).squeeze(2) + rgb = rgb.permute(0, 2, 3, 1) + rgb = rgb.view(rgb.size(0), -1, self.rgb_nfilters) + rgb = self.affine_rgb(rgb) + rgb = self.relu(rgb) + audio1 = self.affine_audio(audio[0].squeeze(-1).squeeze(-1)) + audio1 = self.relu(audio1) + + a_rgb_B = self.w_a_rgb(rgb.contiguous(), audio1.unsqueeze(1).expand(-1, self.img_size[0] * self.img_size[1], -1).contiguous()) + sal_bilin = a_rgb_B + sal_bilin = sal_bilin.view(-1, self.img_size[0], self.img_size[1], self.out_layers) + sal_bilin = sal_bilin.permute(0, 3, 1, 2) + sal_bilin = center_crop(self.upscale_(sal_bilin), self.crop_h, self.crop_w) + + return sal_bilin + +def center_crop(x, height, width): + crop_h = torch.FloatTensor([x.size()[2]]).sub(height).div(-2) + crop_w = torch.FloatTensor([x.size()[3]]).sub(width).div(-2) + + # fixed indexing for PyTorch 0.4 + return F.pad(x, [int(crop_w.ceil()[0]), int(crop_w.floor()[0]), int(crop_h.ceil()[0]), int(crop_h.floor()[0])]) + + +class DSAMScoreDSN(nn.Module): + + def __init__(self, prev_layer, prev_nfilters, prev_nsamples): + + super(DSAMScoreDSN, self).__init__() + i = prev_layer + self.avgpool = nn.AvgPool3d((prev_nsamples, 1, 1), stride=1) + # Make the layers of the preparation step + self.side_prep = nn.Conv2d(prev_nfilters, 16, kernel_size=3, padding=1) + # Make the layers of the score_dsn step + self.score_dsn = nn.Conv2d(16, 1, kernel_size=1, padding=0) + self.upscale_ = nn.ConvTranspose2d(1, 1, kernel_size=2 ** (1 + i), stride=2 ** i, bias=False) + self.upscale = nn.ConvTranspose2d(16, 16, kernel_size=2 ** (1 + i), stride=2 ** i, bias=False) + + def forward(self, x, crop_h, crop_w): 
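+        # crop_h/crop_w give the target spatial size; the upsampled side outputs below
+        # are center-cropped back to that resolution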
+ + self.crop_h = crop_h + self.crop_w = crop_w + x = self.avgpool(x).squeeze(2) + side_temp = self.side_prep(x) + side = center_crop(self.upscale(side_temp), self.crop_h, self.crop_w) + side_out_tmp = self.score_dsn(side_temp) + side_out = center_crop(self.upscale_(side_out_tmp), self.crop_h, self.crop_w) + return side, side_out, side_out_tmp + + +def upsample_filt(size): + factor = (size + 1) // 2 + if size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + og = np.ogrid[:size, :size] + return (1 - abs(og[0] - center) / factor) * \ + (1 - abs(og[1] - center) / factor) + +def spatial_softmax(x): + x = torch.exp(x) + sum_batch = torch.sum(torch.sum(x, 2, keepdim=True), 3, keepdim=True) + x_soft = torch.div(x,sum_batch) + return x_soft + +# set parameters s.t. deconvolutional layers compute bilinear interpolation +# this is for deconvolution without groups +def interp_surgery(lay): + m, k, h, w = lay.weight.data.size() + if m != k: + print('input + output channels need to be the same') + raise ValueError + if h != w: + print('filters need to be square') + raise ValueError + filt = upsample_filt(h) + + for i in range(m): + lay.weight[i, i, :, :].data.copy_(torch.from_numpy(filt)) + + return lay.weight.data + + +class ResNet(nn.Module): + + def __init__(self, + block, + layers, + sample_size, + sample_duration, + shortcut_type='B', + audiovisual=True): + + self.audiovisual = audiovisual + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv3d( + 3, + 64, + kernel_size=7, + stride=(1, 2, 2), + padding=(3, 3, 3), + bias=False) + self.bn1 = nn.BatchNorm3d(64, momentum=0.1 if TRAINING else 0.0) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) + self.layer2 = self._make_layer( + block, 128, layers[1], shortcut_type, stride=2) + self.layer3 = self._make_layer( + block, 256, layers[2], shortcut_type, stride=2) + + score_dsn = nn.modules.ModuleList() + in_channels_dsn = [64, + 64 * block.expansion, + 128 * block.expansion, + 256 *block.expansion] + temp_size_prev = [sample_duration, + int(sample_duration / 2), + int(sample_duration / 4), + int(sample_duration /8)] + temp_img_size_prev = [int(sample_size / 2), + int(sample_size / 4), + int(sample_size / 8), + int(sample_size / 16)] + for i in range(1,5): + score_dsn.append(DSAMScoreDSN(i, in_channels_dsn[i-1], temp_size_prev[i-1])) + self.score_dsn = score_dsn + + self.fuse = nn.Conv2d(64, 1, kernel_size=1, padding=0) + + + if audiovisual: + self.fuseav = nn.Conv2d(128, 1, kernel_size=1, padding=0) + + self.soundnet8 = nn.Sequential( + soundnet8.SoundNet(momentum=0.1 if TRAINING else 0.0, reverse=True), + nn.MaxPool2d((1, 2))) + + self.fusion3 = AVModule(in_channels_dsn[2], + 1024, + [temp_img_size_prev[2], temp_img_size_prev[2]], + temp_size_prev[3], + 128) + + self.fuseav.bias.data = torch.tensor([-6.0]) + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') + elif isinstance(m, nn.BatchNorm3d): + m.weight.data.fill_(1) + m.bias.data.zero_() + if isinstance(m, nn.Conv2d): + m.weight.data.normal_(0, 0.001) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.ConvTranspose2d): + m.weight.data.zero_() + m.weight.data = interp_surgery(m) + if isinstance(m ,nn.Linear): + m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, 
nn.Bilinear): + m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + m.bias.data.zero_() + + self.fuse.bias.data = torch.tensor([-6.0]) + for i in range(0, 4): + self.score_dsn[i].score_dsn.bias.data = torch.tensor([-6.0]) + + + def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + if shortcut_type == 'A': + downsample = partial( + resnet3d.downsample_basic_block, + planes=planes * block.expansion, + stride=stride) + else: + downsample = nn.Sequential( + nn.Conv3d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm3d(planes * block.expansion, momentum=0.1 if TRAINING else 0.0)) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, training=TRAINING)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, training=TRAINING)) + + return nn.Sequential(*layers) + + def forward(self, x, aud): + + if self.audiovisual: + aud = self.soundnet8(aud) + aud = [aud] + + crop_h, crop_w = int(x.size()[-2]), int(x.size()[-1]) + side = [] + side_out = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + (tmp, tmp_, att_tmp) = self.score_dsn[0](x, crop_h, crop_w) + + att = spatial_softmax(att_tmp) + att = att.unsqueeze(1) + side.append(tmp) + side_out.append(tmp_) + x =torch.mul(1+att, x) + x = self.maxpool(x) + + x = self.layer1(x) + + (tmp, tmp_, att_tmp) = self.score_dsn[1](x, crop_h, crop_w) + + att = spatial_softmax(att_tmp) + att = att.unsqueeze(1) + side.append(tmp) + side_out.append(tmp_) + x = torch.mul(1+att, x) + x = self.layer2(x) + + if self.audiovisual: + y = self.fusion3(x, aud, crop_h, crop_w) + (tmp, tmp_, att_tmp) = self.score_dsn[2](x, crop_h, crop_w) + + att = spatial_softmax(att_tmp) + att = att.unsqueeze(1) + side.append(tmp) + side_out.append(tmp_) + x = torch.mul(1+att, x) + x = self.layer3(x) + + (tmp, tmp_, att_tmp) = self.score_dsn[3](x, crop_h, crop_w) + + att = spatial_softmax(att_tmp) + att = att.unsqueeze(1) + side.append(tmp) + side_out.append(tmp_) + + out = torch.cat(side[:], dim=1) + + if self.audiovisual: + appendy = torch.cat((out, y), dim=1) + x_out = self.fuseav(appendy) + side_out = [] + else: + x_out = self.fuse(out) + side_out.append(x_out) + + x_out = {'sal': side_out} + + return x_out + +def resnet50(**kwargs): + """Constructs a ResNet-50 model. 
+ """ + model = ResNet(resnet3d.Bottleneck, [3, 4, 6, 3], **kwargs) + return model \ No newline at end of file diff --git a/gazenet/models/saliency_prediction/tased/__init__.py b/gazenet/models/saliency_prediction/tased/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/saliency_prediction/tased/checkpoints/pretrained_tased_orig/download_model.sh b/gazenet/models/saliency_prediction/tased/checkpoints/pretrained_tased_orig/download_model.sh new file mode 100644 index 0000000..f3dca6d --- /dev/null +++ b/gazenet/models/saliency_prediction/tased/checkpoints/pretrained_tased_orig/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=1pn_ioHdeUzBcX7FBTP8S0f_Ebxpp1Hlf' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1pn_ioHdeUzBcX7FBTP8S0f_Ebxpp1Hlf" -O model.pt && rm -rf /tmp/cookies.txt \ No newline at end of file diff --git a/gazenet/models/saliency_prediction/tased/generator.py b/gazenet/models/saliency_prediction/tased/generator.py new file mode 100644 index 0000000..ab477a1 --- /dev/null +++ b/gazenet/models/saliency_prediction/tased/generator.py @@ -0,0 +1,33 @@ +import torch +import torchvision.transforms.functional as F +import numpy as np +import cv2 + + +def load_video_frames(frames_list, last_frame_idx, img_width, img_height, frames_len=16): + # load video video_frames_list, process them and return a suitable tensor + frame_number = last_frame_idx + start_frame_number = frame_number - frames_len + 1 + start_frame_number = max(0, start_frame_number) + frames_list_idx = [f for f in range(start_frame_number, frame_number)] + if len(frames_list_idx) < frames_len: + nsh = frames_len - len(frames_list_idx) + frames_list_idx = np.concatenate((np.tile(frames_list_idx[0], (nsh)), frames_list_idx), axis=0) + frames = [] + for i in range(len(frames_list_idx)): + # TODO (fabawi): loading the first frame on failure is not ideal. 
Find a better way + try: + img = cv2.resize(frames_list[frames_list_idx[i]].copy(), (img_width, img_height)) + except: + try: + img = cv2.resize(frames_list[frames_list_idx[0]].copy(), (img_width, img_height)) + except: + img = np.zeros((img_height, img_width, 3), dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose((2, 0, 1))).float().mul_(2.).sub_(255).div(255) + # img = F.to_tensor(img) + frames.append(img) + frames = torch.stack(frames, dim=0) + # frames = frames.mul_(2.).sub_(255).div(255) + return frames.permute(1, 0, 2, 3) + diff --git a/gazenet/models/saliency_prediction/tased/infer.py b/gazenet/models/saliency_prediction/tased/infer.py new file mode 100644 index 0000000..a364c8c --- /dev/null +++ b/gazenet/models/saliency_prediction/tased/infer.py @@ -0,0 +1,136 @@ +import re +import sys +import os + +import numpy as np +import cv2 +import torch +import torch.backends.cudnn as cudnn +from scipy.ndimage.filters import gaussian_filter + +from gazenet.utils.registrar import * +from gazenet.models.saliency_prediction.tased.generator import load_video_frames +from gazenet.models.saliency_prediction.tased.model import TASED_v2 +from gazenet.utils.sample_processors import InferenceSampleProcessor + + +MODEL_PATHS = { + "tased": os.path.join("gazenet", "models", "saliency_prediction", "tased", "checkpoints", "pretrained_tased_orig", "model.pt")} + +INP_IMG_WIDTH = 384 +INP_IMG_HEIGHT = 224 +FRAMES_LEN = 32 + + +@InferenceRegistrar.register +class TASEDInference(InferenceSampleProcessor): + + def __init__(self, weights_file=MODEL_PATHS['tased'], w_size=32, + frames_len=FRAMES_LEN, inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, + device="cuda:0", width=None, height=None, **kwargs): + super().__init__(width=width, height=height, w_size=w_size, **kwargs) + self.short_name = "tased" + self._device = device + + self.frames_len = frames_len + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + + # load the model + self.model = TASED_v2() + self.model = self._load_model_(weights_file, self.model) + print("TASED model loaded from", weights_file) + self.model = self.model.to(device) + cudnn.benchmark = False + self.model.eval() + + @staticmethod + def _load_model_(filepath, model): + if os.path.isfile(filepath): + weight_dict = torch.load(filepath) + model_dict = model.state_dict() + for name, param in weight_dict.items(): + if 'module' in name: + name = '.'.join(name.split('.')[1:]) + if name in model_dict: + if param.size() == model_dict[name].size(): + model_dict[name].copy_(param) + else: + print(' size? ' + name, param.size(), model_dict[name].size()) + else: + print(' name? 
' + name) + + return model + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list, + video_frames_list, source_frames_idxs=None, **kwargs): + + frames_idxs = range(len(grouped_video_frames_list)) if source_frames_idxs is None else source_frames_idxs + for f_idx, frame_id in enumerate(frames_idxs): + info = {"frame_detections_" + self.short_name: { + "saliency_maps": [], # detected + }} + video_frames_tensor = load_video_frames(video_frames_list[:frame_id+1], + frame_id+1, + img_width=self.inp_img_width, img_height=self.inp_img_height, + frames_len=self.frames_len) + video_frames = video_frames_tensor.to(self._device) + video_frames = torch.unsqueeze(video_frames, 0) + with torch.no_grad(): + final_prediction = self.model(video_frames) + # get the visual feature maps + for prediction, prediction_name in zip([final_prediction],["saliency_maps"]): + saliency = prediction.cpu().data[0].numpy() + # saliency = (saliency*255.).astype(np.int)/255. + saliency = gaussian_filter(saliency, sigma=7) + saliency = saliency/np.max(saliency) + info["frame_detections_" + self.short_name][prediction_name].append((saliency, -1)) + info_list[frame_id].update(**info) + + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + + def preprocess_frames(self, video_frames_list=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + return features + + def annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + for saliency_map_name, frame_name in zip(["saliency_maps"],[""]): + if grabbed_video: + if show_det_saliency_map: + saliency_map = info["frame_detections_" + self.short_name][saliency_map_name][0][0] + frame_transformed = plotter.plot_color_map(np.uint8(255 * saliency_map), color_map=color_map) + if enable_transform_overlays: + frame_transformed = plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + + frame_name + self.short_name], + frame_transformed, alpha=0.4) + else: + frame_transformed = plotter.resize(frame_transformed, + height=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[0], + width=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[1]) + grouped_video_frames["det_transformed_" + frame_name + self.short_name] = frame_transformed + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git 
a/gazenet/models/saliency_prediction/tased/model.py b/gazenet/models/saliency_prediction/tased/model.py new file mode 100644 index 0000000..bfaec9b --- /dev/null +++ b/gazenet/models/saliency_prediction/tased/model.py @@ -0,0 +1,113 @@ +""" +code from: https://raw.githubusercontent.com/MichiganCOG/TASED-Net/master/model.py +""" + +import torch +from torch import nn + +from gazenet.models.shared_components.conv3d import model as conv3d + + +class TASED_v2(nn.Module): + def __init__(self): + super(TASED_v2, self).__init__() + self.base1 = nn.Sequential( + conv3d.SepConv3d(3, 64, kernel_size=7, stride=2, padding=3), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)), + conv3d.BasicConv3d(64, 64, kernel_size=1, stride=1), + conv3d.SepConv3d(64, 192, kernel_size=3, stride=1, padding=1), + ) + self.maxp2 = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) + self.maxm2 = nn.MaxPool3d(kernel_size=(4,1,1), stride=(4,1,1), padding=(0,0,0)) + self.maxt2 = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1), return_indices=True) + self.base2 = nn.Sequential( + conv3d.Mixed_3b(), + conv3d.Mixed_3c(), + ) + self.maxp3 = nn.MaxPool3d(kernel_size=(3,3,3), stride=(2,2,2), padding=(1,1,1)) + self.maxm3 = nn.MaxPool3d(kernel_size=(4,1,1), stride=(4,1,1), padding=(0,0,0)) + self.maxt3 = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1), return_indices=True) + self.base3 = nn.Sequential( + conv3d.Mixed_4b(), + conv3d.Mixed_4c(), + conv3d.Mixed_4d(), + conv3d.Mixed_4e(), + conv3d.Mixed_4f(), + ) + self.maxt4 = nn.MaxPool3d(kernel_size=(2,1,1), stride=(2,1,1), padding=(0,0,0)) + self.maxp4 = nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2), padding=(0,0,0), return_indices=True) + self.base4 = nn.Sequential( + conv3d.Mixed_5b(), + conv3d.Mixed_5c(), + ) + self.convtsp1 = nn.Sequential( + nn.Conv3d(1024, 1024, kernel_size=1, stride=1, bias=False), + nn.BatchNorm3d(1024, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + + nn.ConvTranspose3d(1024, 832, kernel_size=(1,3,3), stride=1, padding=(0,1,1), bias=False), + nn.BatchNorm3d(832, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + ) + self.unpool1 = nn.MaxUnpool3d(kernel_size=(1,2,2), stride=(1,2,2), padding=(0,0,0)) + self.convtsp2 = nn.Sequential( + nn.ConvTranspose3d(832, 480, kernel_size=(1,3,3), stride=1, padding=(0,1,1), bias=False), + nn.BatchNorm3d(480, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + ) + self.unpool2 = nn.MaxUnpool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) + self.convtsp3 = nn.Sequential( + nn.ConvTranspose3d(480, 192, kernel_size=(1,3,3), stride=1, padding=(0,1,1), bias=False), + nn.BatchNorm3d(192, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + ) + self.unpool3 = nn.MaxUnpool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) + self.convtsp4 = nn.Sequential( + nn.ConvTranspose3d(192, 64, kernel_size=(1,4,4), stride=(1,2,2), padding=(0,1,1), bias=False), + nn.BatchNorm3d(64, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + + nn.Conv3d(64, 64, kernel_size=(2,1,1), stride=(2,1,1), bias=False), + nn.BatchNorm3d(64, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + + nn.ConvTranspose3d(64, 4, kernel_size=1, stride=1, bias=False), + nn.BatchNorm3d(4, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + + nn.Conv3d(4, 4, kernel_size=(2,1,1), stride=(2,1,1), bias=False), + nn.BatchNorm3d(4, eps=1e-3, momentum=0.001, affine=True), + nn.ReLU(), + + nn.ConvTranspose3d(4, 4, kernel_size=(1,4,4), stride=(1,2,2), padding=(0,1,1), 
bias=False), + nn.Conv3d(4, 1, kernel_size=1, stride=1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, x): + y3 = self.base1(x) + y = self.maxp2(y3) + y3 = self.maxm2(y3) + _, i2 = self.maxt2(y3) + y2 = self.base2(y) + y = self.maxp3(y2) + y2 = self.maxm3(y2) + _, i1 = self.maxt3(y2) + y1 = self.base3(y) + y = self.maxt4(y1) + y, i0 = self.maxp4(y) + y0 = self.base4(y) + + z = self.convtsp1(y0) + z = self.unpool1(z, i0) + z = self.convtsp2(z) + z = self.unpool2(z, i1, y2.size()) + z = self.convtsp3(z) + z = self.unpool3(z, i2, y3.size()) + z = self.convtsp4(z) + z = z.view(z.size(0), z.size(3), z.size(4)) + + return z + + + diff --git a/gazenet/models/saliency_prediction/unisal/__init__.py b/gazenet/models/saliency_prediction/unisal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/saliency_prediction/unisal/checkpoints/pretrained_unisal_orig/download_model.sh b/gazenet/models/saliency_prediction/unisal/checkpoints/pretrained_unisal_orig/download_model.sh new file mode 100644 index 0000000..277560c --- /dev/null +++ b/gazenet/models/saliency_prediction/unisal/checkpoints/pretrained_unisal_orig/download_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +wget -O model.pth https://github.com/rdroste/unisal/raw/master/training_runs/pretrained_unisal/weights_best.pth \ No newline at end of file diff --git a/gazenet/models/saliency_prediction/unisal/generator.py b/gazenet/models/saliency_prediction/unisal/generator.py new file mode 100644 index 0000000..d8c6810 --- /dev/null +++ b/gazenet/models/saliency_prediction/unisal/generator.py @@ -0,0 +1,49 @@ + +import torch +import torchvision.transforms.functional as F +import numpy as np +import cv2 + + +def load_video_frames(frames_list, last_frame_idx, img_mean, img_std, img_width, img_height, frames_len=16): + # load video video_frames_list, process them and return a suitable tensor + frame_number = last_frame_idx + start_frame_number = frame_number - frames_len + 1 + start_frame_number = max(0, start_frame_number) + frames_list_idx = [f for f in range(start_frame_number, frame_number)] + if len(frames_list_idx) < frames_len: + nsh = frames_len - len(frames_list_idx) + frames_list_idx = np.concatenate((np.tile(frames_list_idx[0], (nsh)), frames_list_idx), axis=0) + frames = [] + for i in range(len(frames_list_idx)): + # TODO (fabawi): loading the first frame on failure is not ideal. 
Find a better way + try: + img = cv2.resize(frames_list[frames_list_idx[i]].copy(), (img_width, img_height)) + except: + try: + img = cv2.resize(frames_list[frames_list_idx[0]].copy(), (img_width, img_height)) + except: + img = np.zeros((img_height, img_width, 3), dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = F.to_tensor(img) + img = F.normalize(img, img_mean, img_std) + frames.append(img) + frames = torch.stack(frames, dim=0) + return frames # .permute(1, 0, 2, 3) + + +def smooth_sequence(seq, method): + shape = seq.shape + + seq = seq.reshape(shape[1], np.prod(shape[-2:])) + if method[:3] == 'med': + kernel_size = int(method[3:]) + ks2 = kernel_size // 2 + smoothed = np.zeros_like(seq) + for idx in range(seq.shape[0]): + smoothed[idx, :] = np.median(seq[max(0, idx - ks2):min(seq.shape[0], idx + ks2 + 1), :], axis=0) + seq = smoothed.reshape(shape) + else: + raise NotImplementedError + + return seq \ No newline at end of file diff --git a/gazenet/models/saliency_prediction/unisal/infer.py b/gazenet/models/saliency_prediction/unisal/infer.py new file mode 100644 index 0000000..84fac8e --- /dev/null +++ b/gazenet/models/saliency_prediction/unisal/infer.py @@ -0,0 +1,144 @@ +import random +import cv2 +import re +import os + +import torch +import numpy as np + +from gazenet.utils.registrar import * +from gazenet.models.saliency_prediction.unisal.generator import load_video_frames +from gazenet.models.saliency_prediction.unisal.model import UNISAL +from gazenet.utils.sample_processors import InferenceSampleProcessor + +from gazenet.models.saliency_prediction.unisal.generator import smooth_sequence + + +MODEL_PATHS = { + "unisal": os.path.join("gazenet", "models", "saliency_prediction", "unisal", "checkpoints", "pretrained_unisal_orig", "model.pth")} + +INP_IMG_WIDTH = 384 +INP_IMG_HEIGHT = 288 +TRG_IMG_WIDTH = 640 +TRG_IMG_HEIGHT = 480 +INP_IMG_MEAN = (0.485, 0.456, 0.406) +INP_IMG_STD = (0.229, 0.224, 0.225) +FRAMES_LEN = 12 + + +@InferenceRegistrar.register +class UNISALInference(InferenceSampleProcessor): + + def __init__(self, weights_file=MODEL_PATHS['unisal'], w_size=12, + frames_len=FRAMES_LEN, trg_img_width=TRG_IMG_WIDTH, trg_img_height=TRG_IMG_HEIGHT, + inp_img_width=INP_IMG_WIDTH, inp_img_height=INP_IMG_HEIGHT, inp_img_mean=INP_IMG_MEAN, inp_img_std=INP_IMG_STD, + device="cuda:0", width=None, height=None, **kwargs): + super().__init__(width=width, height=height, w_size=w_size, **kwargs) + self.short_name = "unisal" + self._device = device + + self.frames_len = frames_len + self.trg_img_width = trg_img_width + self.trg_img_height = trg_img_height + self.inp_img_width = inp_img_width + self.inp_img_height = inp_img_height + self.inp_img_mean = inp_img_mean + self.inp_img_std = inp_img_std + + # load the model + self.model = UNISAL() + self.model.load_state_dict(torch.load(weights_file, map_location=torch.device(device))) + print("UNISAL model loaded from", weights_file) + self.model = self.model.to(device) + self.model.eval() + + def infer_frame(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, + info_list, properties_list, + video_frames_list, source_frames_idxs=None, smooth_method="med41", **kwargs): + + frames_idxs = range(len(grouped_video_frames_list)) if source_frames_idxs is None else source_frames_idxs + h0 = [None] + model_kwargs = { + 'source': ("DHF1K"), #"eval", + 'target_size': (self.trg_img_height, self.trg_img_width)} + + for f_idx, frame_id in enumerate(frames_idxs): + info = {"frame_detections_" + 
self.short_name: { + "saliency_maps": [], # detected + }} + video_frames_tensor = load_video_frames(video_frames_list[:frame_id + 1], + frame_id + 1, + img_width=self.inp_img_width, img_height=self.inp_img_height, + img_mean=self.inp_img_mean, img_std=self.inp_img_std, + frames_len=self.frames_len) + video_frames = video_frames_tensor.to(self._device) + video_frames = torch.unsqueeze(video_frames, 0) + with torch.no_grad(): + final_prediction, _ = self.model( # final_prediction, h0 = self.model( + video_frames, h0=h0, return_hidden=True, + **model_kwargs) + + + # get the visual feature maps + for prediction, prediction_name in zip([final_prediction], ["saliency_maps"]): + saliency = prediction.cpu() + if smooth_method is not None: + saliency = saliency.numpy() + saliency = smooth_sequence(saliency, smooth_method) + saliency = torch.from_numpy(saliency).float() + + # for _, smap in enumerate(torch.unbind(saliency, dim=1)): + smap = saliency[:, frame_id, ::] + smap = smap.exp() + smap = torch.squeeze(smap) + smap = smap.data.cpu().numpy() + smap = smap / np.amax(smap) + info["frame_detections_" + self.short_name][prediction_name].append((smap, -1)) + info_list[frame_id].update(**info) + + kept_data = self._keep_extracted_frames_data(source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list) + return kept_data + + def preprocess_frames(self, video_frames_list=None, **kwargs): + features = super().preprocess_frames(**kwargs) + pad = features["preproc_pad_len"] + lim = features["preproc_lim_len"] + if video_frames_list is not None: + video_frames_list = list(video_frames_list) + features["video_frames_list"] = video_frames_list[:lim] + [video_frames_list[lim]] * pad + return features + + def annotate_frame(self, input_data, plotter, + show_det_saliency_map=True, + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + properties = {**properties, "show_det_saliency_map": (show_det_saliency_map, "toggle", (True, False))} + + grouped_video_frames = {**grouped_video_frames, + "PLOT": grouped_video_frames["PLOT"] + [["det_source_" + self.short_name, + "det_transformed_" + self.short_name]], + "det_source_" + self.short_name: grouped_video_frames["captured"], + "det_transformed_" + self.short_name: grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + for saliency_map_name, frame_name in zip(["saliency_maps"], [""]): + if grabbed_video: + if show_det_saliency_map: + saliency_map = info["frame_detections_" + self.short_name][saliency_map_name][0][0] + frame_transformed = plotter.plot_color_map(np.uint8(255 * saliency_map), color_map=color_map) + if enable_transform_overlays: + frame_transformed = plotter.plot_alpha_overlay(grouped_video_frames["det_transformed_" + + frame_name + self.short_name], + frame_transformed, alpha=0.4) + else: + frame_transformed = plotter.resize(frame_transformed, + height=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[0], + width=grouped_video_frames["det_transformed_" + frame_name + + self.short_name].shape[1]) + grouped_video_frames["det_transformed_" + frame_name + self.short_name] = frame_transformed + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties diff --git a/gazenet/models/saliency_prediction/unisal/model.py 
b/gazenet/models/saliency_prediction/unisal/model.py new file mode 100644 index 0000000..6f79c88 --- /dev/null +++ b/gazenet/models/saliency_prediction/unisal/model.py @@ -0,0 +1,509 @@ +from collections import OrderedDict +import pprint +from functools import partial +from itertools import product + +import torch +from torch import nn +import torch.nn.functional as F + +from gazenet.models.shared_components.convgru.model import ConvGRU +from gazenet.models.shared_components.mobilenetv2.model import MobileNetV2, InvertedResidual + + +def get_model(): + """Return the model class""" + return UNISAL + + +class BaseModel(nn.Module): + """Abstract model class with functionality to save and load weights""" + + def forward(self, *input): + raise NotImplementedError + + def save_weights(self, directory, name): + torch.save(self.state_dict(), directory / f'weights_{name}.pth') + + def load_weights(self, directory, name): + self.load_state_dict(torch.load(directory / f'weights_{name}.pth')) + + def load_best_weights(self, directory): + self.load_state_dict(torch.load(directory / f'weights_best.pth')) + + def load_epoch_checkpoint(self, directory, epoch): + """Load state_dict from a Trainer checkpoint at a specific epoch""" + chkpnt = torch.load(directory / f"chkpnt_epoch{epoch:04d}.pth") + self.load_state_dict(chkpnt['model_state_dict']) + + def load_checkpoint(self, file): + """Load state_dict from a specific Trainer checkpoint""" + """Load """ + chkpnt = torch.load(file) + self.load_state_dict(chkpnt['model_state_dict']) + + def load_last_chkpnt(self, directory): + """Load state_dict from the last Trainer checkpoint""" + last_chkpnt = sorted(list(directory.glob('chkpnt_epoch*.pth')))[-1] + self.load_checkpoint(last_chkpnt) + + +# Set default backbone CNN kwargs +default_cnn_cfg = { + 'widen_factor': 1., 'pretrained': True, 'input_channel': 32, + 'last_channel': 1280} + +# Set default RNN kwargs +default_rnn_cfg = { + 'kernel_size': (3, 3), 'gate_ksize': (3, 3), + 'dropout': (False, True, False), 'drop_prob': (0.2, 0.2, 0.2), + 'mobile': True, +} + + +class DomainBatchNorm2d(nn.Module): + """ + Domain-specific 2D BatchNorm module. + Stores a BN module for a given list of sources. + During the forward pass, select the BN module based on self.this_source. + """ + + def __init__(self, num_features, sources, momenta=None, **kwargs): + """ + num_features: Number of channels + sources: List of sources + momenta: List of BatchNorm momenta corresponding to the sources. + Default is 0.1 for each source. + kwargs: Other BatchNorm kwargs + """ + super().__init__() + self.sources = sources + + # Process momenta input + if momenta is None: + momenta = [0.1] * len(sources) + self.momenta = momenta + if 'momentum' in kwargs: + del kwargs['momentum'] + + # Instantiate the BN modules + for src, mnt in zip(sources, self.momenta): + self.__setattr__(f"bn_{src}", nn.BatchNorm2d( + num_features, momentum=mnt, **kwargs)) + + # Prepare the self.this_source attribute that will be updated at runtime + # by the model + self.this_source = None + + def forward(self, x): + return self.__getattr__(f"bn_{self.this_source}")(x) + + +class UNISAL(BaseModel): + """ + UNISAL model. See paper for more information. + Arguments: + rnn_input_channels: Number of channels of the RNN input. + rnn_hidden_channels: Number of channels of the RNN hidden state. + cnn_cfg: Dictionary with kwargs for the backbone CNN. + rnn_cfg: Dictionary with kwargs for the RNN. + res_rnn: Whether to add the RNN features with a residual connection. 
+ bypass_rnn: Whether to bypass the RNN for static inputs. + Requires res_rnn. + drop_probs: Dropout probabilities for + [backbone CNN outputs, Skip-2x and Skip-4x]. + gaussian_init: Method to initialize the learned Gaussian parameters. + If "manual", 16 pre-defined Gaussians are initialized. + n_gaussians: Number of Gaussians if gaussian_init is "random". + smoothing_ksize: Size of the Smoothing kernel. + bn_momentum: Momentum of the BatchNorm running estimates for dynamic + batches. + static_bn_momentum: Momentum of the BatchNorm running estimates for + static batches. + sources: List of datasets. + ds_bn: Domain-specific BatchNorm (DSBN). + ds_adaptation: Domain-specific Adaptation. + ds_smoothing: Domain-specific Smoothing. + ds_gaussians: Domain-specific Gaussian prior maps. + verbose: Verbosity level. + """ + + def __init__(self, + rnn_input_channels=256, rnn_hidden_channels=256, + cnn_cfg=None, + rnn_cfg=None, + res_rnn=True, + bypass_rnn=True, + drop_probs=(0.0, 0.6, 0.6), + gaussian_init='manual', + n_gaussians=16, + smoothing_ksize=41, + bn_momentum=0.01, + static_bn_momentum=0.1, + sources=('DHF1K', 'Hollywood', 'UCFSports', 'SALICON'), + ds_bn=True, + ds_adaptation=True, + ds_smoothing=True, + ds_gaussians=True, + verbose=1, + ): + super().__init__() + + # Check inputs + assert(gaussian_init in ('random', 'manual')) + # Bypass-RNN requires residual RNN connection + if bypass_rnn: + assert res_rnn + + # Manual Gaussian initialization generates 16 Gaussians + if n_gaussians > 0 and gaussian_init == 'manual': + n_gaussians = 16 + + self.rnn_input_channels = rnn_input_channels + self.rnn_hidden_channels = rnn_hidden_channels + this_cnn_cfg = default_cnn_cfg.copy() + this_cnn_cfg.update(cnn_cfg or {}) + self.cnn_cfg = this_cnn_cfg + this_rnn_cfg = default_rnn_cfg.copy() + this_rnn_cfg.update(rnn_cfg or {}) + self.rnn_cfg = this_rnn_cfg + self.bypass_rnn = bypass_rnn + self.res_rnn = res_rnn + self.drop_probs = drop_probs + self.gaussian_init = gaussian_init + self.n_gaussians = n_gaussians + self.smoothing_ksize = smoothing_ksize + self.bn_momentum = bn_momentum + self.sources = sources + self.ds_bn = ds_bn + self.static_bn_momentum = static_bn_momentum + self.ds_adaptation = ds_adaptation + self.ds_smoothing = ds_smoothing + self.ds_gaussians = ds_gaussians + self.verbose = verbose + + # Initialize backbone CNN + self.cnn = MobileNetV2(**self.cnn_cfg) + + # Initialize Post-CNN module with optional dropout + post_cnn = [ + ('inv_res', InvertedResidual( + self.cnn.out_channels + n_gaussians, + rnn_input_channels, 1, 1, bn_momentum=bn_momentum, + )) + ] + if self.drop_probs[0] > 0: + post_cnn.insert(0, ( + 'dropout', nn.Dropout2d(self.drop_probs[0], inplace=False) + )) + self.post_cnn = nn.Sequential(OrderedDict(post_cnn)) + + # Initialize Bypass-RNN if training on dynamic data + if sources != ('SALICON',) or not self.bypass_rnn: + self.rnn = ConvGRU( + rnn_input_channels, + hidden_channels=[rnn_hidden_channels], + batchnorm=self.get_bn_module, + **self.rnn_cfg) + self.post_rnn = self.conv_1x1_bn( + rnn_hidden_channels, rnn_input_channels) + + # Initialize first upsampling module US1 + self.upsampling_1 = nn.Sequential(OrderedDict([ + ('us1', self.upsampling(2)), + ])) + + # Number of channels at the 2x scale + channels_2x = 128 + + # Initialize Skip-2x module + self.skip_2x = self.make_skip_connection( + self.cnn.feat_2x_channels, channels_2x, 2, self.drop_probs[1]) + + # Initialize second upsampling module US2 + self.upsampling_2 = nn.Sequential(OrderedDict([ + ('inv_res', 
InvertedResidual( + rnn_input_channels + channels_2x, + channels_2x, 1, 2, batchnorm=self.get_bn_module)), + ('us2', self.upsampling(2)), + ])) + + # Number of channels at the 4x scale + channels_4x = 64 + + # Initialize Skip-4x module + self.skip_4x = self.make_skip_connection( + self.cnn.feat_4x_channels, channels_4x, 2, self.drop_probs[2]) + + # Initialize Post-US2 module + self.post_upsampling_2= nn.Sequential(OrderedDict([ + ('inv_res', InvertedResidual( + channels_2x + channels_4x, channels_4x, 1, 2, + batchnorm=self.get_bn_module)), + ])) + + # Initialize domain-specific modules + for source_str in self.sources: + source_str = f'_{source_str}'.lower() + + # Initialize learned Gaussian priors parameters + if n_gaussians > 0: + self.set_gaussians(source_str) + + # Initialize Adaptation + self.__setattr__( + 'adaptation' + (source_str if self.ds_adaptation else ''), + nn.Sequential(*[ + nn.Conv2d(channels_4x, 1, 1, bias=True) + ])) + + # Initialize Smoothing + smoothing = nn.Conv2d( + 1, 1, kernel_size=smoothing_ksize, padding=0, bias=False) + with torch.no_grad(): + gaussian = self._make_gaussian_maps( + smoothing.weight.data, + torch.Tensor([[[0.5, -2]] * 2]) + ) + gaussian /= gaussian.sum() + smoothing.weight.data = gaussian + self.__setattr__( + 'smoothing' + (source_str if self.ds_smoothing else ''), + smoothing) + + if self.verbose > 1: + pprint.pprint(self.asdict(), width=1) + + @property + def this_source(self): + """Return current source for domain-specific BatchNorm.""" + return self._this_source + + @this_source.setter + def this_source(self, source): + """Set current source for domain-specific BatchNorm.""" + for module in self.modules(): + if isinstance(module, DomainBatchNorm2d): + module.this_source = source + self._this_source = source + + def get_bn_module(self, num_features, **kwargs): + """Return BatchNorm class (domain-specific or domain-invariant).""" + momenta = [self.bn_momentum if src != 'SALICON' + else self.static_bn_momentum for src in self.sources] + if self.ds_bn: + return DomainBatchNorm2d( + num_features, self.sources, momenta=momenta, **kwargs) + else: + return nn.BatchNorm2d(num_features, **kwargs) + + # @staticmethod + def upsampling(self, factor): + """Return upsampling module.""" + return nn.Sequential(*[ + nn.Upsample( + scale_factor=factor, mode='bilinear', align_corners=False), + ]) + + def set_gaussians(self, source_str, prefix='coarse_'): + """Set Gaussian parameters.""" + suffix = source_str if self.ds_gaussians else '' + self.__setattr__( + prefix + 'gaussians' + suffix, + self._initialize_gaussians(self.n_gaussians)) + + def _initialize_gaussians(self, n_gaussians): + """ + Return initialized Gaussian parameters. + Dimensions: [idx, y/x, mu/logstd]. 
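+        For example, the default 'manual' initialization below yields a
+        (16, 2, 2) parameter tensor: 16 Gaussians, (y, x) coordinates, and
+        (mu, log-std) entries per coordinate.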
+ """ + if self.gaussian_init == 'manual': + gaussians = torch.Tensor([ + list(product([0.25, 0.5, 0.75], repeat=2)) + + [(0.5, 0.25), (0.5, 0.5), (0.5, 0.75)] + + [(0.25, 0.5), (0.5, 0.5), (0.75, 0.5)] + + [(0.5, 0.5)], + [(-1.5, -1.5)] * 9 + [(0, -1.5)] * 3 + [(-1.5, 0)] * 3 + + [(0, 0)], + ]).permute(1, 2, 0) + + elif self.gaussian_init == 'random': + with torch.no_grad(): + gaussians = torch.stack([ + torch.randn( + n_gaussians, 2, dtype=torch.float) * .1 + 0.5, + torch.randn( + n_gaussians, 2, dtype=torch.float) * .2 - 1,], + dim=2) + + else: + raise NotImplementedError + + gaussians = nn.Parameter(gaussians, requires_grad=True) + return gaussians + + @staticmethod + def _make_gaussian_maps(x, gaussians, size=None, scaling=6.): + """Construct prior maps from Gaussian parameters.""" + if size is None: + size = x.shape[-2:] + bs = x.shape[0] + else: + size = [size] * 2 + bs = 1 + dtype = x.dtype + device = x.device + + gaussian_maps = [] + map_template = torch.ones(*size, dtype=dtype, device=device) + meshgrids = torch.meshgrid( + [torch.linspace(0, 1, size[0], dtype=dtype, device=device), + torch.linspace(0, 1, size[1], dtype=dtype, device=device),]) + + for gaussian_idx, yx_mu_logstd in enumerate(torch.unbind(gaussians)): + map = map_template.clone() + for mu_logstd, mgrid in zip(yx_mu_logstd, meshgrids): + mu = mu_logstd[0] + std = torch.exp(mu_logstd[1]) + map *= torch.exp(-((mgrid - mu) / std) ** 2 / 2) + + map *= scaling + gaussian_maps.append(map) + + gaussian_maps = torch.stack(gaussian_maps) + gaussian_maps = gaussian_maps.unsqueeze(0).expand(bs, -1, -1, -1) + return gaussian_maps + + def _get_gaussian_maps(self, x, source_str, prefix='coarse_', **kwargs): + """Return the constructed Gaussian prior maps.""" + suffix = source_str if self.ds_gaussians else '' + gaussians = self.__getattr__(prefix + "gaussians" + suffix) + gaussian_maps = self._make_gaussian_maps(x, gaussians, **kwargs) + return gaussian_maps + + # @classmethod + def make_skip_connection(self, input_channels, output_channels, expand_ratio, p, + inplace=False): + """Return skip connection module.""" + hidden_channels = round(input_channels * expand_ratio) + return nn.Sequential(OrderedDict([ + ('expansion', self.conv_1x1_bn( + input_channels, hidden_channels)), + ('dropout', nn.Dropout2d(p, inplace=inplace)), + ('reduction', nn.Sequential(*[ + nn.Conv2d(hidden_channels, output_channels, 1), + self.get_bn_module(output_channels), + ])), + ])) + + # @staticmethod + def conv_1x1_bn(self, inp, oup): + """Return pointwise convolution with BatchNorm and ReLU6.""" + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + self.get_bn_module(oup), + nn.ReLU6(inplace=True) + ) + + def forward(self, x, target_size=None, h0=None, return_hidden=False, + source='DHF1K', static=None): + """ + Forward pass. + Arguments: + x: Input batch of dimensions [batch, time, channel, h, w]. + target_size: (height, width) of the resized output. + h0: Initial hidden state. + return_hidden: Return [prediction, hidden_state]. + source: Data source of current batch. Must be in self.sources. + static: Whether the current input is static. If None, this is + inferred from the input dimensions or self.sources. 
+ """ + if target_size is None: + target_size = x.shape[-2:] + + # Set the current source for the domain-specific BatchNorm modules + self.this_source = source + + # Prepare other parameters + source_str = f'_{source.lower()}' + if static is None: + static = x.shape[1] == 1 or self.sources == ('SALICON',) + + # Compute backbone CNN features and concatenate with Gaussian prior maps + feat_seq_1x = [] + feat_seq_2x = [] + feat_seq_4x = [] + for t, img in enumerate(torch.unbind(x, dim=1)): + im_feat_1x, im_feat_2x, im_feat_4x = self.cnn(img) + + im_feat_2x = self.skip_2x(im_feat_2x) + im_feat_4x = self.skip_4x(im_feat_4x) + + if self.n_gaussians > 0: + gaussian_maps = self._get_gaussian_maps(im_feat_1x, source_str) + im_feat_1x = torch.cat((im_feat_1x, gaussian_maps), dim=1) + + im_feat_1x = self.post_cnn(im_feat_1x) + feat_seq_1x.append(im_feat_1x) + feat_seq_2x.append(im_feat_2x) + feat_seq_4x.append(im_feat_4x) + + feat_seq_1x = torch.stack(feat_seq_1x, dim=1) + + # Bypass-RNN + hidden, rnn_feat_seq, rnn_feat = (None,) * 3 + if not (static and self.bypass_rnn): + rnn_feat_seq, hidden = self.rnn(feat_seq_1x, hidden=h0) + + # Decoder + output_seq = [] + for idx, im_feat in enumerate( + torch.unbind(feat_seq_1x, dim=1)): + + if not (static and self.bypass_rnn): + rnn_feat = rnn_feat_seq[:, idx, ...] + rnn_feat = self.post_rnn(rnn_feat) + if self.res_rnn: + im_feat = im_feat + rnn_feat + else: + im_feat = rnn_feat + + im_feat = self.upsampling_1(im_feat) + im_feat = torch.cat((im_feat, feat_seq_2x[idx]), dim=1) + im_feat = self.upsampling_2(im_feat) + im_feat = torch.cat((im_feat, feat_seq_4x[idx]), dim=1) + im_feat = self.post_upsampling_2(im_feat) + + im_feat = self.__getattr__( + 'adaptation' + (source_str if self.ds_adaptation else ''))( + im_feat) + + im_feat = F.interpolate( + im_feat, size=x.shape[-2:], mode='nearest') + + im_feat = F.pad(im_feat, [self.smoothing_ksize // 2] * 4, + mode='replicate') + im_feat = self.__getattr__( + 'smoothing' + (source_str if self.ds_smoothing else ''))( + im_feat) + + im_feat = F.interpolate( + im_feat, size=target_size, mode='bilinear', align_corners=False) + + im_feat = log_softmax(im_feat) + output_seq.append(im_feat) + output_seq = torch.stack(output_seq, dim=1) + + outputs = [output_seq] + if return_hidden: + outputs.append(hidden) + if len(outputs) == 1: + return outputs[0] + return outputs + + +def log_softmax(x): + x_size = x.size() + x = x.view(x.size(0), -1) + x = F.log_softmax(x, dim=1) + return x.view(x_size) \ No newline at end of file diff --git a/gazenet/models/shared_components/__init__.py b/gazenet/models/shared_components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/attentive_convlstm/__init__.py b/gazenet/models/shared_components/attentive_convlstm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/attentive_convlstm/model.py b/gazenet/models/shared_components/attentive_convlstm/model.py new file mode 100644 index 0000000..4c85e7b --- /dev/null +++ b/gazenet/models/shared_components/attentive_convlstm/model.py @@ -0,0 +1,112 @@ +import torch.nn as nn + +nb_timestep = 4 + +# https://github.com/PanoAsh/Saliency-Attentive-Model-Pytorch/blob/master/main.py +class AttentiveLSTM(nn.Module): + + def __init__(self, nb_features_in, nb_features_out, nb_features_att, nb_rows, nb_cols): + super(AttentiveLSTM, self).__init__() + + # define the fundamantal parameters + self.nb_features_in = nb_features_in + self.nb_features_out = 
nb_features_out + self.nb_features_att = nb_features_att + self.nb_rows = nb_rows + self.nb_cols = nb_cols + + # define convs + self.W_a = nn.Conv2d(in_channels=self.nb_features_att, out_channels=self.nb_features_att, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + self.U_a = nn.Conv2d(in_channels=self.nb_features_in, out_channels=self.nb_features_att, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + self.V_a = nn.Conv2d(in_channels=self.nb_features_att, out_channels=1, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=False) + + self.W_i = nn.Conv2d(in_channels=self.nb_features_in, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + self.U_i = nn.Conv2d(in_channels=self.nb_features_out, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + + self.W_f = nn.Conv2d(in_channels=self.nb_features_in, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + self.U_f = nn.Conv2d(in_channels=self.nb_features_out, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + + self.W_c = nn.Conv2d(in_channels=self.nb_features_in, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + self.U_c = nn.Conv2d(in_channels=self.nb_features_out, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + + self.W_o = nn.Conv2d(in_channels=self.nb_features_in, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + self.U_o = nn.Conv2d(in_channels=self.nb_features_out, out_channels=self.nb_features_out, + kernel_size=self.nb_rows, stride=1, padding=1, dilation=1, groups=1, bias=True) + + # define activations + self.tanh = nn.Tanh() + self.sigmoid = nn.Sigmoid() + self.softmax = nn.Softmax(dim=-1) + + # define number of temporal steps + self.nb_ts = nb_timestep + + def forward(self, x): + # get the current cell memory and hidden state + h_curr, c_curr = x, x + + for i in range(self.nb_ts): + + # the attentive model + my_Z = self.V_a(self.tanh(self.W_a(h_curr) + self.U_a(x))) + my_A = self.softmax(my_Z) + AM_cL = my_A * x + + # the convLSTM model + my_I = self.sigmoid(self.W_i(AM_cL) + self.U_i(h_curr)) + my_F = self.sigmoid(self.W_f(AM_cL) + self.U_f(h_curr)) + my_O = self.sigmoid(self.W_o(AM_cL) + self.U_o(h_curr)) + my_G = self.tanh(self.W_c(AM_cL) + self.U_c(h_curr)) + c_next = my_G * my_I + my_F * c_curr + h_next = self.tanh(c_next) * my_O + + c_curr = c_next + h_curr = h_next + + return h_curr + + +class SequenceAttentiveLSTM(AttentiveLSTM): + def __init__(self, *args, sequence_len=2, sequence_norm=True, **kwargs): + super().__init__(*args, **kwargs) + + if sequence_norm: + self.sequence_norm = nn.BatchNorm3d(sequence_len) + # self.sequence_len = sequence_len + else: + self.sequence_norm = lambda x : x + # self.sequence_len = None + + def forward(self, x): + x = self.sequence_norm(x) + # get the current cell memory and hidden state + h_curr, c_curr = x[:,0], x[:,0] + + for i in range(x.shape[1]): # for i in range(self.sequence_len): + # the attentive model + my_Z = self.V_a(self.tanh(self.W_a(h_curr) + self.U_a(x[:,i]))) + my_A = self.softmax(my_Z) + AM_cL = my_A * x[:,i] + + # 
the convLSTM model + my_I = self.sigmoid(self.W_i(AM_cL) + self.U_i(h_curr)) + my_F = self.sigmoid(self.W_f(AM_cL) + self.U_f(h_curr)) + my_O = self.sigmoid(self.W_o(AM_cL) + self.U_o(h_curr)) + my_G = self.tanh(self.W_c(AM_cL) + self.U_c(h_curr)) + c_next = my_G * my_I + my_F * c_curr + h_next = self.tanh(c_next) * my_O + + c_curr = c_next + h_curr = h_next + + return h_curr diff --git a/gazenet/models/shared_components/conv3d/__init__.py b/gazenet/models/shared_components/conv3d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/conv3d/model.py b/gazenet/models/shared_components/conv3d/model.py new file mode 100644 index 0000000..ba4491a --- /dev/null +++ b/gazenet/models/shared_components/conv3d/model.py @@ -0,0 +1,299 @@ +import torch +from torch import nn + + +class BasicConv3d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv3d, self).__init__() + self.conv = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class SepConv3d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(SepConv3d, self).__init__() + self.conv_s = nn.Conv3d(in_planes, out_planes, kernel_size=(1,kernel_size,kernel_size), stride=(1,stride,stride), padding=(0,padding,padding), bias=False) + self.bn_s = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True) + self.relu_s = nn.ReLU() + + self.conv_t = nn.Conv3d(out_planes, out_planes, kernel_size=(kernel_size,1,1), stride=(stride,1,1), padding=(padding,0,0), bias=False) + self.bn_t = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True) + self.relu_t = nn.ReLU() + + def forward(self, x): + x = self.conv_s(x) + x = self.bn_s(x) + x = self.relu_s(x) + + x = self.conv_t(x) + x = self.bn_t(x) + x = self.relu_t(x) + return x + + +class Mixed_3b(nn.Module): + def __init__(self): + super(Mixed_3b, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(192, 64, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(192, 96, kernel_size=1, stride=1), + SepConv3d(96, 128, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(192, 16, kernel_size=1, stride=1), + SepConv3d(16, 32, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(192, 32, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + + return out + + +class Mixed_3c(nn.Module): + def __init__(self): + super(Mixed_3c, self).__init__() + self.branch0 = nn.Sequential( + BasicConv3d(256, 128, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(256, 128, kernel_size=1, stride=1), + SepConv3d(128, 192, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(256, 32, kernel_size=1, stride=1), + SepConv3d(32, 96, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(256, 64, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 
= self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Mixed_4b(nn.Module): + def __init__(self): + super(Mixed_4b, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(480, 192, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(480, 96, kernel_size=1, stride=1), + SepConv3d(96, 208, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(480, 16, kernel_size=1, stride=1), + SepConv3d(16, 48, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(480, 64, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Mixed_4c(nn.Module): + def __init__(self): + super(Mixed_4c, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(512, 160, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(512, 112, kernel_size=1, stride=1), + SepConv3d(112, 224, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(512, 24, kernel_size=1, stride=1), + SepConv3d(24, 64, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(512, 64, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Mixed_4d(nn.Module): + def __init__(self): + super(Mixed_4d, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(512, 128, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(512, 128, kernel_size=1, stride=1), + SepConv3d(128, 256, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(512, 24, kernel_size=1, stride=1), + SepConv3d(24, 64, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(512, 64, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Mixed_4e(nn.Module): + def __init__(self): + super(Mixed_4e, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(512, 112, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(512, 144, kernel_size=1, stride=1), + SepConv3d(144, 288, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(512, 32, kernel_size=1, stride=1), + SepConv3d(32, 64, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(512, 64, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Mixed_4f(nn.Module): + def __init__(self): + super(Mixed_4f, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(528, 256, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(528, 160, kernel_size=1, stride=1), + SepConv3d(160, 320, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + 
BasicConv3d(528, 32, kernel_size=1, stride=1), + SepConv3d(32, 128, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(528, 128, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Mixed_5b(nn.Module): + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(832, 256, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(832, 160, kernel_size=1, stride=1), + SepConv3d(160, 320, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(832, 32, kernel_size=1, stride=1), + SepConv3d(32, 128, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(832, 128, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Mixed_5c(nn.Module): + def __init__(self): + super(Mixed_5c, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv3d(832, 384, kernel_size=1, stride=1), + ) + self.branch1 = nn.Sequential( + BasicConv3d(832, 192, kernel_size=1, stride=1), + SepConv3d(192, 384, kernel_size=3, stride=1, padding=1), + ) + self.branch2 = nn.Sequential( + BasicConv3d(832, 48, kernel_size=1, stride=1), + SepConv3d(48, 128, kernel_size=3, stride=1, padding=1), + ) + self.branch3 = nn.Sequential( + nn.MaxPool3d(kernel_size=(3,3,3), stride=1, padding=1), + BasicConv3d(832, 128, kernel_size=1, stride=1), + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out \ No newline at end of file diff --git a/gazenet/models/shared_components/convgru/__init__.py b/gazenet/models/shared_components/convgru/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/convgru/model.py b/gazenet/models/shared_components/convgru/model.py new file mode 100644 index 0000000..f894f9a --- /dev/null +++ b/gazenet/models/shared_components/convgru/model.py @@ -0,0 +1,371 @@ +import math +from collections import OrderedDict + +import torch +from torch.distributions.bernoulli import Bernoulli +import torch.nn as nn +import torch.nn.functional as f +from torch.nn.parameter import Parameter +from torch.nn import init + +# Inspired by: +# https://github.com/jacobkimmel/pytorch_convgru +# https://gist.github.com/halochou/acbd669af86ecb8f988325084ba7a749 + + +class ConvGRUCell(nn.Module): + """ + Generate a convolutional GRU cell. + Arguments: + input_ch: Number of channels of the input. + hidden_ch: Number of channels of hidden state. + kernel_size (tuple): Kernel size of the U and W operations. + gate_ksize (tuple): Kernel size for the gates. + bias: Add bias term to layers. + norm: Normalization method. 'batch', 'instance' or ''. + norm_momentum: BatchNorm momentum. + affine_norm: Affine BatchNorm. + batchnorm: External function that accepts a number of channels and + returns a BatchNorm module (for DSBN). Overwrites norm and + norm_momentum. + drop_prob: Tuple of dropout probabilities for input, recurrent and + output dropout. 
+ do_mode: If 'recurrent', the variational dropout is used, dropping out + the same channels at every time step. If 'naive', different channels + are dropped at each time step. + r_bias, z_bias: Bias initialization for r and z gates. + mobile: If True, MobileNet-style convolutions are used. + """ + + def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), + bias=True, norm='', norm_momentum=0.1, affine_norm=True, + batchnorm=None, gain=1, drop_prob=(0., 0., 0.), + do_mode='recurrent', r_bias=0., z_bias=0., mobile=False, + **kwargs): + super().__init__() + + self.input_ch = input_ch + self.hidden_ch = hidden_ch + self.kernel_size = kernel_size + self.gate_ksize = gate_ksize + self.mobile = mobile + self.kwargs = {'init': 'xavier_uniform_'} + self.kwargs.update(kwargs) + + # Process normalization arguments + self.norm = norm + self.norm_momentum = norm_momentum + self.affine_norm = affine_norm + self.batchnorm = batchnorm + self.norm_kwargs = None + if self.batchnorm is not None: + self.norm = 'batch' + elif self.norm: + self.norm_kwargs = { + 'affine': self.affine_norm, 'track_running_stats': True, + 'momentum': self.norm_momentum} + + # Prepare normalization modules + if self.norm: + self.norm_r_x = self.get_norm_module(self.hidden_ch) + self.norm_r_h = self.get_norm_module(self.hidden_ch) + self.norm_z_x = self.get_norm_module(self.hidden_ch) + self.norm_z_h = self.get_norm_module(self.hidden_ch) + self.norm_out_x = self.get_norm_module(self.hidden_ch) + self.norm_out_h = self.get_norm_module(self.hidden_ch) + + # Prepare dropout + self.drop_prob = drop_prob + self.do_mode = do_mode + if self.do_mode == 'recurrent': + # Prepare dropout masks if using recurrent dropout + for idx, mask in self.yield_drop_masks(): + self.register_buffer(self.mask_name(idx), mask) + elif self.do_mode != 'naive': + raise ValueError('Unknown dropout mode ', self.do_mode) + + # Instantiate the main weight matrices + self.w_r = self._conv2d(self.input_ch, self.gate_ksize, bias=False) + self.u_r = self._conv2d(self.hidden_ch, self.gate_ksize, bias=False) + self.w_z = self._conv2d(self.input_ch, self.gate_ksize, bias=False) + self.u_z = self._conv2d(self.hidden_ch, self.gate_ksize, bias=False) + self.w = self._conv2d(self.input_ch, self.kernel_size, bias=False) + self.u = self._conv2d(self.hidden_ch, self.gate_ksize, bias=False) + + # Instantiate the optional biases and affine paramters + self.bias = bias + self.r_bias = r_bias + self.z_bias = z_bias + if self.bias or self.affine_norm: + self.b_r = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + self.b_z = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + self.b_h = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + if self.affine_norm: + self.a_r_x = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + self.a_r_h = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + self.a_z_x = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + self.a_z_h = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + self.a_h_x = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + self.a_h_h = Parameter(torch.Tensor(self.hidden_ch, 1, 1)) + + self.gain = gain + self.set_weights() + + def set_weights(self): + """Initialize the parameters""" + def gain_from_ksize(ksize): + n = ksize[0] * ksize[1] * self.hidden_ch + return math.sqrt(2. 
/ n) + with torch.no_grad(): + if not self.mobile: + if self.gain < 0: + gain_1 = gain_from_ksize(self.kernel_size) + gain_2 = gain_from_ksize(self.gate_ksize) + else: + gain_1 = gain_2 = self.gain + init_fn = getattr(init, self.kwargs['init']) + init_fn(self.w_r.weight, gain=gain_2) + init_fn(self.u_r.weight, gain=gain_2) + init_fn(self.w_z.weight, gain=gain_2) + init_fn(self.u_z.weight, gain=gain_2) + init_fn(self.w.weight, gain=gain_1) + init_fn(self.u.weight, gain=gain_2) + if self.bias or self.affine_norm: + self.b_r.data.fill_(self.r_bias) + self.b_z.data.fill_(self.z_bias) + self.b_h.data.zero_() + if self.affine_norm: + self.a_r_x.data.fill_(1) + self.a_r_h.data.fill_(1) + self.a_z_x.data.fill_(1) + self.a_z_h.data.fill_(1) + self.a_h_x.data.fill_(1) + self.a_h_h.data.fill_(1) + + def forward(self, x, h_tm1): + # Initialize hidden state if necessary + if h_tm1 is None: + h_tm1 = self._init_hidden(x, cuda=x.is_cuda) + + # Compute gate components + r_x = self.w_r(self.apply_dropout(x, 0, 0)) + r_h = self.u_r(self.apply_dropout(h_tm1, 1, 0)) + z_x = self.w_z(self.apply_dropout(x, 0, 1)) + z_h = self.u_z(self.apply_dropout(h_tm1, 1, 1)) + h_x = self.w(self.apply_dropout(x, 0, 2)) + h_h = self.u(self.apply_dropout(h_tm1, 1, 2)) + + if self.norm: + # Apply normalization + r_x = self.norm_r_x(r_x) + r_h = self.norm_r_h(r_h) + z_x = self.norm_z_x(z_x) + z_h = self.norm_z_h(z_h) + h_x = self.norm_out_x(h_x) + h_h = self.norm_out_h(h_h) + + if self.affine_norm: + # Apply affine transformation + r_x = r_x * self.a_r_x + r_h = r_h * self.a_r_h + z_x = z_x * self.a_z_x + z_h = z_h * self.a_z_h + h_x = h_x * self.a_h_x + h_h = h_h * self.a_h_h + + # Compute gates with optinal bias + if self.bias or self.affine_norm: + r = torch.sigmoid(r_x + r_h + self.b_r) + z = torch.sigmoid(z_x + z_h + self.b_z) + else: + r = torch.sigmoid(r_x + r_h) + z = torch.sigmoid(z_x + z_h) + + # Compute new hidden state + if self.bias or self.affine_norm: + h = torch.tanh(h_x + r * h_h + self.b_h) + else: + h = torch.tanh(h_x + r * h_h) + h = (1 - z) * h_tm1 + z * h + + # Optionally apply output dropout + y = self.apply_dropout(h, 2, 0) + + return y, h + + @staticmethod + def mask_name(idx): + return 'drop_mask_{}'.format(idx) + + def set_drop_masks(self): + """Set the dropout masks for the current sequence""" + for idx, mask in self.yield_drop_masks(): + setattr(self, self.mask_name(idx), mask) + + def yield_drop_masks(self): + """Iterator over recurrent dropout masks""" + n_masks = (3, 3, 1) + n_channels = (self.input_ch, self.hidden_ch, self.hidden_ch) + for idx, p in enumerate(self.drop_prob): + if p > 0: + yield (idx, self.generate_do_mask( + p, n_masks[idx], n_channels[idx])) + + @staticmethod + def generate_do_mask(p, n, ch): + """Generate a dropout mask for recurrent dropout""" + with torch.no_grad(): + mask = Bernoulli(torch.full((n, ch), 1 - p)).sample() / (1 - p) + mask = mask.requires_grad_(False).cuda() + return mask + + def apply_dropout(self, x, idx, sub_idx): + """Apply recurrent or naive dropout""" + if self.training and self.drop_prob[idx] > 0 and idx != 2: + if self.do_mode == 'recurrent': + x = x.clone() * torch.reshape( + getattr(self, self.mask_name(idx)) + [sub_idx, :], (1, -1, 1, 1)) + elif self.do_mode == 'naive': + x = f.dropout2d( + x, self.drop_prob[idx], self.training, inplace=False) + else: + x = x.clone() + return x + + def get_norm_module(self, channels): + """Return normalization module instance""" + norm_module = None + if self.batchnorm is not None: + norm_module = 
self.batchnorm(channels) + elif self.norm == 'instance': + norm_module = nn.InstanceNorm2d(channels, **self.norm_kwargs) + elif self.norm == 'batch': + norm_module = nn.BatchNorm2d(channels, **self.norm_kwargs) + return norm_module + + def _conv2d(self, in_channels, kernel_size, bias=True): + """ + Return convolutional layer. + Supports standard convolutions and MobileNet-style convolutions. + """ + padding = tuple(k_size // 2 for k_size in kernel_size) + if not self.mobile or kernel_size == (1, 1): + return nn.Conv2d(in_channels, self.hidden_ch, kernel_size, + padding=padding, bias=bias) + else: + return nn.Sequential(OrderedDict([ + ('conv_dw', nn.Conv2d( + in_channels, in_channels, kernel_size=kernel_size, + padding=padding, groups=in_channels, bias=False)), + ('sep_bn', self.get_norm_module(in_channels)), + ('sep_relu', nn.ReLU6()), + ('conv_sep', nn.Conv2d( + in_channels, self.hidden_ch, 1, bias=bias)), + ])) + + def _init_hidden(self, input_, cuda=True): + """Initialize the hidden state""" + batch_size, _, height, width = input_.data.size() + prev_state = torch.zeros( + batch_size, self.hidden_ch, height, width) + if cuda: + prev_state = prev_state.cuda() + return prev_state + + +class ConvGRU(nn.Module): + + def __init__(self, input_channels=None, hidden_channels=None, + kernel_size=(3, 3), gate_ksize=(1, 1), + dropout=(False, False, False), drop_prob=(0.5, 0.5, 0.5), + **kwargs): + """ + Generates a multi-layer convolutional GRU. + Preserves spatial dimensions across cells, only altering depth. + Arguments: + input_channels: Number of channels of the input. + hidden_channels (list): List of hidden channels for each layer. + kernel_size (tuple): Kernel size of the U and W operations. + gate_ksize (tuple): Kernel size for the gates. + dropout: Tuple of Booleans for input, recurrent and output dropout. + drop_prob: Tuple of dropout probabilities for each selected dropout. + kwargs: Additional parameters for the cGRU cells. + """ + + super().__init__() + + kernel_size = tuple(kernel_size) + gate_ksize = tuple(gate_ksize) + dropout = tuple(dropout) + drop_prob = tuple(drop_prob) + + assert len(hidden_channels) > 0 + self.input_channels = [input_channels] + hidden_channels + self.hidden_channels = hidden_channels + self.num_layers = len(hidden_channels) + self._check_kernel_size_consistency(kernel_size) + self._check_kernel_size_consistency(gate_ksize) + self.kernel_size = self._extend_for_multilayer(kernel_size) + self.gate_ksize = self._extend_for_multilayer(gate_ksize) + self.dropout = self._extend_for_multilayer(dropout) + drop_prob = self._extend_for_multilayer(drop_prob) + self.drop_prob = [tuple(dp_ if do_ else 0. for dp_, do_ in zip(dp, do)) + for dp, do in zip(drop_prob, self.dropout)] + self.kwargs = kwargs + + cell_list = [] + for idx in range(self.num_layers): + if idx < self.num_layers - 1: + # Switch output dropout off for hidden layers. + # Otherwise it would confict with input dropout. 
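+                # e.g. a drop_prob of (0.2, 0.2, 0.2) for a hidden layer becomes (0.2, 0.2, 0.0)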
+ this_drop_prob = self.drop_prob[idx][:2] + (0.,) + else: + this_drop_prob = self.drop_prob[idx] + cell_list.append(ConvGRUCell( + self.input_channels[idx], self.hidden_channels[idx], + self.kernel_size[idx], drop_prob=this_drop_prob, + gate_ksize=self.gate_ksize[idx], **kwargs)) + self.cell_list = nn.ModuleList(cell_list) + + def forward(self, input_tensor, hidden=None): + """ + Args: + input_tensor: + 5-D Tensor of shape (b, t, c, h, w) + hidden: + optional initial hiddens state + Returns: + outputs + """ + if not hidden: + hidden = [None] * self.num_layers + + outputs = [] + iterator = torch.unbind(input_tensor, dim=1) + + for t, x in enumerate(iterator): + for layer_idx in range(self.num_layers): + if self.cell_list[layer_idx].do_mode == 'recurrent'\ + and t == 0: + self.cell_list[layer_idx].set_drop_masks() + (x, h) = self.cell_list[layer_idx](x, hidden[layer_idx]) + hidden[layer_idx] = h.clone() + outputs.append(x.clone()) + outputs = torch.stack(outputs, dim=1) + + return outputs, hidden + + @staticmethod + def _check_kernel_size_consistency(kernel_size): + if not (isinstance(kernel_size, tuple) or + (isinstance(kernel_size, list) and + all([isinstance(elem, tuple) for elem in kernel_size]))): + raise ValueError('`kernel_size` must be tuple or list of tuples') + + def _extend_for_multilayer(self, param): + if not isinstance(param, list): + param = [param] * self.num_layers + else: + assert(len(param) == self.num_layers) + return param \ No newline at end of file diff --git a/gazenet/models/shared_components/gmu/__init__.py b/gazenet/models/shared_components/gmu/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/gmu/model.py b/gazenet/models/shared_components/gmu/model.py new file mode 100644 index 0000000..6b4dff3 --- /dev/null +++ b/gazenet/models/shared_components/gmu/model.py @@ -0,0 +1,1243 @@ +# -*- coding: utf-8 -*- +"""This module implements the Gated Multimodal Units in PyTorch + +Currently there are two versions: +Two versions, the general GMU and the simplified, bimodal unit +are described in Arevalo et al., Gated multimodal networks, 2020 +(https://link.springer.com/article/10.1007/s00521-019-04559-1) + +The published code of the authors contains an implementation +of the bimodal version in the Theano framework Bricks. +However, this version is a bit restrictive. It constraints +the input size with the hidden size. +See https://github.com/johnarevalo/gmu-mmimdb/blob/master/model.py + +The general GMU and the bimodal version with tied gates +will be implemented here as GMU and GBU. +Now, there is also the GMU Conv2d version in here. +""" + +import torch + + +class GMU(torch.nn.Module): + """Gated Multimodal Unit, a hidden unit in a neural network that learns + to combine the representation of different modalities into a single one + via gates (similar to LSTM). + + h generally refers to the hidden state (i.e. modality information, this is + the naming scheme chosen by the original GMU authors, but I do not like it + that much), while + z generally refers to the gates. + """ + + def __init__( + self, + in_features, + out_features, + modalities, + activation=torch.tanh, + gate_activation=torch.sigmoid, + hidden_weight_init=lambda x: torch.nn.init.uniform_(x, -0.01, 0.01), + gate_weight_init=lambda x: torch.nn.init.uniform_(x, -0.01, 0.01), + gate_hidden_interaction=lambda x, y: x * y, + gate_transformation=None, + bias=True, + ): + """Init function. 
+ + Args: + in_features (int): vector length of a single modality + out_features (int): number of (hidden) units / output features + modalities (int): number of modalities + activation (torch func): activation function for the modalities + gate_activation (torch func): activation function for the gate + hidden_weight_init (torch init func): init method for the neuronal + weights + gate_weight_init (torch init func): init method for the gate weights + gate_hidden_interaction (lambda func): how does h and z interact + with another. Could be linear or non-linear (e.g. x * (1+y)) + gate_transformation (lambda func): processes the gate + activations before they interact with the hidden state, + e.g. normalise / gain control them by + lambda x: x / torch.sum(x, 1, keepdim=True) + bias (bool): should the computation contain a + bias (not specified in the original paper) + + """ + + super(GMU, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.modalities = modalities + self.gates = modalities + self.activation = activation + self.gate_activation = gate_activation + self.hidden_weight_init = hidden_weight_init + self.gate_weight_init = gate_weight_init + self.hidden_bias_init = lambda x: torch.nn.init.uniform_(x, -0.01, 0.01) + self.gate_bias_init = lambda x: torch.nn.init.uniform_(x, -0.01, 0.01) + self.gate_hidden_interaction = gate_hidden_interaction + self.gate_transformation = gate_transformation + self.W_h = self.initialize_hidden_weights() + self.W_z = self.initialize_gate_weights() + self.register_bias(bias) + + def register_bias(self, bias): + """ register biases """ + if bias: + self.hidden_bias = self.initialize_hidden_bias() + self.gate_bias = self.initialize_gate_bias() + else: + self.register_parameter("hidden_bias", None) + self.register_parameter("gate_bias", None) + + def initialize_hidden_bias(self): + """Initializes hidden weight parameters + + Returns: + torch.nn.Parameter + """ + + b = torch.nn.Parameter(torch.empty((1, self.modalities, self.out_features))) + self.hidden_bias_init(b) + return b + + def initialize_gate_bias(self): + """Initializes hidden weight parameters + + Returns: + torch.nn.Parameter + """ + + b = torch.nn.Parameter(torch.empty((1, self.gates, self.out_features))) + self.gate_bias_init(b) + return b + + def initialize_hidden_weights(self): + """Initializes hidden weight parameters + + Returns: + torch.nn.Parameter + + """ + # each neuron only receives the information of its associated modality + W = torch.nn.Parameter( + torch.empty((1, self.modalities, self.in_features, self.out_features)) + ) + return self.hidden_weight_init(W) + + def initialize_gate_weights(self): + """Initializes gate weight parameters + + Returns: + torch.nn.Parameter + + """ + # each gate gets the information of all modalities + W = torch.nn.Parameter( + torch.empty( + ( + self.modalities * self.in_features, + self.gates * self.out_features, + ) + ) + ) + return self.gate_weight_init(W) + + @staticmethod + def check_input(inputs): + """Checks if the input is already a Torch tensor, + if it is a list or tuple (hopefully one of the two), + stack them into a tensor + + Args: + inputs: input to the layer/cell + + Returns: + Torch tensor of size (N,C,self.in_features) + + """ + + if not isinstance(inputs, torch.Tensor): + inputs = torch.stack(inputs, 1) + return inputs + + def get_modality_activation(self, inputs): + """Processes the the modality information separately with a set of weights + + Args: + inputs: input to the layer/cell + + Returns: + 
Torch tensor of size (N,self.modalities,self.out_features) + + """ + h = torch.sum(inputs.unsqueeze(-1) * self.W_h, -2) + if self.hidden_bias is not None: + h += self.hidden_bias + h = self.activation(h) + return h + + def get_gate_activation(self, inputs): + """Processes the modality information separately with a set of weights + + Args: + inputs: input to the layer/cell + + Returns: + Torch tensor of size (N,self.gates,self.out_features) + + """ + + z = torch.matmul(inputs.view(-1, self.in_features * self.modalities), self.W_z) + if self.gate_bias is not None: + z = z.view(-1, self.gates, self.out_features) + self.gate_bias + z = self.gate_activation(z) + return z + + def forward(self, inputs): + """Calculates the output of the unit + + Args: + inputs (torch.Tensors): consisting of + multiple modalities as torch.Tensors in the form NCH. + N is batch size, C is the modalities and H the length + of the modality vectors. + + Returns: + A tuple of torch.Tensor of size (N, self.out_features) + + """ + + inputs = self.check_input(inputs) + h = self.get_modality_activation(inputs) + z = self.get_gate_activation(inputs) + if self.gate_transformation is not None: + z = self.gate_transformation(z) + return torch.sum(self.gate_hidden_interaction(h, z), 1), (h, z) + + +class GBU(GMU): + """Gated Bimodal Unit, a hidden unit in a neural network that learns + to combine the representation of two modalities into a single one + via a single gate. See GMU for more general information. + + h generally refers to the hidden state, while + z generally refers to the gate. + + Note: Since this is a specialised subclass of the GMU, most of the + general behaviour is handled in the GMU class + """ + + def __init__( + self, + in_features, + out_features, + activation=None, + gate_activation=None, + hidden_weight_init=None, + gate_weight_init=None, + gate_hidden_interaction=None, + gate_transformation=lambda x: torch.cat((x, (1 - x)), 1), + bias=True, + ): + """Init function. + + Args: + out_features (int): number of (hidden) units / output features + in_features (int): vector length of a single modality + activation (torch func, optional): activation function for the + modalities + gate_activation (torch func, optional): activation function for the + gate + hidden_weight_init (torch init func, optional): init method for the + neuronal weights + gate_weight_init (torch init func, optional): init method for the + gate weights + gate_hidden_interaction (lambda func): how does h and z interact + with another. Could be linear or non-linear (e.g. 
x * (1+y)) + gate_transformation (lambda func): processes the gate + activations before they interact with the hidden state, + here in the bimodal case, it just concats the activations + with the complementary probabilities + + Notes: + # TODO hidden bias inheritance + """ + + super(GBU, self).__init__( + in_features=in_features, out_features=out_features, modalities=2 + ) + if activation: + self.activation = activation + if gate_activation: + self.gate_activation = gate_activation + if hidden_weight_init: + self.hidden_weight_init = hidden_weight_init + if gate_weight_init: + self.gate_weight_init = gate_weight_init + if gate_hidden_interaction: + self.gate_hidden_interaction = gate_hidden_interaction + self.gate_transformation = gate_transformation + self.gates = 1 + self.W_h = self.initialize_hidden_weights() + self.W_z = self.initialize_gate_weights() + self.register_bias(bias) + + +class RGMU(GMU): + """Recurrent Gated Multimodal Unit, a hidden unit in a neural network that + learns to combine the representation of several modalities into a single one + incorporating recurrent activation over time. See GMU for more general + information. + + h generally refers to the hidden state, while + z generally refers to the gate. + h_l are the lateral information from the hidden state, while + z_l are the lateral information from the gate, i.e. the activations + from the last timestep. + + Note: Since this is a specialised subclass of the GMU, most of the + general behaviour is handled in the GMU class + + """ + + def __init__( + self, + in_features, + out_features, + modalities, + recurrent_modalities=True, + recurrent_gates=True, + activation=None, + gate_activation=None, + hidden_weight_init=None, + lateral_hidden_weight_init=None, + gate_weight_init=None, + lateral_gate_weight_init=None, + gate_hidden_interaction=None, + gate_transformation=None, + batch_first=True, + bias=True, + return_sequences=False, + ): + """Init function. + + Args: + out_features (int): number of (hidden) units / output features + in_features (int): vector length of a single modality + modalities (int): number of modalities + recurrent_modalities (bool, optional): if modality activation should + incorporate recurrent information + recurrent_gates (bool, optional): if gate activations should + incorporate recurrent information + activation (torch func, optional): activation function for the + modalities + gate_activation (torch func, optional): activation function for the + gate + hidden_weight_init (torch init func, optional): init method for the + neuronal weights + lateral_hidden_weight_init (torch init func, optional): init method + for the recurrent neural weights + gate_weight_init (torch init func, optional): init method for the + gate weights + lateral_gate_weight_init (torch init func, optional): init method + for the recurrent gate weights + gate_hidden_interaction (lambda func): how does h and z interact + with another. Could be linear or non-linear (e.g. x * (1+y)) + gate_transformation (lambda func): processes the gate + activations before they interact with the hidden state + batch_first (bool): use batch, sequence, feature instead of + sequence, batch, feature + default: True + bias (bool): tbd #todo + return_sequences (bool): if true returns all hidden states + from the intermediate time steps (as a list). The keras/tf + behaviour was the inspiration for that. 
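+
+        Example (illustrative only, not part of the original code base):
+
+            rgmu = RGMU(in_features=10, out_features=4, modalities=3)
+            x = torch.randn(8, 5, 3, 10)   # batch, sequence, modalities, features
+            out, (h, z) = rgmu(x)          # out has shape (8, 4)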
+
+        """
+
+        super(RGMU, self).__init__(
+            in_features=in_features,
+            out_features=out_features,
+            modalities=modalities,
+        )
+        if activation is not None:
+            self.activation = activation
+        if gate_activation is not None:
+            self.gate_activation = gate_activation
+        if hidden_weight_init is not None:
+            self.hidden_weight_init = hidden_weight_init
+        # fall back to the (possibly default) hidden/gate init when no
+        # dedicated lateral init is supplied
+        if lateral_hidden_weight_init is None:
+            self.lateral_hidden_weight_init = self.hidden_weight_init
+        else:
+            self.lateral_hidden_weight_init = lateral_hidden_weight_init
+        if gate_weight_init is not None:
+            self.gate_weight_init = gate_weight_init
+        if lateral_gate_weight_init is None:
+            self.lateral_gate_weight_init = self.gate_weight_init
+        else:
+            self.lateral_gate_weight_init = lateral_gate_weight_init
+        if gate_hidden_interaction is not None:
+            self.gate_hidden_interaction = gate_hidden_interaction
+        self.gate_transformation = gate_transformation
+        self.recurrent_modalities = recurrent_modalities
+        self.recurrent_gates = recurrent_gates
+        self.W_h = self.initialize_hidden_weights()
+        self.W_h_l = self.initialize_lateral_hidden_weights()
+        self.W_z = self.initialize_gate_weights()
+        self.W_z_l = self.initialize_lateral_gate_weights()
+        self.batch_first = batch_first
+        self.register_bias(bias)
+        self.register_recurrent_bias(bias)
+        self.return_sequences = return_sequences
+
+    def register_recurrent_bias(self, bias):
+        """ register recurrent biases """
+        if bias:
+            self.recurrent_hidden_bias = self.initialize_hidden_bias()
+            self.recurrent_gate_bias = self.initialize_gate_bias()
+        else:
+            self.register_parameter("recurrent_hidden_bias", None)
+            self.register_parameter("recurrent_gate_bias", None)
+
+    def initialize_lateral_state(self, batch_size=1):
+        """Initializes lateral state for the first forward pass
+
+        Returns:
+            a tuple of torch.Tensors
+
+        """
+
+        h_l = torch.zeros((batch_size, self.modalities, self.out_features), device=self.W_h.device)
+        z_l = torch.zeros((batch_size, self.gates, self.out_features), device=self.W_h.device)
+        return h_l, z_l
+
+    def initialize_lateral_hidden_weights(self):
+        """Initializes lateral hidden weights
+
+        Returns:
+            torch.nn.Parameter
+
+        """
+        W = torch.nn.Parameter(torch.empty((self.modalities, self.out_features)))
+        if self.lateral_hidden_weight_init is not None:
+            self.lateral_hidden_weight_init(W)
+        return W
+
+    def initialize_lateral_gate_weights(self):
+        """Initializes lateral gate weight parameters
+
+        Returns:
+            torch.nn.Parameter
+
+        """
+        W = torch.nn.Parameter(torch.empty((self.gates, self.out_features)))
+        if self.lateral_gate_weight_init is not None:
+            self.lateral_gate_weight_init(W)
+        return W
+
+    def get_recurrent_modality_activation(self, inputs, h_l):
+        """Processes the modality information separately with a set of
+        weights and the weighted recurrent information from the last timestep
+
+        Args:
+            inputs (Torch.Tensor): input to the layer/cell
+            h_l (Torch.Tensor): activations of the last timestep
+
+        Returns:
+            Torch tensor of size (N,self.modalities,self.out_features)
+
+        """
+
+        h = torch.sum(inputs.unsqueeze(-1) * self.W_h, -2) + self.W_h_l * h_l
+        if self.recurrent_hidden_bias is not None:
+            h += self.hidden_bias + self.recurrent_hidden_bias
+        return self.activation(h)
+
+    def get_recurrent_gate_activation(self, inputs, z_l):
+        """Processes the gate information separately with a set of weights
+        and the weighted recurrent information from the last timestep
+
+        Args:
+            inputs (Torch.Tensor): input to the layer/cell
+            z_l (Torch.Tensor): activations of the last timestep
+
+        Returns:
+            Torch
tensor of size (N,self.modalities,self.out_features) + + """ + + z = ( + torch.matmul(inputs.view(-1, self.in_features * self.modalities), self.W_z) + + (self.W_z_l.unsqueeze(0) * z_l).view(-1, self.gates * self.out_features) + ).view(-1, self.gates, self.out_features) + if self.recurrent_gate_bias is not None: + z += self.gate_bias + self.recurrent_gate_bias + z = self.gate_activation(z) + if self.gate_transformation is not None: + z = self.gate_transformation(z) + return z + + def step(self, inputs, lateral): + """Calculates the output of one timestep, depending on which of the + parts are recurrent, either modalities, gates or both + + Args: + inputs (torch.Tensors): consisting of + multiple modalities as torch.Tensors in the form NCH. + N is batch size, C is the modalities and H the length + of the modality vectors. + lateral (tuple of torch.Tensors): tuple consisting of both, + recurrent modality activations and recurrent gate + activations + + Returns: + A tuple of (torch.Tensor of size (N, self.out_features) and + a tuple of (modality and gate activations)). + """ + + inputs = self.check_input(inputs) + h_l, z_l = lateral + if self.recurrent_modalities: + h = self.get_recurrent_modality_activation(inputs, h_l) + else: + h = self.get_modality_activation(inputs) + + if self.recurrent_gates: + z = self.get_recurrent_gate_activation(inputs, z_l) + else: + z = self.get_gate_activation(inputs) + + return torch.sum(self.gate_hidden_interaction(h, z), 1), (h, z) + + def forward(self, inputs, lateral=None): + """Applies the layer computation to the whole sequence + + Args: + inputs (torch.Tensors): consisting of + multiple modalities as torch.Tensors in the form + if batch_first: NSCH. + N is batch size, S is sequence, C is the modalities and H the + length + of the modality vectors + else: SNCH + lateral (tuple of torch.Tensors): tuple consisting of both, + recurrent modality activations and recurrent gate + activations, if none is supplied, the lateral is intialized as zeros + + Returns: + A tuple of (torch.Tensor of size (N, self.out_features) and + a tuple of (modality (N,modalities,self.out_features) and gate + activations (N,gates,self.out_features)). + If return_sequences, then we follow the batch_first approach, + where the dimensions are N, sequences, self.out_feautures. + The lateral tuples will simply be in a list (for now). + """ + if lateral is None: + lateral = self.initialize_lateral_state() + + if self.return_sequences: + output_sequences = [] + lateral_sequences = [] + + if self.batch_first: + for i in range(inputs.shape[1]): + output, lateral = self.step(inputs[:, i], lateral) + if self.return_sequences: + output_sequences.append(output) + lateral_sequences.append(lateral) + else: + for data in inputs: + output, lateral = self.step(data, lateral) + if self.return_sequences: + output_sequences.append(output) + lateral_sequences.append(lateral) + if self.return_sequences: + return torch.stack(output_sequences, 1), lateral_sequences + else: + return output, lateral + + +class GMUConv2d(torch.nn.Module): + """Gated Multimodal Unit, a hidden unit in a neural network that learns + to combine the representation of different modalities into a single one + via gates (similar to LSTM). + Here, a specialised version is used that takes as input feature maps, + or general 2d input, convolves over these maps and subsequently, + outputs feature maps. + The only real difference to the non-conv versions is that the states + and values of the units are feature maps and not scalars. 
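+    Note: the forward pass expects the modalities to already be concatenated
+    along the channel dimension, i.e. a single tensor of shape
+    (N, modalities * in_channels, H, W).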
+
+    h generally refers to the hidden state, while
+    z generally refers to the gates.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        modalities,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        activation=torch.tanh,
+        gate_activation=torch.sigmoid,
+        hidden_weight_init=lambda x: torch.nn.init.uniform_(x, -0.01, 0.01),
+        gate_weight_init=lambda x: torch.nn.init.uniform_(x, -0.01, 0.01),
+        gate_hidden_interaction=lambda x, y: x * y,
+        gate_transformation=None,
+        bias=True,
+    ):
+        """Init function.
+
+        Args:
+            in_channels (int): number of input channels of each modality
+            out_channels (int): number of (hidden) units / output feature maps
+            modalities (int): number of modalities
+            kernel_size (int): size of the square convolution kernel
+            stride (int): stride of the convolution
+            padding (int): zero-padding added to both sides of the input
+            dilation (int): spacing between kernel elements
+            activation (torch func): activation function for the modalities
+            gate_activation (torch func): activation function for the gate
+            hidden_weight_init (torch init func): init method for the neuronal
+                weights
+            gate_weight_init (torch init func): init method for the gate weights
+            gate_hidden_interaction (lambda func): how h and z interact with
+                one another. Could be linear or non-linear (e.g. x * (1+y))
+            gate_transformation (lambda func): processes the gate
+                activations before they interact with the hidden state,
+                e.g. normalise / gain control them by
+                lambda x: x / torch.sum(x, 1, keepdim=True)
+            bias (bool): whether the hidden and gate convolutions use a bias
+
+        Note: at the moment, the input feature maps of all modalities must
+            have the same number of channels and be concatenated along the
+            channel dimension.
+        """
+
+        super(GMUConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.modalities = modalities
+        self.gates = modalities
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.activation = activation
+        self.gate_activation = gate_activation
+        self.hidden_weight_init = hidden_weight_init
+        self.gate_weight_init = gate_weight_init
+        self.hidden_bias_init = lambda x: torch.nn.init.uniform_(x, -0.01, 0.01)  # TODO: expose as keyword argument
+ self.gate_bias_init = lambda x: torch.nn.init.uniform_(x, -0.01, 0.01) + self.gate_hidden_interaction = gate_hidden_interaction + self.gate_transformation = gate_transformation + self.W_h = self.initialize_hidden_weights() + self.W_z = self.initialize_gate_weights() + self.register_bias(bias) + + def register_bias(self, bias): + if bias: + self.hidden_bias = self.initialize_hidden_bias() + self.gate_bias = self.initialize_gate_bias() + else: + self.register_parameter("hidden_bias", None) + self.register_parameter("gate_bias", None) + + def initialize_hidden_bias(self): + """ + + Returns: + torch.nn.Parameter + """ + b = torch.nn.Parameter(torch.empty((self.modalities * self.out_channels))) + self.hidden_bias_init(b) + return b + + def initialize_gate_bias(self): + """ + + Returns: + torch.nn.Parameter + """ + b = torch.nn.Parameter(torch.empty((self.gates * self.out_channels))) + self.gate_bias_init(b) + return b + + def initialize_gate_weights(self): + """Initializes gate weight/kernel parameters + + Returns: + torch.nn.Parameter + + """ + # each gate gets the information of all modalities + # one gate per modality + W = torch.nn.Parameter( + torch.empty( + ( + self.gates * self.out_channels, + self.modalities * self.in_channels, + self.kernel_size, + self.kernel_size, + ) + ) + ) + if self.gate_weight_init is not None: + self.gate_weight_init(W) + return W + + def initialize_hidden_weights(self): + """Initializes hidden weight/kernel parameters + + Returns: + torch.nn.Parameter + + """ + # each neuron only receives the information of its associated modality + W = torch.nn.Parameter( + torch.empty( + ( + self.modalities * self.out_channels, + self.in_channels, + self.kernel_size, + self.kernel_size, + ) + ) + ) + if self.hidden_weight_init is not None: + self.hidden_weight_init(W) + return W + + def get_modality_activation(self, inputs): + """Processes the modality information separately with a set of weights + + Notes: + The groups parameter is a bit poorly documented. It works as follows: https://mc.ai/how-groups-work-in-pytorch-convolutions/ + + Args: + inputs: input feature map to the layer/cell + + Returns: + Torch tensor of size (N,self.modalities,self.out_channels, *h, *w) + The *height and *weight are determined by the input size and the + use of padding, + dilation, stride etc. + + """ + + h = self.activation( + torch.nn.functional.conv2d( + inputs, + self.W_h, + self.hidden_bias, + self.stride, + self.padding, + self.dilation, + self.modalities, + ) + ) + return h.view(h.shape[0], self.modalities, -1, h.shape[-2], h.shape[-1]) + + def get_gate_activation(self, inputs): + """Processes the modality information with a set of weights (modalities are not treated separately but together) + + Args: + inputs: input feature map to the layer/cell + + Returns: + Torch tensor of size (N,self.gates,self.out_channels, *h, *w) + The *height and *weight are determined by the input size and the + use of padding, + dilation, stride etc. + """ + + z = self.gate_activation( + torch.nn.functional.conv2d( + inputs, + self.W_z, + self.gate_bias, + self.stride, + self.padding, + self.dilation, + 1, + ) + ) + z = z.view(z.shape[0], self.gates, -1, z.shape[-2], z.shape[-1]) + if self.gate_transformation is not None: + z = self.gate_transformation(z) + return z + + def forward(self, inputs): + """Calculates the output of the unit + + Args: + inputs (tuple of torch.Tensors): input tuple consisting of + multiple modalities as torch.Tensors in the form NCH. 
+ N is batch size, C is the modalities (as in stacked on top of each other, even if they have multiple channels each) and HW are the sizes of the feature map + + Returns: + torch.Tensor of size (N, out_channels, *h, *w) + The *height and *weight are determined by the input size and the + use of padding, + dilation, stride etc. + + """ + + inputs = GMU.check_input(inputs) + h = self.get_modality_activation(inputs) + z = self.get_gate_activation(inputs) + return torch.sum(self.gate_hidden_interaction(h, z), 1), (h, z) + + +class GBUConv2d(GMUConv2d): + """Gated Multimodal Unit, a hidden unit in a neural network that learns + to combine the representation of different modalities into a single one + via gates (similar to LSTM). + Here, a specialised version is used that takes as input feature maps, + or general 2d input, convolves over these maps and subsequently, + outputs feature maps. + The only real difference to the non-conv versions is that the states + and values of the units are feature maps and not scalars. + GBU here indicates that only two modalities are possible for input + and only one gate is used. + + h generally refers to the hidden state, while + z generally refers to the gates. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=None, + padding=None, + dilation=None, + activation=None, + gate_activation=None, + hidden_weight_init=None, + gate_weight_init=None, + gate_hidden_interaction=None, + gate_transformation=lambda x: torch.cat((x, (1 - x)), 1), + bias=True, + ): + """Init function. + + Args: + in_channels (int): number of input channels of each modality + out_channels (int): number of (hidden) units / output feature maps + modalities (int): number of modalities + activation (torch func): activation function for the modalities + gate_activation (torch func): activation function for the gate + weight_init (torch init func): init method for the neuronal weights + gate_weight_init (torch init func): init method for the gate weights + gate_hidden_interaction (lambda func): how does h and z interact + with another. Could be linear or non-linear (e.g. x * (1+y)) + gate_transformation (lambda func): processes the gate + activations before they interact with the hidden state, + here in the bimodal case, it just concats the activations + with the complementary probabilites + + + Note: at the moment, the input feature maps have to be streamlined in + the channel dimension. + i.e. they all have to have the same number of channels. 
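+
+        Example (illustrative only, not part of the original code base):
+
+            gbu = GBUConv2d(in_channels=1, out_channels=4, kernel_size=3)
+            x = torch.ones(2, 2, 16, 16)  # two 1-channel modalities, channel-concatenated
+            out, (h, z) = gbu(x)          # out has shape (2, 4, 14, 14)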
+ """ + + super(GBUConv2d, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + modalities=2, + kernel_size=kernel_size, + ) + if stride is not None: + self.stride = stride + if padding is not None: + self.padding = padding + if dilation is not None: + self.dilation = dilation + if activation is not None: + self.activation = activation + if gate_activation is not None: + self.gate_activation = gate_activation + if hidden_weight_init is not None: + self.hidden_weight_init = hidden_weight_init + if gate_weight_init is not None: + self.gate_weight_init = gate_weight_init + if gate_hidden_interaction is not None: + self.gate_hidden_interaction = gate_hidden_interaction + if gate_transformation is not None: + self.gate_transformation = gate_transformation + self.gates = 1 + self.W_h = self.initialize_hidden_weights() + self.W_z = self.initialize_gate_weights() + self.register_bias(bias) + + +class RGMUConv2d(GMUConv2d): + """Recurrent Gated Multimodal Unit, a hidden unit in a neural network that + learns to combine the representation of different modalities into a + single one via gates (similar to LSTM). + Here, a specialised version is used that takes as input feature maps, + or general 2d input, convolves over these maps and subsequently, + outputs feature maps. + The only real difference to the non-conv versions is that the states + and values of the units are feature maps and not scalars. + Recurrent means that the either the gates or the modalities, or both, + incorporate information from prior timesteps in there processing. + + h generally refers to the hidden state, while + z generally refers to the gates. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + modalities, + input_size, + recurrent_modalities=True, + recurrent_gates=True, + stride=None, + padding=None, + dilation=None, + activation=None, + gate_activation=None, + hidden_weight_init=None, + lateral_hidden_weight_init=None, + gate_weight_init=None, + lateral_gate_weight_init=None, + gate_hidden_interaction=None, + gate_transformation=None, + batch_first=True, + return_sequences=False, + bias=True, + device="cuda:0", + ): + """Init function. + + Args: + in_channels (int): number of input channels of each modality + out_channels (int): number of (hidden) units / output feature maps + modalities (int): number of modalities + input_size (list or tuple): height and width of the input + recurrent_modalities (bool, optional): if modality activation should + incorporate recurrent information + recurrent_gates (bool, optional): if gate activations should + incorporate recurrent information + activation (torch func): activation function for the modalities + gate_activation (torch func): activation function for the gate + weight_init (torch init func): init method for the neuronal weights + lateral_hidden_weight_init (torch init func, optional): init method + for the recurrent neural weights + gate_weight_init (torch init func): init method for the gate weights + lateral_gate_weight_init (torch init func, optional): init method + for the recurrent gate weights + gate_hidden_interaction (lambda func): how does h and z interact + with another. Could be linear or non-linear (e.g. 
x * (1+y)) + gate_transformation (lambda func): processes the gate + activations before they interact with the hidden state, + here in the bimodal case, it just concats the activations + with the complementary probabilites + device (string): gpu or cpu device + + + Note: at the moment, the input feature maps have to be streamlined in + the channel dimension. + i.e. they all have to have the same number of channels. + """ + + super(RGMUConv2d, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + modalities=modalities, + kernel_size=kernel_size, + ) + self.device = device + self.height = input_size[0] + self.width = input_size[1] + if stride is not None: + self.stride = stride + if padding is not None: + self.padding = padding + if dilation is not None: + self.dilation = dilation + if activation is not None: + self.activation = activation + if gate_activation is not None: + self.gate_activation = gate_activation + if hidden_weight_init is not None: + self.hidden_weight_init = hidden_weight_init + if lateral_hidden_weight_init is not None: + self.lateral_hidden_weight_init = lateral_hidden_weight_init + else: + self.lateral_hidden_weight_init = self.hidden_weight_init + if gate_weight_init is not None: + self.gate_weight_init = gate_weight_init + if lateral_gate_weight_init is not None: + self.lateral_gate_weight_init = lateral_gate_weight_init + else: + self.lateral_gate_weight_init = self.gate_weight_init + if gate_hidden_interaction is not None: + self.gate_hidden_interaction = gate_hidden_interaction + if gate_transformation is not None: + self.gate_transformation = gate_transformation + self.gates = modalities + self.recurrent_modalities = recurrent_modalities + self.recurrent_gates = recurrent_gates + self.return_sequences = return_sequences + self.batch_first = batch_first + self.W_h = self.initialize_hidden_weights() + self.W_h_l = self.initialize_lateral_hidden_weights() + self.W_z = self.initialize_gate_weights() + self.W_z_l = self.initialize_lateral_gate_weights() + self.register_bias(bias) + self.register_recurrent_bias(bias) + + def register_recurrent_bias(self, bias): + if bias: + self.recurrent_hidden_bias = self.initialize_hidden_bias() + self.recurrent_gate_bias = self.initialize_gate_bias() + else: + self.register_parameter("recurrent_hidden_bias", None) + self.register_parameter("recurrent_gate_bias", None) + + def initialize_lateral_state(self, batch_size=1): + """ Todo: Docstring""" + h_l = torch.zeros( + ( + batch_size, + self.modalities * self.out_channels, + (self.height - self.kernel_size + self.padding * 2) // self.stride + 1, + (self.width - self.kernel_size + self.padding * 2) // self.stride + 1, + ) + ) + z_l = torch.zeros( + ( + batch_size, + self.gates * self.out_channels, + self.height - (self.kernel_size - 1) + self.padding * 2, + self.width - (self.kernel_size - 1) + self.padding * 2, + ) + ) + return h_l.to(self.device), z_l.to(self.device) + + def initialize_lateral_gate_weights(self): + """Initializes gate weight/kernel parameters + + Returns: + torch.nn.Parameter + + """ + # the recurrent processing takes as input the output of the gate + # processing, i.e. 
one feature map per gate, per RGMUCell + W = torch.nn.Parameter( + torch.empty( + ( + self.gates * self.out_channels, + 1, + self.kernel_size, + self.kernel_size, + ) + ) + ) + if self.gate_weight_init is not None: + self.gate_weight_init(W) + return W + + def initialize_lateral_hidden_weights(self): + """Initializes hidden weight/kernel parameters + + Returns: + torch.nn.Parameter + + """ + # as input we receive the output of the modality processing, + # i.e. one feature map per modality, per RGMUCell + W = torch.nn.Parameter( + torch.empty( + ( + self.modalities * self.out_channels, + 1, + self.kernel_size, + self.kernel_size, + ) + ) + ) + if self.hidden_weight_init is not None: + self.hidden_weight_init(W) + return W + + def get_recurrent_modality_activation(self, inputs, h_l): + """Processes the the modality information separately with a set of + weights and the weighted recurrent information from the last timestep + + Args: + inputs (Torch.Tensor): input to the layer/cell + h_l (Torch.Tensor): activations of the last timestep + + Returns: + Torch tensor of size (N,self.modalities,self.out_channels, *h, *w) + + """ + + h = self.activation( + torch.nn.functional.conv2d( + inputs, + self.W_h, + self.hidden_bias, + self.stride, + self.padding, + self.dilation, + self.modalities, + ) + + torch.nn.functional.conv2d( + h_l, + self.W_h_l, + self.recurrent_hidden_bias, + self.stride, + (self.kernel_size - 1) // 2, + self.dilation, + self.modalities * self.out_channels, + ) + ) + return h.view(h.shape[0], self.modalities, -1, h.shape[-2], h.shape[-1]) + + def get_recurrent_gate_activation(self, inputs, z_l): + """Processes the gate information separately with a set of weights + and the weighted recurrent information from the last timestep + + Args: + inputs (Torch.Tensor): input to the layer/cell + z_l (Torch.Tensor): activations of the last timestep + + Returns: + Torch tensor of size (N,self.gates,self.out_channels, *h, *w) + + """ + + z = self.gate_activation( + torch.nn.functional.conv2d( + inputs, + self.W_z, + self.gate_bias, + self.stride, + self.padding, + self.dilation, + 1, + ) + + torch.nn.functional.conv2d( + z_l, + self.W_z_l, + self.recurrent_gate_bias, + self.stride, + (self.kernel_size - 1) // 2, + self.dilation, + self.gates * self.out_channels, + ) + ) + z = z.view(z.shape[0], self.gates, -1, z.shape[-2], z.shape[-1]) + if self.gate_transformation is not None: + z = self.gate_transformation(z) + return z + + def step(self, inputs, lateral): + """ Copy from RGMU but adapt to 2D""" + + h_l, z_l = lateral + if self.recurrent_modalities: + h = self.get_recurrent_modality_activation(inputs, h_l) + else: + h = self.get_modality_activation(inputs) + + if self.recurrent_gates: + z = self.get_recurrent_gate_activation(inputs, z_l) + else: + z = self.get_gate_activation(inputs) + + return torch.sum(self.gate_hidden_interaction(h, z), 1), ( + h.view( + h.shape[0], + self.modalities * self.out_channels, + h.shape[-2], + h.shape[-1], + ), + z.view( + z.shape[0], + self.gates * self.out_channels, + z.shape[-2], + z.shape[-1], + ), + ) + + def forward(self, inputs, lateral=None): + """# TODO adapt docstring + + Args: + inputs (torch.Tensors): consisting of + multiple modalities as torch.Tensors in the form + if batch_first: NSCH. 
+ N is batch size, S is sequence, C is the modalities and H the + length + of the modality vectors + else: SNCH + lateral (tuple of torch.Tensors): tuple consisting of both, + recurrent modality activations and recurrent gate + activations, if none is supplied, the lateral is intialized as zeros + + Returns: + A tuple of (torch.Tensor of size (N, self.out_features) and + a tuple of (modality (N,modalities,self.out_features) and gate + activations (N,gates,self.out_features)). + If return_sequences, then we follow the batch_first approach, + where the dimensions are N, sequences, self.out_feautures. + The lateral tuples will simply be in a list (for now). + """ + if lateral is None: + lateral = self.initialize_lateral_state() + # So if you run the code on GPU, it leads to errors. + + if self.return_sequences: + output_sequences = [] + lateral_sequences = [] + + if self.batch_first: + for i in range(inputs.shape[1]): + output, lateral = self.step(inputs[:, i], lateral) + if self.return_sequences: + output_sequences.append(output) + lateral_sequences.append(lateral) + else: + for data in inputs: + output, lateral = self.step(data, lateral) + if self.return_sequences: + output_sequences.append(output) + lateral_sequences.append(lateral) + if self.return_sequences: + return torch.stack(output_sequences, 1), lateral_sequences + else: + return output, lateral + + +if __name__ == "__main__": + # import numpy as np + + # np.random.seed(1337) + # from torch.utils.data import Dataset + # import multimodal as mm + + rgmuconv_in1_out2_mod3 = RGMUConv2d( + in_channels=1, + out_channels=2, + modalities=3, + kernel_size=3, + input_size=[5, 5], + ) + + input2d_5x5_c1_mod3_len4 = torch.ones((8, 4, 3, 5, 5)) + lat = rgmuconv_in1_out2_mod3.initialize_lateral_state() + h, z = lat + + rgmuconv_in1_out2_mod3(input2d_5x5_c1_mod3_len4, lat) diff --git a/gazenet/models/shared_components/gradcam/__init__.py b/gazenet/models/shared_components/gradcam/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/gradcam/model.py b/gazenet/models/shared_components/gradcam/model.py new file mode 100644 index 0000000..ab22fe3 --- /dev/null +++ b/gazenet/models/shared_components/gradcam/model.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Implementation of ESR-9 (Siqueira et al., 2020) trained on AffectNet (Mollahosseini et al., 2017) for emotion +and affect perception. + + +Reference: + Siqueira, H., Magg, S. and Wermter, S., 2020. Efficient Facial Feature Learning with Wide Ensemble-based + Convolutional Neural Networks. Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence + (AAAI-20), pages 1–1, New York, USA. + + Mollahosseini, A., Hasani, B. and Mahoor, M.H., 2017. AffectNet: A database for facial expression, valence, + and arousal computing in the wild. IEEE Transactions on Affective Computing, 10(1), pp.18-31. +""" + +__author__ = "Henrique Siqueira" +__email__ = "siqueira.hc@outlook.com" +__license__ = "MIT license" +__version__ = "1.0" + +# External libraries +import torch +from torch.autograd import Variable +import time + + +class GradCAM: + """ + Implementation of the Grad-CAM visualization algorithm (Selvaraju et al., 2017). + Generates saliency maps with respect to discrete emotion labels (the second last, fully-connected + layer of ESR-9 (Siqueira et al., 2020)). + + Reference: + Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D. and Batra, D., 2017. 
+ Grad-cam: Visual explanations from deep networks via gradient-based localization. + In Proceedings of the IEEE international conference on computer vision (pp. 618-626). + + Siqueira, H., Magg, S. and Wermter, S., 2020. Efficient Facial Feature Learning with Wide Ensemble-based + Convolutional Neural Networks. Proceedings of the Thirty-Fourth AAAI Conference on Artificial Intelligence + (AAAI-20), pages 1–1, New York, USA. + """ + + # def __init__(self, esr_base, esr_branch_to_last_conv_layer, esr_branch_from_last_conv_layer_to_emotion_output): + def __init__(self, esr, device): + self._zero_grad = esr.zero_grad + self._esr_base = esr.base + self._esr_branch_to_last_conv_layer = [] + self._esr_branch_from_last_conv_layer_to_emotion_output = [] + for branch in esr.convolutional_branches: + self._esr_branch_to_last_conv_layer.append(branch.forward_to_last_conv_layer) + self._esr_branch_from_last_conv_layer_to_emotion_output.append( + branch.forward_from_last_conv_layer_to_output_layer) + + self._gradients = None + self._device = device + + def __call__(self, x, i): + # Clear gradients + self._gradients = [] + + # Forward activations to the last convolutional layer + feature_maps = self._esr_base(x) + feature_maps = self._esr_branch_to_last_conv_layer[i](feature_maps) + + # Saves gradients + feature_maps.register_hook(self.set_gradients) + + # Forward feature maps to the discrete emotion output layer (the second last, fully-connected layer) + output_activations = self._esr_branch_from_last_conv_layer_to_emotion_output[i](feature_maps) + + return feature_maps, output_activations + + def set_gradients(self, grads): + self._gradients.append(grads) + + def get_mean_gradients(self): + return self._gradients[0].mean(3).mean(2)[0] + + def grad_cam(self, x, list_y): + # cumm_saliency_map = None + list_saliency_maps = [] + + for i, y in enumerate(list_y): + # Set gradients to zero + self._zero_grad() + + # Forward phase + feature_maps, output_activations = self(x, i) + feature_maps = feature_maps[0] + + # TODO (fabawi): seperate the backward and forward to maintain keep_graph=True when training just like + # https://github.com/jacobgil/pytorch-grad-cam/blob/master/gradcam.py + # Backward the activation of the neuron associated to + # the predicted emotion to the last convolutional layer + one_hot = torch.zeros(output_activations.size()) + one_hot[0][y] = 1 + one_hot = Variable(one_hot, requires_grad=True).to(self._device) + one_hot = torch.sum(one_hot * output_activations) + + # Back-propagate activations + one_hot.backward(retain_graph=False) + + # Get mean gradient for every convolutional filter + grad_cam_weights = self.get_mean_gradients() + + # Computes saliency map as a weighted sum of feature maps and mean gradient + saliency_map = torch.zeros(feature_maps.size()[1:]).to(self._device) + for j, w in enumerate(grad_cam_weights): + saliency_map += w * feature_maps[j, :, :] + + # Normalize saliency maps + # saliency_map = torch.clamp(saliency_map, min=0) + saliency_map -= torch.min(saliency_map) + # saliency_map /= torch.max(saliency_map) + + list_saliency_maps.append(saliency_map) + + # Return the list of normalized saliency maps + return list_saliency_maps + + +''' + print(saliency_map) + + if power == 0: + if cumm_saliency_map is None: + cumm_saliency_map = torch.ones(feature_maps.size()[1:]).to(self._device) + cumm_saliency_map = torch.mul(cumm_saliency_map, saliency_map) + else: + if cumm_saliency_map is None: + cumm_saliency_map = torch.zeros(feature_maps.size()[1:]).to(self._device) + 
cumm_saliency_map = torch.stack((cumm_saliency_map, saliency_map.pow(power))).sum(dim=0) + + time.sleep(2) + feature_maps = None + output_activations = None + one_hot = None + grad_cam_weights = None + saliency_map = None + torch.cuda.empty_cache() + print("here 3", i) + + if power == 0: + cumm_saliency_map = torch.pow(cumm_saliency_map, 1 / len(list_y)) + else: + cumm_saliency_map /= len(list_y) + cumm_saliency_map = torch.pow(cumm_saliency_map, 1 / power) + + cumm_saliency_map /= torch.max(cumm_saliency_map) + + # Return the list of normalized saliency maps + return cumm_saliency_map''' \ No newline at end of file diff --git a/gazenet/models/shared_components/mobilenetv2/__init__.py b/gazenet/models/shared_components/mobilenetv2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/mobilenetv2/model.py b/gazenet/models/shared_components/mobilenetv2/model.py new file mode 100644 index 0000000..28ee80b --- /dev/null +++ b/gazenet/models/shared_components/mobilenetv2/model.py @@ -0,0 +1,190 @@ +import math +from pathlib import Path + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +MODEL_URLS = {"mobilenetv2": "https://github.com/rdroste/unisal/raw/master/unisal/models/weights/mobilenet_v2.pth.tar"} + +# Source: https://github.com/tonylins/pytorch-mobilenet-v2 + + +def conv_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +def conv_1x1_bn(inp, oup): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU6(inplace=True) + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio, omit_stride=False, + no_res_connect=False, dropout=0., bn_momentum=0.1, + batchnorm=None): + super().__init__() + self.out_channels = oup + self.stride = stride + self.omit_stride = omit_stride + self.use_res_connect = not no_res_connect and\ + self.stride == 1 and inp == oup + self.dropout = dropout + actual_stride = self.stride if not self.omit_stride else 1 + if batchnorm is None: + def batchnorm(num_features): + return nn.BatchNorm2d(num_features, momentum=bn_momentum) + + assert actual_stride in [1, 2] + + hidden_dim = round(inp * expand_ratio) + if expand_ratio == 1: + modules = [ + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, + groups=hidden_dim, bias=False), + batchnorm(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + batchnorm(oup), + ] + if self.dropout > 0: + modules.append(nn.Dropout2d(self.dropout)) + self.conv = nn.Sequential(*modules) + else: + modules = [ + # pw + nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), + batchnorm(hidden_dim), + nn.ReLU6(inplace=True), + # dw + nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, + groups=hidden_dim, bias=False), + batchnorm(hidden_dim), + nn.ReLU6(inplace=True), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + batchnorm(oup), + ] + if self.dropout > 0: + modules.insert(3, nn.Dropout2d(self.dropout)) + self.conv = nn.Sequential(*modules) + self._initialize_weights() + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +class MobileNetV2(nn.Module): + def __init__(self, widen_factor=1., pretrained=True, + last_channel=None, input_channel=32): + super().__init__() + self.widen_factor = widen_factor + self.pretrained = pretrained + self.last_channel = last_channel + self.input_channel = input_channel + + block = InvertedResidual + interverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # building first layer + input_channel = int(self.input_channel * widen_factor) + self.features = [conv_bn(3, input_channel, 2)] + # building inverted residual blocks + for t, c, n, s in interverted_residual_setting: + output_channel = int(c * widen_factor) + for i in range(n): + if i == 0: + self.features.append(block( + input_channel, output_channel, s, expand_ratio=t, + omit_stride=True)) + else: + self.features.append(block( + input_channel, output_channel, 1, expand_ratio=t)) + input_channel = output_channel + # building last several layers + if self.last_channel is not None: + output_channel = int(self.last_channel * widen_factor)\ + if widen_factor > 1.0 else self.last_channel + self.features.append(conv_1x1_bn(input_channel, output_channel)) + # make it nn.Sequential + self.features = nn.Sequential(*self.features) + self.out_channels = output_channel + self.feat_1x_channels = int( + interverted_residual_setting[-1][1] * widen_factor) + self.feat_2x_channels = int( + interverted_residual_setting[-2][1] * widen_factor) + self.feat_4x_channels = int( + interverted_residual_setting[-4][1] * widen_factor) + self.feat_8x_channels = int( + interverted_residual_setting[-5][1] * widen_factor) + + if self.pretrained: + + self.load_state_dict(model_zoo.load_url(MODEL_URLS['mobilenetv2']), strict = False) + else: + self._initialize_weights() + + def forward(self, x): + # x = self.features(x) + feat_2x, feat_4x, feat_8x = None, None, None + for idx, module in enumerate(self.features._modules.values()): + x = module(x) + if idx == 7: + feat_4x = x.clone() + elif idx == 14: + feat_2x = x.clone() + if idx > 0 and hasattr(module, 'stride') and module.stride != 1: + x = x[..., ::2, ::2] + + return x, feat_2x, feat_4x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() diff --git a/gazenet/models/shared_components/resnet/__init__.py b/gazenet/models/shared_components/resnet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/resnet/model.py b/gazenet/models/shared_components/resnet/model.py new file mode 100644 index 0000000..feed92a --- /dev/null +++ b/gazenet/models/shared_components/resnet/model.py @@ -0,0 +1,270 @@ +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + +MODEL_URLS = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 
= self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + # self.avgpool = nn.AdaptativeAvgPool((1,1), stride=1) + self.fc1 = nn.Linear(512 * block.expansion, 1000) + self.fc2 = nn.Linear(1000, 3) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + feat_D = self.layer2(x) + x = self.layer3(feat_D) + x = self.layer4(x) + # print('Size at output',x.size()) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + # x = nn.Dropout()(x) + x = nn.ReLU()(self.fc1(x)) + x = self.fc2(x) + + return x + + +class ResNetCAM(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNetCAM, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc1 = nn.Linear(512 * block.expansion, 1000) + self.fc2 = nn.Linear(1000, 3) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x2 = self.layer2(x) + x2 = self.layer3(x2) + x2 = self.layer4(x2) + return x, x2 + + +def resnetCAM(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNetCAM(BasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(MODEL_URLS['resnet18']), strict=False) + return model + + +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(MODEL_URLS['resnet34'])) + return model + + +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(MODEL_URLS['resnet50']), strict=False) + return model + + +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(MODEL_URLS['resnet101'])) + return model + + +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(MODEL_URLS['resnet152'])) + return model diff --git a/gazenet/models/shared_components/resnet3d/__init__.py b/gazenet/models/shared_components/resnet3d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/resnet3d/model.py b/gazenet/models/shared_components/resnet3d/model.py new file mode 100644 index 0000000..ced693f --- /dev/null +++ b/gazenet/models/shared_components/resnet3d/model.py @@ -0,0 +1,268 @@ +# +# 3D-ResNet implementation +# provided by Kensho Hara +# introduced in +# Kensho Hara, Hirokatsu Kataoka, and Yutaka Satoh, +# "Can Spatiotemporal 3D CNNs Retrace the History of 2D CNNs and ImageNet?", +# Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 6546-6555, 2018. 
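+#
+# Example (illustrative usage, not part of the original implementation;
+# assumes 16-frame 112x112 RGB clips):
+#   model = resnet18(sample_size=112, sample_duration=16, num_classes=400)
+#   logits = model(torch.randn(1, 3, 16, 112, 112))  # -> shape (1, 400)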
+# + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import math +from functools import partial, reduce + +__all__ = ['ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet200'] + + +class LambdaBase(nn.Sequential): + def __init__(self, fn, *args): + super(LambdaBase, self).__init__(*args) + self.lambda_func = fn + + def forward_prepare(self, input): + output = [] + for module in self._modules.values(): + output.append(module(input)) + return output if output else input + +class Lambda(LambdaBase): + def forward(self, input): + return self.lambda_func(self.forward_prepare(input)) + +class LambdaMap(LambdaBase): + def forward(self, input): + return list(map(self.lambda_func,self.forward_prepare(input))) + +class LambdaReduce(LambdaBase): + def forward(self, input): + return reduce(self.lambda_func,self.forward_prepare(input)) + + +def conv3x3x3(in_planes, out_planes, stride=1): + # 3x3x3 convolution with padding + return nn.Conv3d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=1, bias=False) + + +def downsample_basic_block(x, planes, stride): + out = F.avg_pool3d(x, kernel_size=1, stride=stride) + zero_pads = torch.Tensor(out.size(0), planes - out.size(1), + out.size(2), out.size(3), + out.size(4)).zero_() + if isinstance(out.data, torch.cuda.FloatTensor): + zero_pads = zero_pads.cuda() + + out = Variable(torch.cat([out.data, zero_pads], dim=1)) + + return out + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, training=True): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm3d(planes, momentum=0.1 if training else 0.0) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3x3(planes, planes) + self.bn2 = nn.BatchNorm3d(planes, momentum=0.1 if training else 0.0) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, training=True): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm3d(planes, momentum=0.1 if training else 0.0) + self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm3d(planes, momentum=0.1 if training else 0.0) + self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm3d(planes * 4, momentum=0.1 if training else 0.0) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, sample_size, sample_duration, shortcut_type='B', num_classes=400, last_fc=True, last_pool=True): + self.last_fc = last_fc + self.last_pool = 
last_pool + + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), + padding=(3, 3, 3), bias=False) + self.bn1 = nn.BatchNorm3d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) + self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type, stride=2) + last_duration = math.ceil(sample_duration / 16) + last_size = math.ceil(sample_size / 32) + self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) + self.fc_new = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv3d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm3d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + if shortcut_type == 'A': + downsample = partial(downsample_basic_block, + planes=planes * block.expansion, + stride=stride) + else: + downsample = nn.Sequential( + nn.Conv3d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm3d(planes * block.expansion) + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + if self.last_pool: + x = self.avgpool(x) + + if self.last_fc: + x = x.view(x.size(0), -1) + x = self.fc_new(x) + + return x + + +def get_fine_tuning_parameters(model, ft_begin_index): + if ft_begin_index == 0: + return model.parameters() + + ft_module_names = [] + for i in range(ft_begin_index, 5): + ft_module_names.append('layer{}'.format(ft_begin_index)) + ft_module_names.append('fc') + + parameters = [] + for k, v in model.named_parameters(): + for ft_module in ft_module_names: + if ft_module in k: + parameters.append({'params': v}) + break + else: + parameters.append({'params': v, 'lr': 0.0}) + + return parameters + + +def resnet10(**kwargs): + """Constructs a ResNet-10 model. + """ + model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) + return model + +def resnet18(**kwargs): + """Constructs a ResNet-18 model. + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + return model + +def resnet34(**kwargs): + """Constructs a ResNet-34 model. + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + return model + +def resnet50(**kwargs): + """Constructs a ResNet-50 model. + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + return model + +def resnet101(**kwargs): + """Constructs a ResNet-101 model. + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + return model + +def resnet152(**kwargs): + """Constructs a ResNet-152 model. + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + return model + +def resnet200(**kwargs): + """Constructs a ResNet-200 model. 
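+
+    A minimal construction sketch (argument values here are illustrative only;
+    ``sample_size`` and ``sample_duration`` are required positional arguments of
+    ``ResNet.__init__`` above, and any other keyword is forwarded unchanged):
+
+        model = resnet200(sample_size=112, sample_duration=16, num_classes=400)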
+ """ + model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) + return model \ No newline at end of file diff --git a/gazenet/models/shared_components/soundnet8/__init__.py b/gazenet/models/shared_components/soundnet8/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/soundnet8/model.py b/gazenet/models/shared_components/soundnet8/model.py new file mode 100644 index 0000000..121ce14 --- /dev/null +++ b/gazenet/models/shared_components/soundnet8/model.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn + + +class SoundNet(nn.Module): + + def __init__(self, momentum=0.1, reverse=False): + super(SoundNet, self).__init__() + self.reverse = reverse + + self.conv1 = nn.Conv2d(1, 16, kernel_size=self._rev(64, 1), stride=self._rev(2, 1), + padding=self._rev(32, 0)) + self.batchnorm1 = nn.BatchNorm2d(16, eps=1e-5, momentum=momentum) + self.relu1 = nn.ReLU(True) + self.maxpool1 = nn.MaxPool2d(self._rev(8, 1), stride=self._rev(8, 1)) + + self.conv2 = nn.Conv2d(16, 32, kernel_size=self._rev(32, 1), stride=self._rev(2, 1), + padding=self._rev(16, 0)) + self.batchnorm2 = nn.BatchNorm2d(32, eps=1e-5, momentum=momentum) + self.relu2 = nn.ReLU(True) + self.maxpool2 = nn.MaxPool2d(self._rev(8, 1), stride=self._rev(8, 1)) + + self.conv3 = nn.Conv2d(32, 64, kernel_size=self._rev(16, 1), stride=self._rev(2, 1), + padding=self._rev(8, 0)) + self.batchnorm3 = nn.BatchNorm2d(64, eps=1e-5, momentum=momentum) + self.relu3 = nn.ReLU(True) + + self.conv4 = nn.Conv2d(64, 128, kernel_size=self._rev(8, 1), stride=self._rev(2, 1), + padding=self._rev(4, 0)) + self.batchnorm4 = nn.BatchNorm2d(128, eps=1e-5, momentum=momentum) + self.relu4 = nn.ReLU(True) + + self.conv5 = nn.Conv2d(128, 256, kernel_size=self._rev(4, 1), stride=self._rev(2, 1), + padding=self._rev(2, 0)) + self.batchnorm5 = nn.BatchNorm2d(256, eps=1e-5, momentum=momentum) + self.relu5 = nn.ReLU(True) + self.maxpool5 = nn.MaxPool2d(self._rev(4, 1), stride=self._rev(4, 1)) + + self.conv6 = nn.Conv2d(256, 512, kernel_size=self._rev(4, 1), stride=self._rev(2, 1), + padding=self._rev(2, 0)) + self.batchnorm6 = nn.BatchNorm2d(512, eps=1e-5, momentum=momentum) + self.relu6 = nn.ReLU(True) + + self.conv7 = nn.Conv2d(512, 1024, kernel_size=self._rev(4, 1), stride=self._rev(2, 1), + padding=self._rev(2, 0)) + self.batchnorm7 = nn.BatchNorm2d(1024, eps=1e-5, momentum=momentum) + self.relu7 = nn.ReLU(True) + + self.conv8_objs = nn.Conv2d(1024, 1000, kernel_size=(8, 1), + stride=(2, 1)) + self.conv8_scns = nn.Conv2d(1024, 401, kernel_size=(8, 1), + stride=(2, 1)) + + def forward(self, waveform): + x = self.conv1(waveform) + x = self.batchnorm1(x) + x = self.relu1(x) + x = self.maxpool1(x) + + x = self.conv2(x) + x = self.batchnorm2(x) + x = self.relu2(x) + x = self.maxpool2(x) + + x = self.conv3(x) + x = self.batchnorm3(x) + x = self.relu3(x) + + x = self.conv4(x) + x = self.batchnorm4(x) + x = self.relu4(x) + + x = self.conv5(x) + x = self.batchnorm5(x) + x = self.relu5(x) + x = self.maxpool5(x) + + x = self.conv6(x) + x = self.batchnorm6(x) + x = self.relu6(x) + + x = self.conv7(x) + x = self.batchnorm7(x) + x = self.relu7(x) + + return x + + def _rev(self, *tup): + if self.reverse: + new_tup = tup[::-1] + return new_tup + else: + return tup diff --git a/gazenet/models/shared_components/squeezeexcitation/__init__.py b/gazenet/models/shared_components/squeezeexcitation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/squeezeexcitation/model.py 
b/gazenet/models/shared_components/squeezeexcitation/model.py new file mode 100644 index 0000000..2aaf235 --- /dev/null +++ b/gazenet/models/shared_components/squeezeexcitation/model.py @@ -0,0 +1,20 @@ +import torch +from torch import nn + +# code from: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py +class SELayer(nn.Module): + def __init__(self, channel, reduction=2): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y.expand_as(x) \ No newline at end of file diff --git a/gazenet/models/shared_components/transformer/__init__.py b/gazenet/models/shared_components/transformer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/models/shared_components/transformer/model.py b/gazenet/models/shared_components/transformer/model.py new file mode 100644 index 0000000..e4dda17 --- /dev/null +++ b/gazenet/models/shared_components/transformer/model.py @@ -0,0 +1,71 @@ +import math +import torch +from torch import nn + + +class PositionalEncoding(nn.Module): + + def __init__(self, feat_size, dropout=0.1, max_len=4): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, feat_size) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, feat_size, 2).float() * (-math.log(10000.0) / feat_size)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, x): + # print(x.shape, self.pe.shape) + x = x + self.pe + # return self.dropout(x) + return x + + +class Transformer(nn.Module): + def __init__(self, feat_size, hidden_size=256, nhead=4, num_encoder_layers=3, max_len=4, num_decoder_layers=-1, + num_queries=4, spatial_dim=-1): + super(Transformer, self).__init__() + self.pos_encoder = PositionalEncoding(feat_size, max_len=max_len) + encoder_layers = nn.TransformerEncoderLayer(feat_size, nhead, hidden_size) + + self.spatial_dim = spatial_dim + if self.spatial_dim != -1: + transformer_encoder_spatial_layers = nn.TransformerEncoderLayer(spatial_dim, nhead, hidden_size) + self.transformer_encoder_spatial = nn.TransformerEncoder(transformer_encoder_spatial_layers, + num_encoder_layers) + + self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers) + self.use_decoder = (num_decoder_layers != -1) + + if self.use_decoder: + decoder_layers = nn.TransformerDecoderLayer(hidden_size, nhead, hidden_size) + self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_decoder_layers, + norm=nn.LayerNorm(hidden_size)) + self.tgt_pos = nn.Embedding(num_queries, hidden_size).weight + assert self.tgt_pos.requires_grad == True + + def forward(self, embeddings, idx): + ''' embeddings: CxBxCh*H*W ''' + # print(embeddings.shape) + batch_size = embeddings.size(1) + + if self.spatial_dim != -1: + embeddings = embeddings.permute((2, 1, 0)) + embeddings = self.transformer_encoder_spatial(embeddings) + embeddings = embeddings.permute((2, 1, 0)) + + x = self.pos_encoder(embeddings) + x = self.transformer_encoder(x) + if self.use_decoder: + if idx != -1: + tgt_pos = 
self.tgt_pos[idx].unsqueeze(0) + # print(tgt_pos.size()) + tgt_pos = tgt_pos.unsqueeze(1).repeat(1, batch_size, 1) + else: + tgt_pos = self.tgt_pos.unsqueeze(1).repeat(1, batch_size, 1) + tgt = torch.zeros_like(tgt_pos) + x = self.transformer_decoder(tgt + tgt_pos, x) + return x diff --git a/gazenet/readers/__init__.py b/gazenet/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/readers/coutrot.py b/gazenet/readers/coutrot.py new file mode 100644 index 0000000..8d71f21 --- /dev/null +++ b/gazenet/readers/coutrot.py @@ -0,0 +1,228 @@ +""" +Class for reading and decoding the Coutrot1 [1] and Coutrot2 [2] datasets + + +[1] Coutrot, A., & Guyader, N. (2014). + How saliency, faces, and sound influence gaze in dynamic social scenes. + Journal of vision, 14(8), 5-5. + +[2] Coutrot, A., & Guyader, N. (2015, August). + An efficient audiovisual saliency model to infer eye positions when looking at conversations. + In 2015 23rd European Signal Processing Conference (EUSIPCO) (pp. 1531-1535). IEEE. +""" + +import os + +import numpy as np +import scipy.io as sio +from tqdm import tqdm + +from gazenet.utils.registrar import * +from gazenet.utils.helpers import extract_thumbnail_from_video, check_audio_in_video +from gazenet.utils.sample_processors import SampleReader, SampleProcessor + + +@ReaderRegistrar.register +class CoutrotSampleReader(SampleReader): + def __init__(self, video_dir, annotations_file, + database_type, # database_type = 'Coutrot_Database1'| 'Coutrot_Database2' + auditory_condition, video_format="avi", + extract_thumbnails=True, + pickle_file=None, mode=None, **kwargs): + self.short_name = "coutrot" + self.video_dir = video_dir + self.annotations_file = annotations_file + self.database_type = database_type + self.auditory_condition = auditory_condition + self.video_format = video_format + self.extract_thumbnails = extract_thumbnails + + super().__init__(pickle_file=pickle_file, mode=mode, **kwargs) + + def read_raw(self): + # single annotations file in matlab format + annotations = sio.loadmat(self.annotations_file) + # annotations['Coutrot_Database1'][0][0][x] Auditory condition -> clips in red on webpage are actually excluded + # annotations['Coutrot_Database1']['OriginalSounds'][0][0]['clip_1'][0][0][0][0]['data'][1][2][3] -> [1]:x(0),y(1),[2]: video_frames_list, [3]: participents? + + for video_name in tqdm(sorted(os.listdir(self.video_dir)), desc="Samples Read"): + if video_name.endswith("." + self.video_format): + id = video_name.replace("." 
+ self.video_format, "") + + try: + # annotation assembly + annotation = annotations[self.database_type][self.auditory_condition][0][0][id][0][0][0] + self.samples.append({"id": id, + "audio_name": os.path.join(self.video_dir, video_name), + "video_name": os.path.join(self.video_dir, video_name), + "video_fps": annotation['info'][0]['fps'][0][0][0][0], + "video_width": annotation['info'][0]['vidwidth'][0][0][0][0], + "video_height": annotation['info'][0]['vidheight'][0][0][0][0], + "video_thumbnail": extract_thumbnail_from_video( + os.path.join(self.video_dir, video_name)) if self.extract_thumbnails else None, + "len_frames": annotation['info'][0]['nframe'][0][0][0][0], + "has_audio": check_audio_in_video(os.path.join(self.video_dir, video_name)), + "annotation_name": os.path.join(self.database_type, self.auditory_condition, id), + "annotations": {"xyp": annotation['data'][0]} + }) + self.video_id_to_sample_idx[id] = len(self.samples) - 1 + self.len_frames += self.samples[-1]["len_frames"] + except: + print("Error: Access non-existent annotation " + id) + + @staticmethod + def dataset_info(): + return {"summary": "TODO", + "name": "Coutrot Dataset", + "link": "TODO"} + +@ReaderRegistrar.register +class Coutrot1SampleReader(CoutrotSampleReader): + def __init__(self, video_dir="datasets/ave/database1/ERB3_Stimuli", + annotations_file="datasets/ave/database1/coutrot_database1.mat", + database_type='Coutrot_Database1', auditory_condition='OriginalSounds', + pickle_file="temp/coutrot1.pkl", mode=None, **kwargs): + super().__init__(video_dir=video_dir, annotations_file=annotations_file, + database_type=database_type, auditory_condition=auditory_condition, + pickle_file=pickle_file, mode=mode, **kwargs) + self.short_name = "coutrot1" + + @staticmethod + def dataset_info(): + return {"summary": "TODO", + "name": "Coutrot Dataset1 (Coutrot et al.)", + "link": "TODO"} + +@ReaderRegistrar.register +class Coutrot2SampleReader(CoutrotSampleReader): + def __init__(self, video_dir="datasets/ave/database2/ERB4_Stimuli", + annotations_file="datasets/ave/database2/coutrot_database2.mat", + database_type='Coutrot_Database2', auditory_condition='AudioVisual', + pickle_file="temp/coutrot2.pkl", mode=None, **kwargs): + super().__init__(video_dir=video_dir, annotations_file=annotations_file, + database_type=database_type, auditory_condition=auditory_condition, + pickle_file=pickle_file, mode=mode, **kwargs) + self.short_name = "coutrot2" + + @staticmethod + def dataset_info(): + return {"summary": "TODO", + "name": "Coutrot Dataset2 (Coutrot et al.)", + "link": "TODO"} + +@SampleRegistrar.register +class CoutrotSample(SampleProcessor): + def __init__(self, reader, index=-1, frame_index=0, width=640, height=480, **kwargs): + assert isinstance(reader, CoutrotSampleReader) + self.short_name = reader.short_name + self.reader = reader + self.index = index + + if frame_index > 0: + self.goto_frame(frame_index) + super().__init__(width=width, height=height, **kwargs) + next(self) + + def __next__(self): + with self.read_lock: + self.index += 1 + self.index %= len(self.reader.samples) + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def __len__(self): + return len(self.reader) + + def next(self): + return next(self) + + def goto(self, name, by_index=True): + if by_index: + index = name + else: + index = self.reader.video_id_to_sample_idx[name] + + with self.read_lock: + self.index = index + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + 
return curr_metadata + + def annotate_frame(self, input_data, plotter, + show_saliency_map=False, + show_fixation_locations=False, + participant=None, # None means all participants will be plotted + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, _ = input_data + + properties = {"show_saliency_map": (show_saliency_map, "toggle", (True, False)), + "show_fixation_locations": (show_fixation_locations, "toggle", (True, False))} + + info = {**info, "frame_annotations": { + "eye_fixation_points": [], + "eye_fixation_participants": [] + }} + # info["frame_info"]["dataset_name"] = self.reader.short_name + # info["frame_info"]["video_id"] = self.reader.samples[self.index]["id"] + # info["frame_info"]["frame_height"] = self.reader.samples[self.index]["video_height"] + # info["frame_info"]["frame_width"] = self.reader.samples[self.index]["video_width"] + + grouped_video_frames = {**grouped_video_frames, + "PLOT": [["captured", "transformed_salmap", "transformed_fixmap"]], + "transformed_salmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"]), + "transformed_fixmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + try: + frame_index = self.frame_index() + video_frame_salmap = grouped_video_frames["transformed_salmap"] + video_frame_fixmap = grouped_video_frames["transformed_fixmap"] + if grabbed_video: + + ann = self.reader.samples[self.index]["annotations"] + if participant is None: + fixation_participants = ann["xyp"][2, frame_index - 1, :] + fixation_annotations = np.vstack((ann["xyp"][0, frame_index - 1, :], + # empirically introduced a vertical shift of 20 pixels + ann["xyp"][1, frame_index - 1, :] - 20, + np.ones_like((ann["xyp"][0, frame_index - 1, :])) + # no fixation amplitude + )).transpose() + else: + fixation_participants = ann["xyp"][2, frame_index - 1, participant] + fixation_annotations = np.vstack((ann["xyp"][0, frame_index - 1, participant], + # empirically introduced a vertical shift of 20 pixels + ann["xyp"][1, frame_index - 1, participant] - 20, + np.ones_like((ann["xyp"][0, frame_index - 1, participant])) + # no fixation amplitude + )).transpose() + + info["frame_annotations"]["eye_fixation_participants"].append(fixation_participants) + info["frame_annotations"]["eye_fixation_points"].append(fixation_annotations) + if show_saliency_map: + video_frame_salmap = plotter.plot_fixations_density_map(video_frame_salmap, fixation_annotations, + xy_std=(20, 20), + color_map=color_map, + alpha=0.4 if enable_transform_overlays else 1.0) + if show_fixation_locations: + video_frame_fixmap = plotter.plot_fixations_locations(video_frame_fixmap, fixation_annotations, radius=1) + + grouped_video_frames["transformed_salmap"] = video_frame_salmap + grouped_video_frames["transformed_fixmap"] = video_frame_fixmap + + except: + pass + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties + + def get_participant_frame_range(self, participant_id): + raise NotImplementedError + + +if __name__ == "__main__": + reader1 = Coutrot1SampleReader(mode="w") + reader2 = Coutrot2SampleReader(mode="w") diff --git a/gazenet/readers/diem.py b/gazenet/readers/diem.py new file mode 100644 index 0000000..2d42c46 --- /dev/null +++ b/gazenet/readers/diem.py @@ -0,0 +1,208 @@ +""" +Class for reading and decoding the DIEM [1] dataset + +[1] Mital, P. 
K., Smith, T. J., Hill, R. L., & Henderson, J. M. (2011). + Clustering of gaze during dynamic scene viewing is predicted by motion. + Cognitive computation, 3(1), 5-24. +""" + +import os +from glob import glob + +import numpy as np +import pandas as pd +from tqdm import tqdm + +from gazenet.utils.registrar import * +from gazenet.utils.helpers import extract_thumbnail_from_video, extract_width_height_from_video +from gazenet.utils.sample_processors import SampleReader, SampleProcessor + + +@ReaderRegistrar.register +class DIEMSampleReader(SampleReader): + def __init__(self, video_audio_annotations_dir="datasets/ave/diem", + video_format="mp4", audio_format="wav", annotation_format="txt", annotation_columns=None, + extract_thumbnails=True, + pickle_file="temp/diem.pkl", mode=None, **kwargs): + self.short_name = "diem" + self.video_audio_annotations_dir = video_audio_annotations_dir + self.video_format = video_format + self.audio_format = audio_format + self.annotation_format = annotation_format + self.annotation_columns = annotation_columns + self.extract_thumbnails = extract_thumbnails + + super().__init__(pickle_file=pickle_file, mode=mode, **kwargs) + + def read_raw(self): + if self.annotation_columns is None: + self.annotation_columns = ["frame", "left_x", "left_y", "left_dil", "left_event", "right_x", "right_y", + "right_dil", "right_event"] + ids = [dI for dI in sorted(os.listdir(self.video_audio_annotations_dir)) if os.path.isdir(os.path.join(self.video_audio_annotations_dir, dI))] + + for id in tqdm(ids, desc="Samples Read"): + video_dir = os.path.join(self.video_audio_annotations_dir, id, "video") + video_name = id + "." + self.video_format + audio_dir = os.path.join(self.video_audio_annotations_dir, id, "audio") + audio_name = id + "." + self.audio_format + annotations_dir = os.path.join(self.video_audio_annotations_dir, id, "event_data") + annotations_name = "*" + id + "." + self.annotation_format + # [frame] [left_x] [left_y] [left_dil] [left_event] [right_x] [right_y] [right_dil] [right_event] + # Frames are 30 video_frames_list per second, indexed at 1; x,y are screen coordinates; dil represents pupil dilation; and the event flag represents: + # -1 = Error/dropped frame + # 0 = Blink + # 1 = Fixation + # 2 = Saccade + + # read the annotations in all the files (participants?) + annotations = [] + for part_id, annotation_path in enumerate(sorted(glob(os.path.join(annotations_dir, annotations_name)))): + annotation_name = os.path.basename(annotation_path) + part_name = annotation_name.replace("_" + id + "." 
+ self.annotation_format, "") + # annotation assembly + annotation = pd.read_csv(annotation_path, sep="\t", names=self.annotation_columns) + annotation["participant"] = part_name + annotation["participant_id"] = part_id + annotations.append(annotation) + video_width_height = extract_width_height_from_video(os.path.join(video_dir, video_name)) + self.samples.append({"id": id, + "audio_name": os.path.join(audio_dir, audio_name), + "video_name": os.path.join(video_dir, video_name), + "video_width": video_width_height[0], + "video_height": video_width_height[1], + "video_thumbnail": extract_thumbnail_from_video( + os.path.join(video_dir, video_name)) if self.extract_thumbnails else None, + "annotation_name": os.path.join(annotations_dir, annotations_name), + "has_audio": True, + "annotations": pd.concat(annotations, axis=0, ignore_index=True, sort=False)}) + self.video_id_to_sample_idx[id] = len(self.samples) - 1 + self.samples[-1].update({"len_frames": int(self.samples[-1]["annotations"]["frame"].max())}) + self.len_frames += self.samples[-1]["len_frames"] + + @staticmethod + def dataset_info(): + return {"summary": "TODO", + "name": "DIEM Dataset (Mital et al.)", + "link": "TODO"} + + +@SampleRegistrar.register +class DIEMSample(SampleProcessor): + def __init__(self, reader, index=-1, frame_index=0, width=640, height=480, **kwargs): + assert isinstance(reader, DIEMSampleReader) + self.short_name = reader.short_name + self.reader = reader + self.index = index + + if frame_index > 0: + self.goto_frame(frame_index) + super().__init__(width=width, height=height, **kwargs) + + next(self) + + def __next__(self): + with self.read_lock: + self.index += 1 + self.index %= len(self.reader.samples) + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def __len__(self): + return len(self.reader) + + def next(self): + return next(self) + + def goto(self, name, by_index=True): + if by_index: + index = name + else: + index = self.reader.video_id_to_sample_idx[name] + + with self.read_lock: + self.index = index + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def annotate_frame(self, input_data, plotter, + show_saliency_map=False, + show_fixation_locations=False, + eye="left", + participant=None, # None means all participants will be plotted + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, _ = input_data + + properties = {"show_saliency_map": (show_saliency_map, "toggle", (True, False)), + "show_fixation_locations": (show_fixation_locations, "toggle", (True, False))} + + info = {**info, "frame_annotations": { + "eye_fixation_points": [], + "eye_fixation_participants": [] + }} + # info["frame_info"]["dataset_name"] = self.reader.short_name + # info["frame_info"]["video_id"] = self.reader.samples[self.index]["id"] + # info["frame_info"]["frame_height"] = self.reader.samples[self.index]["video_height"] + # info["frame_info"]["frame_width"] = self.reader.samples[self.index]["video_width"] + + grouped_video_frames = {**grouped_video_frames, + "PLOT": [["captured", "transformed_salmap", "transformed_fixmap"]], + "transformed_salmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"]), + "transformed_fixmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + try: + frame_index = 
self.frame_index() + video_frame_salmap = grouped_video_frames["transformed_salmap"] + video_frame_fixmap = grouped_video_frames["transformed_fixmap"] + + if grabbed_video: + ann = self.reader.samples[self.index]["annotations"] + if participant is not None: + ann = ann.loc[ann["participant_id"] == participant] + + if eye == "left": + fixation_annotations = ann.loc[(ann["frame"] == frame_index) & (ann["left_event"] == 1)][ + ["participant_id", "left_x", "left_y", "left_dil", "left_event"]] + elif eye == "right": + fixation_annotations = ann.loc[(ann["frame"] == frame_index) & (ann["right_event"] == 1)][ + ["participant_id", "right_x", "right_y", "right_dil"]] + else: + raise NotImplementedError("Implement both eyes at once") + + fixation_participants = fixation_annotations.iloc[:, [0]].values + fixation_annotations = np.hstack((fixation_annotations.iloc[:, [1]].values, + # empirically introduced a vertical shift of 100 pixels + fixation_annotations.iloc[:, [2]].values - 100, + fixation_annotations.iloc[:, [3]].values)) + + # fixation_annotations = np.squeeze(fixation_annotations, axis=-1) + info["frame_annotations"]["eye_fixation_participants"].append(fixation_participants) + info["frame_annotations"]["eye_fixation_points"].append(fixation_annotations) + if fixation_annotations.shape[0] != 0: + if show_saliency_map: + video_frame_salmap = plotter.plot_fixations_density_map(video_frame_salmap, fixation_annotations, + xy_std=(60, 60), + color_map=color_map, + alpha=0.4 if enable_transform_overlays else 1.0) + if show_fixation_locations: + video_frame_fixmap = plotter.plot_fixations_locations(video_frame_fixmap, fixation_annotations, + radius=1) + + grouped_video_frames["transformed_salmap"] = video_frame_salmap + grouped_video_frames["transformed_fixmap"] = video_frame_fixmap + except: + pass + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties + + def get_participant_frame_range(self, participant_id): + raise NotImplementedError + + +if __name__ == "__main__": + reader = DIEMSampleReader(mode="w") diff --git a/gazenet/readers/findwho.py b/gazenet/readers/findwho.py new file mode 100644 index 0000000..2573c90 --- /dev/null +++ b/gazenet/readers/findwho.py @@ -0,0 +1,200 @@ +""" +Class for reading and decoding the Find who to look at [1] dataset + + +[1] Xu, M., Liu, Y., Hu, R., & He, F. (2018). + Find who to look at: Turning from action to saliency. + IEEE Transactions on Image Processing, 7(9), 4529-4544. IEEE. 
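+
+Illustrative usage (a sketch only, assuming the default dataset paths configured in
+``FindWhoSampleReader.__init__`` below exist on disk):
+
+    reader = FindWhoSampleReader(mode="w")
+    sample = FindWhoSample(reader)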
+ +""" + +import os + +import numpy as np +import pandas as pd +import scipy.io as sio +from tqdm import tqdm + +from gazenet.utils.registrar import * +from gazenet.utils.helpers import extract_thumbnail_from_video, check_audio_in_video, aggregate_frame_ranges +from gazenet.utils.sample_processors import SampleReader, SampleProcessor + + +@ReaderRegistrar.register +class FindWhoSampleReader(SampleReader): + def __init__(self, + video_dir="datasets/findwho/raw_videos", + annotations_file="datasets/findwho/video_database.mat", + database_type='video_database', + video_format="mp4", + extract_thumbnails=True, + pickle_file="temp/findwho.pkl", mode=None, **kwargs): + self.short_name = "findwho" + self.video_dir = video_dir + self.annotations_file = annotations_file + self.database_type = database_type + self.video_format = video_format + self.extract_thumbnails = extract_thumbnails + + super().__init__(pickle_file=pickle_file, mode=mode, **kwargs) + + def read_raw(self): + # single annotations file in matlab format + annotations = sio.loadmat(self.annotations_file) + # annotations['video_data'][0][0][x] fixdata fields -> SubjectIndex,VideoIndex,Timestamp,Duration,x,y + annotation_columns = [annotations[self.database_type]['fixdata_fields'][0][0][0][i].tolist()[0].replace(' ', '') + for i in range(len(annotations[self.database_type]['fixdata_fields'][0][0][0]))] + annotation_fixations = pd.DataFrame(data=annotations[self.database_type]['fixdata'][0][0], + columns=annotation_columns, index=None) + for video_name in tqdm(sorted(os.listdir(self.video_dir)), desc="Samples Read"): + if video_name.endswith("." + self.video_format): + id = video_name.replace("." + self.video_format, "") + + try: + # annotation assembly + id_idx = int(id) - 1 + video_annotations = annotation_fixations.loc[(annotation_fixations["VideoIndex"] == id_idx+1)] + + self.samples.append({"id": id, + "audio_name": os.path.join(self.video_dir, video_name), + "video_name": os.path.join(self.video_dir, video_name), + "video_fps": annotations[self.database_type]['videos_info'][0][0]['framerate_fps'][0][0][id_idx][0], + "video_width": annotations[self.database_type]['videos_info'][0][0]['size'][0][0][id_idx][0], + "video_height": annotations[self.database_type]['videos_info'][0][0]['size'][0][0][id_idx][1], + "video_thumbnail": extract_thumbnail_from_video( + os.path.join(self.video_dir, video_name)) if self.extract_thumbnails else None, + "len_frames": annotations[self.database_type]['videos_info'][0][0]['frames'][0][0][id_idx][0], + "has_audio": check_audio_in_video(os.path.join(self.video_dir, video_name)), + "annotation_name": os.path.join(self.database_type, id), + # TODO (fabawi): this should only contain the annotations of this video. 
Check out diem + "annotations": video_annotations, + "participant_frames": {str(participant_id): aggregate_frame_ranges(video_annotations.loc[ + (video_annotations["SubjectIndex"] == participant_id)][["Timestamp", "GazeDuration"]].values.tolist()) + for participant_id in video_annotations["SubjectIndex"].unique().tolist()} + }) + self.video_id_to_sample_idx[id] = len(self.samples) - 1 + self.len_frames += self.samples[-1]["len_frames"] + except: + print("Error: Access non-existent annotation " + id) + + @staticmethod + def dataset_info(): + return {"summary": "TODO", + "name": "'Find Who to look at' Dataset", + "link": "TODO"} + + +@SampleRegistrar.register +class FindWhoSample(SampleProcessor): + def __init__(self, reader, index=-1, frame_index=0, width=640, height=480, **kwargs): + assert isinstance(reader, FindWhoSampleReader) + self.short_name = reader.short_name + self.reader = reader + self.index = index + + if frame_index > 0: + self.goto_frame(frame_index) + super().__init__(width=width, height=height, **kwargs) + next(self) + + def __next__(self): + with self.read_lock: + self.index += 1 + self.index %= len(self.reader.samples) + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def __len__(self): + return len(self.reader) + + def next(self): + return next(self) + + def goto(self, name, by_index=True): + if by_index: + index = name + else: + index = self.reader.video_id_to_sample_idx[name] + + with self.read_lock: + self.index = index + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def annotate_frame(self, input_data, plotter, + show_saliency_map=False, + show_fixation_locations=False, + participant=None, # None means all participants will be plotted + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, _ = input_data + + properties = {"show_saliency_map": (show_saliency_map, "toggle", (True, False)), + "show_fixation_locations": (show_fixation_locations, "toggle", (True, False))} + + info = {**info, "frame_annotations": { + "eye_fixation_points": [], + "eye_fixation_participants": [] + }} + # info["frame_info"]["dataset_name"] = self.reader.short_name + # info["frame_info"]["video_id"] = self.reader.samples[self.index]["id"] + # info["frame_info"]["frame_height"] = self.reader.samples[self.index]["video_height"] + # info["frame_info"]["frame_width"] = self.reader.samples[self.index]["video_width"] + + grouped_video_frames = {**grouped_video_frames, + "PLOT": [["captured", "transformed_salmap", "transformed_fixmap"]], + "transformed_salmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"]), + "transformed_fixmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + try: + frame_index = self.frame_index() + video_frame_salmap = grouped_video_frames["transformed_salmap"] + video_frame_fixmap = grouped_video_frames["transformed_fixmap"] + + if grabbed_video: + fps = self.reader.samples[self.index]["video_fps"] + ann = self.reader.samples[self.index]["annotations"] + if participant is not None: + ann = ann.loc[ann["SubjectIndex"] == participant] + + fixation_annotations = ann.loc[(ann["Timestamp"]/1000 <= frame_index/fps) & + (ann["Timestamp"]/1000 + ann["GazeDuration"]/1000 >= frame_index/fps)][ + ["SubjectIndex", "FixationPointX", "FixationPointY", 
"GazeDuration"]] + + fixation_participants = fixation_annotations.iloc[:, [0]].values + fixation_annotations = np.hstack((fixation_annotations.iloc[:, [1]].values, + fixation_annotations.iloc[:, [2]].values, + fixation_annotations.iloc[:, [3]].values)) + + # fixation_annotations = np.squeeze(fixation_annotations, axis=-1) + info["frame_annotations"]["eye_fixation_participants"].append(fixation_participants) + info["frame_annotations"]["eye_fixation_points"].append(fixation_annotations) + if fixation_annotations.shape[0] != 0: + if show_saliency_map: + video_frame_salmap = plotter.plot_fixations_density_map(video_frame_salmap, + fixation_annotations, + xy_std=(66, 66), + color_map=color_map, + alpha=0.4 if enable_transform_overlays else 1.0) + if show_fixation_locations: + video_frame_fixmap = plotter.plot_fixations_locations(video_frame_fixmap, fixation_annotations, + radius=1) + + grouped_video_frames["transformed_salmap"] = video_frame_salmap + grouped_video_frames["transformed_fixmap"] = video_frame_fixmap + except: + pass + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties + + def get_participant_frame_range(self, participant_id): + raise NotImplementedError + + +if __name__ == "__main__": + reader = FindWhoSampleReader(mode="w") \ No newline at end of file diff --git a/gazenet/readers/summe_etmd.py b/gazenet/readers/summe_etmd.py new file mode 100644 index 0000000..88f8ee1 --- /dev/null +++ b/gazenet/readers/summe_etmd.py @@ -0,0 +1,206 @@ +""" +Class for reading and decoding the AVEyetracking[3] dataset composed of the ETMD [1] and SumMe [2] datasets + +[1] Gygli, M., Grabner, H., Riemenschneider, H., & Van Gool, L. (2014). + Creating summaries from user videos. + In Proceedings of the European Conference on Computer Vision, 2014, pp. 505 - 520. + +[2] Koutras, P., & Maragos, P. (2015). + Perceptually based spatio-temporal computational framework for visual saliency estimation. + Signal Processing: Image Commununication, 2015, pp. 15 - 31 + +[3] Tsiami, A., Koutras, P., Katsamanis, A., Vatakis, A., & Maragos, P. (2019). + Signal Processing: Image Communication, 2019, pp. 
186 - 200 +""" + +import os + +import numpy as np +import scipy.io as sio +from tqdm import tqdm + +from gazenet.utils.registrar import * +from gazenet.utils.helpers import extract_thumbnail_from_video, extract_width_height_from_video, check_audio_in_video +from gazenet.utils.sample_processors import SampleReader, SampleProcessor + + +@ReaderRegistrar.register +class SumMeETMDSampleReader(SampleReader): + def __init__(self, video_dir="datasets/aveyetracking/SumMe_ETMD/video", + audio_dir="datasets/aveyetracking/SumMe_ETMD/audio_mono", + annotations_file="datasets/aveyetracking/SumMe_ETMD/eyetracking/all_videos.mat", + audio_format="wav", video_format="mp4", + extract_thumbnails=True, + pickle_file="temp/summeetmd.pkl", mode=None, **kwargs): + self.short_name = "summeetmd" + self.video_dir = video_dir + self.audio_dir = audio_dir + self.annotations_file = annotations_file + self.audio_format = audio_format + self.video_format = video_format + self.extract_thumbnails = extract_thumbnails + + super().__init__(pickle_file=pickle_file, mode=mode, **kwargs) + + def read_raw(self): + # single annotations file in matlab format + annotations = sio.loadmat(self.annotations_file) + # annotations['eye_data_all'][clip_name][0][0][1][2][3] -> [1]:x(0),y(1), [2]:video_frames_list, [3]:participants{10} + + for video_name in tqdm(sorted(os.listdir(self.video_dir)), desc="Samples Read"): + if video_name.endswith("." + self.video_format): + id = video_name.replace("." + self.video_format, "") + audio_full_name = os.path.join(self.audio_dir, id + "." + self.audio_format) + if os.path.isfile(audio_full_name): + has_audio = True + else: + audio_full_name = os.path.join(self.video_dir, video_name) + has_audio = check_audio_in_video(os.path.join(self.video_dir, video_name)) + try: + # annotation assembly + annotation = annotations['eye_data_all'][id][0][0] + + video_width_height = extract_width_height_from_video(os.path.join(self.video_dir, video_name)) + + annotation[0, :, :] *= video_width_height[0] + annotation[1, :, :] *= video_width_height[1] + annotation = annotation.astype(np.uint32) + self.samples.append({"id": id, + "audio_name": audio_full_name, + "video_name": os.path.join(self.video_dir, video_name), + "video_width": video_width_height[0], + "video_height": video_width_height[1], + "video_thumbnail": extract_thumbnail_from_video( + os.path.join(self.video_dir, video_name)) if self.extract_thumbnails else None, + "len_frames": len(annotation[0]), + "has_audio": has_audio, + "annotation_name": id, + "annotations": {"xyp": annotation} + }) + self.video_id_to_sample_idx[id] = len(self.samples) - 1 + self.len_frames += self.samples[-1]["len_frames"] + except: + print("Error: Access non-existent annotation " + id) + + @staticmethod + def dataset_info(): + return {"summary": "TODO", + "name": "SumMe (Koutras et al.) & ETMD (Gygli et al.) 
Datasets", + "link": "TODO"} + + +@SampleRegistrar.register +class SumMeETMDSample(SampleProcessor): + def __init__(self, reader, index=-1, frame_index=0, width=640, height=480, **kwargs): + assert isinstance(reader, SumMeETMDSampleReader) + self.short_name = reader.short_name + self.reader = reader + self.index = index + + if frame_index > 0: + self.goto_frame(frame_index) + super().__init__(width=width, height=height, **kwargs) + next(self) + + def __next__(self): + with self.read_lock: + self.index += 1 + self.index %= len(self.reader.samples) + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def __len__(self): + return len(self.reader) + + def next(self): + return next(self) + + def goto(self, name, by_index=True): + if by_index: + index = name + else: + index = self.reader.video_id_to_sample_idx[name] + + with self.read_lock: + self.index = index + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def annotate_frame(self, input_data, plotter, + show_saliency_map=False, + show_fixation_locations=False, + participant=None, # None means all participants will be plotted + enable_transform_overlays=True, + color_map=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, _ = input_data + + properties = {"show_saliency_map": (show_saliency_map, "toggle", (True, False)), + "show_fixation_locations": (show_fixation_locations, "toggle", (True, False))} + + info = {**info, "frame_annotations": { + "eye_fixation_points": [], + "eye_fixation_participants": [] + }} + # info["frame_info"]["dataset_name"] = self.reader.short_name + # info["frame_info"]["video_id"] = self.reader.samples[self.index]["id"] + # info["frame_info"]["frame_height"] = self.reader.samples[self.index]["video_height"] + # info["frame_info"]["frame_width"] = self.reader.samples[self.index]["video_width"] + + grouped_video_frames = {**grouped_video_frames, + "PLOT": [["captured", "transformed_salmap", "transformed_fixmap"]], + "transformed_salmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"]), + "transformed_fixmap": grouped_video_frames["captured"] + if enable_transform_overlays else np.zeros_like(grouped_video_frames["captured"])} + + try: + frame_index = self.frame_index() + video_frame_salmap = grouped_video_frames["transformed_salmap"] + video_frame_fixmap = grouped_video_frames["transformed_fixmap"] + if grabbed_video: + + ann = self.reader.samples[self.index]["annotations"] + if participant is None: + fixation_participants = ann["xyp"][2, frame_index - 1, :] + fixation_annotations = np.vstack((ann["xyp"][0, frame_index - 1, :], + # empirically introduced a vertical shift of 20 pixels + ann["xyp"][1, frame_index - 1, :], + np.ones_like((ann["xyp"][0, frame_index - 1, :])) + # no fixation amplitude + )).transpose() + else: + fixation_participants = ann["xyp"][2, frame_index - 1, participant] + fixation_annotations = np.vstack((ann["xyp"][0, frame_index - 1, participant], + ann["xyp"][1, frame_index - 1, participant], + np.ones_like((ann["xyp"][0, frame_index - 1, participant])) + # no fixation amplitude + )).transpose() + + info["frame_annotations"]["eye_fixation_participants"].append(fixation_participants) + info["frame_annotations"]["eye_fixation_points"].append(fixation_annotations) + if show_saliency_map: + video_frame_salmap = plotter.plot_fixations_density_map(video_frame_salmap, fixation_annotations, + 
xy_std=(20, 20), + color_map=color_map, + alpha=0.4 if enable_transform_overlays else 1.0) + if show_fixation_locations: + video_frame_fixmap = plotter.plot_fixations_locations(video_frame_fixmap, fixation_annotations, + radius=1) + + grouped_video_frames["transformed_salmap"] = video_frame_salmap + grouped_video_frames["transformed_fixmap"] = video_frame_fixmap + + except: + pass + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties + + def get_participant_frame_range(self, participant_id): + raise NotImplementedError + + +if __name__ == "__main__": + reader = SumMeETMDSampleReader(mode="w") diff --git a/gazenet/readers/visualization/__init__.py b/gazenet/readers/visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/gazenet/readers/visualization/assets/base.css b/gazenet/readers/visualization/assets/base.css new file mode 100644 index 0000000..87301a3 --- /dev/null +++ b/gazenet/readers/visualization/assets/base.css @@ -0,0 +1,434 @@ +/* Table of contents +–––––––––––––––––––––––––––––––––––––––––––––––––– +- Plotly.js +- Grid +- Base Styles +- Typography +- Links +- Buttons +- Forms +- Lists +- Code +- Tables +- Spacing +- Utilities +- Clearing +- Media Queries +*/ + +/* PLotly.js +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +/* plotly.js's modebar's z-index is 1001 by default + * https://github.com/plotly/plotly.js/blob/7e4d8ab164258f6bd48be56589dacd9bdd7fded2/src/css/_modebar.scss#L5 + * In case a dropdown is above the graph, the dropdown's options + * will be rendered below the modebar + * Increase the select option's z-index + */ + +/* This was actually not quite right - + dropdowns were overlapping each other (edited October 26) + +.Select { + z-index: 1002; +}*/ + +/* Grid +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.container { + position: relative; + width: 100%; + /*max-width: 960px;*/ + /*margin: 0 auto;*/ + padding: 0 20px; + box-sizing: border-box; } +.column, +.columns { + width: 100%; + float: left; + box-sizing: border-box; } + +/* For devices larger than 400px */ +@media (min-width: 400px) { + .container { + width: 100%; + padding: 0; } +} + +/* For devices larger than 550px */ +@media (min-width: 550px) { + .container { + width: 100%; } + .column, + .columns { + margin-left: 0; } + .column:first-child, + .columns:first-child { + margin-left: 0; } + + .one.column, + .one.columns { width: 4.66666666667%; } + .two.columns { width: 13.3333333333%; } + .three.columns { width: 22%; } + .four.columns { width: 32.6666666667%; } + .five.columns { width: 39.3333333333%; } + .six.columns { width: 48%; } + .seven.columns { width: 56.6666666667%; } + .eight.columns { width: 65.3333333333%; } + .nine.columns { width: 74.0%; } + .ten.columns { width: 82.6666666667%; } + .eleven.columns { width: 91.3333333333%; } + .twelve.columns { width: 100%; margin-left: 0; } + + .one-third.column { width: 30.6666666667%; } + .two-thirds.column { width: 65.3333333333%; } + + .one-half.column { width: 48%; } + + /* Offsets */ + .offset-by-one.column, + .offset-by-one.columns { margin-left: 8.66666666667%; } + .offset-by-two.column, + .offset-by-two.columns { margin-left: 17.3333333333%; } + .offset-by-three.column, + .offset-by-three.columns { margin-left: 26%; } + .offset-by-four.column, + .offset-by-four.columns { margin-left: 34.6666666667%; } + .offset-by-five.column, + .offset-by-five.columns { margin-left: 43.3333333333%; } + .offset-by-six.column, + .offset-by-six.columns { margin-left: 52%; } + 
.offset-by-seven.column, + .offset-by-seven.columns { margin-left: 60.6666666667%; } + .offset-by-eight.column, + .offset-by-eight.columns { margin-left: 69.3333333333%; } + .offset-by-nine.column, + .offset-by-nine.columns { margin-left: 78.0%; } + .offset-by-ten.column, + .offset-by-ten.columns { margin-left: 86.6666666667%; } + .offset-by-eleven.column, + .offset-by-eleven.columns { margin-left: 95.3333333333%; } + + .offset-by-one-third.column, + .offset-by-one-third.columns { margin-left: 34.6666666667%; } + .offset-by-two-thirds.column, + .offset-by-two-thirds.columns { margin-left: 69.3333333333%; } + + .offset-by-one-half.column, + .offset-by-one-half.columns { margin-left: 52%; } + +} + + +/* Base Styles +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +/* NOTE +html is set to 62.5% so that all the REM measurements throughout Skeleton +are based on 10px sizing. So basically 1.5rem = 15px :) */ +html { + font-size: 62.5%; } +body { + font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */ + line-height: 1.6; + font-weight: 400; + font-family: "Open Sans", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif; + color: rgb(50, 50, 50); } + + +/* Typography +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +h1, h2, h3, h4, h5, h6 { + margin-top: 0; + margin-bottom: 0; + font-weight: 300; } +h1 { font-size: 4.5rem; line-height: 1.2; letter-spacing: -.1rem; margin-bottom: 2rem; } +h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; margin-bottom: 1.8rem; margin-top: 1.8rem;} +h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; margin-bottom: 1.5rem; margin-top: 1.5rem;} +h4 { font-size: 2.6rem; line-height: 1.35; letter-spacing: -.08rem; margin-bottom: 1.2rem; margin-top: 1.2rem;} +h5 { font-size: 2.2rem; line-height: 1.5; letter-spacing: -.05rem; margin-bottom: 0.6rem; margin-top: 0.6rem;} +h6 { font-size: 2.0rem; line-height: 1.6; letter-spacing: 0; margin-bottom: 0.75rem; margin-top: 0.75rem;} + +p { + margin-top: 0; } + + +/* Blockquotes +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +blockquote { + border-left: 4px lightgrey solid; + padding-left: 1rem; + margin-top: 2rem; + margin-bottom: 2rem; + margin-left: 0rem; +} + + +/* Links +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +a { + color: #1EAEDB; + text-decoration: underline; + cursor: pointer;} +a:hover { + color: #0FA0CE; } + + +/* Buttons +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.button, +button, +input[type="submit"], +input[type="reset"], +input[type="button"] { + display: inline-block; + height: 38px; + padding: 0 30px; + color: #555; + text-align: center; + font-size: 11px; + font-weight: 600; + line-height: 38px; + letter-spacing: .1rem; + text-transform: uppercase; + text-decoration: none; + white-space: nowrap; + background-color: transparent; + border-radius: 4px; + border: 1px solid #bbb; + cursor: pointer; + box-sizing: border-box; } +.button:hover, +button:hover, +input[type="submit"]:hover, +input[type="reset"]:hover, +input[type="button"]:hover, +.button:focus, +button:focus, +input[type="submit"]:focus, +input[type="reset"]:focus, +input[type="button"]:focus { + color: #333; + border-color: #888; + outline: 0; } +.button.button-primary, +button.button-primary, +input[type="submit"].button-primary, +input[type="reset"].button-primary, +input[type="button"].button-primary { + color: #FFF; + background-color: #33C3F0; + border-color: #33C3F0; } +.button.button-primary:hover, 
+button.button-primary:hover, +input[type="submit"].button-primary:hover, +input[type="reset"].button-primary:hover, +input[type="button"].button-primary:hover, +.button.button-primary:focus, +button.button-primary:focus, +input[type="submit"].button-primary:focus, +input[type="reset"].button-primary:focus, +input[type="button"].button-primary:focus { + color: #FFF; + background-color: #1EAEDB; + border-color: #1EAEDB; } + + +/* Forms +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +input[type="email"], +input[type="number"], +input[type="search"], +input[type="text"], +input[type="tel"], +input[type="url"], +input[type="password"], +textarea, +select { + height: 38px; + padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */ + background-color: #fff; + border: 1px solid #D1D1D1; + border-radius: 4px; + box-shadow: none; + box-sizing: border-box; + font-family: inherit; + font-size: inherit; /*https://stackoverflow.com/questions/6080413/why-doesnt-input-inherit-the-font-from-body*/} +/* Removes awkward default styles on some inputs for iOS */ +input[type="email"], +input[type="number"], +input[type="search"], +input[type="text"], +input[type="tel"], +input[type="url"], +input[type="password"], +textarea { + -webkit-appearance: none; + -moz-appearance: none; + appearance: none; } +textarea { + min-height: 65px; + padding-top: 6px; + padding-bottom: 6px; } +input[type="email"]:focus, +input[type="number"]:focus, +input[type="search"]:focus, +input[type="text"]:focus, +input[type="tel"]:focus, +input[type="url"]:focus, +input[type="password"]:focus, +textarea:focus, +select:focus { + border: 1px solid #33C3F0; + outline: 0; } +label, +legend { + display: block; + margin-bottom: 0px; } +fieldset { + padding: 0; + border-width: 0; } +input[type="checkbox"], +input[type="radio"] { + display: inline; } +label > .label-body { + display: inline-block; + margin-left: .5rem; + font-weight: normal; } + + +/* Lists +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +ul { + list-style: circle inside; } +ol { + list-style: decimal inside; } +ol, ul { + padding-left: 0; + margin-top: 0; } +ul ul, +ul ol, +ol ol, +ol ul { + margin: 1.5rem 0 1.5rem 3rem; + font-size: 90%; } +li { + margin-bottom: 1rem; } + + +/* Tables +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +table { + border-collapse: collapse; +} +th, +td { + padding: 12px 15px; + text-align: left; + border-bottom: 1px solid #E1E1E1; } +th:first-child, +td:first-child { + padding-left: 0; } +th:last-child, +td:last-child { + padding-right: 0; } + + +/* Spacing +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +button, +.button { + margin-bottom: 0rem; } +input, +textarea, +select, +fieldset { + margin-bottom: 0rem; } +pre, +dl, +figure, +table, +form { + margin-bottom: 0rem; } +p, +ul, +ol { + margin-bottom: 0.75rem; } + +/* Utilities +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +.u-full-width { + width: 100%; + box-sizing: border-box; } +.u-max-full-width { + max-width: 100%; + box-sizing: border-box; } +.u-pull-right { + float: right; } +.u-pull-left { + float: left; } + + +/* Misc +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +hr { + margin-top: 3rem; + margin-bottom: 3.5rem; + border-width: 0; + border-top: 1px solid #E1E1E1; } + + +/* Clearing +–––––––––––––––––––––––––––––––––––––––––––––––––– */ + +/* Self Clearing Goodness */ +.container:after, +.row:after, +.u-cf { + content: ""; + display: table; + clear: both; } + + +/* Media Queries 
+–––––––––––––––––––––––––––––––––––––––––––––––––– */ +/* +Note: The best way to structure the use of media queries is to create the queries +near the relevant code. For example, if you wanted to change the styles for buttons +on small devices, paste the mobile query code up in the buttons section and style it +there. +*/ + + +/* Larger than mobile */ +@media (min-width: 300px) { + #logo-mobile { + margin: 0 auto; + } +} + +/* Larger than phablet (also point when grid becomes active) */ +@media (min-width: 550px) { + #logo-mobile { + margin: 0 auto; + } +} + +/* Larger than tablet */ +@media (min-width: 750px) { + #logo-mobile { + margin: 0 auto; + } +} + +/* Larger than desktop */ +@media (min-width: 1000px) { + #logo-mobile { + margin: 0 auto; + } +} + +/* Larger than Desktop HD */ +@media (min-width: 1200px) { + #logo-mobile { + margin: 0 auto; + } +} \ No newline at end of file diff --git a/gazenet/readers/visualization/assets/internal.css b/gazenet/readers/visualization/assets/internal.css new file mode 100644 index 0000000..f152520 --- /dev/null +++ b/gazenet/readers/visualization/assets/internal.css @@ -0,0 +1,399 @@ +@import url("https://fonts.googleapis.com/css?family=Roboto|Raleway"); + +/* Remove Undo +–––––––––––––––––––––––––––––––––––––––––––––––––– */ +._dash-undo-redo { + display: none; +} +body { + margin: 0 !important; + color: #606060 !important; + font-family: "Raleway", sans-serif; + background-color: #f9f9f9 !important; +} + +.page-content { + display: flex; +} + + +#top-bar { + background-color: #fa4f56; + height: 0px; /**fabawi: changed this to 0 from 5px**/ +} + +#left-side-column { + display: flex; + flex-direction: column; + flex: 1; + height: 100vh; + background-color: #f2f2f2; + overflow-y: scroll; + margin-left: 0px; + float: left; + justify-content: flex-start; + align-items: center; + padding: 1rem 2rem; + +} + +#right-side-column { + height: calc(100vh - 5px); + overflow-y: inherit; + margin-left: 1%; + display: flex; + flex: 1; + flex-direction: column; +} + +h4 { + font-family: "Roboto", sans-serif; + font-weight: 400; +} + +p { + font-size: 16px; + font-weight: 300; +} + +.modal { + display: block; /*Hidden by default */ + position: fixed; /* Stay in place */ + z-index: 1000; /* Sit on top */ + left: 0; + top: 0; + width: 100vw; /* Full width */ + height: 100vh; /* Full height */ + overflow: auto; /* Enable scroll if needed */ + background-color: rgb(0, 0, 0); /* Fallback color */ + background-color: rgba(0, 0, 0, 0.4); /* Black w/ opacity */ +} + +.markdown-container { + width: 60vw; + margin: 10% auto; + padding: 10px 15px; + background-color: #f9f9f9; + border-radius: 10px; +} + +.close-container { + display: inline-block; + width: 100%; + margin: 0; + padding: 0; +} + +.markdown-text { + padding: 0px 10px; +} + +.closeButton { + padding: 0 15px; + font-weight: normal; + float: right; + font-size: 1.2rem; + border: none; + height: 100%; +} + +.closeButton:hover { + color: red !important; +} + +#header-section { + width: 100%; + margin-top: 2%; +} + +.button:focus { + color: #ffffff; + border-color: #bbb; +} + +.play-button { + border-top: 10px solid white; + border-bottom: 10px solid white; + border-left: 20px solid black; + height: 0px; +} + +#learn-more-button { + background-color: #fa4f56; + color: #ffffff; + font-size: 13px; + font-weight: 500; + padding: 5px 20px; + text-transform: none; + line-height: 30px; + font-family: "Raleway", sans-serif; +} + +.video-outer-container { + width: 90%; + margin-top: 2%; + margin-bottom: 2%; + min-width: 
500px; + display: inline-block; +} + +.video-container { + width: 100%; + padding-bottom: 1%; + position: relative; + +} + +.video-container > div, .video-container > div > div { + position: absolute; + width:100%; + height:100%; + top:0; + left:0; + bottom:0; + right:0; +} + +.control-section { + width: 100%; + padding: 5px; + display: flex; + flex-flow: column nowrap; + flex-shrink: 0; + margin-bottom: 2%; +} + +.control-element { + font-size: 15px; + padding: 10px; + display: flex; + flex-flow: row nowrap; +} + +.control-element > div:nth-child(1) { + width: 40%; +} + +.control-element > div:nth-child(2) { + width: 60%; +} + +.video-control-section { + width: 100%; + flex-flow: column nowrap; + flex-shrink: 0; +} + +.video-control-element { + font-size: 15px; + display: flex; + flex-flow: row nowrap; +} + +.video-control-element > div:nth-child(1) { + padding-top: 0px; + width: 1%; +} +.video-control-element > div:nth-child(2) { + padding-top: 0px; + width: 1%; +} +.video-control-element > div:nth-child(3) { + width: 98%; +} + +.Select-value { + background-color: #f2f2f2; +} +.Select-control { + width: 100% !important; +} + +.Select-menu-outer { + position: relative; +} + +.has-value.Select--single > .Select-control .Select-value .Select-value-label, +.has-value.is-pseudo-focused.Select--single + > .Select-control + .Select-value + .Select-value-label { + color: #606060; +} + +.rc-slider-track { + background-color: #fa4f56; +} + +.rc-slider-dot-active, +.rc-slider-dot, +.rc-slider-handle { + border-color: #fa4f56; +} + +.rc-slider-handle:hover { + border-color: #fa4f56; +} + +.img-container { + display: flex; + justify-content: flex-end; + height: 35px; + margin: 10px; +} + +#logo-web { + height: 100%; + min-height: 30px; + width: auto; + margin: 2px; +} + +#logo-mobile { + display: none; +} + +.plot-title { + margin-left: 5%; + margin-bottom: 0; + font-weight: 500; +} + +#heatmap-confidence { + height: 45vh; + width: 100%; +} + +#pie-object-count { + height: 40vh; + width: 100%; +} + +#bar-score-graph { + height: 55vh; +} + +/* + ##Device = Most of the Smartphones Mobiles, Ipad (Portrait) + */ + +@media only screen and (max-width: 750px), + screen and (min-width: 768px) and (max-width: 1024px) { + .container { + padding: 0; + width: 100%; + height: 100%; + } + + .columns { + width: 100% !important; + float: left; + } + + h4 { + font-size: 25px; + } + + #header-section { + text-align: center; + } + + body { + font-size: 12px; + } + + p { + font-size: 12px; + } + + #learn-more-button { + font-size: 11px; + font-weight: 400; + padding: 5px 20px; + text-transform: none; + } + + .video-outer-container { + min-width: 80vw; + margin-top: 1.5rem; + margin-bottom: 0.1rem; + } + + .control-element { + font-size: 12px; + flex-flow: column nowrap; + padding: 10px 0px; + height: 100%; + } + + .control-element > div { + width: 100% !important; + height: 100%; + } + + .control-element > div:nth-child(1) { + text-align: center; + } + + /*#slider-minimum-confidence-threshold {*/ + /* margin-bottom: 10px;*/ + /*}*/ + + #div-visual-mode, + #div-detection-mode { + text-align: center; + } + + .annotation-text { + font-size: 0.8rem !important; + } + + .img-container { + display: none; + } + + #left-side-column { + height: auto; + margin-left: 0; + padding: 0; + } + + #right-side-column { + justify-content: center; + margin-right: 0; + padding: 1rem; + overflow-y: inherit; + height: 100%; + } + + #logo-mobile { + display: inline-flex; + height: 30px; + align-self: flex-start; + } + + h5 { + font-size: 
1.5rem; + } +} + +#left-side-column::-webkit-scrollbar { + width: 5px; + } + + /* Track */ +#left-side-column::-webkit-scrollbar-track { + box-shadow: inset 0 0 2px grey; + border-radius: 10px; + } + + /* Handle */ +#left-side-column::-webkit-scrollbar-thumb { + background: rgb(7, 7, 7); + border-radius: 10px; + } + + /* Handle on hover */ +#left-side-column::-webkit-scrollbar-thumb:hover { + background: #009eb3; + } \ No newline at end of file diff --git a/gazenet/readers/visualization/visualize_local.py b/gazenet/readers/visualization/visualize_local.py new file mode 100644 index 0000000..7f30926 --- /dev/null +++ b/gazenet/readers/visualization/visualize_local.py @@ -0,0 +1,85 @@ +import time +import queue +import copy +import threading + +import cv2 +import numpy as np +import sounddevice as sd + +from gazenet.utils.registrar import * +from gazenet.utils.annotation_plotter import OpenCV +import gazenet.utils.sample_processors as sp + + +read_lock = threading.Lock() + + +def view(plotter, video, n_frames=5, enable_audio=True, video_properties={}): + video.start(plotter=plotter, **video_properties) + # audio_streamer = sd.play(sample_streamer.audio_cap["audio"][0] + # [sample_streamer.audio_cap["video_frames_list"][sample_streamer.audio_cap["curr_frame"]]:], + # sample_streamer.audio_cap["audio"][1]) + i = 0 + buff_audio_frames = [] + while i < n_frames: + ts = time.time() + # if i < 100: + # sample_streamer.play() + # elif i < 200: + # sample_streamer.pause() + # elif i == 200: + # sample_streamer.goto_frame(0) + # sample_streamer.pause() + # else: + # sample_streamer.play() + video.play() + + grabbed_video, video_frame, grabbed_audio, audio_frames, _, _ = video.read() + if enable_audio: + if len(buff_audio_frames) >= video.audio_cap.get(sp.AUCAP_PROP_SAMPLE_RATE): + with read_lock: + new_buff_audio_frames = buff_audio_frames.copy() + buff_audio_frames = [] + # sd.wait() + sd.play(np.array(new_buff_audio_frames), video.audio_cap.get(sp.AUCAP_PROP_SAMPLE_RATE)) + if grabbed_audio: + with read_lock: + buff_audio_frames.extend(audio_frames) + if grabbed_video: + cv2.imshow('Frame', video_frame) + orig_fps = video.frames_per_sec() * 1.06 # used to be multiplied by 1.51 + td = time.time() - ts + if td < 1.0/orig_fps: + cv2.waitKey(int((1.0/orig_fps - td)*1000)) & 0xFF + else: + cv2.waitKey(1) & 0xFF + else: + break + i += 1 + video.stop() + sd.stop() + cv2.destroyAllWindows() + + +if __name__ == '__main__': + width, height = 800, 400 + play_audio = True + + # define the reader + reader = "FindWhoSampleReader" + sampler = "FindWhoSample" + # sampler_properties = {"show_saliency_map": True, "enable_transform_overlays":False, "color_map": "bone"} + sampler_properties = {"show_saliency_map": True, "participant": None} + + + SampleRegistrar.scan() + ReaderRegistrar.scan() + + plotter = OpenCV() + video_source = ReaderRegistrar.registry[reader](mode="d") + video = SampleRegistrar.registry[sampler](video_source, w_size=1, width=width, height=height, enable_audio=play_audio) + audio_streamer = None + for i in range(len(video)): + view(plotter, video, n_frames=500, enable_audio=play_audio, video_properties=sampler_properties) + next(video) \ No newline at end of file diff --git a/gazenet/readers/visualization/visualize_server.py b/gazenet/readers/visualization/visualize_server.py new file mode 100644 index 0000000..c9dd22e --- /dev/null +++ b/gazenet/readers/visualization/visualize_server.py @@ -0,0 +1,445 @@ +# TODO (fabawi): The audio is not actually streamed but played locally. 
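# Overview: this module builds a Dash UI on top of a Flask server for browsing dataset
# samples. Annotated video frames are JPEG-encoded and pushed to the browser from the
# /video_feed route as a multipart/x-mixed-replace stream, while the corresponding audio
# frames are played back on the server machine itself through sounddevice.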
Need to change that +from io import * +import base64 +import os +import time +import re + +import sounddevice as sd +from flask import Flask, Response +import numpy as np +import cv2 +import dash +import dash_core_components as dcc +import dash_bootstrap_components as dbc +import dash_html_components as html +from dash.dependencies import Input, Output, State + +from gazenet.utils.registrar import * +from gazenet.utils.helpers import encode_image +from gazenet.utils.annotation_plotter import OpenCV +from gazenet.utils.dataset_processors import DataSplitter +import gazenet.utils.sample_processors as sp + +width, height = 800, 400 + +reader = "FindWhoSampleReader" +sampler = "FindWhoSample" +play_mode = 'play' +DEBUG = False +sp.SERVER_MODE = True + +SampleRegistrar.scan() +ReaderRegistrar.scan() +dspl = DataSplitter() + +video_source = ReaderRegistrar.registry[reader](mode="d") +video = SampleRegistrar.registry[sampler](video_source, w_size=1, width=width, height=height) + + +FONT_AWESOME = "https://use.fontawesome.com/releases/v5.7.2/css/all.css" + +server = Flask(__name__) +app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, FONT_AWESOME], server=server) + +plotter = OpenCV() + +dataset_info = video_source.dataset_info() +dataset_name = dataset_info["name"] +dataset_summary = dataset_info["summary"] +dataset_link = dataset_info["link"] + + +reset_sample = True + +# TODO (fabawi): pass initialization arguments to the preprocess for tracing the properties +preprocessed_data = video.preprocess_frames() +if preprocessed_data is not None: + extracted_data_list = video.extract_frames(**preprocessed_data) +else: + extracted_data_list = video.extract_frames() +_, _, _, _, _, properties = video.annotate_frame(input_data=next(zip(*extracted_data_list)), plotter=plotter) +video_properties = {k: v[0] for k,v in properties.items()} + +dummy_img_vid = encode_image(np.random.randint(255, size=(height,width,3),dtype=np.uint8), raw=True) +dummy_img = encode_image(np.random.randint(255, size=(height,width,3),dtype=np.uint8)) + +def gen_video(): + while True: + ts = time.time() + with video.read_lock: + video.buffer.put(play_mode) + grabbed_video, video_frame, grabbed_audio, audio_frames, _, _ = video.read() + + time.sleep(0.01) + # time.sleep(sample_streamer.frames_per_sec() * 0.00106) + if video_frame is not None: + if grabbed_audio: + sd.play(audio_frames, samplerate=video.audio_cap.get(sp.AUCAP_PROP_SAMPLE_RATE)) + video_frame = encode_image(video_frame.copy(), True) + orig_fps = video.frames_per_sec() * 1.06 # used to be multiplied by 1.51 + td = time.time() - ts + try: + if td < 1.0 / orig_fps: + time.sleep(1.0 / orig_fps - td) + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + video_frame + b'\r\n\r\n') + except: + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + dummy_img_vid + b'\r\n\r\n') + else: + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + dummy_img_vid + b'\r\n\r\n') + +@server.route('/video_feed') +def video_feed(): + time.sleep(1) + stream_img = gen_video() + return Response(stream_img, + mimetype='multipart/x-mixed-replace; boundary=frame') + +dynamic_elements = {'toggle': [], 'multilist': []} +def generate_elements(properties): + all_elements = [] + for property_name, property in properties.items(): + if property[1] == 'toggle': + all_elements.append( + html.Div( + className="control-element", + children=[ + html.Div( + children=["Switch property:"] + ), + dbc.Checklist(options=[ + {"label": property_name.replace("_", " "), 
"value": 1}, + ], + value=[], + id="property-"+property_name, + switch=True, + ) + ]) + ) + dynamic_elements['toggle'].append("property-"+property_name) + return all_elements + + +cards = [] +for i, sample in enumerate(video_source.samples): + audio_symbol = "\U0001F508" if sample["has_audio"] else " " + split, category = dspl.sample(sample["id"], video_source.short_name , mode="r") + card = dbc.Card( + [ + dbc.CardImg(src=encode_image(sample['video_thumbnail']), + id="video-"+str(i)+"-card-img", top=True), # style={"display": "none"} + dbc.CardBody([ + dbc.Button( + os.path.basename(sample["video_name"]) + audio_symbol, + id="video-" + str(i) + "-card-button", + color="dark", + style={'display': 'inline-block'}), + dbc.DropdownMenu( + label="", + children=[ + dbc.DropdownMenuItem("Split", header=True), + dbc.DropdownMenuItem("train", id="video-" + str(i) + "-split-train-button"), + dbc.DropdownMenuItem("val", id="video-" + str(i) + "-split-val-button"), + dbc.DropdownMenuItem("test", id="video-" + str(i) + "-split-test-button"), + dbc.DropdownMenuItem("None", id="video-" + str(i) + "-split-None-button"), + dbc.DropdownMenuItem(divider=True), + dbc.DropdownMenuItem("Category", header=True), + dbc.DropdownMenuItem("Social", id="video-" + str(i) + "-cat-Social-button"), + dbc.DropdownMenuItem("Nature", id="video-" + str(i) + "-cat-Nature-button"), + dbc.DropdownMenuItem("Other", id="video-" + str(i) + "-cat-Other-button"), + dbc.DropdownMenuItem("None", id="video-" + str(i) + "-cat-None-button"), + ], + style={'display': 'inline-block', 'float': 'right'}, + ), + html.Div([ + dbc.Badge(split if split is not None else "None", + id="video-" + str(i) + "-split-bdg", + color="success", className="mr-1"), + dbc.Badge(category if category is not None else "None", + id="video-" + str(i) + "-category-bdg", + color="info", className="mr-1"), + ], style={"display": "block"}) + + ])], + color="primary", inverse=True) + + cards.append(card) + +# Main App +app.layout = html.Div( + children=[ + dcc.Interval(id="video-player", interval=100, n_intervals=0), + html.Div(id="top-bar", className="row"), + html.Div( + className="page-content", + children=[ + html.Div( + id="left-side-column", + className="four columns", + children=[ + dbc.CardColumns(cards), + html.Img(src='', id="magic-img") + ], + ), + html.Div( + id="right-side-column", + className="eight columns", + children=[ + html.Div( + id="header-section", + children=[ + html.H4(dataset_name), + html.P( + dataset_summary + ), + dcc.Link(html.Button( + "Learn More", id="learn-more-button", n_clicks=0 + ), href=dataset_link, target='_blank'), + html.Div( + [ + dbc.Button("Save", id="open-save-dialog-button", + style={'display': 'inline-block', 'float': 'right'}), + dbc.Modal( + [ + dbc.ModalHeader("Save"), + dbc.ModalBody("Are you sure you want to save the dataset splits and categories?"), + dbc.ModalFooter([ + dbc.Button( + "Yes", id="save-split-cat-button", className="ml-auto", color="success", + ), + dbc.Button( + "No", id="close-save-dialog-button", className="ml-auto", color="dark", + ) + ] + ), + ], + backdrop=False, + zIndex=1000000, + id="save-dialog", + centered=True, + ), + ]) + ], + ), + html.Div(children=html.Hr()), + html.Div( + id="video-name", + children=[] + ), + html.Div( + className="video-outer-container", + children=[ + html.Div( + className="video-container", + children=[ + html.Img(id="video-container", + src="/video_feed"), + ] + ), + html.Div( + className="video-control-section", + children=[ + html.Div( + 
className="video-control-element", + children=[ + html.Button( + "", + className="fas fa-pause-circle fa-lg", + style={'border': 'none', 'height': '15px', 'padding':'0px'}, + id="video-control-button", + n_clicks=0 + ), + html.Button( + "", + className="fas fa-stop-circle fa-lg", + style={'border': 'none', 'height': '15px', 'padding':'0px'}, + id="video-stop-button", + n_clicks=0 + ), + dcc.Slider( + id="video-frame-slider", + min=20, + max=80, + value=0), + ], + ), + ], + ), + ] + ), + html.Div(id="enabled-properties", style={"display": "none"}, children=[]), + html.Div( + className="control-section", + children= generate_elements(properties) , + ), + ], + ), + + ], + ), + ] +) + +@app.callback([Output("video-stop-button", "className")], + [Input("video-stop-button", "n_clicks")], + [State("video-stop-button", "className")]) +def stop_video(n, class_name): + global play_mode + global reset_sample + if n > 0: + video.goto_frame(0) + video.stop() + video.goto(video.index) + video.set_annotation_properties(video_properties.copy()) + reset_sample = True + video.start(plotter) + play_mode = 'pause' + sd.stop() + return [class_name] + + +@app.callback([Output("video-control-button", "className")], + [Input("video-control-button", "n_clicks")], + [State("video-control-button", "className")]) +def control_video(n, class_name): + global play_mode + if n > 0: + if class_name == 'fas fa-play-circle fa-lg': + play_mode = 'play' + return ['fas fa-pause-circle fa-lg'] + else: + play_mode = 'pause' + sd.stop() + return ['fas fa-play-circle fa-lg'] + return [class_name] + + +@app.callback([Output("video-frame-slider", "value")], + [Input("video-player", "n_intervals")], + [State("video-frame-slider", "value"), + State("video-control-button", "className")]) +def start_video(n, prev_frame_idx, class_name): + global reset_sample + if n > 0 and not reset_sample and prev_frame_idx is not None: + new_frame_idx = video.frame_index() + # TODO (fabawi): Find a better way to seek + if np.abs((prev_frame_idx - new_frame_idx)) > video.frames_per_sec() + 10: + video.goto_frame(prev_frame_idx) + return [video.frame_index()] + else: + reset_sample = False + return [0] + + +@app.callback( + [Output("video-name", "children"), Output("video-frame-slider","min"), Output("video-frame-slider","max")], + [Input(f"video-{i}-card-button", "n_clicks_timestamp") for i in range(len(cards))] + [Input(f"video-{i}-card-button", "n_clicks") for i in range(len(cards))] , + [State(f"video-{i}-card-img", "src") for i in range(len(cards))] +) +def activate_video(*args): + global reset_sample + h_args = args[:len(cards)+1] + if h_args and h_args is not None: + f_h_args = list(filter(None.__ne__, h_args)) + if f_h_args: + index = h_args.index(max(f_h_args)) + (len(cards)*2) + if args[index - len(cards)] is not None and args[index - len(cards)] > 0: + video.goto_frame(0) + video.stop() + video.goto(index - (len(cards) * 2)) + video.set_annotation_properties(video_properties.copy()) + reset_sample = True + video.start(plotter) + curr_sample = video.reader.samples[video.index] + video_name = os.path.basename(curr_sample['video_name']) + time.sleep(1) + return [html.H5(video_name), 0, video.len_frames() - 2] + else: + return [html.H5("No video selected"), 1, 2] + + +# Toggle properties +@app.callback( + [Output("enabled-properties", "children")], + [Input(f"{i}", "value") for i in dynamic_elements['toggle']], + [State(f"{i}", "id") for i in dynamic_elements['toggle']] +) +def set_property(*args): + values = 
args[:len(dynamic_elements['toggle'])] + ids = args[len(dynamic_elements['toggle']):] + for idx, value in enumerate(values): + video_properties[ids[idx].replace('property-', '')] = True if len(value) == 1 else False + video.set_annotation_properties(video_properties.copy()) + return ["None"] + +@app.callback( + [Output(f"video-{i}-split-bdg", "children") for i in range(len(cards))], + [Input(f"video-{i}-split-train-button", "n_clicks") for i in range(len(cards))] + + [Input(f"video-{i}-split-val-button", "n_clicks") for i in range(len(cards))] + + [Input(f"video-{i}-split-test-button", "n_clicks") for i in range(len(cards))] + + [Input(f"video-{i}-split-None-button", "n_clicks") for i in range(len(cards))] , + [State(f"video-{i}-split-bdg", "children") for i in range(len(cards))] +) +def control_video_split(*args): + # we performed the button checking before in a different way, but this is the most recently recommended method + # https://dash.plotly.com/dash-html-components/button + changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0] + changed_id = changed_id + outputs = [*args[-len(cards):]] + if changed_id is not None or changed_id != ".": + for split_name in ("train", "val", "test", "None"): + match = re.match(r'video-(.*)-split-' + split_name + '-button(.*)', changed_id) + if match is not None: + index = int(match.group(1)) + dspl.sample(video_source.samples[index]["id"], video_source.short_name, + fps=video_source.samples[index].get("video_fps", 30), split=split_name, mode="d") + outputs[index] = split_name + return outputs + +@app.callback( + [Output(f"video-{i}-category-bdg", "children") for i in range(len(cards))], + [Input(f"video-{i}-cat-Social-button", "n_clicks") for i in range(len(cards))] + + [Input(f"video-{i}-cat-Nature-button", "n_clicks") for i in range(len(cards))] + + [Input(f"video-{i}-cat-Other-button", "n_clicks") for i in range(len(cards))] + + [Input(f"video-{i}-cat-None-button", "n_clicks") for i in range(len(cards))] , + [State(f"video-{i}-category-bdg", "children") for i in range(len(cards))] +) +def control_video_category(*args): + # we performed the button checking before in a different way, but this is the most recently recommended method + # https://dash.plotly.com/dash-html-components/button + changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0] + changed_id = changed_id + outputs = [*args[-len(cards):]] + if changed_id is not None or changed_id != ".": + for cat_name in ("Social", "Nature", "Other", "None"): + match = re.match(r'video-(.*)-cat-' + cat_name + '-button(.*)', changed_id) + if match is not None: + index = int(match.group(1)) + dspl.sample(video_source.samples[index]["id"], video_source.short_name, + fps=video_source.samples[index].get("video_fps", None), scene_type=cat_name, mode="d") + outputs[index] = cat_name + return outputs + +@app.callback( + Output("save-dialog", "is_open"), + [Input("open-save-dialog-button", "n_clicks"), Input("save-split-cat-button", "n_clicks"), Input("close-save-dialog-button", "n_clicks")], + [State("save-dialog", "is_open")], +) +def save_split_category(nopen, nsave, nclose, is_open): + if nopen or nsave or nclose: + if nsave is not None and nsave > 0: + dspl.save() + return not is_open + return is_open + + +if __name__ == '__main__': + # run on flask + # app.run_server(debug=DEBUG) + + # run on waitress + from waitress import serve + serve(server, host="0.0.0.0", port=8080) \ No newline at end of file diff --git a/gazenet/utils/__init__.py b/gazenet/utils/__init__.py new 
file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/gazenet/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/gazenet/utils/annotation_plotter.py b/gazenet/utils/annotation_plotter.py new file mode 100644 index 0000000..a539a76 --- /dev/null +++ b/gazenet/utils/annotation_plotter.py @@ -0,0 +1,204 @@ +import cv2 +import numpy as np + +np.seterr(divide='ignore', invalid='ignore') + +from gazenet.utils.helpers import circular_list, mp_multivariate_gaussian, conic_projection + + +class OpenCV(object): + def __init__(self, colors=None, color_maps=None): + if colors is None: + self.colors = circular_list( + [(0, 255, 0), (255, 0, 0), (0, 0, 255), (128, 0, 255), (255, 128, 0), (255, 0, 128), (0, 128, 255), + (0, 255, 128), (128, 255, 0), (255, 128, 64), (255, 64, 128), (128, 255, 64), + (128, 64, 255), (64, 128, 255), (64, 255, 128)]) + else: + self.colors = colors + if color_maps is None: + self.color_maps = {"autumn": cv2.COLORMAP_AUTUMN, + "bone": cv2.COLORMAP_BONE, + "jet": cv2.COLORMAP_JET, + "winter": cv2.COLORMAP_WINTER, + "rainbow": cv2.COLORMAP_RAINBOW, + "ocean": cv2.COLORMAP_OCEAN, + "summer": cv2.COLORMAP_SUMMER, + "spring": cv2.COLORMAP_SPRING, + "cool": cv2.COLORMAP_COOL, + "hsv": cv2.COLORMAP_HSV, + "pink": cv2.COLORMAP_PINK, + "hot": cv2.COLORMAP_HOT} + else: + self.color_maps = color_maps + + self.interpolation = {"nearest": cv2.INTER_NEAREST, + "linear": cv2.INTER_LINEAR, + "area": cv2.INTER_AREA, + "cubic": cv2.INTER_CUBIC, + "lanczos": cv2.INTER_LANCZOS4, + } + + def __prep_col_image__(self, frame=None, color_id=None): + if frame is not None: + frame = frame.copy() + if color_id is None: + color = self.colors[0] + elif isinstance(color_id, tuple): + color = color_id + else: + color = self.colors[color_id] + return frame, color + + def __prep_map_image__(self, frame=None, color_map=None): + if frame is not None: + frame = frame.copy() + if color_map is None: + color = self.color_maps["jet"] + elif isinstance(color_map, str): + color = self.color_maps[color_map] + else: + color = color_map + return frame, color + + def resize(self, frame, width=None, height=None, interpolation="nearest"): + if isinstance(interpolation, str): + interpolation = self.interpolation[interpolation] + frame = frame.copy() + frame = cv2.resize(frame, (width, height), interpolation) + return frame + + def plot_color_map(self, np_frame, color_map=None): + _, color = self.__prep_map_image__(None, color_map) + frame = cv2.applyColorMap(np_frame, colormap=color) + return frame + + def plot_point(self, frame, xy, color_id=None, radius=10, thickness=-1): + frame, color = self.__prep_col_image__(frame, color_id) + cv2.circle(frame, xy, radius, color, thickness) + return frame + + def plot_bbox(self, frame, xy_min, xy_max, color_id=None, thickness=2): + frame, color = self.__prep_col_image__(frame, color_id) + cv2.rectangle(frame, xy_min, xy_max, color, thickness) + return frame + + def plot_text(self, frame, text, xy, color_id=None, thickness=2, font_scale=0.5): + frame, color = self.__prep_col_image__(frame, color_id) + cv2.putText(frame, str(text), xy, + cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness) + return frame + + def plot_arrow(self, frame, xy_orig, xy_tgt, color_id=None, thickness=2): + frame, color = self.__prep_col_image__(frame, color_id) + cv2.arrowedLine(frame, xy_orig, xy_tgt, color, thickness) + return frame + + def plot_axis(self, frame, xy_min, xy_max, xyz, thickness=2): + frame = frame.copy() + pitch = xyz[0] + yaw = -xyz[1] + roll = xyz[2] + + xy_min = np.array(xy_min) + 
xy_max = np.array(xy_max) + size = np.linalg.norm(xy_max - xy_min) + xy1 = xy_min + ((xy_max - xy_min) / 2) + xy1 = xy1.astype("int32") + + tdx = xy1[0] + tdy = xy1[1] + + # X-Axis pointing to right. drawn in red + x1 = size * (np.cos(yaw) * np.cos(roll)) + tdx + y1 = size * (np.cos(pitch) * np.sin(roll) + np.cos(roll) * np.sin(pitch) * np.sin(yaw)) + tdy + + # Y-Axis | drawn in green + # v + x2 = size * (-np.cos(yaw) * np.sin(roll)) + tdx + y2 = size * (np.cos(pitch) * np.cos(roll) - np.sin(pitch) * np.sin(yaw) * np.sin(roll)) + tdy + + # Z-Axis (out of the screen) drawn in blue + x3 = size * (np.sin(yaw)) + tdx + y3 = size * (-np.cos(yaw) * np.sin(pitch)) + tdy + + cv2.line(frame, (int(tdx), int(tdy)), (int(x1), int(y1)), (0, 0, 255), thickness) + cv2.line(frame, (int(tdx), int(tdy)), (int(x2), int(y2)), (0, 255, 0), thickness) + cv2.line(frame, (int(tdx), int(tdy)), (int(x3), int(y3)), (255, 0, 0), thickness) + + return frame + + def plot_fov_mask(self, frame, xy, radius=50, thickness=-1): + frame = frame.copy() + mask = np.zeros_like(frame) + cv2.circle(mask, xy, radius, (255, 255, 255), thickness=thickness) + mask_blur = cv2.GaussianBlur(mask, (51, 51), 0) + frame = frame * (mask_blur / 255) + return frame + + def plot_conic_field(self, frame, xyz_orig, xyz_tgt, radius_orig=1, radius_tgt=10, color_map=None): + frame, color = self.__prep_map_image__(frame, color_map) + h, w, _ = frame.shape + xyz_orig = np.array(xyz_orig[:3]) + xyz_tgt = xyz_orig + np.array(xyz_tgt[:3]) * 10 + p2i = conic_projection(xyz_orig, xyz_tgt, width=w, height=h, radius_orig=radius_orig, radius_tgt=radius_tgt) + frame = self.plot_color_map(255 - np.uint8(((p2i - p2i.min()) / (p2i.max() - p2i.min())) * 255), color_map=color) + return frame + + def plot_alpha_overlay(self, frame, overlay, xy_min=None, xy_max=None, alpha=0.2, interpolation="nearest"): + frame = frame.copy() + overlay = overlay.copy() + h, w, _ = frame.shape + ho, wo, co = overlay.shape + if xy_min is None and xy_max is None: + if w != wo or h != ho: + overlay = self.resize(overlay, height=h, width=w, interpolation=interpolation) + else: + y1, y2 = max(0, xy_min[1]), min(h, xy_max[1]) + x1, x2 = max(0, xy_min[0]), min(w, xy_max[0]) + # add alpha channel + tmp = cv2.cvtColor(overlay, cv2.COLOR_BGR2GRAY) + _, a = cv2.threshold(tmp, 0, 255, cv2.THRESH_TOZERO) + b, g, r = cv2.split(overlay) + rgba = [b, g, r, a] + overlay = cv2.merge(rgba, 4) + # resize overlay to fit bbox + overlay = self.resize(overlay, width=x2 - x1, height=y2 - y1) + alpha_s = overlay[:, :, 3] / 255.0 + alpha_l = 1.0 - alpha_s + overlay_full = np.zeros_like(frame) + for c in range(0, 3): + overlay_full[y1:y2, x1:x2, c] = (alpha_s * overlay[:, :, c] + alpha_l * frame[y1:y2, x1:x2, c]) + overlay = overlay_full + + frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0) + + return frame + + ####################################### Multiple points ################################################# + + def plot_fixations_density_map(self, frame, xy_fix, xy_std=(10, 10), alpha=0.2, color_map=None): + frame, color = self.__prep_map_image__(frame, color_map) + height, width = frame.shape[:2] + heatmap = mp_multivariate_gaussian(xy_fix, width=width, height=height, xy_std=xy_std) + heatmap = np.divide(heatmap, np.amax(heatmap), out=heatmap, where=np.amax(heatmap) != 0) + heatmap *= 255 + heatmap = heatmap.astype("uint8") + + overlay = self.plot_color_map(heatmap, color_map=color) + + overlay = overlay.astype("uint8") + frame = self.plot_alpha_overlay(frame, overlay, alpha=alpha) + return 
frame + + def plot_fixations_locations(self, frame, xy_fix, radius=10): + frame = frame.copy() + for xy in xy_fix: + try: + cv2.circle(frame, (int(xy[0]), int(xy[1])), radius, (255,255,255), -1) + except ValueError: + pass + return frame + +class MatplotLib(object): + def __init__(self): + raise NotImplementedError("Matplotlib plotting not yet supported") diff --git a/gazenet/utils/audio_features.py b/gazenet/utils/audio_features.py new file mode 100644 index 0000000..251bfaf --- /dev/null +++ b/gazenet/utils/audio_features.py @@ -0,0 +1,278 @@ +# Copyright 2017 The TensorFlow Authors All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified by Fares Abawi (abawi@informatik.uni-hamburg.de) +# ============================================================================== + + +import resampy + +import numpy as np + +from gazenet.utils.registrar import * + + +@AudioFeatureRegistrar.register +class WindowedAudioFeatures(object): + """Split the raw signal into windows and applies a hanning window without any spectral transformation""" + def __init__(self, rate=48000, + win_len=64, + **kwargs): + self.rate = rate + self.win_len = win_len + + def waveform_to_feature(self, data, rate): + # TODO (fabawi): resampling is an improvisation. Check if it actually works + data = resampy.resample(data, rate, self.rate) + audiowav = data * (2 ** -23) + # TODO (fabawi): check the hanning + audiowav = np.hanning(audiowav.shape[1]) * audiowav + return audiowav + +@AudioFeatureRegistrar.register +class MFCCAudioFeatures(object): + """Defines routines to compute mel spectrogram features from audio waveform.""" + def __init__(self, rate=16000, + win_len_sec=0.64, # Each example contains 50 10ms video_frames_list + hop_len_sec=0.02, # Defined dynamically as 1/(video_fps) + log_offset=0.010, # Offset used for stabilized log of input mel-spectrogram + stft_win_len_sec=0.025, + stft_hop_len_sec=0.010, + mel_len=64, + mel_min_hz=125, + mel_max_hz=7500, + mel_break_hz=700.0, + mel_high_q=1127.0, + **kwargs): + self.rate = rate + self.win_len_sec = win_len_sec + self.hop_len_sec = hop_len_sec + self.log_offset = log_offset + self.stft_win_len_sec = stft_win_len_sec + self.stft_hop_len_sec = stft_hop_len_sec + self.mel_len = mel_len + self.mel_min_hz = mel_min_hz + self.mel_max_hz = mel_max_hz + self.mel_break_hz = mel_break_hz + self.mel_high_q = mel_high_q + + def frame(self, data, win_len, hop_len): + """Convert array into a sequence of successive possibly overlapping video_frames_list. + + An n-dimensional array of shape (num_samples, ...) is converted into an + (n+1)-D array of shape (num_frames, window_length, ...), where each frame + starts hop_length points after the preceding one. + + This is accomplished using stride_tricks, so the original data is not + copied. However, there is no zero-padding, so any incomplete video_frames_list at the + end are not included. + + Args: + data: np.array of dimension N >= 1. + win_len: Number of samples in each frame. 
+ hop_len: Advance (in samples) between each window. + + Returns: + (N+1)-D np.array with as many rows as there are complete video_frames_list that can be + extracted. + """ + num_samples = data.shape[0] + num_frames = 1 + int(np.floor((num_samples - win_len) / hop_len)) + shape = (num_frames, win_len) + data.shape[1:] + strides = (data.strides[0] * hop_len,) + data.strides + return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) + + def periodic_hann(self, win_len): + """Calculate a "periodic" Hann window. + + The classic Hann window is defined as a raised cosine that starts and + ends on zero, and where every value appears twice, except the middle + point for an odd-length window. Matlab calls this a "symmetric" window + and np.hanning() returns it. However, for Fourier analysis, this + actually represents just over one cycle of a period N-1 cosine, and + thus is not compactly expressed on a length-N Fourier basis. Instead, + it's better to use a raised cosine that ends just before the final + zero value - i.e. a complete cycle of a period-N cosine. Matlab + calls this a "periodic" window. This routine calculates it. + + Args: + win_len: The number of points in the returned window. + + Returns: + A 1D np.array containing the periodic hann window. + """ + return 0.5 - (0.5 * np.cos(2 * np.pi / win_len * np.arange(win_len))) + + def stft_magnitude(self, signal, fft_len, hop_len, win_len): + """Calculate the short-time Fourier transform magnitude. + + Args: + signal: 1D np.array of the input time-domain signal + fft_len: Size of the FFT to apply + hop_len: Advance (in samples) between each frame passed to FFT + win_len: Length of each block of samples to pass to FFT + + Returns: + 2D np.array where each row contains the magnitudes of the fft_length/2+1 + unique values of the FFT for the corresponding frame of input samples. + """ + frames = self.frame(signal, win_len, hop_len) + # Apply frame window to each frame. We use a periodic Hann (cosine of period + # window_length) instead of the symmetric Hann of np.hanning (period + # window_length-1). + window = self.periodic_hann(win_len) + windowed_frames = frames * window + return np.abs(np.fft.rfft(windowed_frames, int(fft_len))) + + def hertz_to_mel(self, freqs_hz): + """Convert frequencies to mel scale using HTK formula. + + Args: + freqs_hertz: Scalar or np.array of frequencies in hertz. + + Returns: + Object of same size as frequencies_hertz containing corresponding values + on the mel scale. + """ + return self.mel_high_q * np.log( + 1.0 + (freqs_hz / self.mel_break_hz)) + + def spectrogram_to_mel_matrix(self, spectro_len=129): + + """Return a matrix that can post-multiply spectrogram rows to make mel. + + Returns a np.array matrix A that can be used to post-multiply a matrix S of + spectrogram values (STFT magnitudes) arranged as video_frames_list x bins to generate a + "mel spectrogram" M of video_frames_list x mel_len. M = S A. + + The classic HTK algorithm exploits the complementarity of adjacent mel bands + to multiply each FFT bin by only one mel weight, then add it, with positive + and negative signs, to the two adjacent mel bands to which that bin + contributes. Here, by expressing this operation as a matrix multiply, we go + from num_fft multiplies per frame (plus around 2*num_fft adds) to around + num_fft^2 multiplies and adds. However, because these are all presumably + accomplished in a single call to np.dot(), it's not clear which approach is + faster in Python. 
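For illustration (argument names here are placeholders), the matrix returned by this
    method is meant to post-multiply a magnitude spectrogram of shape
    (num_frames, spectro_len), roughly:

        spectrogram = self.stft_magnitude(data, fft_len, hop_len, win_len)
        mel = np.dot(spectrogram, self.spectrogram_to_mel_matrix(spectro_len=spectrogram.shape[1]))

    which is how log_mel_spectrogram below applies it.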
The matrix multiplication has the attraction of being more + general and flexible, and much easier to read. + + Args: + spectro_len: How many bins there are in the source spectrogram + data, which is understood to be fft_size/2 + 1, i.e. the spectrogram + only contains the nonredundant FFT bins + + Returns: + An np.array with shape (num_spectrogram_bins, mel_len). + + Raises: + ValueError: if frequency edges are incorrectly ordered or out of range. + """ + nyquist_hz = self.rate / 2. + mel_len = self.mel_len + mel_min_hz = self.mel_min_hz + mel_max_hz = self.mel_max_hz + if mel_min_hz < 0.0: + raise ValueError("mel_min_hz %.1f must be >= 0" % mel_min_hz) + if mel_min_hz >= mel_max_hz: + raise ValueError("mel_min_hz %.1f >= mel_max_hz %.1f" % + (mel_min_hz, mel_max_hz)) + if mel_max_hz > nyquist_hz: + raise ValueError("mel_max_hz %.1f is greater than Nyquist %.1f" % + (mel_max_hz, nyquist_hz)) + spectro_len_hz = np.linspace(0.0, nyquist_hz, spectro_len) + spectro_len_mel = self.hertz_to_mel(spectro_len_hz) + # The i'th mel band (starting from i=1) has center frequency + # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge + # band_edges_mel[i+1]. Thus, we need mel_len + 2 values in + # the band_edges_mel arrays. + band_edges_mel = np.linspace(self.hertz_to_mel(mel_min_hz), + self.hertz_to_mel(mel_max_hz), mel_len + 2) + # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins + # of spectrogram values. + mel_weights_matrix = np.empty((spectro_len, mel_len)) + for i in range(mel_len): + lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] + # Calculate lower and upper slopes for every spectrogram bin. + # Line segments are linear in the *mel* domain, not hertz. + lower_slope = ((spectro_len_mel - lower_edge_mel) / + (center_mel - lower_edge_mel)) + upper_slope = ((upper_edge_mel - spectro_len_mel) / + (upper_edge_mel - center_mel)) + # .. then intersect them with each other and zero. + mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, + upper_slope)) + # HTK excludes the spectrogram DC bin; make sure it always gets a zero + # coefficient. + mel_weights_matrix[0, :] = 0.0 + return mel_weights_matrix + + def log_mel_spectrogram(self, data): + """Convert waveform to a log magnitude mel-frequency spectrogram. + + Args: + data: 1D np.array of waveform data + + Returns: + 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank + magnitudes for successive video_frames_list. + """ + win_len_samples = int(round(self.rate * self.stft_win_len_sec)) + hop_len_samples = int(round(self.rate * self.stft_hop_len_sec)) + fft_length = 2 ** int(np.ceil(np.log(win_len_samples) / np.log(2.0))) + spectrogram = self.stft_magnitude( + data, + fft_len=fft_length, + hop_len=hop_len_samples, + win_len=win_len_samples) + mel_spectrogram = np.dot(spectrogram, self.spectrogram_to_mel_matrix( + spectro_len=spectrogram.shape[1])) + return np.log(mel_spectrogram + self.log_offset) + + def waveform_to_feature(self, data, rate): + """Converts audio waveform into an array of examples for VGGish. + + Args: + data: np.array of either one dimension (mono) or two dimensions + (multi-channel, with the outer dimension representing channels). + Each sample is generally expected to lie in the range [-1.0, +1.0], + although this is not required. 
+ sample_rate: Sample rate of data + + Returns: + 3-D np.array of shape [num_examples, num_frames, num_bands] which represents + a sequence of examples, each of which contains a patch of log mel + spectrogram, covering num_frames video_frames_list of audio and num_bands mel frequency + bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. + """ + data = data.flatten()[..., np.newaxis] + # convert to mono. + if len(data.shape) > 1: + data = np.mean(data, axis=1) + + # resample to the rate assumed by VGGish + if rate != self.rate: + data = resampy.resample(data, rate, self.rate) + + # compute log mel spectrogram features + log_mel = self.log_mel_spectrogram(data) + + # frame features into examples + features_sample_rate = 1.0 / self.stft_hop_len_sec + example_win_len = int(round( + self.win_len_sec * features_sample_rate)) + example_hop_len = int(round( + self.hop_len_sec * features_sample_rate)) + log_mel_examples = self.frame( + log_mel, + win_len=example_win_len, + hop_len=example_hop_len) + return log_mel_examples diff --git a/gazenet/utils/dataset_processors.py b/gazenet/utils/dataset_processors.py new file mode 100644 index 0000000..259defd --- /dev/null +++ b/gazenet/utils/dataset_processors.py @@ -0,0 +1,390 @@ +import pickle +import os + +import cv2 +import numpy as np +import pandas as pd +from tqdm import tqdm + +from gazenet.utils.registrar import * +from gazenet.utils.helpers import extract_width_height_thumbnail_from_image +from gazenet.utils.sample_processors import SampleReader, SampleProcessor, ImageCapture + + +# TODO (fabawi): support annotation reading +@ReaderRegistrar.register +class DataSampleReader(SampleReader): + def __init__(self, video_dir="datasets/processed/Grouped_frames", + annotations_dir=None, + extract_thumbnails=True, + thumbnail_image_file="captured_1.jpg", + pickle_file="temp/processed.pkl", mode=None, **kwargs): + self.short_name = "processed" + self.video_dir = video_dir + self.annotations_dir = annotations_dir + self.extract_thumbnails = extract_thumbnails + self.thumbnail_image_file = thumbnail_image_file + + super().__init__(pickle_file=pickle_file, mode=mode, **kwargs) + + def read_raw(self): + video_groups = [video_group for video_group in sorted(os.listdir(self.video_dir))] + video_names = [os.path.join(video_group, video_name) for video_group in video_groups + for video_name in sorted(os.listdir(os.path.join(self.video_dir, video_group)))] + + for video_name in tqdm(video_names, desc="Samples Read"): + id = video_name + try: + + len_frames = len([name for name in os.listdir(os.path.join(self.video_dir, video_name)) + if os.path.isdir(os.path.join(self.video_dir, video_name))]) + width, height, thumbnail = extract_width_height_thumbnail_from_image( + os.path.join(self.video_dir, video_name, "1", self.thumbnail_image_file)) + + self.samples.append({"id": id, + "audio_name": '', + "video_name": os.path.join(self.video_dir, video_name), + "video_fps": 25, # 30 + "video_width": width, + "video_height":height, + "video_thumbnail": thumbnail, + "len_frames": len_frames, + "has_audio": False, + "annotation_name": os.path.join('videogaze', id), + "annotations": {} + }) + self.video_id_to_sample_idx[id] = len(self.samples) - 1 + self.len_frames += self.samples[-1]["len_frames"] + except: + print("Error: Access non-existent annotation " + id) + + @staticmethod + def dataset_info(): + return {"summary": "TODO", + "name": "Processed Dataset", + "link": "TODO"} + + +@SampleRegistrar.register +class DataSample(SampleProcessor): + def 
__init__(self, reader, index=-1, frame_index=0, width=640, height=480, **kwargs): + assert isinstance(reader, DataSampleReader) + self.short_name = reader.short_name + self.reader = reader + self.index = index + + if frame_index > 0: + self.goto_frame(frame_index) + + kwargs.update(enable_audio=False) + super().__init__(width=width, height=height, + video_reader=(ImageCapture, {"extension": "jpg", + "sub_directories": True, + "image_file": "captured_1"}), **kwargs) + next(self) + + def __next__(self): + with self.read_lock: + self.index += 1 + self.index %= len(self.reader.samples) + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def __len__(self): + return len(self.reader) + + def next(self): + return next(self) + + def goto(self, name, by_index=True): + if by_index: + index = name + else: + index = self.reader.video_id_to_sample_idx[name] + + with self.read_lock: + self.index = index + curr_metadata = self.reader.samples[self.index] + self.load(curr_metadata) + return curr_metadata + + def frames_per_sec(self): + if self.video_cap is not None: + return self.reader.samples[self.index]["video_fps"] + else: + return 0 + + def annotate_frame(self, input_data, plotter, + show_gaze=False, show_gaze_label=False, img_names_list=None, + **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, _ = input_data + + properties = {} + + info = {**info, "frame_annotations": {}} + # info["frame_info"]["dataset_name"] = self.reader.short_name + # info["frame_info"]["video_id"] = self.reader.samples[self.index]["id"] + # info["frame_info"]["frame_height"] = self.reader.samples[self.index]["video_height"] + # info["frame_info"]["frame_width"] = self.reader.samples[self.index]["video_width"] + + grouped_video_frames = {**grouped_video_frames, + "PLOT": [["captured"]] + } + + try: + frame_index = self.frame_index() + frame_name = self.video_cap.frames[frame_index-1] + frame_dir = os.path.join(self.video_cap.directory, os.path.dirname(frame_name)) + if grabbed_video and img_names_list is not None: + for img_name in img_names_list: + try: + img = cv2.imread(os.path.join(frame_dir, img_name + "_1.jpg")) + except cv2.Error: + img = np.zeros_like(grouped_video_frames["captured"]) + + grouped_video_frames[img_name] = img + + except: + pass + + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties + + def get_participant_frame_range(self,participant_id): + raise NotImplementedError + + +class DataSplitter(object): + """ + Reads and writes the split (Train, validation, and Test) sets and stores the groups for training and + evaluation. The file names are stored in csv files and are not split automatically. 
This provides an interface + for manually adding videos to the assigned lists + """ + def __init__(self, train_csv_file="datasets/processed/train.csv", + val_csv_file="datasets/processed/validation.csv", + test_csv_file="datasets/processed/test.csv", + mode="d", **kwargs): + if (train_csv_file is None and val_csv_file is None and test_csv_file is None) or mode is None: + raise AttributeError("Specify atleast 1 csv file and/or choose a supported mode (r,w,x,d)") + + self.train_csv_file = train_csv_file + self.val_csv_file = val_csv_file + self.test_csv_file = test_csv_file + + self.mode = mode + self.columns = ["video_id", "fps", "scene_type", "dataset"] + self.samples = pd.DataFrame(columns=self.columns + ["split"]) + self.open() + + def read(self, csv_file, split): + if csv_file is not None: + if self.mode == "r": # read or append + samples = pd.read_csv(csv_file, names=self.columns, header=0) + samples["split"] = split + + self.samples = pd.concat([self.samples, samples]) + elif self.mode == "d": # dynamic: if the pickle_file exists it will be read, otherwise, a new dataset is created + if os.path.exists(csv_file): + samples = pd.read_csv(csv_file, names=self.columns, header=0) + samples["split"] = split + self.samples = pd.concat([self.samples, samples]) + + elif self.mode == "x": # safe write + if os.path.exists(csv_file): + raise FileExistsError("Read mode 'x' safely writes a file. " + "Either delete the csv_file '" + csv_file + "' or change the read mode") + + def sample(self, video_id, dataset, fps=0, scene_type=None, split=None, mode="d"): + # the mode specified here controls the sample whereas the class' mode controls the data splits on file + # the grouping is based on the video_id and dataset + if mode == "r": + match = self.samples[(self.samples["video_id"] == video_id) & (self.samples["dataset"] == dataset)] + if match.empty: + match = {"split": None, "scene_type": None} + elif mode == "d": + match = self.samples[(self.samples["video_id"] == video_id) & (self.samples["dataset"] == dataset)] + if match.empty: + match = pd.DataFrame([[video_id, fps, scene_type, dataset, split]], columns=self.columns + ["split"]) + self.samples = self.samples.append(match, ignore_index=True) + else: + if fps is not None: + self.samples.loc[(self.samples["video_id"] == video_id) & (self.samples["dataset"] == dataset), "fps"] = fps + if scene_type is not None: + self.samples.loc[(self.samples["video_id"] == video_id) & (self.samples["dataset"] == dataset), "scene_type"] = scene_type + if split is not None: + self.samples.loc[(self.samples["video_id"] == video_id) & (self.samples["dataset"] == dataset), "split"] = split + match = self.samples[(self.samples["video_id"] == video_id) & (self.samples["dataset"] == dataset)] + elif mode == "x": + match = self.samples[(self.samples["video_id"] == video_id) & (self.samples["dataset"] == dataset)] + if match.empty: + match = pd.DataFrame([[video_id, fps, scene_type, dataset, split]], columns=self.columns + ["split"]) + self.samples = self.samples.append(match, ignore_index=True) + elif mode == "w": + match = pd.DataFrame([[video_id, fps, scene_type, dataset, split]], columns=self.columns + ["split"]) + self.samples = self.samples.append(match, ignore_index=True) + + return match["split"], match["scene_type"] + + def write(self, csv_file, split): + if csv_file is not None: + if self.mode == "d" or self.mode == "w" or self.mode == "x": # read or append + for name, group in self.samples.groupby("split"): + if name == split: + group = group.drop(["split"], 
axis=1) + group.to_csv(csv_file, index=False) + + def open(self): + self.read(self.train_csv_file, "train") + self.read(self.val_csv_file, "val") + self.read(self.test_csv_file, "test") + + def save(self): + self.write(self.train_csv_file, "train") + self.write(self.val_csv_file, "val") + self.write(self.test_csv_file, "test") + + def close(self): + self.save() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def __len__(self): + return len(self.samples) + + +class DataWriter(object): + """ + Writes the dataset in the format supported by the DataLoader. This writer assumes the structures of + grouped_video_frames_list and info_list as resulting from annotate_frames in VideoProcessor + """ + def __init__(self, dataset_name, video_name, save_dir="datasets/processed/", + output_video_size=(640, 480), frames_per_sec=20, + write_images=True, write_videos=False, write_annotations=True): + + if dataset_name == "processed": + rename_ds = video_name.split(os.sep) + dataset_name = rename_ds[0] + video_name = rename_ds[1] + + self.save_dir = save_dir + self.dataset_name = dataset_name + self.video_name = video_name + self.output_video_size = output_video_size + self.frames_per_sec = frames_per_sec + self.write_images = write_images + self.write_videos = write_videos + self.write_annotations = write_annotations + # loading annotations not needed if they will not be written + if write_annotations: + path_to_pickle = os.path.join(self.save_dir, "Annotations", self.dataset_name, self.video_name) + if os.path.exists(path_to_pickle): + self.annotations = {} + for root_dir, dirnames, filenames in sorted(os.walk(path_to_pickle)): + for filename in filenames: + if filename.endswith(".pkl"): + self.annotations[int(filename.rstrip(".pkl"))] = pickle.load(open(os.path.join(path_to_pickle, + filename), "rb")) + else: + self.annotations = {} + else: + self.annotations = {} + self.videos = {} + + def add_detections(self, returns, models): + current_dict = {} + for idx_model, model_data in enumerate(models): + for i, frame_dict in enumerate(returns[2 + idx_model][4]): + if self.write_annotations: + current_dict = {**current_dict, **self.make_id_key(frame_dict)} + if self.write_images or self.write_videos: + for image_group in returns[2 + idx_model][1][i]["PLOT"]: + for image_name in image_group: + image = returns[2 + idx_model][1][i][image_name] + if image is not None and np.shape(image) != () and image.any(): + self.write_transformed_image(frame_id=frame_dict["frame_info"]["frame_id"], + img_array=image, img_name=image_name) + if self.write_annotations: + self.merge_into_annotations(current_dict) + + def merge_into_annotations(self, current_dict): + for key in current_dict: + if key in self.annotations and not "frame_info": + self.annotations[key] = self.deep_append(self.annotations[key], current_dict[key]) + else: + self.annotations[key] = current_dict[key] + + @staticmethod + def make_id_key(old_dict): + # quick restructuring so that frame_id is top level + new_dict = {old_dict["frame_info"]["frame_id"]: old_dict} + return new_dict + + @staticmethod + def deep_append(old_dict, new_dict): + # merges dicts of dicts (of dicts) of lists without deleting + for key in old_dict: + if isinstance(old_dict[key], dict): + if new_dict[key] and new_dict[key] is not None: + old_dict[key] = DataWriter.deep_append(old_dict[key], new_dict[key]) + else: + if new_dict[key] and new_dict[key] is not None: + old_dict[key].extend(new_dict[key]) + return old_dict + + def write_transformed_image(self, frame_id, 
img_array, img_name): + # save the transformed images as jpegs, see mattermost for folder structure + write_path = os.path.join(self.save_dir, "{}", self.dataset_name) + if self.write_images: + img_path = os.path.join(write_path.format("Grouped_frames"), self.video_name, str(frame_id)) + if not os.path.exists(img_path): + os.makedirs(img_path, exist_ok=True) + index = 1 + while os.path.exists(os.path.join(img_path, img_name + "_" + str(index) + ".jpg")): + index += 1 + cv2.imwrite(os.path.join(img_path, img_name + "_" + str(index) + ".jpg"), img_array) + if self.write_videos: + if not img_name in self.videos: + vid_path = os.path.join(write_path.format("Videos"), self.video_name) + if not os.path.exists(vid_path): + os.makedirs(vid_path, exist_ok=True) + video_enc = cv2.VideoWriter_fourcc(*"XVID") + self.videos[img_name] = {"writer": cv2.VideoWriter(os.path.join(vid_path, img_name + '.avi'), video_enc, + self.frames_per_sec, self.output_video_size), # 25, (1232,504)), # + "last_frame": frame_id-1} + if self.videos[img_name]["last_frame"] < frame_id: + # self.videos[img_name]["writer"].write(cv2.resize(img_array, (1232,504))) + self.videos[img_name]["writer"].write(cv2.resize(img_array, self.output_video_size)) + self.videos[img_name]["last_frame"] = frame_id + + def dump_annotations(self): + if self.write_annotations: + path_to_pickle = os.path.join(self.save_dir, "Annotations", self.dataset_name, self.video_name) + if not os.path.exists(path_to_pickle): + os.makedirs(path_to_pickle, exist_ok=True) + for frame, annotation in self.annotations.items(): + pickle.dump(annotation, open(os.path.join(path_to_pickle, str(frame) + ".pkl"), "wb"), + protocol=pickle.HIGHEST_PROTOCOL) + + def clear_annotations(self): + self.annotations = {} + + def dump_videos(self): + if self.write_videos: + for vid_name in self.videos.keys(): + self.videos[vid_name]["writer"].release() + self.videos = {} + + def set_new_name(self, vid_name, output_vid_size=None, fps=None): + if os.sep in vid_name: + rename_ds = vid_name.split(os.sep) + self.dataset_name = rename_ds[0] + self.video_name = rename_ds[1] + else: + self.video_name = vid_name + self.clear_annotations() + if output_vid_size is not None: + self.output_video_size = output_vid_size + if fps is not None: + self.frames_per_sec = fps + diff --git a/gazenet/utils/face_detectors.py b/gazenet/utils/face_detectors.py new file mode 100644 index 0000000..61b82d7 --- /dev/null +++ b/gazenet/utils/face_detectors.py @@ -0,0 +1,96 @@ +from gazenet.utils.registrar import * + + +@FaceDetectorRegistrar.register +class DlibFaceDetection(object): + def __init__(self): + import face_recognition + self.__face_recognition__ = face_recognition + + def detect_frames(self, video_frames_list, match_face=None, **kwargs): + # loading the features for tracking + if match_face is not None: + p_image = self.__face_recognition__.load_image_file("face.jpg") + p_encoding = self.__face_recognition__.face_encodings(p_image)[0] + + faces_locations = [] + for f_idx in range(len(video_frames_list)): + # detecting the person inside the image if specified + if video_frames_list[f_idx] is not None: + boxes = self.__face_recognition__.face_locations(video_frames_list[f_idx]) + if match_face is not None: + tmp_encodings = self.__face_recognition__.face_encodings(video_frames_list[f_idx]) + results = self.__face_recognition__.compare_faces(tmp_encodings, p_encoding) + for id, box in enumerate(boxes): + if results[id]: + boxes = [box] + break + faces_locations.append(boxes) + else: + 
faces_locations.append([]) + return faces_locations + + +@FaceDetectorRegistrar.register +class MTCNNFaceDetection(object): + def __init__(self, device="cuda:0"): + from facenet_pytorch import MTCNN + import numpy as np + self.__mtcnn__ = MTCNN(keep_all=True, device=device) + self.__np__ = np + + def detect_frames(self, video_frames_list, **kwargs): + faces_locations = [] + for f_idx in range(len(video_frames_list)): + # detect faces + boxes = [] + if video_frames_list[f_idx] is not None: + box_vals, _ = self.__mtcnn__.detect(video_frames_list[f_idx]) + if box_vals is not None: + for box in box_vals: + (left, top, right, bottom) = box.astype(self.__np__.int32).tolist() + boxes.append([top, right, bottom, left]) + faces_locations.append(boxes) + return faces_locations + + +@FaceDetectorRegistrar.register +class SFDFaceDetection: + def __init__(self, landmarks_type=1, device='cuda:0', flip_input=False, verbose=False): + import torch + import torch.backends.cudnn as cudnn + from face_alignment.detection.sfd import FaceDetector + import numpy as np + self.__torch__ = torch + self.__np__ = np + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + if 'cuda' in device: + cudnn.benchmark = True + + # Get the face detector + self.face_detector = FaceDetector(device=device, verbose=verbose) + + def detect_frames(self, video_frames_list, **kwargs): + # images = self.__np__.asarray(video_frames_list)[..., ::-1] + # images = self.__np__.squeeze(images, axis=1) + # images = self.__torch__.FloatTensor(images) + images = self.__np__.moveaxis(self.__np__.stack(video_frames_list), -1, 1) + images = self.__torch__.from_numpy(images).to(device=self.device) + detected_faces = self.face_detector.detect_from_batch(images) + face_locations = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + face_locations.append([]) + continue + boxes = [] + for b in d: + b = self.__np__.clip(b, 0, None) + x1, y1, x2, y2 = map(int, b[:-1]) + boxes.append([y1, x2, y2, x1]) + face_locations.append(boxes) + return face_locations \ No newline at end of file diff --git a/gazenet/utils/helpers.py b/gazenet/utils/helpers.py new file mode 100644 index 0000000..56d473d --- /dev/null +++ b/gazenet/utils/helpers.py @@ -0,0 +1,368 @@ +################################################ Formatting ########################################################## + +# based on: https://stackoverflow.com/a/62001539 +def flatten_dict(input_node: dict, key_: str = '', output_dict: dict = {}): + if isinstance(input_node, dict): + for key, val in input_node.items(): + new_key = f"{key_}.{key}" if key_ else f"{key}" + flatten_dict(val, new_key, output_dict) + elif isinstance(input_node, list) or isinstance(input_node, tuple): + for idx, item in enumerate(input_node): + flatten_dict(item, f"{key_}.[{idx}]", output_dict) + else: + output_dict[key_] = input_node + return output_dict + + +def dynamic_module_import(modules, globals): + import importlib + for module_name in modules: + if not module_name.endswith(".py") or module_name.endswith("__.py"): + continue + module_name = module_name[:-3] + module_name = module_name.replace("/", ".") + module = __import__(module_name, fromlist=['*']) + # importlib.import_module(module_name) + if hasattr(module, '__all__'): + all_names = module.__all__ + else: + all_names = [name for name in dir(module) if not name.startswith('_')] + globals.update({name: getattr(module, name) for name in all_names}) + + +def adjust_len(a, b): + # adjusts the 
len of two sorted lists + al = len(a) + bl = len(b) + if al > bl: + start = (al - bl) // 2 + end = bl + start + a = a[start:end] + if bl > al: + a, b = adjust_len(b, a) + return a, b + + +def circular_list(ls): + class CircularList(list): + def __getitem__(self, x): + import operator + if isinstance(x, slice): + return [self[x] for x in self._rangeify(x)] + + index = operator.index(x) + try: + return super().__getitem__(index % len(self)) + except ZeroDivisionError: + raise IndexError('list index out of range') + + def _rangeify(self, slice): + start, stop, step = slice.start, slice.stop, slice.step + if start is None: + start = 0 + if stop is None: + stop = len(self) + if step is None: + step = 1 + return range(start, stop, step) + return CircularList(ls) + + +def check_audio_in_video(filename): + # TODO (fabawi): slows down the reading. Consider finding an alternative + import subprocess + import re + mean_volume = subprocess.run("ffmpeg -hide_banner -i " + filename + + " -af volumedetect -vn -f null - 2>&1 | grep mean_volume", + stdout=subprocess.PIPE, shell=True).stdout.decode('utf-8') + # if mean_volume is not None: + # mean_volume = float(re.search(r'mean_volume:(.*?)dB', mean_volume).group(1)) + # else: + # mean_volume = -91.0 + # + # if mean_volume > -90.0: + # has_audio = True + # else: + # has_audio = False + if not mean_volume or '-91.0 dB' in mean_volume: + has_audio = False + else: + has_audio = True + return has_audio + + +def extract_width_height_from_video(filename): + import cv2 + vcap = cv2.VideoCapture(filename) + width = vcap.get(cv2.CAP_PROP_FRAME_WIDTH) + height = vcap.get(cv2.CAP_PROP_FRAME_HEIGHT) + vcap.release() + return int(width), int(height) + + +def extract_thumbnail_from_video(filename, thumb_width=180, thumb_height=108, threshold=1): + import cv2 + vcap = cv2.VideoCapture(filename) + res, im_ar = vcap.read() + while im_ar.mean() < threshold and res: + res, im_ar = vcap.read() + im_ar = cv2.resize(im_ar, (thumb_width, thumb_height), 0, 0, cv2.INTER_LINEAR) + vcap.release() + return im_ar + + +def extract_width_height_thumbnail_from_image(filename, thumb_width=180, thumb_height=108): + import cv2 + im = cv2.imread(filename) + height, width = im.shape[:2] + im_ar = cv2.resize(im, (thumb_width, thumb_height), 0, 0, cv2.INTER_LINEAR) + return int(width), int(height), im_ar + + +def encode_image(img, raw=False): + import cv2 + import base64 + ret, jpeg = cv2.imencode('.jpg', img) + if raw: + enc_img = jpeg.tobytes() + return enc_img + else: + enc_img = base64.b64encode(jpeg).decode('UTF-8') + return 'data:image/jpeg;base64,{}'.format(enc_img) + + +def stack_images(grouped_video_frames_list, grabbed_video_list=None, plot_override=None): + import numpy as np + import cv2 + + # resize_to_match = lambda img_src, img_tgt: cv2.resize(img_src, (img_tgt.shape[1], img_tgt.shape[0]), 0, 0, cv2.INTER_LINEAR) + resize_to_match_y = lambda img_src, img_tgt: cv2.resize(img_src, (img_src.shape[1], img_tgt.shape[0]), 0, 0, cv2.INTER_LINEAR) + resize_to_match_x = lambda img_src, img_tgt: cv2.resize(img_src, (img_tgt.shape[1], img_src.shape[0]), 0, 0, cv2.INTER_LINEAR) + + def stack_image(grouped_video_frames, grabbed_video, plot_override=None): + if not grabbed_video: + return None + rows = [] + plot_frames = grouped_video_frames["PLOT"] if plot_override is None else plot_override + for row in plot_frames: + if len(row) > 1: + rows.append(np.concatenate([ + resize_to_match_y(grouped_video_frames[row_frame], grouped_video_frames[row[0]]) for row_frame in row], axis=1)) + else: + 
rows.append(grouped_video_frames[row[0]]) + if len(rows) > 1: + return np.concatenate([resize_to_match_x(row, rows[0]) for row in rows], axis=0) + else: + return rows[0] + + if isinstance(grouped_video_frames_list, list): + frames_list = [] + for gv_idx, grouped_video_frames in enumerate(grouped_video_frames_list): + if grabbed_video_list is not None: + grabbed_video = grabbed_video_list[gv_idx] + else: + grabbed_video = True + frames_list.append(stack_image(grouped_video_frames, + grabbed_video=grabbed_video, + plot_override=plot_override)) + return frames_list + else: # elif isinstance(grouped_video_frames_list, dict) + return stack_image(grouped_video_frames_list, + grabbed_video=True if grabbed_video_list is None else grabbed_video_list, + plot_override=plot_override) + + +def aggregate_frame_ranges(frame_ids): + """ + Compresses a list of frames to ranges + :param frame_ids: assumes either a list of frame_ids or a list of lists [frame_id, duration] + :return: list of frame ranges + """ + import itertools + frame_ids_updated = [] + + if isinstance(frame_ids[0], list) or isinstance(frame_ids[0], tuple): + frame_ids_updated.extend(i for frame_id, frame_duration in frame_ids for i in range(frame_id, frame_id + frame_duration)) + if not frame_ids_updated: + frame_ids_updated = frame_ids + frame_ids_updated = sorted(list(set(frame_ids_updated))) + + def ranges(frame_ids_updated): + for a, b in itertools.groupby(enumerate(frame_ids_updated), lambda pair: pair[1] - pair[0]): + b = list(b) + yield b[0][1], b[-1][1] + + return list(ranges(frame_ids_updated)) + + +################################################### Math ############################################################# + +def calc_overlap_ratio(bbox, patch_size, patch_num): + """ + compute the overlaping ratio of the bbox and each patch (10x16) + """ + import numpy as np + patch_area = float(patch_size[0] * patch_size[1]) + aoi_ratio = np.zeros((1, patch_num[1], patch_num[0]), dtype=np.float32) + + tl_x, tl_y = bbox[0], bbox[1] + br_x, br_y = bbox[0] + bbox[2], bbox[1] + bbox[3] + lx, ux = tl_x // patch_size[0], br_x // patch_size[0] + ly, uy = tl_y // patch_size[1], br_y // patch_size[1] + + for x in range(lx, ux + 1): + for y in range(ly, uy + 1): + patch_tlx, patch_tly = x * patch_size[0], y * patch_size[1] + patch_brx, patch_bry = patch_tlx + patch_size[ + 0], patch_tly + patch_size[1] + + aoi_tlx = tl_x if patch_tlx < tl_x else patch_tlx + aoi_tly = tl_y if patch_tly < tl_y else patch_tly + aoi_brx = br_x if patch_brx > br_x else patch_brx + aoi_bry = br_y if patch_bry > br_y else patch_bry + + aoi_ratio[0, y, x] = max((aoi_brx - aoi_tlx), 0) * max( + (aoi_bry - aoi_tly), 0) / float(patch_area) + + return aoi_ratio + +def multi_hot_coding(bbox, patch_size, patch_num): + """ + compute the overlaping ratio of the bbox and each patch (10x16) + """ + import numpy as np + thresh = 0.5 + aoi_ratio = calc_overlap_ratio(bbox, patch_size, patch_num) + hot_ind = aoi_ratio > thresh + while hot_ind.sum() == 0: + thresh *= 0.8 + hot_ind = aoi_ratio > thresh + + aoi_ratio[hot_ind] = 1 + aoi_ratio[np.logical_not(hot_ind)] = 0 + + return aoi_ratio[0] + +def pixels_to_bounded_range(xy_pix_max, xy_peak, xy_bounds=(-1, 1)): + import numpy as np + xy_pix_max = np.array(xy_pix_max) + xy_peak = np.array(xy_peak) + xy_peak = xy_peak / xy_pix_max + xy_peak = (xy_bounds[1] - xy_bounds[0]) * xy_peak - xy_bounds[1] + return xy_peak + + +def cartesian_to_spherical(xyz): + import numpy as np + ptr = np.zeros((3,)) + xy = xyz[0] ** 2 + xyz[1] ** 2 + ptr[0] = 
np.arctan2(xyz[1], xyz[0]) + ptr[1] = np.arctan2(xyz[2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + # ptr[1] = np.arctan2(np.sqrt(xy), xyz[2]) # for elevation angle defined from Z-axis down + ptr[2] = np.sqrt(xy + xyz[2] ** 2) + return ptr + + +def spherical_to_euler(pt): + import numpy as np + import math as m + v1 = np.array([0, 0, -1]) + v2 = np.array([np.sin(pt[1]) * np.cos(pt[0]), np.sin(pt[1]) * np.sin(pt[0]), np.cos(pt[0])]) + Z = np.cross(v1, v2) + Z /= np.sqrt(Z[0] ** 2 + Z[1] ** 2 + Z[2] ** 2) + Y = np.cross(Z, v1) + t = m.atan2(-Z[0], Z[1]) + p = m.asin(Z[0]) + psi = m.atan2(-Y[0], v1[0]) + return np.array([t, p, psi]) + + +def foveal_to_mask(xy, radius, width, height): + import numpy as np + Y, X = np.ogrid[:height, :width] + dist = np.sqrt((X - xy[0]) ** 2 + (Y - xy[1]) ** 2) + mask = dist <= radius + return mask.astype(np.float32) + + +def mp_multivariate_gaussian(entries, width, height, xy_std=(10, 10)): + import numpy as np + from joblib import Parallel, delayed + import time + + def multivariate_gaussian(x, y, xy_mean, width, height, xy_std, amplitude=64): + if np.isnan(xy_mean[0]) == False and np.isnan(xy_mean[1]) == False: + x0 = xy_mean[0] + y0 = xy_mean[1] + # now = time.time() + result = amplitude * np.exp( + -((((x - x0) ** 2) / (2 * xy_std[0] ** 2)) + (((y - y0) ** 2) / (2 * xy_std[1] ** 2)))) + # print('gaussian time:', time.time() - now) + return result + else: + return np.zeros((height, width)) + + + # std_x = np.std(xyfix[:, 0]) / 10 + # std_y = np.std(xyfix[:, 1]) / 10 + x = np.arange(0, width, 1, float) + y = np.arange(0, height, 1, float) + x, y = np.meshgrid(x, y, copy=False, sparse=True) + results = Parallel(n_jobs=4, prefer="threads")(delayed(multivariate_gaussian)(x, y, (entries[i, 0], entries[i, 1],), + width, height, xy_std, + amplitude = entries[i, 2]) + for i in range(entries.shape[0])) + result = np.sum(results, axis=0) + return result + + +# based on: https://stackoverflow.com/a/39823124/190597 (astrokeat) +def truncated_cone(xyz_orig, xyz_tgt, radius_orig, radius_tgt): + from scipy.linalg import norm + import numpy as np + # vector in direction of axis + v = xyz_tgt - xyz_orig + # find magnitude of vector + mag = norm(v) + # unit vector in direction of axis + v = v / mag + # make some vector not in the same direction as v + not_v = np.array([1, 1, 0]) + if (v == not_v).all(): + not_v = np.array([0, 1, 0]) + # make vector perpendicular to v + n1 = np.cross(v, not_v) + # print n1,'\t',norm(n1) + # normalize n1 + n1 /= norm(n1) + # make unit vector perpendicular to v and n1 + n2 = np.cross(v, n1) + # surface ranges over t from 0 to length of axis and 0 to 2*pi + n = 100 + t = np.linspace(0, mag, n) + theta = np.linspace(0, 2 * np.pi, n) + # use meshgrid to make 2d arrays + t, theta = np.meshgrid(t, theta) + r = np.linspace(radius_orig, radius_tgt, n) + # generate coordinates for surface + x, y, z = [xyz_orig[i] + v[i] * t + r * + np.sin(theta) * n1[i] + r * np.cos(theta) * n2[i] for i in [0, 1, 2]] + x = np.reshape(x, -1) + y = np.reshape(y, -1) + z = np.reshape(z, -1) + x,y,z = x.astype(np.int64), y.astype(np.int64), z.astype(np.int64) + return x, y, z + + +def conic_projection(xyz_orig, xyz_tgt, width, height, radius_orig=1, radius_tgt=10): + import numpy as np + from scipy.interpolate import griddata + from scipy import ndimage + + x, y, z = truncated_cone(xyz_orig, xyz_tgt, radius_orig, radius_tgt) + ############################### + xi = np.linspace(0, width, width) + yi = np.linspace(0, height, height) + zi = griddata((x, y), 
z, (xi[None, :], yi[:, None]), method="nearest") + p2i = ndimage.gaussian_filter(zi, sigma=6) + return p2i \ No newline at end of file diff --git a/gazenet/utils/registrar.py b/gazenet/utils/registrar.py new file mode 100644 index 0000000..9a5eadc --- /dev/null +++ b/gazenet/utils/registrar.py @@ -0,0 +1,197 @@ +import os +from glob import glob + +from gazenet.utils.helpers import dynamic_module_import + + +class RobotControllerRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + RobotControllerRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "robots", "**", "controller.py"), recursive=True) + modules = ["gazenet.robots." + module.replace(os.path.dirname(__file__) + "/../robots/", "") for module in modules] + dynamic_module_import(modules, globals()) + + +# A pytorch lightning module +class ModelRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + ModelRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "models", "**", "model.py"), recursive=True) + modules = ["gazenet.models." + module.replace(os.path.dirname(__file__) + "/../models/", "") for module in modules] + dynamic_module_import(modules, globals()) + + +# A pytorch lightning data module +class ModelDataRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + ModelDataRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "models", "**", "generator.py"), recursive=True) + modules = ["gazenet.models." + module.replace(os.path.dirname(__file__) + "/../models/", "") for module in modules] + dynamic_module_import(modules, globals()) + + +class MetricsRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + MetricsRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "models", "**", "metrics.py"), recursive=True) + modules = ["gazenet.models." + module.replace(os.path.dirname(__file__) + "/../models/", "") for module in modules] + dynamic_module_import(modules, globals()) + + +# An InferenceSampleProcessor inheriting class +class InferenceRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + InferenceRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "models", "**", "infer.py"), recursive=True) + modules = ["gazenet.models." + module.replace(os.path.dirname(__file__) + "/../models/", "") for module in modules] + dynamic_module_import(modules, globals()) + + +# A SampleReader inheriting class +class ReaderRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + ReaderRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "readers", "*.py"), recursive=True) + modules = ["gazenet.readers." + module.replace(os.path.dirname(__file__) + "/../readers/", "") for module in modules] + dynamic_module_import(modules, globals()) + # add the data reader as well + modules = glob(os.path.join(os.path.dirname(__file__), "dataset_processors.py"), recursive=False) + modules = ["gazenet.utils." 
+ modules[0].replace(os.path.dirname(__file__) + "/", "")] + dynamic_module_import(modules, globals()) + + +# A SampleProcessor inheriting class +class SampleRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + SampleRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "readers", "*.py"), recursive=True) + modules = ["gazenet.readers." + module.replace(os.path.dirname(__file__) + "/../readers/", "") for module in modules] + dynamic_module_import(modules, globals()) + # add the data sample as well + modules = glob(os.path.join(os.path.dirname(__file__), "dataset_processors.py"), recursive=False) + modules = ["gazenet.utils." + modules[0].replace(os.path.dirname(__file__) + "/", "")] + dynamic_module_import(modules, globals()) + + +class InferenceConfigRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + InferenceConfigRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "configs", "*.py"), recursive=True) + modules = ["gazenet.configs." + module.replace(os.path.dirname(__file__) + "/../configs/", "") for module in modules] + dynamic_module_import(modules, globals()) + + +class TrainingConfigRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + TrainingConfigRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "..", "configs", "*.py"), recursive=True) + modules = ["gazenet.configs." + module.replace(os.path.dirname(__file__) + "/../configs/", "") for module in modules] + dynamic_module_import(modules, globals()) + + +class PlotterRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + PlotterRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "annotation_plotter.py"), recursive=False) + modules = ["gazenet.utils." + modules[0].replace(os.path.dirname(__file__) + "/", "")] + dynamic_module_import(modules, globals()) + + +class FaceDetectorRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + FaceDetectorRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "face_detectors.py"), recursive=False) + modules = ["gazenet.utils." + modules[0].replace(os.path.dirname(__file__) + "/", "")] + dynamic_module_import(modules, globals()) + + +class AudioFeatureRegistrar(object): + registry = {} + + @staticmethod + def register(cls): + AudioFeatureRegistrar.registry[cls.__name__] = cls + return cls + + @staticmethod + def scan(): + modules = glob(os.path.join(os.path.dirname(__file__), "audio_features.py"), recursive=False) + modules = ["gazenet.utils." 
+ modules[0].replace(os.path.dirname(__file__) + "/", "")] + dynamic_module_import(modules, globals()) \ No newline at end of file diff --git a/gazenet/utils/sample_processors.py b/gazenet/utils/sample_processors.py new file mode 100644 index 0000000..cd03b5f --- /dev/null +++ b/gazenet/utils/sample_processors.py @@ -0,0 +1,675 @@ +import threading +import pickle +import subprocess +import queue +import os +from pathlib import Path + +import cv2 +import numpy as np +import sounddevice as sd +import librosa + +from gazenet.utils.helpers import stack_images + +SERVER_MODE = False +REVIVE_RETRIES = 5 + +DEFAULT_SAMPLE_RATE = 16000 + +# audio capturer property ids +AUCAP_PROP_SAMPLE_RATE = 1001 +AUCAP_PROP_CHUNK_SIZE = 1002 +AUCAP_PROP_BUFFER_SIZE = 1003 +AUCAP_PROP_CHANNELS = 1004 +AUCAP_PROP_POS_FRAMES = 1005 +AUCAP_PROP_FRAME_COUNT = 1007 +# AUCAP_PROP_POS_MSEC = 1008 + + +class SampleReader(object): + """ + Handles the dataset reading process and provides a unified interface for + """ + def __init__(self, pickle_file, mode=None, **kwargs): + self.pickle_file = pickle_file + self.samples = [] + self.video_id_to_sample_idx = {} + self.len_frames = 0 + + if pickle_file is not None: + if mode == "r": # read + self.read() + + elif mode == "w": # write + self.read_raw() + self.write() + elif mode == "x": # safe write + if os.path.exists(pickle_file): + raise FileExistsError("Read mode 'x' safely writes a file. " + "Either delete the pickle_file or change the read mode") + self.read_raw() + self.write() + elif mode == "a": # append + self.read() + self.read_raw() + self.write() + elif mode == "d": # dynamic: if the pickle_file exists it will be read, otherwise, a new dataset is created + if os.path.exists(pickle_file): + self.read() + else: + self.read_raw() + self.write() + else: + self.read_raw() + else: + if mode is not None: + raise AttributeError("Specify the pickle_file attribute to make use of the mode") + self.read_raw() + + def write(self): + with open(self.pickle_file, 'wb') as f: + pickle.dump({"samples": self.samples, + "len_frames": self.len_frames, + "video_id_to_sample_idx": self.video_id_to_sample_idx, + "__name__": self.__class__.__name__}, f, pickle.HIGHEST_PROTOCOL) + + def read(self): + with open(self.pickle_file, "rb") as f: + data = pickle.load(f) + assert data["__name__"] == self.__class__.__name__, \ + "The pickle_file has a mismatching name. " \ + "Ensure the correct pickle_file is read" + + self.len_frames = data["len_frames"] + self.samples = data["samples"] + self.video_id_to_sample_idx = data["video_id_to_sample_idx"] + + def read_raw(self): + raise NotImplementedError("Not implemented for abstract Reader") + + def __len__(self): + return len(self.samples) + + +class ImageCapture(object): + """ + Loads all the image indices in a directory to the memory. 
When too many images are in a directory, use cv2.VideoCapture + instead, making sure the images follow the string format provided and are ordered sequentially + """ + def __init__(self, directory, extension="jpg", fps=1, sub_directories=False, image_file="captured_1", *args, **kwargs): + self.properties = { + cv2.CAP_PROP_POS_FRAMES: 0, + cv2.CAP_PROP_FPS: fps, + cv2.CAP_PROP_FRAME_COUNT: None, + cv2.CAP_PROP_FRAME_WIDTH: None, + cv2.CAP_PROP_FRAME_HEIGHT: None + } + self.directory = directory + self.sub_directories = True + + if sub_directories: + self.frames = [os.path.join(f, image_file+"."+extension) for f in os.listdir(directory)] + self.frames = sorted(self.frames, key=lambda x: float(x.split(os.sep, 1)[0])) + else: + self.frames = [f for f in os.listdir(directory) if f.endswith("."+extension)] + self.frames = sorted(self.frames, key=lambda x: float(x[:-(len(extension)+1)])) + self.set(cv2.CAP_PROP_FRAME_COUNT, len(self.frames)) + + file = os.path.join(self.directory, self.frames[0]) + im = cv2.imread(file) + h, w, c = im.shape + self.set(cv2.CAP_PROP_FRAME_WIDTH, w) + self.set(cv2.CAP_PROP_FRAME_HEIGHT, h) + + self.opened = True + + def read(self): + try: + frame_index = self.get(cv2.CAP_PROP_POS_FRAMES) + file = os.path.join(self.directory, self.frames[frame_index]) + im = cv2.imread(file) + self.opened = True + self.set(cv2.CAP_PROP_POS_FRAMES, frame_index+1) + return True, im + except: + self.opened = False + return False, None + + def isOpened(self): + return self.opened + + def release(self): + self.frames = [] + + def set(self, propId, value): + self.properties[propId] = value + + def get(self, propId): + return self.properties[propId] + + +class AudioCapture(object): + """ + Loads an entire audio file into memory (no buffering due to limited format support) and microphone (blocking) + """ + def __init__(self, source, buffer_size=30, rate=None, channels=1, len_frames=1): + self.properties = { + AUCAP_PROP_BUFFER_SIZE: buffer_size, + AUCAP_PROP_SAMPLE_RATE: rate, + AUCAP_PROP_CHANNELS: channels, + AUCAP_PROP_CHUNK_SIZE: None, + AUCAP_PROP_FRAME_COUNT: len_frames, + AUCAP_PROP_POS_FRAMES: 0, + } + + # capturing mode + if "." 
in source: # file + self.reader = self.__getfile__ + self.stream = librosa.load(source, sr=rate, duration=len_frames/buffer_size) + self.frame_indices, chunk_size = np.linspace(0, len(self.stream[0]), num=len_frames, retstep=True, endpoint=False, dtype=int) + self.set(AUCAP_PROP_CHUNK_SIZE, int(chunk_size)) + self.set(AUCAP_PROP_SAMPLE_RATE, self.stream[1]) + + else: # microphone + # TODO (fabawi): microphone reading is very choppy + self.reader = self.__getmic__ + if len_frames <= 0: + len_frames = 1 + if rate is None: + device_info = sd.query_devices(source, 'input') + rate = int(device_info['default_samplerate']) + self.set(AUCAP_PROP_SAMPLE_RATE, int(rate)) + chunk_size = int(rate * len_frames / buffer_size) + self.stream = sd.InputStream(device=source, + samplerate=rate, + channels=channels, + blocksize=chunk_size*buffer_size) + self.stream.start() + self.set(AUCAP_PROP_CHUNK_SIZE, chunk_size) + self.opened = True + self.state = -1 + self.read_lock = threading.Lock() + + def __getmic__(self): + frames = self.stream.read(self.get(AUCAP_PROP_CHUNK_SIZE)*self.get(AUCAP_PROP_BUFFER_SIZE)) + frames = np.array(np.split(frames[0], self.get(AUCAP_PROP_BUFFER_SIZE))) + return frames + pass + + def __getfile__(self): + curr_frame_idx = self.frame_indices[self.get(AUCAP_PROP_POS_FRAMES)] + frames = self.stream[0][curr_frame_idx: + curr_frame_idx + (self.get(AUCAP_PROP_BUFFER_SIZE)*self.get(AUCAP_PROP_CHUNK_SIZE))] + frames = np.array(np.split(frames, self.get(AUCAP_PROP_BUFFER_SIZE))) + return frames + + def read(self, *args, stateful=False, **kwargs): + # TODO (fabawi): being stateful causes issues with seeking. Also, the buffer size should be larger to avoid + # pauses (should be dynamic and loads the whole clip when reading from disk) + try: + if stateful: + with self.read_lock: + self.state += 1 + self.state %= self.get(AUCAP_PROP_BUFFER_SIZE) + if self.state != 0: + return False, None + + frame = self.reader() + with self.read_lock: + self.set(AUCAP_PROP_POS_FRAMES, + self.get(AUCAP_PROP_POS_FRAMES) + self.get(AUCAP_PROP_BUFFER_SIZE)) + self.opened = True + return True, frame + except: + with self.read_lock: + self.opened = False + return False, None + + def isOpened(self): + return self.opened + + def release(self): + try: + self.stream.stop() + except: + self.stream = None + with self.read_lock: + self.state = 0 + + def set(self, propId, value): + self.properties[propId] = value + + def get(self, propId): + return self.properties[propId] + + +class SampleProcessor(object): + """ + Handles video (images/audio) and processes it in a format suitable for render and display + """ + def __init__(self, width=None, height=None, enable_audio=True, + video_reader=(cv2.VideoCapture,{}), audio_reader=(AudioCapture, {}), w_size=1, **kwargs): + self.video_cap = None + self.audio_cap = None + self.width = width + self.height = height + self.enable_audio = enable_audio + self.video_reader = video_reader + self.audio_reader = audio_reader + self.w_size = w_size + self.video_out_path = '' + self.audio_out_path = '' + + self.grabbed_video, self.video_frame, self.info, self.properties = False, None, {}, {} + self.grabbed_audio, self.audio_frames = False, None + self.rt_index = 0 + self.started = False + self.buffer = queue.Queue() + self.read_lock = threading.Lock() + self.annotation_properties = {} + self.thread = None + + def load(self, metadata): + # create the names for the output files + self.video_out_path = os.path.dirname(os.path.join("gazenet", "readers", "visualization", "assets", "media", 
str(metadata["video_name"]))) + Path(self.video_out_path).mkdir(parents=True, exist_ok=True) + self.video_out_path = os.path.join(self.video_out_path, 'temp_vid_' + os.path.basename(str(metadata["video_name"])) + '.avi') + if self.enable_audio and metadata["has_audio"]: + self.audio_out_path = os.path.dirname(os.path.join("gazenet", "readers", "visualization", "assets", "media", str(metadata["audio_name"]))) + Path(self.audio_out_path).mkdir(parents=True, exist_ok=True) + self.audio_out_path = os.path.join(self.audio_out_path, 'temp_aud_' + os.path.basename(str(metadata["video_name"])) + '.wav') + + if self.video_reader is not None: + # setup the video capturer + if self.video_cap is not None: + self.video_cap.release() + video_properties = self.video_reader[1].copy() + self.video_cap = self.video_reader[0](metadata["video_name"], **video_properties) + + if self.enable_audio and metadata["has_audio"]: + if self.audio_reader is not None: + # setup the audio capturer + if self.audio_cap is not None: + self.audio_cap.release() + audio_properties = self.audio_reader[1].copy() + if "buffer_size" not in audio_properties: + audio_properties["buffer_size"] = int(self.video_cap.get(cv2.CAP_PROP_FPS)) + if "len_frames" not in audio_properties: + audio_properties["len_frames"] = int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)) # set to total num video_frames_list + self.audio_cap = self.audio_reader[0](metadata["audio_name"], **audio_properties) + + def goto_frame(self, frame_index): + if self.video_cap is not None: + with self.read_lock: + try: + self.video_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) + except: + pass + if self.audio_cap is not None: + with self.read_lock: + try: + self.audio_cap.set(AUCAP_PROP_POS_FRAMES, int(frame_index)) + except: + pass + + def frames_per_sec(self): + # if self.enable_audio and self.audio_cap is not None: + # # return audio frame for granular precision + # return self.audio_cap.get(AUCAP_PROP_FPS) + if self.video_cap is not None: + return int(self.video_cap.get(cv2.CAP_PROP_FPS)) + else: + return 0 + + def frame_index(self): + # if self.enable_audio and self.audio_cap is not None: + # # return audio frame for granular precision + # return self.audio_cap.get(cv2.CAP_PROP_POS_FRAMES) + if self.video_cap is not None: + return int(self.video_cap.get(cv2.CAP_PROP_POS_FRAMES)) + else: + return 0 + + def len_frames(self): + # curr_sample = self.reader.samples[self.index] + # return curr_sample['len_frames'] + if self.video_cap is not None: + return int(self.video_cap.get(cv2.CAP_PROP_FRAME_COUNT)) + else: + return 0 + + def frame_size(self): + if self.video_cap is not None: + width = int(self.video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(self.video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + return width, height + else: + return 0, 0 + + def start(self, *args, **kwargs): + if self.started: + print('[!] Asynchronous video capturing has already been started.') + return None + with self.read_lock: + self.started = True + # start the video processing thread + self.thread = threading.Thread(target=self.update, kwargs=kwargs, args=(self.buffer, *args)) + self.thread.start() + return self + + def update(self, q, *args, **kwargs): + retries = REVIVE_RETRIES + while self.started: + try: + cmd = q.get(timeout=2) + if cmd == 'play': + # TODO (fabawi): this loops over all extracted grouped_video_frames. 
they should be returned as a list instead + preprocessed_data = self.preprocess_frames(*args, **kwargs) + if preprocessed_data is not None: + extracted_data_list = self.extract_frames(**preprocessed_data) + else: + extracted_data_list = self.extract_frames() + + for extracted_data in zip(*extracted_data_list): + if self.annotation_properties: + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = \ + self.annotate_frame(extracted_data, *args, **self.annotation_properties) + else: + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = \ + self.annotate_frame(extracted_data, *args, **kwargs) + + if grabbed_video: + if self.width is not None and self.height is not None: + video_frame = cv2.resize(stack_images(grouped_video_frames), (self.width, self.height)) + else: + video_frame = stack_images(grouped_video_frames) + else: + video_frame = None + + if grabbed_audio: + audio_frames = audio_frames.flatten() + else: + audio_frames = None + + with self.read_lock: + self.grabbed_video = grabbed_video + self.video_frame = video_frame + self.info = info + self.properties = properties + if self.enable_audio: + self.grabbed_audio = grabbed_audio + if self.grabbed_audio: + self.audio_frames = audio_frames + + elif cmd == 'pause': + # if self.audio_cap is not None: + # with self.read_lock: + # self.audio_cap["curr_frame"] = self.frame_index() + continue + elif cmd == 'stop': + with self.read_lock: + self.started = False + except queue.Empty: + retries -= 1 + if retries == 0: + with self.read_lock: + self.started = False + break + continue + + def read(self): + grabbed_video = self.grabbed_video + if self.video_frame is not None: + video_frame = self.video_frame.copy() + else: + video_frame = self.video_frame + info = self.info.copy() + properties = self.properties.copy() + + grabbed_audio = self.grabbed_audio + audio_frames = self.audio_frames + + return grabbed_video, video_frame, grabbed_audio, audio_frames, info, properties + + def stop(self): + if SERVER_MODE: + if self.video_cap is not None: + self.video_cap.release() + if self.audio_cap is not None: + self.audio_cap.release() + with self.read_lock: + self.started = False + if self.thread is not None: + self.thread.join() + print("Stopped the capture thread") + + def __exit__(self, exec_type, exc_value, traceback): + if self.video_cap is not None: + self.video_cap.release() + if self.audio_cap is not None: + self.audio_cap.release() + + def play(self): + with self.read_lock: + self.buffer.put('play') + + def pause(self): + with self.read_lock: + self.buffer.put('pause') + + def preprocess_frames(self, *args, **kwargs): + return None + + def postprocess_frames(self, grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, + info_list, properties_list, + keep_video=True, keep_audio=True, keep_info=True, keep_properties=True, + info_override=None, properties_override=None, plot_override=None, + keep_plot_frames_only=False, resize_frames=False, convert_plots_gray=False, + duplicate_audio_frames=False, + *args, **kwargs): + # TODO (fabawi): these may break some functionality. 
Make sure to externally handle None values + if not keep_video: + grouped_video_frames_list = [None] * len(grouped_video_frames_list) + if not keep_audio: + audio_frames_list = [None] * len(audio_frames_list) + if not keep_properties: + properties_list = [None] * len(properties_list) + if not keep_info: + info_list = [None] * len(info_list) + + for grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties in zip( + grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list,info_list, properties_list): + + if keep_video: + if plot_override is not None: + grouped_video_frames["PLOT"] = plot_override + if keep_plot_frames_only or resize_frames or convert_plots_gray: + del_frame_names = [] + keep_plot_names = [item for sublist in grouped_video_frames["PLOT"] for item in sublist] + ["PLOT"] + for plot_name, plot in grouped_video_frames.items(): + if keep_plot_frames_only: + if plot_name not in keep_plot_names: + grouped_video_frames[plot_name] = None + del_frame_names.append(plot_name) + if plot_name != "PLOT" and plot is not None: + if resize_frames: + if grabbed_video: + grouped_video_frames[plot_name] = cv2.resize(plot.copy(), (self.width, self.height)) + # else: + # grouped_video_frames[plot_name] = np.zeros((self.height, self.width, 3)) + if convert_plots_gray: + if grabbed_video: + grouped_video_frames[plot_name] = cv2.cvtColor(plot.copy(), cv2.COLOR_RGB2GRAY) + # else: + # grouped_video_frames[plot_name] = np.zeros((self.height, self.width, 1)) + + for del_frame_name in del_frame_names: + del grouped_video_frames[del_frame_name] + + if keep_audio: + if duplicate_audio_frames: + raise NotImplementedError + + if keep_info: + if info_override is not None: + # overrides names in list at the surface level of the dictionary + for override in info_override: + try: + del info[override] + except: + pass + if keep_properties: + if properties_override is not None: + # overrides names in list at the surface level of the dictionary + for override in properties_override: + try: + del properties[override] + except: + pass + + return grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, \ + info_list, properties_list + + def extract_frames(self, *args, extract_video=True, extract_audio=True, realtime_indexing=False, **kwargs): + grabbed_video_list = [] + grouped_video_frames_list = [] + audio_frames_list = [] + grabbed_audio_list = [] + info_list = [] + properties_list = [] + for w_idx in range(self.w_size): + if self.video_cap.isOpened() and extract_video: + grabbed_video, video_frame = self.video_cap.read() + else: + grabbed_video, video_frame = False, None + grouped_video_frames_list.append({"captured": video_frame}) + grabbed_video_list.append(grabbed_video) + if self.enable_audio and self.audio_cap.isOpened() and extract_audio: + grabbed_audio, audio_frames = self.audio_cap.read(*args, stateful=True, **kwargs) + else: + grabbed_audio, audio_frames = False, None + audio_frames_list.append(audio_frames) + grabbed_audio_list.append(grabbed_audio) + if realtime_indexing: + info_list.append({"frame_info": {"frame_id": self.rt_index}}) + with self.read_lock: + self.rt_index += 1 + else: + info_list.append({"frame_info": {"frame_id": self.frame_index()}}) + properties_list.append({}) + + return grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list + + def annotate_frames(self, input_data_list, plotter, *args, **kwargs): + grabbed_video_list, 
grouped_video_frames_list, grabbed_audio_list, audio_frames_list = [], [], [], [] + info_list, properties_list = [], [] + for extracted_data in zip(*input_data_list): + if self.annotation_properties: + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = \ + self.annotate_frame(extracted_data, plotter, *args, **self.annotation_properties) + else: + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = \ + self.annotate_frame(extracted_data, plotter, *args, **kwargs) + grabbed_video_list.append(grabbed_video) + grouped_video_frames_list.append(grouped_video_frames) + grabbed_audio_list.append(grabbed_audio) + audio_frames_list.append(audio_frames) + info_list.append(info) + properties_list.append(properties) + return grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, info_list, properties_list + + def annotate_frame(self, input_data, plotter, *args, **kwargs): + grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties = input_data + + grouped_video_frames = {"PLOT": [["captured"]], **grouped_video_frames} + return grabbed_video, grouped_video_frames, grabbed_audio, audio_frames, info, properties + + def set_annotation_properties(self, annotation_properties): + with self.read_lock: + self.annotation_properties = annotation_properties + + # def _write_audio_video(self): + # # write the audio to a file + # if self.audio_writer is not None: + # self.video_writer.release() + # self.audio_writer.release() + # # if self.enable_audio and self.audio_cap is not None: + # # librosa.output.write_wav(self.audio_out_path, *self.audio_cap['audio']) + # cmd = 'ffmpeg -y -i ' + \ + # self.audio_out_path + ' -r ' + \ + # str(self.frames_per_sec()) + ' -i ' + \ + # self.video_out_path + ' -filter:a aresample=async=1 -c:a flac -c:v copy ' + \ + # self.video_out_path + '.mkv' + # subprocess.call(cmd, shell=True) # "Muxing Done + # print('Muxing done') + # elif self.video_writer is not None: + # self.video_writer.release() + + +class InferenceSampleProcessor(SampleProcessor): + """ + Wraps inference classes to support integration into visualizers and include most functionality + supported by the video processor + """ + def __init__(self, width=None, height=None, w_size=1, **kwargs): + super().__init__(width=width, height=height, w_size=w_size, enable_audio=False, video_reader=None, audio_reader=None, + **kwargs) + + def infer_frame(self, *args, **kwargs): + raise NotImplementedError("Infer not defined in base class") + + def extract_frames(self, *args, **kwargs): + return self.infer_frame(*args, **kwargs) + + def preprocess_frames(self, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list, **kwargs): + if kwargs: + features = {**kwargs} + else: + features = {} + grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, \ + info_list, properties_list = list(grabbed_video_list), list(grouped_video_frames_list), \ + list(grabbed_audio_list), list(audio_frames_list), \ + list(info_list), list(properties_list) + pad = 1 + if self.w_size > len(grabbed_video_list): + pad = self.w_size + 1 - len(grabbed_video_list) + lim = min(self.w_size - 1, len(grabbed_video_list) - 1) + features["preproc_pad_len"] = pad + features["preproc_lim_len"] = lim + features["grabbed_video_list"] = grabbed_video_list[:lim] + [grabbed_video_list[lim]] * pad + features["grouped_video_frames_list"] = grouped_video_frames_list[:lim] + 
[grouped_video_frames_list[lim]] * pad + + if grabbed_audio_list: + aud_lim = min(lim, len(grabbed_audio_list) - 1) + aud_pad = pad - (aud_lim-lim) + features["grabbed_audio_list"] = grabbed_audio_list[:aud_lim] + [grabbed_audio_list[aud_lim]] * aud_pad + features["audio_frames_list"] = audio_frames_list[:aud_lim] + [audio_frames_list[aud_lim]] * aud_pad + else: + features["grabbed_audio_list"] = grabbed_audio_list * pad + features["audio_frames_list"] = audio_frames_list * pad + + features["info_list"] = info_list[:lim] + [info_list[lim]] * pad + features["properties_list"] = properties_list[:lim] + [properties_list[lim]] * pad + + return features + + def _keep_extracted_frames_data(self, source_frames_idxs, grabbed_video_list, grouped_video_frames_list, + grabbed_audio_list, audio_frames_list, info_list, properties_list): + if source_frames_idxs is not None: + grabbed_video_list = [grabbed_video_list[i] for i in source_frames_idxs] + grouped_video_frames_list = [grouped_video_frames_list[i] for i in source_frames_idxs] + if grabbed_audio_list: + grabbed_audio_list = [grabbed_audio_list[i] for i in source_frames_idxs] + audio_frames_list = [audio_frames_list[min(len(audio_frames_list)-1,i)] for i in source_frames_idxs] + else: + grabbed_audio_list = [[]] * len(source_frames_idxs) + audio_frames_list = [[]] * len(source_frames_idxs) + info_list = [info_list[i] for i in source_frames_idxs] + properties_list = [properties_list[i] for i in source_frames_idxs] + return grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, \ + info_list, properties_list + else: + return grabbed_video_list, grouped_video_frames_list, grabbed_audio_list, audio_frames_list, \ + info_list, properties_list diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9c558e3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +. 
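For orientation, the `grouped_video_frames` dictionaries produced by `SampleProcessor.annotate_frame` and consumed by `stack_images` pair named frames with a `"PLOT"` layout (a list of rows of frame names). A minimal sketch of that structure, using illustrative frame names and sizes:

```python
import numpy as np

from gazenet.utils.helpers import stack_images

# two equally sized frames under hypothetical names; "PLOT" lays them out in one row
captured = np.zeros((120, 160, 3), dtype=np.uint8)
saliency = np.full((120, 160, 3), 255, dtype=np.uint8)

grouped_video_frames = {
    "PLOT": [["captured", "saliency"]],  # row-major layout: one row, two columns
    "captured": captured,
    "saliency": saliency,
}

# a single dict (rather than a list of dicts) yields a single tiled image
stacked = stack_images(grouped_video_frames)
print(stacked.shape)  # (120, 320, 3): frames in a row are concatenated horizontally
```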
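The registrar classes in `gazenet/utils/registrar.py` all follow the same pattern: `register` is a class decorator that stores the class in a class-level `registry` keyed by class name, and `scan` imports the relevant modules so their decorators run. A rough sketch of how a detector could be registered and looked up; `DummyDetector` here is hypothetical and not part of the patch:

```python
from gazenet.utils.registrar import FaceDetectorRegistrar

@FaceDetectorRegistrar.register  # stores the class in FaceDetectorRegistrar.registry
class DummyDetector(object):
    def detect_frames(self, video_frames_list, **kwargs):
        # mimic the detector interface: one (empty) list of face boxes per frame
        return [[] for _ in video_frames_list]

# scan() additionally imports gazenet/utils/face_detectors.py so that
# DlibFaceDetection, MTCNNFaceDetection and SFDFaceDetection register themselves
FaceDetectorRegistrar.scan()

detector = FaceDetectorRegistrar.registry["DummyDetector"]()
print(detector.detect_frames([None, None]))  # [[], []]
```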
diff --git a/requirements_extras.txt b/requirements_extras.txt new file mode 100644 index 0000000..c8d7260 --- /dev/null +++ b/requirements_extras.txt @@ -0,0 +1,4 @@ +dash==1.11.0 +dash-bootstrap-components +ffpyplayer +# -e git+https://github.com/pytorch/accimage.git#egg=accimage diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..84e4718 --- /dev/null +++ b/setup.py @@ -0,0 +1,55 @@ +from setuptools import setup, find_packages + +setup( + name='GASP', + version='1.0.1', + packages=find_packages(), + include_package_data=True, + url='software.knowledge-technology.info#gasp', + license='MIT License', + author='Fares Abawi', + author_email='fares.abawi@uni-hamburg.de', + maintainer='Fares Abawi', + maintainer_email='fares.abawi@uni-hamburg.de', + description='Social cue integration for dynamic saliency prediction', + install_requires=['cython==0.29.1', + 'pandas==1.0.3', + 'opencv-python==4.2.0.34', + 'opencv-contrib-python==4.2.0.34', + 'matplotlib==3.2.1', + 'torch==1.7.1', + 'tqdm==4.46.0', + 'librosa==0.4.2', + 'cffi==1.14.0', + 'resampy==0.2.2', + 'sounddevice==0.3.15', + 'torchvision==0.8.2', + 'pytorch-lightning==1.3.3', + 'tensorboard==2.2.0', + 'numpy==1.18.3', + 'h5py==2.10.0', + 'scipy==1.4.1', + 'pillow==7.1.2', + 'urllib3==1.25.9', + 'numba==0.49.1', + 'scikit-image==0.17.2', + 'scikit-learn==0.22.2.post1', + 'face_recognition==1.3.0', + 'face-alignment==1.1.1', + 'facenet-pytorch==2.5.0', + 'comet_ml'], + entry_points = { + 'console_scripts': [ + 'gasp_train=gazenet.bin.train:main', + 'gasp_infer=gazenet.bin.infer:main', + 'gasp_download_manager=gazenet.bin.download_manager:main', + 'gasp_scripts=gazenet.bin.scripts:main', + ], + }, + exclude_package_data={ + "datasets": ["*.zip", "*.7z", "*.tar.gz", "*.ptb", "*.ptb.tar", "*.npy", "*.npz", "*.hd5", "*.txt", "*.jpg", "*.png", "*.gif", "*.avi", "*.mp4", "*.wav", "*.mp3"]}, + package_data={ + "": ["datasets/processed/center_bias.jpg", "datasets/processed/center_bias_bw.jpg"], + }, + data_files=[("datasets/processed/", ["datasets/processed/center_bias.jpg", "datasets/processed/center_bias_bw.jpg"])] + ) diff --git a/showcase/coutrot2_clip13_compressed.gif b/showcase/coutrot2_clip13_compressed.gif new file mode 100644 index 0000000..0e0a33e Binary files /dev/null and b/showcase/coutrot2_clip13_compressed.gif differ diff --git a/showcase/det_transformed_dave_coutrot1_clip48_compressed.gif b/showcase/det_transformed_dave_coutrot1_clip48_compressed.gif new file mode 100644 index 0000000..3646a8d Binary files /dev/null and b/showcase/det_transformed_dave_coutrot1_clip48_compressed.gif differ diff --git a/showcase/det_transformed_esr9_coutrot1_clip48_compressed.gif b/showcase/det_transformed_esr9_coutrot1_clip48_compressed.gif new file mode 100644 index 0000000..0a0f477 Binary files /dev/null and b/showcase/det_transformed_esr9_coutrot1_clip48_compressed.gif differ diff --git a/showcase/det_transformed_gaze360_coutrot1_clip48_compressed.gif b/showcase/det_transformed_gaze360_coutrot1_clip48_compressed.gif new file mode 100644 index 0000000..5e8c7ed Binary files /dev/null and b/showcase/det_transformed_gaze360_coutrot1_clip48_compressed.gif differ diff --git a/showcase/det_transformed_videogaze_coutrot1_clip48_compressed.gif b/showcase/det_transformed_videogaze_coutrot1_clip48_compressed.gif new file mode 100644 index 0000000..b721c9a Binary files /dev/null and b/showcase/det_transformed_videogaze_coutrot1_clip48_compressed.gif differ diff --git a/showcase/multimodalsaliency.png b/showcase/multimodalsaliency.png new 
file mode 100644 index 0000000..307271c Binary files /dev/null and b/showcase/multimodalsaliency.png differ
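A short usage sketch for two of the formatting helpers in `gazenet/utils/helpers.py`; the inputs are made up for illustration. Note that `flatten_dict` declares a mutable default for `output_dict`, so passing a fresh dict explicitly avoids results accumulating across calls:

```python
from gazenet.utils.helpers import flatten_dict, aggregate_frame_ranges

config = {"model": {"name": "gasp", "gates": ["dam", "lam"]}, "fps": 25}
flat = flatten_dict(config, output_dict={})
# {'model.name': 'gasp', 'model.gates.[0]': 'dam', 'model.gates.[1]': 'lam', 'fps': 25}

# frame ids, or [frame_id, duration] pairs, are compressed into inclusive ranges
print(aggregate_frame_ranges([1, 2, 3, 7, 8, 12]))  # [(1, 3), (7, 8), (12, 12)]
print(aggregate_frame_ranges([[1, 3], [10, 2]]))    # [(1, 3), (10, 11)]
```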
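When `AudioCapture` reads from a file, it aligns the waveform to the video by splitting it into one chunk per video frame, so `AUCAP_PROP_CHUNK_SIZE` works out to roughly `sample_rate / fps` samples, and a single `read()` advances by `buffer_size` frames (the video FPS, as set in `SampleProcessor.load`). A back-of-the-envelope check with assumed values:

```python
# illustrative values only: a 10 s clip at 25 fps, audio resampled to 16 kHz
sample_rate = 16000   # DEFAULT_SAMPLE_RATE
fps = 25              # buffer_size defaults to the video FPS
len_frames = 250      # total number of video frames

total_samples = sample_rate * (len_frames / fps)        # 160000 samples loaded
chunk_size = int(total_samples // len_frames)           # samples per video frame
samples_per_read = chunk_size * fps                     # one read() spans ~one second

print(chunk_size, samples_per_read)  # 640 16000
```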