
# Dataloading

## salt.data.SaltDataset

Bases: torch.utils.data.Dataset

An efficient map-style dataset for loading data from an H5 file containing structured arrays.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `filename` | `str` | Input h5 filepath containing structured arrays | *required* |
| `norm_dict` | `str` | Path to file containing normalisation parameters | *required* |
| `variables` | `salt.stypes.Vars` | Input variables used in the forward pass for each input type | *required* |
| `stage` | `str` | Stage of the training process | *required* |
| `num` | `int` | Number of input samples to use. If `-1`, use all input samples | `-1` |
| `labels` | `salt.stypes.Vars` | List of required labels for each input type | `None` |
| `mf_config` | `salt.utils.configs.MaskformerConfig` | Config for Maskformer matching | `None` |
| `input_map` | `dict` | Map names to the corresponding dataset names in the input h5 file. If not provided, the input names are used as the dataset names. | `None` |
| `num_inputs` | `dict` | Truncate the number of constituent inputs to this number, to speed up training | `None` |
| `nan_to_num` | `bool` | Convert nans to zeros when loading inputs | `False` |
| `global_object` | `str` | Name of the global input object, as opposed to the constituent-level inputs | `'jets'` |
| `PARAMETERS` | `dict \| None` | Variables used to parameterise the network | `None` |
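
For illustration, a minimal construction sketch is shown below. The file paths, variable names, label names, and the `stage` value are hypothetical placeholders, not values prescribed by the library; the `variables` and `labels` arguments are dictionaries mapping each input type to a list of variable names.

```python
from salt.data import SaltDataset

# Minimal sketch with hypothetical paths and variable names: the H5 file is
# assumed to contain "jets" and "tracks" structured arrays.
dataset = SaltDataset(
    filename="pp_output_train.h5",
    norm_dict="norm_dict.yaml",
    variables={"jets": ["pt", "eta"], "tracks": ["d0", "z0SinTheta"]},
    stage="fit",  # assumed stage name
    labels={"jets": ["flavour_label"]},
    num=-1,  # use all available samples
)
print(len(dataset))  # number of samples the dataset will serve
```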
Source code in salt/data/datasets.py
def __init__(
    self,
    filename: str,
    norm_dict: str,
    variables: Vars,
    stage: str,
    num: int = -1,
    labels: Vars = None,
    mf_config: MaskformerConfig | None = None,
    input_map: dict[str, str] | None = None,
    num_inputs: dict | None = None,
    nan_to_num: bool = False,
    global_object: str = "jets",
    PARAMETERS: dict | None = None,
):
    """An efficient map-style dataset for loading data from an H5 file containing structured
    arrays.

    Parameters
    ----------
    filename : str
        Input h5 filepath containing structured arrays
    norm_dict : str
        Path to file containing normalisation parameters
    variables : Vars
        Input variables used in the forward pass for each input type
    stage : str
        Stage of the training process
    num : int, optional
        Number of input samples to use. If `-1`, use all input samples
    labels : Vars
        List of required labels for each input type
    mf_config : MaskformerConfig, optional
        Config for Maskformer matching, by default None
    input_map : dict, optional
        Map names to the corresponding dataset names in the input h5 file.
        If not provided, the input names will be used as the dataset names.
    num_inputs : dict, optional
        Truncate the number of constituent inputs to this number, to speed up training
    nan_to_num : bool, optional
        Convert nans to zeros when loading inputs
    global_object : str
        Name of the global input object, as opposed to the constituent-level
        inputs
    PARAMETERS : dict, optional
        Variables used to parameterise the network, by default None.
    """
    super().__init__()
    # check labels have been configured
    self.labels = labels if labels is not None else {}

    # default input mapping: use input names as dataset names
    if input_map is None:
        input_map = {k: k for k in variables}

    if "GLOBAL" in input_map:
        input_map["GLOBAL"] = global_object

    if "PARAMETERS" in input_map:
        input_map["PARAMETERS"] = global_object

    self.input_map = input_map
    self.filename = filename
    self.file = h5py.File(self.filename, "r")
    self.num_inputs = num_inputs
    self.nan_to_num = nan_to_num
    self.global_object = global_object

    # If MaskFormer matching is enabled, extract the relevant labels
    self.mf_config = deepcopy(mf_config)
    if self.mf_config:
        self.input_map["objects"] = self.mf_config.object.name

    self.variables = variables
    self.norm_dict = norm_dict
    self.PARAMETERS = PARAMETERS
    self.stage = stage
    self.rng = np.random.default_rng()

    # check that num_inputs contains valid keys
    if self.num_inputs is not None and not set(self.num_inputs).issubset(self.variables):
        raise ValueError(
            f"num_inputs keys {self.num_inputs.keys()} must be a subset of input variables"
            f" {self.variables.keys()}"
        )

    self.check_file()

    self.input_variables = variables
    assert self.input_variables is not None

    # check parameters listed in variables appear in the same order in the PARAMETERS block
    if "PARAMETERS" in self.input_variables:
        assert self.PARAMETERS is not None
        assert self.input_variables["PARAMETERS"] is not None
        assert len(self.input_variables["PARAMETERS"]) == len(self.PARAMETERS)
        for idx, param_key in enumerate(self.PARAMETERS.keys()):
            assert self.input_variables["PARAMETERS"][idx] == param_key

    # setup datasets and accessor arrays
    self.dss = {}
    self.arrays = {}
    for internal, external in self.input_map.items():
        self.dss[internal] = self.file[external]
        this_vars = self.labels[internal].copy() if internal in self.labels else []
        this_vars += self.input_variables.get(internal, [])
        if internal == "EDGE":
            dtype = get_dtype_edge(self.file[external], this_vars)
        else:
            dtype = get_dtype(self.file[external], this_vars)
        self.arrays[internal] = np.array(0, dtype=dtype)
    if self.global_object not in self.dss:
        self.dss[self.global_object] = self.file[self.global_object]

    # set number of objects
    self.num = self.get_num(num)
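
The constructor above resolves `input_map` before any datasets are opened: when no map is given, each input name is used directly as the H5 dataset name, and the special `GLOBAL` and `PARAMETERS` inputs are both read from the global object. The standalone sketch below, using hypothetical input names, reproduces that defaulting logic.

```python
# Standalone sketch of the input_map defaulting shown above (hypothetical input names).
variables = {"GLOBAL": ["pt", "eta"], "tracks": ["d0"], "PARAMETERS": ["mass"]}
global_object = "jets"

input_map = {k: k for k in variables}        # default: dataset names == input names
if "GLOBAL" in input_map:
    input_map["GLOBAL"] = global_object      # GLOBAL variables are read from the global object
if "PARAMETERS" in input_map:
    input_map["PARAMETERS"] = global_object  # parameterisation variables are too

print(input_map)  # {'GLOBAL': 'jets', 'tracks': 'tracks', 'PARAMETERS': 'jets'}
```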

## salt.data.SaltDataModule

Bases: lightning.LightningDataModule

Datamodule wrapping a salt.data.SaltDataset for training, validation and testing.

This datamodule will load data from h5 files. The training, validation and test files are specified by the train_file, val_file and test_file arguments.

The arguments of this class can be set from the YAML config file or from the command line using the data key. For example, to set the batch_size from the command line, use --data.batch_size=1000.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `train_file` | `str` | Training file path | *required* |
| `val_file` | `str` | Validation file path | *required* |
| `batch_size` | `int` | Number of samples to process in each training step | *required* |
| `num_workers` | `int` | Number of CPU worker processes to load batches from disk | *required* |
| `num_train` | `int` | Total number of training samples | *required* |
| `num_val` | `int` | Total number of validation samples | *required* |
| `num_test` | `int` | Total number of testing samples | *required* |
| `move_files_temp` | `str` | Directory to move training files to. If `None`, files are not copied. | `None` |
| `class_dict` | `str` | Path to umami preprocessing scale dict file | `None` |
| `test_file` | `str` | Test file path | `None` |
| `test_suff` | `str` | Test file suffix | `None` |
| `pin_memory` | `bool` | Pin memory for faster GPU transfer | `True` |
| `config_S3` | `dict \| None` | Parameters for S3 access | `None` |
| `**kwargs` | | Keyword arguments for `salt.data.SaltDataset` | `{}` |
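
As a rough guide, the sketch below constructs the datamodule directly in Python; in practice these values normally come from the `data` block of the YAML config. The file paths, sample counts, and the forwarded dataset arguments are hypothetical.

```python
from salt.data import SaltDataModule

datamodule = SaltDataModule(
    train_file="pp_output_train.h5",  # hypothetical paths
    val_file="pp_output_val.h5",
    batch_size=1000,
    num_workers=10,
    num_train=1_000_000,
    num_val=100_000,
    num_test=100_000,
    # remaining keyword arguments are forwarded to SaltDataset (assumed example values)
    norm_dict="norm_dict.yaml",
    variables={"jets": ["pt", "eta"], "tracks": ["d0", "z0SinTheta"]},
    labels={"jets": ["flavour_label"]},
)
```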
Source code in salt/data/datamodules.py
def __init__(
    self,
    train_file: str,
    val_file: str,
    batch_size: int,
    num_workers: int,
    num_train: int,
    num_val: int,
    num_test: int,
    move_files_temp: str | None = None,
    class_dict: str | None = None,
    test_file: str | None = None,
    test_suff: str | None = None,
    pin_memory: bool = True,
    config_S3: dict | None = None,
    **kwargs,
):
    """Datamodule wrapping a [`salt.data.SaltDataset`][salt.data.SaltDataset] for training,
    validation and testing.

    This datamodule will load data from h5 files. The training, validation and test files
    are specified by the `train_file`, `val_file` and `test_file` arguments.

    The arguments of this class can be set from the YAML config file or from the command line
    using the `data` key. For example, to set the `batch_size` from the command line, use
    `--data.batch_size=1000`.

    Parameters
    ----------
    train_file : str
        Training file path
    val_file : str
        Validation file path
    batch_size : int
        Number of samples to process in each training step
    num_workers : int
        Number of CPU worker processes to load batches from disk
    num_train : int
        Total number of training samples
    num_val : int
        Total number of validation samples
    num_test : int
        Total number of testing samples
    move_files_temp : str
        Directory to move training files to, default is None,
        which will result in no copying of files
    class_dict : str
        Path to umami preprocessing scale dict file
    test_file : str
        Test file path, default is None
    test_suff : str
        Test file suffix, default is None
    pin_memory : bool
        Pin memory for faster GPU transfer, default is True
    config_S3 : dict, optional
        Parameters for S3 access
    **kwargs
        Keyword arguments for [`salt.data.SaltDataset`][salt.data.SaltDataset]
    """
    super().__init__()
    self.train_file = train_file
    self.val_file = val_file
    self.test_file = test_file
    self.test_suff = test_suff
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.num_train = num_train
    self.num_val = num_val
    self.num_test = num_test
    self.class_dict = class_dict
    self.move_files_temp = move_files_temp
    self.pin_memory = pin_memory
    self.config_S3 = config_S3
    self.kwargs = kwargs
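
Since the class is a standard `lightning.LightningDataModule`, it is passed to a Lightning `Trainer` in the usual way. A brief sketch, continuing from the construction example above and assuming `model` is an already-configured `LightningModule`:

```python
import lightning as L

trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)  # `model`: an assumed, pre-configured LightningModule
```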
