
Evaluation

You can evaluate models trained using salt over a test set. As for training, test samples are loaded from structured numpy arrays stored in h5 files. After producing the evaluation file, you can make performance plots using puma.

Running the Test Loop

To evaluate a trained model on a test file, use the salt test command.

salt test --config logs/<timestamp>/config.yaml --data.test_file path/to/test.h5

As in the example above, you need to specify the saved config from the training run. By default, the checkpoint with the lowest validation loss is used for evaluation. You can specify a different checkpoint with the --ckpt_path argument.

When evaluating a model from a resumed training run, you need to explicitly specify --ckpt_path.

When you resume training, you specify a --ckpt_path, and this is saved with the model config. If you then run salt test on the resulting config without specifying a new --ckpt_path, this same checkpoint will be evaluated. To evaluate the desired checkpoint from the resumed training job instead, explicitly specify --ckpt_path again to overwrite the one already saved in the config.

If you still want to choose the best epoch automatically, use --ckpt_path null.
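The precedence rules above can be summarised in a short Python sketch. This is not salt's implementation. The names `resolve_ckpt_path`, `best_checkpoint` and the `UNSET` sentinel are hypothetical. Passing `--ckpt_path null` is modelled here as `None`. The assumption that the best checkpoint can be found by reading a `val_loss=<number>` field from each file name is made for illustration only.

```python
from __future__ import annotations

import re

# Hypothetical sentinel meaning "--ckpt_path was not passed on the command line".
UNSET = object()


def best_checkpoint(ckpt_names: list[str]) -> str:
    """Pick the checkpoint with the lowest validation loss.

    Assumes (for illustration only) that each file name embeds the loss
    as `val_loss=<number>`, e.g. `epoch=019-val_loss=0.64290.ckpt`.
    """

    def val_loss(name: str) -> float:
        match = re.search(r"val_loss=(\d+(?:\.\d+)?)", name)
        if match is None:
            raise ValueError(f"no val_loss in checkpoint name: {name}")
        return float(match.group(1))

    return min(ckpt_names, key=val_loss)


def resolve_ckpt_path(cli_value: object, config_value: str | None, ckpt_names: list[str]) -> str:
    """Return the checkpoint that would be evaluated under the rules above (sketch)."""
    # An explicit --ckpt_path on the CLI (including `null`, i.e. None) wins;
    # otherwise the value saved in the config by a resumed run is reused.
    chosen = config_value if cli_value is UNSET else cli_value
    if chosen is None:
        # No checkpoint pinned anywhere: fall back to the lowest validation loss.
        return best_checkpoint(ckpt_names)
    return str(chosen)


ckpts = ["epoch=010-val_loss=0.71000.ckpt", "epoch=019-val_loss=0.64290.ckpt"]
print(resolve_ckpt_path(UNSET, None, ckpts))           # epoch=019-val_loss=0.64290.ckpt
print(resolve_ckpt_path(UNSET, "resume.ckpt", ckpts))  # resume.ckpt (the resumed-run pitfall)
print(resolve_ckpt_path(None, "resume.ckpt", ckpts))   # epoch=019-val_loss=0.64290.ckpt
```

The second call shows why a resumed run needs an explicit --ckpt_path. Without one, the checkpoint stored in the config is silently reused.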

You also need to specify a path to the test file using --data.test_file. This should be a prepared umami test file. The framework extracts the sample name from it and appends this to the checkpoint file basename. The result is saved as an h5 file in the ckpts/ dir.

You can use --data.num_test to set the number of samples to test on if you want to override the default value from the training config.

Only one GPU is supported for the test loop.

When testing, only a single GPU is supported. This is enforced by the framework, so if you try to use more than one device you will see the message "Setting --trainer.devices=1".

Output files are overwritten by default.

You can use --data.test_suff to append an additional suffix to the evaluation output file name.
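The naming scheme for the evaluation file can be sketched in a few lines. The file sits in the checkpoint's directory and takes the checkpoint's base name plus `__test_<sample><suffix>.h5`. The function name `eval_output_path` is hypothetical. Taking the sample name to be the test file's stem is an assumption, and salt may extract it differently.

```python
from pathlib import Path


def eval_output_path(ckpt_path: str, test_file: str, test_suff: str = "") -> Path:
    """Sketch of the evaluation file path for a given checkpoint and test file.

    Assumption: the sample name is the test file's stem; salt may derive it differently.
    """
    ckpt = Path(ckpt_path)
    sample = Path(test_file).stem  # hypothetical sample-name extraction
    # Same directory and base name as the checkpoint, plus __test_<sample><suffix>.h5
    return ckpt.with_name(f"{ckpt.stem}__test_{sample}{test_suff}.h5")


print(eval_output_path("logs/run/ckpts/epoch=019-val_loss=0.64290.ckpt", "path/to/ttbar.h5", "_v2"))
# logs/run/ckpts/epoch=019-val_loss=0.64290__test_ttbar_v2.h5  (on a POSIX system)
```

Because output files are overwritten by default, two runs with the same checkpoint, test file and suffix map to the same path. A distinct --data.test_suff keeps both results.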

Extra Evaluation Variables

When evaluating a model, you can configure which jet and track variables are included in the output file. This is done in the PredictionWriter callback configuration in the base configuration file, as follows:

callbacks:
    - class_path: salt.callbacks.Checkpoint
      init_args:
        monitor_loss: val_jet_classification_loss
    - class_path: salt.callbacks.PredictionWriter
      init_args:
        write_tracks: False
        extra_vars:
          jets:
            - pt_btagJes
            - eta_btagJes
            - HadronConeExclTruthLabelID
            - n_tracks
            - n_truth_promptLepton
          tracks:
            - truthOriginLabel
            - truthVertexIndex
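In Python terms, the extra_vars block is a mapping from input type to a list of variable names, with jets and tracks as sibling keys. The sketch below applies the documented rule that an input type without an entry gets every variable from the test file. The helper name `vars_to_write` is hypothetical and is not part of salt's API.

```python
from __future__ import annotations

# The YAML extra_vars block as a plain dict: one list of variable names per input type.
EXTRA_VARS: dict[str, list[str]] = {
    "jets": [
        "pt_btagJes",
        "eta_btagJes",
        "HadronConeExclTruthLabelID",
        "n_tracks",
        "n_truth_promptLepton",
    ],
    "tracks": ["truthOriginLabel", "truthVertexIndex"],
}


def vars_to_write(extra_vars: dict[str, list[str]], input_type: str, available: list[str]) -> list[str]:
    """Illustrative rule: the configured names, else every variable in the test file."""
    requested = extra_vars.get(input_type)
    if not requested:
        # Input type not configured: keep all variables present in the test file
        return list(available)
    return list(requested)
```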

By default, only jet-level outputs are written, to save time and disk space. To study the performance of the track auxiliary tasks, set write_tracks: True in the PredictionWriter callback configuration.

The full API for the PredictionWriter callback is found below.

salt.callbacks.PredictionWriter

Bases: lightning.Callback

Write test outputs to h5 file.

This callback will write the outputs of the model to an h5 evaluation file. The outputs are produced by calling the run_inference method of each task. The output file is written to the same directory as the checkpoint file, and has the same name as the checkpoint file, but with the suffix __test_<sample><suffix>.h5. The file will contain one dataset for each input type, with the same name as the input type in the test file.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| write_tracks | bool | If False, skip any tasks with "tracks" in input_name. | False |
| write_objects | bool | If False, skip any tasks with input_name="objects" and outputs of the MaskDecoder. | False |
| half_precision | bool | If True, write outputs at half precision. | False |
| object_classes | list | List of flavour names with the index corresponding to the label values. This is used to construct the global object classification probability output names. | None |
| extra_vars | salt.stypes.Vars | Extra variables to write to file for each input type. If not specified for a given input type, all variables in the test file will be written. | None |
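The write_tracks and write_objects switches act as filters on a task's input_name. half_precision selects the output float width, which the constructor below maps to "f2" or "f4". Here is a minimal sketch of that filtering. The `Task` dataclass and the task names are hypothetical stand-ins, and the MaskDecoder outputs also covered by write_objects are not modelled.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Task:
    """Hypothetical stand-in for a salt task: only the fields the filter needs."""

    name: str
    input_name: str


def tasks_to_write(tasks: list[Task], write_tracks: bool = False, write_objects: bool = False) -> list[Task]:
    kept = []
    for task in tasks:
        # write_tracks=False: skip any task whose input_name contains "tracks"
        if not write_tracks and "tracks" in task.input_name:
            continue
        # write_objects=False: skip tasks whose input_name is exactly "objects"
        if not write_objects and task.input_name == "objects":
            continue
        kept.append(task)
    return kept


def output_dtype(half_precision: bool) -> str:
    # Mirrors the callback's __init__: float16 ("f2") or float32 ("f4")
    return "f2" if half_precision else "f4"


tasks = [Task("jet_classification", "jets"), Task("track_origin", "tracks"), Task("object_class", "objects")]
print([t.name for t in tasks_to_write(tasks)])  # ['jet_classification']
```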
Source code in salt/callbacks/predictionwriter.py
from collections import defaultdict

from lightning import Callback

from salt.stypes import Vars


class PredictionWriter(Callback):
    def __init__(
        self,
        write_tracks: bool = False,
        write_objects: bool = False,
        half_precision: bool = False,
        object_classes: list | None = None,
        extra_vars: Vars | None = None,
    ) -> None:
        """Write test outputs to h5 file.

        This callback will write the outputs of the model to an h5 evaluation file. The outputs
        are produced by calling the `run_inference` method of each task. The output file
        is written to the same directory as the checkpoint file, and has the same name
        as the checkpoint file, but with the suffix `__test_<sample><suffix>.h5`. The file will
        contain one dataset for each input type, with the same name as the input type in the test
        file.

        Parameters
        ----------
        write_tracks : bool
            If False, skip any tasks with `"tracks" in input_name`.
        write_objects : bool
            If False, skip any tasks with `input_name="objects"` and outputs of the
            MaskDecoder. Default is False
        half_precision : bool
            If true, write outputs at half precision
        object_classes : list
            List of flavour names with the index corresponding to the label values. This is used
            to construct the global object classification probability output names.
        extra_vars : Vars
            Extra variables to write to file for each input type. If not specified for a given input
            type, all variables in the test file will be written.
        """
        super().__init__()
        # Fall back to an empty mapping so lookups for unconfigured input types do not fail
        if extra_vars is None:
            extra_vars = defaultdict(list)
        self.extra_vars = extra_vars
        self.write_tracks = write_tracks
        self.write_objects = write_objects
        self.half_precision = half_precision
        # numpy dtype code for output floats: float16 ("f2") or float32 ("f4")
        self.precision = "f2" if self.half_precision else "f4"
        self.object_classes = object_classes

Last update: December 4, 2023
Created: October 24, 2022