# Evaluation
You can evaluate models trained with salt over a test set. Test samples are loaded from structured numpy arrays stored in h5 files, just as for training. After producing the evaluation file, you can make performance plots using puma.
## Running the Test Loop
To evaluate a trained model on a test file, use the `salt test` command.

```bash
salt test --config logs/<timestamp>/config.yaml --data.test_file path/to/test.h5
```
As in the above example, you need to specify the saved config from the training run.
By default, the checkpoint with the lowest validation loss is used for evaluation.
You can specify a different checkpoint with the `--ckpt_path` argument.
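For example, a sketch of evaluating a specific checkpoint rather than the default best one (the checkpoint filename below is illustrative and will depend on your training run):

```bash
# evaluate an explicit checkpoint instead of the default lowest-val-loss one
salt test \
    --config logs/<timestamp>/config.yaml \
    --data.test_file path/to/test.h5 \
    --ckpt_path logs/<timestamp>/ckpts/epoch=009-val_loss=0.65.ckpt
```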
When evaluating a model from a resumed training, you need to explicitly specify `--ckpt_path`.
When you resume training, you specify a `--ckpt_path`, and this is saved with the model config.
If you then run `salt test` on the resulting config without specifying a new `--ckpt_path`, this same checkpoint will be evaluated.
To instead evaluate the desired checkpoint from the resumed training job, you should explicitly specify `--ckpt_path` again to overwrite the one that is already saved in the config.
If you still want to choose the best epoch automatically, use `--ckpt_path null`.
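For example, a minimal sketch of testing a resumed training while overriding the checkpoint stored in its config and falling back to automatic best-epoch selection (paths are placeholders):

```bash
# ignore the checkpoint saved in the resumed-training config and
# let salt pick the checkpoint with the lowest validation loss
salt test \
    --config logs/<timestamp>/config.yaml \
    --data.test_file path/to/test.h5 \
    --ckpt_path null
```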
You also need to specify a path to the test file using `--data.test_file`.
This should be a prepared umami test file. The framework will extract the sample name and append it to the checkpoint file basename.
The result is saved as an h5 file in the `ckpts/` dir.
You can use `--data.num_test` to set the number of samples to test on if you want to override the default value from the training config.
Only a single GPU is supported for the test loop.
This is enforced by the framework, so if you try to use more than one device you will see the message `Setting --trainer.devices=1`.
Output files are overwritten by default.
You can use `--data.test_suff` to append an additional suffix to the evaluation output file name.
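As a rough sketch combining the options above (the sample size and suffix values are arbitrary examples, not defaults):

```bash
# evaluate on a fixed number of samples and tag the output file name
salt test \
    --config logs/<timestamp>/config.yaml \
    --data.test_file path/to/test.h5 \
    --data.num_test 500000 \
    --data.test_suff _mysuffix
```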
## Extra Evaluation Variables
When evaluating a model, the jet and track variables included in the output file are configurable.
They can be specified as follows within the `PredictionWriter` callback configuration in the base configuration file.
```yaml
callbacks:
  - class_path: salt.callbacks.Checkpoint
    init_args:
      monitor_loss: val_jet_classification_loss
  - class_path: salt.callbacks.PredictionWriter
    init_args:
      write_tracks: False
      extra_vars:
        jets:
          - pt_btagJes
          - eta_btagJes
          - HadronConeExclTruthLabelID
          - n_tracks
          - n_truth_promptLepton
        tracks:
          - truthOriginLabel
          - truthVertexIndex
```
By default, only the jet quantities are evaluated to save time and space.
If you want to study the track aux task performance, you need to specify `write_tracks: True` in the `PredictionWriter` callback configuration.
The full API for the `PredictionWriter` callback is found below.
### salt.callbacks.PredictionWriter

Bases: `lightning.Callback`
Write test outputs to h5 file.

This callback will write the outputs of the model to an h5 evaluation file. The outputs are produced by calling the `run_inference` method of each task. The output file is written to the same directory as the checkpoint file, and has the same name as the checkpoint file, but with the suffix `__test_<sample><suffix>.h5`. The file will contain one dataset for each input type, with the same name as the input type in the test file.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `write_tracks` | `bool` | If False, skip any tasks with track inputs. | `False` |
| `write_objects` | `bool` | If False, skip any tasks with object inputs. | `False` |
| `half_precision` | `bool` | If true, write outputs at half precision. | `False` |
| `object_classes` | `list` | List of flavour names with the index corresponding to the label values. This is used to construct the global object classification probability output names. | `None` |
| `extra_vars` | `salt.stypes.Vars` | Extra variables to write to file for each input type. If not specified for a given input type, all variables in the test file will be written. | `None` |
Source code in `salt/callbacks/predictionwriter.py`