This directory contains the code to train the RRWNet model.
Configuration options and hyperparameters can be found in `config.py`.
The data can be downloaded from the following link:
Place the data in the `_Data/` directory under `train/`.
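Before starting a long training run, it can help to confirm that the data ended up where the training code expects it. The snippet below is a minimal sketch, not part of the repository. It assumes it is run from this `train/` directory (the same place `train.py` is run from), so the data directory resolves to `_Data/`. The helper name `check_data_dir` is hypothetical, and the check makes no assumption about the files inside `_Data/`.

```python
from pathlib import Path

# Assumption: run from the train/ directory, like train.py, so the data
# directory described above is simply _Data/ relative to the working directory.
DATA_DIR = Path("_Data")


def check_data_dir(data_dir: Path = DATA_DIR) -> list[str]:
    """Hypothetical helper: return the sorted entry names in data_dir.

    Raises FileNotFoundError if the directory is missing or empty.
    """
    if not data_dir.is_dir():
        raise FileNotFoundError(f"Expected the downloaded data in {data_dir}/")
    entries = sorted(p.name for p in data_dir.iterdir())
    if not entries:
        raise FileNotFoundError(f"{data_dir}/ exists but is empty")
    return entries
```

Calling `check_data_dir()` from `train/` either lists what is in `_Data/` or fails early with a clear message.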
To train the model, run the following commands:
```shell
# Activate the virtual environment
source ../venv/bin/activate
# Train the model
python3 train.py --dataset RITE-train --model RRWNet
```
The datasets available for training (passed via `--dataset`) are `RITE-train` and `HRF-Karlsson-w1024`, while the available models (passed via `--model`) are `RRWNet`, `RRWNetAll`, `RRUNet`, `WNet`, and `UNet`. See the paper for more details.
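To train several dataset/model combinations in one go, the `train.py` command above can be generated programmatically. This is a hypothetical convenience sketch, not a script shipped with the repository: it only uses the `--dataset` and `--model` flags shown above, assumes it is run from `train/` with the virtual environment already activated, and defaults to a dry run that only prints the commands. The names `build_train_commands` and `run_all` are illustrative.

```python
import shlex
import subprocess

# Dataset and model names as listed in this README.
DATASETS = ["RITE-train", "HRF-Karlsson-w1024"]
MODELS = ["RRWNet", "RRWNetAll", "RRUNet", "WNet", "UNet"]


def build_train_commands(datasets=DATASETS, models=MODELS):
    """Return one train.py invocation (as an argv list) per dataset/model pair."""
    return [
        ["python3", "train.py", "--dataset", dataset, "--model", model]
        for dataset in datasets
        for model in models
    ]


def run_all(dry_run=True):
    """Print every command; launch them one after another only if dry_run is False."""
    for cmd in build_train_commands():
        print(shlex.join(cmd))
        if not dry_run:
            # check=True aborts the sweep at the first run that exits with an error.
            subprocess.run(cmd, check=True)
```

Runs are launched sequentially rather than in parallel, since each training job will typically want the whole GPU.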
Training logs and weights will be saved under the `__training/` directory.
Once the model is trained, the predictions can be generated using the following command:

```shell
python3 get_predictions.py -p <path_to_the_trained_model> -i <path_to_the_images>
```
The predictions will be saved under the `tests_predictions/` directory in the path specified by the `-p` flag.
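The prediction step can also be driven from Python, for example at the end of a training script. The sketch below is an illustration only: it calls `get_predictions.py` with the `-p` and `-i` flags shown above and returns where the predictions should appear. It assumes that `-p` points to a directory (for instance a run directory under `__training/`), so `tests_predictions/` is created directly inside it; if your trained model path is a file, adjust `predictions_dir` accordingly. Both function names are hypothetical.

```python
import subprocess
from pathlib import Path


def predictions_dir(model_path: str) -> Path:
    # Assumption: -p is a directory, so tests_predictions/ lives directly inside it.
    return Path(model_path) / "tests_predictions"


def predict(model_path: str, images_path: str) -> Path:
    """Run get_predictions.py and return the directory where predictions are expected."""
    subprocess.run(
        ["python3", "get_predictions.py", "-p", model_path, "-i", images_path],
        check=True,  # raise CalledProcessError if the script fails
    )
    return predictions_dir(model_path)
```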
To evaluate the predictions, see `eval/README.md`.