The intermediate coursework report can be found there.
The code was tested with Python 3.8.13. The minimal set of required packages is listed in `requirements.txt`. All of them can be installed with:
python3 -m pip install --upgrade pip
python3 -m pip install --prefer-binary -r requirements.txt
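The code was only tested with Python 3.8.13, so it may be worth checking the interpreter version before installing the requirements. The sketch below is not part of the repository, and the name `matches_tested_python` is hypothetical. It compares only the major and minor versions (3.8). This is a deliberate simplification that assumes patch releases do not matter here.

```python
import sys

# Major and minor version the README reports the code was tested with (3.8.13).
TESTED_VERSION = (3, 8)


def matches_tested_python(version_info=sys.version_info) -> bool:
    """Return True if the interpreter's major.minor matches the tested version."""
    return tuple(version_info[:2]) == TESTED_VERSION


if __name__ == "__main__":
    if not matches_tested_python():
        print(
            f"Warning: running Python {sys.version.split()[0]}, "
            "but the code was tested with Python 3.8.13.",
            file=sys.stderr,
        )
```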
The entry point for model training and evaluation is the `run.py` script. It accepts the following arguments:
| argument | description |
|---|---|
| `--checkpoint_name CHECKPOINT_NAME` | Name used to identify a trained model instance. When starting to train a new model, this creates a directory `saved_models/CHECKPOINT_NAME`, in which the configuration and the model states at various epochs are saved. |
| `--config CONFIG` | Configuration file that sets up the model and the training. Required when starting to train a new model; must be omitted when resuming a previously stopped training run. Some predefined model configs can be found in `models/configs`. |
| `--gpu_num GPU_NUM` | GPU index for selecting a single GPU on a multi-GPU machine. Pass an empty string (`''`) when running the code on a machine without a GPU. |
| `--prediction_only` | Boolean flag that switches to model evaluation mode. Running with this option produces evaluation plots like Fig. 3 of Eur. Phys. J. C 81, 599 (2021). The plots are saved under `saved_models/CHECKPOINT_NAME/prediction_XXXXX`, where `XXXXX` is the number of the epoch picked for evaluation (the latest one available). |
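The rules in the table can be summarized as an `argparse` definition. In particular, `--config` is required for a new model and must be omitted for an existing one. The following is a hypothetical re-creation, not the actual contents of `run.py`. Three details are assumptions:

- `--checkpoint_name` being mandatory.
- GPU selection working through the `CUDA_VISIBLE_DEVICES` environment variable.
- `XXXXX` being the epoch number zero-padded to five digits.

The names `build_parser`, `check_args`, `select_gpu` and `prediction_dir` are illustrative.

```python
import argparse
import os


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical re-creation of the command-line interface in the table above."""
    parser = argparse.ArgumentParser(description="Train or evaluate a model.")
    parser.add_argument("--checkpoint_name", required=True,
                        help="name of the model directory under saved_models/")
    parser.add_argument("--config", default=None,
                        help="model config; required only when starting a new model")
    parser.add_argument("--gpu_num", default="",
                        help="GPU index; pass '' to run without a GPU")
    parser.add_argument("--prediction_only", action="store_true",
                        help="evaluate the latest saved epoch instead of training")
    return parser


def check_args(args: argparse.Namespace, model_exists: bool) -> None:
    """Enforce the --config rules; model_exists says whether the checkpoint dir exists."""
    if not model_exists:
        if args.prediction_only:
            raise ValueError("nothing to evaluate: the checkpoint does not exist yet")
        if args.config is None:
            raise ValueError("--config is required when starting a new model")
    elif args.config is not None:
        raise ValueError("omit --config for an existing checkpoint")


def select_gpu(gpu_num: str) -> None:
    # Assumption: the GPU is picked via CUDA_VISIBLE_DEVICES. An empty value
    # hides all GPUs, so TensorFlow falls back to the CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_num


def prediction_dir(checkpoint_name: str, epoch: int) -> str:
    # Assumption: XXXXX is the epoch number zero-padded to five digits.
    return os.path.join("saved_models", checkpoint_name, f"prediction_{epoch:05d}")


if __name__ == "__main__":
    args = build_parser().parse_args()
    model_dir = os.path.join("saved_models", args.checkpoint_name)
    check_args(args, model_exists=os.path.isdir(model_dir))
    select_gpu(args.gpu_num)
```

Keeping `check_args` free of filesystem access (the caller passes `model_exists`) makes the rules easy to test in isolation.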
An example command to run model training:
python3 run.py --config models/configs/baseline.yaml --checkpoint_name test_run --gpu_num ''
An example command to run model evaluation:
python3 run.py --checkpoint_name test_run --gpu_num '' --prediction_only
During training, the model is evaluated every `save_every` epochs (as defined in the model config). The evaluation results are written in TensorBoard format to the `logs/CHECKPOINT_NAME` folder. Simple quantities, such as the generator and discriminator losses, are written every epoch. TensorFlow provides a tool, the TensorBoard server, for monitoring this information interactively in a web browser. The TensorBoard server is included as a dependency in `requirements.txt`, so it should already be installed on your machine if you followed the instructions above.
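To make the logging cadence concrete, here is a minimal sketch of which quantities end up in `logs/CHECKPOINT_NAME` after each epoch. It is not the repository's training loop, and `logging_schedule` is a hypothetical name. Two points are assumptions:

- Epochs are counted from 1, which determines exactly which epochs are evaluated.
- `chi2` and the validation images are produced by the periodic evaluation.

```python
from typing import Iterator, List, Tuple


def logging_schedule(num_epochs: int, save_every: int) -> Iterator[Tuple[int, List[str]]]:
    """Yield (epoch, quantities written to logs/CHECKPOINT_NAME) for each epoch.

    Assumptions: epochs are counted from 1, and an evaluation runs after
    every epoch that is a multiple of save_every.
    """
    for epoch in range(1, num_epochs + 1):
        written = ["generator loss", "discriminator loss"]  # logged every epoch
        if epoch % save_every == 0:
            # Evaluation results, shown in the SCALARS and IMAGES tabs.
            written += ["chi2", "validation histograms and profiles"]
        yield epoch, written
```

For example, with `save_every` set to 3 and six epochs of training, the losses appear after every epoch, while the evaluation results appear only after epochs 3 and 6.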
If you run everything on your local machine, it should be sufficient to run:
tensorboard --logdir=logs/
This should start a server that is accessible locally in your browser at http://localhost:6006/. If you run everything on a remote machine accessed via SSH, you will also need SSH port forwarding to access the server from the web browser on your local machine. This can be done with the `-L <LOCAL_PORT>:<HOST>:<PORT>` SSH option, which forwards connections made to `<LOCAL_PORT>` on your local machine to `<HOST>:<PORT>` as seen from the remote machine. For example, you can connect to `REMOTE_MACHINE` with:
ssh -L 4321:localhost:6006 username@REMOTE_MACHINE
after which opening http://localhost:4321/ in your local browser is forwarded through SSH as if you had opened http://localhost:6006/ on `REMOTE_MACHINE`. Port 6006 is the default TensorBoard port, but any other port can be set with the `--port` argument of `tensorboard`.
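If TensorBoard is started on a non-default port, the port in the `-L` option has to change with it. The hypothetical helpers below are not part of the repository. They only assemble the two command lines so that the ports stay consistent. `shlex.join` (available since Python 3.8) turns each list into a copy-pasteable shell string.

```python
import shlex
from typing import List


def tensorboard_command(port: int = 6006) -> List[str]:
    """TensorBoard command to run on the remote machine (6006 is the default port)."""
    return ["tensorboard", "--logdir=logs/", "--port", str(port)]


def ssh_tunnel_command(remote: str, local_port: int = 4321,
                       host: str = "localhost", port: int = 6006) -> List[str]:
    """SSH command forwarding local_port on this machine to host:port as seen from remote."""
    return ["ssh", "-L", f"{local_port}:{host}:{port}", remote]


if __name__ == "__main__":
    # Both commands must agree on the TensorBoard port.
    tb_port = 6007
    print("on the remote machine:", shlex.join(tensorboard_command(tb_port)))
    print("on your local machine:",
          shlex.join(ssh_tunnel_command("username@REMOTE_MACHINE", port=tb_port)))
```

With the defaults, `shlex.join(ssh_tunnel_command("username@REMOTE_MACHINE"))` reproduces the SSH command shown above.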
Once you have set up port forwarding (if necessary), started the TensorBoard server and opened it in your web browser, you should see a page with two tabs: `SCALARS` and `IMAGES`. The `SCALARS` tab contains the generator and discriminator losses, along with a quantity called `chi2`. This `chi2` quantity is a sum of squared discrepancy-over-error terms. The discrepancies are computed between the data and the model prediction for the upper and lower bands in each bin of profiles like Fig. 3 of Eur. Phys. J. C 81, 599 (2021), excluding the amplitude profiles. Strictly speaking, `chi2` is not a chi-squared statistic, because the terms are correlated with each other. It does, however, reflect the overall agreement between the model and the data (the lower `chi2`, the better). The `IMAGES` tab should contain validation histograms and profiles, as well as examples of generated responses.
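The `chi2` computation described above can be sketched in plain Python. The per-bin error is an assumption: data and model uncertainties are treated as independent and added in quadrature, which may differ from what the repository actually does. `BandProfile`, `chi2_term` and `chi2` are hypothetical names. The caller is expected to pass only the non-amplitude profiles.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class BandProfile:
    """Upper and lower band values of one profile and their errors, one entry per bin."""
    upper: Sequence[float]
    upper_err: Sequence[float]
    lower: Sequence[float]
    lower_err: Sequence[float]


def chi2_term(data, data_err, model, model_err) -> float:
    """Sum over bins of squared (data - model) discrepancy over error."""
    # Assumption: data and model errors are independent, so they add in quadrature.
    # A bin where both errors are zero would raise ZeroDivisionError.
    return sum((d - m) ** 2 / (de ** 2 + me ** 2)
               for d, de, m, me in zip(data, data_err, model, model_err))


def chi2(data_profiles: Sequence[BandProfile], model_profiles: Sequence[BandProfile]) -> float:
    """Total over all (non-amplitude) profiles, for both the upper and lower bands."""
    total = 0.0
    for data, model in zip(data_profiles, model_profiles):
        total += chi2_term(data.upper, data.upper_err, model.upper, model.upper_err)
        total += chi2_term(data.lower, data.lower_err, model.lower, model.lower_err)
    return total
```

Because every term is simply added, the sketch shares the caveat stated above: correlated terms are summed as if they were independent, so the result is a measure of agreement rather than a proper chi-squared.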