# ONNX Export
In order to use your trained model in Athena you need to export it to ONNX.
## Model Export
The `onnx/to_onnx.py` Python script handles the ONNX conversion process for you.
The script has several arguments; you can learn about them by running

```bash
to_onnx --help
```
At a minimum, you should specify the path to a checkpoint to convert and a name for the exported model. For example:

```bash
to_onnx \
    --ckpt_path logs/<timestamp>/ckpts/checkpoint.ckpt \
    --name GN2vXX
```
If you don't specify a config path using `--config`, the script will look for one in the parent directory of the specified `--ckpt_path`.
**The `r22default` track selection is used by default**

The track selection you specify must correspond to one of the options defined in the `trk_select_regexes` variable in `DataPrepUtilities.cxx`, and it must also match the selection applied in your training samples.
Track selection is applied when dumping with the TDD.
The current default FTAG selection is called `r22default`, but you should take note of the changes described in training-dataset-dumper!427 to make sure you are using the correct selection.
You can also optionally specify a different scale dict from the one in the training config, and a model name (by default this is `salt`).
The model name is used to construct the output probability variable names in Athena.
**Exporting a model trained with `torch.compile()`**

If you trained your model with `torch.compile()`, you need to repair your checkpoint before exporting.
You can do this by running the `repair_ckpt.py` script:

```bash
repair_ckpt <path_to_checkpoint>
```
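For context, `torch.compile()` wraps your model and saves its weights under keys prefixed with `_orig_mod.`, which a plain (uncompiled) model cannot load. The sketch below illustrates the general idea of such a repair; it is an assumption about the mechanism, not the actual `repair_ckpt` implementation, so prefer the provided script.

```python
import torch

# Illustrative sketch only (not the real repair_ckpt implementation):
# torch.compile() stores weights under keys like "model._orig_mod.layer.weight".
# Stripping the "_orig_mod." prefix restores the plain key names so the
# checkpoint can be loaded and exported normally.
ckpt = torch.load("checkpoint.ckpt", map_location="cpu")
ckpt["state_dict"] = {
    key.replace("_orig_mod.", ""): value for key, value in ckpt["state_dict"].items()
}
torch.save(ckpt, "checkpoint_repaired.ckpt")
```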
## Athena Validation
You may see some warnings during export, but the `to_onnx` script will verify that there is a good level of compatibility between the PyTorch and ONNX model outputs, and that there are no `nan` or `0` values in the output.
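If you want to understand what this kind of check involves, it boils down to running identical inputs through both models and requiring close agreement. Below is a self-contained sketch using a toy model as a stand-in for a real salt tagger; the input names, shapes, and tolerance are illustrative, not salt's actual interface.

```python
import numpy as np
import onnxruntime as ort
import torch

# Toy illustration of a PyTorch vs ONNX consistency check: export a small
# model, run the same input through both, and require agreement within a
# small tolerance, with no nan or exactly-zero outputs.
model = torch.nn.Sequential(
    torch.nn.Linear(21, 16), torch.nn.ReLU(), torch.nn.Linear(16, 3), torch.nn.Softmax(dim=-1)
)
model.eval()

dummy = torch.randn(1, 21)
torch.onnx.export(model, dummy, "toy_model.onnx", input_names=["feats"], output_names=["probs"])

session = ort.InferenceSession("toy_model.onnx", providers=["CPUExecutionProvider"])
test_input = torch.randn(1, 21)

with torch.no_grad():
    torch_out = model(test_input).numpy()
onnx_out = session.run(None, {"feats": test_input.numpy()})[0]

np.testing.assert_allclose(onnx_out, torch_out, atol=1e-6)
assert not np.isnan(onnx_out).any() and not (onnx_out == 0).any()
```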
However, as a final check, you should verify the performance of your PyTorch model against a version running from the TDD.
First, follow the instructions here to dump the scores of your exported model. Please take note of the following considerations when comparing Athena and Python evaluated models:

- Models in Athena are evaluated with full precision inputs. Make sure to dump using the TDD at full precision (use the provided flag `--force-full-precision`); a quick way to check the resulting dump is sketched below.
- Models evaluated in Python are limited to 40 input tracks, whereas models evaluated in Athena have no such limit.
Once you have evaluated your model using the TDD, you should use the resulting h5 file to run `salt test`.
Be sure to run with `--trainer.precision 32`.
Finally, you can use the `compare_models` command to compare the scores of the two models:

```bash
compare_models \
    --file_A tdd/output.h5 \
    --tagger_A name \
    --file_B salt/eval.h5 \
    --tagger_B name
```

See `compare_models.py -h` for more information.
**What level of discrepancy is expected?**

We usually ask for agreement within `1e-6` for the output probabilities, which is approximately floating point precision error.
If you see one or two jets with a discrepancy of `1e-5`, this is probably fine.
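If you want to look at the agreement directly rather than via `compare_models`, the check is essentially the maximum absolute difference between the two sets of probabilities. A minimal sketch follows; the group name `jets` and the variable name `GN2vXX_pb` are assumptions, and it also assumes the jets appear in the same order in both files.

```python
import h5py
import numpy as np

# Minimal sketch of the agreement check between the TDD (Athena) dump and the
# salt evaluation file. The group "jets" and variable "GN2vXX_pb" are assumed
# names; adjust them to your own tagger. Assumes identical jet ordering.
with h5py.File("tdd/output.h5") as f_a, h5py.File("salt/eval.h5") as f_b:
    p_a = f_a["jets"]["GN2vXX_pb"][:]
    p_b = f_b["jets"]["GN2vXX_pb"][:]

diff = np.abs(p_a.astype(np.float64) - p_b.astype(np.float64))
print("max |diff|:", diff.max())
print("jets above 1e-6:", int((diff > 1e-6).sum()))
```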
Common causes of more significant discrepancies are:

- Not dumping at full precision using the TDD (see above).
- Not running `salt test` with `--trainer.precision 32` if you trained at lower precision.
- Not writing out salt evaluation scores at full precision (see the `PredictionWriter` callback).
- Enabling some runtime optimisation in PyTorch (e.g. here).
## Deploying FTAG Models in Athena
Please see this page in the central FTAG documentation.
## Viewing ONNX Model Metadata
To view the metadata stored in an ONNX file, you can use

```bash
get_onnx_metadata path/to/model.onnx
```

The metadata includes the list of input features (with their normalisation values), the list of outputs, and the model name.
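If you prefer to inspect the metadata from Python, the standard `onnx` package exposes it as key/value string pairs; a minimal sketch is shown below. Whether a given value is JSON-encoded is an assumption, so the sketch falls back to printing the raw string.

```python
import json
import onnx

# Minimal sketch: print the ONNX metadata key/value pairs that
# get_onnx_metadata reports. Values are pretty-printed as JSON when
# possible (an assumption), otherwise shown verbatim.
model = onnx.load("path/to/model.onnx")
for prop in model.metadata_props:
    try:
        value = json.dumps(json.loads(prop.value), indent=2)
    except ValueError:
        value = prop.value
    print(f"{prop.key}:\n{value}\n")
```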
**A command with the same name is also available in Athena**

After setting up Athena, you can also run a different `get_onnx_metadata` command which has the same function.