Dataloading#

salt.data.SaltDataset #

Bases: torch.utils.data.Dataset

An efficient map-style dataset for loading data from an H5 file containing structured arrays.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| filename | str \| pathlib.Path | Input h5 filepath containing structured arrays | required |
| norm_dict | str \| pathlib.Path | Path to file containing normalisation parameters | required |
| variables | salt.stypes.Vars | Input variables used in the forward pass for each input type | required |
| stage | str | Stage of the training process | required |
| num | int | Number of input samples to use. If -1, use all input samples | -1 |
| labeller_config | salt.utils.configs.LabellerConfig \| None | Configuration to apply relabelling on-the-fly for jet classification | None |
| labels | salt.stypes.Vars \| None | List of required labels for each input type | None |
| mf_config | salt.utils.configs.MaskformerConfig \| None | Config for Maskformer matching | None |
| input_map | dict[str, str] \| None | Map names to the corresponding dataset names in the input h5 file. If not provided, the input names are used as the dataset names | None |
| num_inputs | dict \| None | Truncate the number of constituent inputs to this number, to speed up training | None |
| non_finite_to_num | bool | Convert NaNs and infs to zeros when loading inputs | False |
| global_object | str | Name of the global input object, as opposed to the constituent-level inputs | 'jets' |
| parameters | dict \| None | Variables used to parameterise the network | None |
| selections | dict[str, list[str]] \| None | Selections to apply to the input data | None |
| ignore_finite_checks | bool | Ignore checks for non-finite inputs | False |
| recover_malformed | bool | Convert malformed inputs in truthOriginLabel to invalid tracks | False |
| transforms | list[collections.abc.Callable] \| None | Transformations to apply to the data | None |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If use_labeller is set to true but the classes for relabelling are not supplied |
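
A minimal usage sketch, not part of the documented API above: the file paths and variable names are placeholders, and variables is assumed to be a plain mapping from input name to a list of variable names.

```python
from salt.data import SaltDataset

# Placeholder paths and illustrative variable names; adjust to your own files.
dataset = SaltDataset(
    filename="pp_output_train.h5",
    norm_dict="norm_dict.yaml",
    variables={
        "jets": ["pt", "eta"],           # global-level inputs (illustrative)
        "tracks": ["d0", "z0SinTheta"],  # constituent-level inputs (illustrative)
    },
    stage="fit",           # stage string as used by the training setup (assumed value)
    num=-1,                # use all available samples
    global_object="jets",
)

# As a map-style dataset it exposes a length.
print(len(dataset))
```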

Source code in salt/data/datasets.py
def __init__(
    self,
    filename: str | Path,
    norm_dict: str | Path,
    variables: Vars,
    stage: str,
    num: int = -1,
    labeller_config: LabellerConfig | None = None,
    labels: Vars | None = None,
    mf_config: MaskformerConfig | None = None,
    input_map: dict[str, str] | None = None,
    num_inputs: dict | None = None,
    non_finite_to_num: bool = False,
    global_object: str = "jets",
    parameters: dict | None = None,
    selections: dict[str, list[str]] | None = None,
    ignore_finite_checks: bool = False,
    recover_malformed: bool = False,
    transforms: list[Callable] | None = None,
):
    super().__init__()
    # check labels have been configured
    self.labels = labels if labels is not None else {}

    # default input mapping: use input names as dataset names
    # allow only partial maps to be provided
    if input_map is None:
        input_map = {k: k for k in variables}

    if "global" in input_map:
        input_map["global"] = global_object

    if "parameters" in input_map:
        input_map["parameters"] = global_object

    self.input_map = input_map
    self.filename = Path(filename)
    self.file = h5py.File(self.filename, "r")
    self.num_inputs = num_inputs
    self.non_finite_to_num = non_finite_to_num
    self.global_object = global_object
    self.selections = selections
    self.selectors = {}
    self.transforms = transforms
    if self.selections:
        for key, value in self.selections.items():
            self.selectors[key] = TrackSelector(Cuts.from_list(value))

    # If MaskFormer matching is enabled, extract the relevant labels
    self.mf_config = deepcopy(mf_config)
    if self.mf_config:
        self.input_map["objects"] = self.mf_config.object.name

    self.variables = variables
    self.norm_dict = norm_dict
    self.parameters = parameters
    self.stage = stage
    self.labeller_config = labeller_config
    if labeller_config and labeller_config.use_labeller:
        self.labeller = Labeller(labeller_config.class_names, labeller_config.require_labels)
    else:
        self.labeller = None
    self.rng = np.random.default_rng()

    # check that num_inputs contains valid keys
    if self.num_inputs is not None and not set(self.num_inputs).issubset(self.variables):
        raise ValueError(
            f"num_inputs keys {self.num_inputs.keys()} must be a subset of input variables"
            f" {self.variables.keys()}"
        )

    self.check_file()

    self.input_variables = variables
    assert self.input_variables is not None

    # check parameters listed in variables appear in the same order in the parameters block
    if "parameters" in self.input_variables:
        assert self.parameters is not None
        assert self.input_variables["parameters"] is not None
        assert len(self.input_variables["parameters"]) == len(self.parameters)
        for idx, param_key in enumerate(self.parameters.keys()):
            assert self.input_variables["parameters"][idx] == param_key

    # set number of objects
    self.num = self.get_num(num)
    self.ignore_finite_checks = ignore_finite_checks
    self.recover_malformed = recover_malformed

    self._is_setup = False

salt.data.SaltDataModule #

Bases: lightning.LightningDataModule

Datamodule wrapping a salt.data.SaltDataset for training, validation and testing.

This datamodule will load data from h5 files. The training, validation and test files are specified by the train_file, val_file and test_file arguments.

The arguments of this class can be set from the YAML config file or from the command line using the data key. For example, to set the batch_size from the command line, use --data.batch_size=1000.
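
As an illustration only, the sketch below constructs the datamodule directly in Python with placeholder values; in practice these arguments are normally supplied through the data key of the YAML config or on the command line as shown above.

```python
from salt.data import SaltDataModule

# Placeholder paths and illustrative sample counts.
datamodule = SaltDataModule(
    train_file="pp_output_train.h5",
    val_file="pp_output_val.h5",
    batch_size=1000,
    num_workers=8,
    num_train=1_000_000,
    num_val=100_000,
    num_test=100_000,
    # Remaining keyword arguments are forwarded to salt.data.SaltDataset,
    # for example (illustrative values):
    norm_dict="norm_dict.yaml",
    variables={"jets": ["pt", "eta"]},
)

# With a Lightning Trainer (model and trainer defined elsewhere):
# trainer.fit(model, datamodule=datamodule)
```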

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| train_file | str \| pathlib.Path | Training file path | required |
| val_file | str \| pathlib.Path | Validation file path | required |
| batch_size | int | Number of samples to process in each training step | required |
| num_workers | int | Number of CPU worker processes to load batches from disk | required |
| num_train | int | Total number of training samples | required |
| num_val | int | Total number of validation samples | required |
| num_test | int | Total number of testing samples | required |
| move_files_temp | str \| None | Directory to move training files to; if None, files are not copied | None |
| class_dict | str \| None | Path to umami preprocessing scale dict file | None |
| test_file | str \| None | Test file path | None |
| test_suff | str \| None | Test file suffix | None |
| pin_memory | bool | Pin memory for faster GPU transfer | True |
| config_s3 | dict \| None | Parameters for the S3 access | None |
| **kwargs | | Keyword arguments for salt.data.SaltDataset | {} |
Source code in salt/data/datamodules.py
def __init__(
    self,
    train_file: str | Path,
    val_file: str | Path,
    batch_size: int,
    num_workers: int,
    num_train: int,
    num_val: int,
    num_test: int,
    move_files_temp: str | None = None,
    class_dict: str | None = None,
    test_file: str | None = None,
    test_suff: str | None = None,
    pin_memory: bool = True,
    config_s3: dict | None = None,
    **kwargs,
):
    super().__init__()
    self.train_file = Path(train_file)
    self.val_file = Path(val_file)
    self.test_file = test_file
    self.test_suff = test_suff
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.num_train = num_train
    self.num_val = num_val
    self.num_test = num_test
    self.class_dict = class_dict
    self.move_files_temp = move_files_temp
    self.pin_memory = pin_memory
    self.config_s3 = config_s3
    self.kwargs = kwargs
