
Models

Components#

salt.models.Dense #

Bases: torch.nn.Module

A fully connected feed forward neural network, which can take in additional contextual information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_size | int | Input size | required |
| output_size | int | Output size. If not specified this will be the same as the input size. | None |
| hidden_layers | list | Number of nodes per layer, if not specified, the network will have a single hidden layer with size `input_size * hidden_dim_scale`. | None |
| hidden_dim_scale | int | Scale factor for the hidden layer size. | 2 |
| activation | str | Activation function for hidden layers. Must be a valid torch.nn activation function. | 'ReLU' |
| final_activation | str | Activation function for the output layer. Must be a valid torch.nn activation function. | None |
| dropout | float | Apply dropout with the supplied probability. | 0.0 |
| bias | bool | Whether to use bias in the linear layers. | True |
| context_size | int | Size of the context tensor, 0 means no context information is provided. | 0 |
| muP | bool | Whether to use the muP parametrisation (impacts initialisation). | False |

Source code in salt/models/dense.py
def __init__(
    self,
    input_size: int,
    output_size: int | None = None,
    hidden_layers: list[int] | None = None,
    hidden_dim_scale: int = 2,
    activation: str = "ReLU",
    final_activation: str | None = None,
    dropout: float = 0.0,
    bias: bool = True,
    context_size: int = 0,
    muP: bool = False,
) -> None:
    """A fully connected feed forward neural network, which can take
    in additional contextual information.

    Parameters
    ----------
    input_size : int
        Input size
    output_size : int
        Output size. If not specified this will be the same as the input size.
    hidden_layers : list, optional
        Number of nodes per layer, if not specified, the network will have
        a single hidden layer with size `input_size * hidden_dim_scale`.
    hidden_dim_scale : int, optional
        Scale factor for the hidden layer size.
    activation : str
        Activation function for hidden layers.
        Must be a valid torch.nn activation function.
    final_activation : str, optional
        Activation function for the output layer.
        Must be a valid torch.nn activation function.
    dropout : float, optional
        Apply dropout with the supplied probability.
    bias : bool, optional
        Whether to use bias in the linear layers.
    context_size : int
        Size of the context tensor, 0 means no context information is provided.
    muP : bool, optional
        Whether to use the muP parametrisation (impacts initialisation).
    """
    super().__init__()

    if output_size is None:
        output_size = input_size
    if hidden_layers is None:
        hidden_layers = [input_size * hidden_dim_scale]

    # Save the networks input and output sizes
    self.input_size = input_size
    self.output_size = output_size
    self.context_size = context_size
    self.muP = muP

    # build nodelist
    self.node_list = [input_size + context_size, *hidden_layers, output_size]

    # input and hidden layers
    layers = []

    num_layers = len(self.node_list) - 1
    for i in range(num_layers):
        if dropout:
            layers.append(nn.Dropout(dropout))

        # linear projection
        layers.append(nn.Linear(self.node_list[i], self.node_list[i + 1], bias=bias))

        # activation for all but the final layer
        if i != num_layers - 1:
            layers.append(getattr(nn, activation)())

        # final layer: return logits by default, or activation if specified
        elif final_activation:
            layers.append(getattr(nn, final_activation)())

    # build the net
    self.net = nn.Sequential(*layers)

    if self.muP:
        self._reset_parameters()
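
The layer ordering produced by the constructor above can be traced without PyTorch. The following is a minimal sketch, not part of Salt: `dense_layer_plan` is a hypothetical helper that mirrors the constructor's loop and returns layer descriptions as strings instead of building modules. The `bias` and `muP` options are left out because they do not change the ordering.

```python
from __future__ import annotations


def dense_layer_plan(
    input_size: int,
    output_size: int | None = None,
    hidden_layers: list[int] | None = None,
    hidden_dim_scale: int = 2,
    activation: str = "ReLU",
    final_activation: str | None = None,
    dropout: float = 0.0,
    context_size: int = 0,
) -> list[str]:
    """Hypothetical helper: describe the layers Dense would stack, in order."""
    # Same defaults as the constructor
    if output_size is None:
        output_size = input_size
    if hidden_layers is None:
        hidden_layers = [input_size * hidden_dim_scale]

    # First entry is input_size + context_size, as in the constructor's node_list
    node_list = [input_size + context_size, *hidden_layers, output_size]

    plan = []
    num_layers = len(node_list) - 1
    for i in range(num_layers):
        # Dropout, when enabled, precedes every linear layer (including the first)
        if dropout:
            plan.append(f"Dropout({dropout})")
        plan.append(f"Linear({node_list[i]}, {node_list[i + 1]})")
        # Hidden activation after all but the last linear layer
        if i != num_layers - 1:
            plan.append(activation)
        # The last layer only gets an activation if one was requested
        elif final_activation:
            plan.append(final_activation)
    return plan


print(dense_layer_plan(4))
print(dense_layer_plan(4, output_size=3, hidden_layers=[16, 8], dropout=0.1, context_size=2))
```

With only `input_size=4`, the plan is a single hidden layer of width 8 followed by a linear output layer of width 4, with no final activation.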

Models#

salt.models.SaltModel #

Bases: torch.nn.Module

A generic multi-modal, multi-task neural network.

This model can be used to implement a wide range of models, including DL1, DIPS, GN2 and more.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| init_nets | list[dict] | Keyword arguments for one or more initialisation networks. See salt.models.InitNet. Each initialisation network produces an initial input embedding for a single input type. | required |
| tasks | torch.nn.ModuleList | Task heads, see salt.models.TaskBase. These can be used to implement object tagging, vertexing, regression, classification, etc. | required |
| encoder | torch.nn.Module | Main input encoder, which takes the output of the initialisation networks and produces a single embedding for each constituent. If not provided this model is essentially a DeepSets model. | None |
| mask_decoder | torch.nn.Module | Mask decoder, which takes the output of the encoder and produces a series of learned embeddings to represent object masks. | None |
| pool_net | torch.nn.Module | Pooling network which computes a global representation of the object by aggregating over the constituents. If not provided, assume that the only inputs are global features (i.e. no constituents). | None |
| num_register_tokens | int | Number of randomly initialised register tokens of the same length as any other input sequences after initialiser networks (e.g. tracks). See https://arxiv.org/abs/2309.16588. | 0 |
| merge_dict | dict[str, list[str]] \| None | A dictionary that lets Salt concatenate the representations of the inputs named in list[str] and act on them in following layers (e.g. transformer or tasks) as if they came from one input type. | None |
| featurewise_nets | list[dict] | Keyword arguments for featurewise transformation networks that perform featurewise scaling and biasing. | None |

Source code in salt/models/saltmodel.py
def __init__(
    self,
    init_nets: list[dict],
    tasks: nn.ModuleList,
    encoder: nn.Module = None,
    mask_decoder: nn.Module = None,
    pool_net: Pooling = None,
    num_register_tokens: int = 0,
    merge_dict: dict[str, list[str]] | None = None,
    featurewise_nets: list[dict] | None = None,
):
    """A generic multi-modal, multi-task neural network.

    This model can be used to implement a wide range of models, including
    [DL1](https://ftag.docs.cern.ch/algorithms/taggers/dl1/),
    [DIPS](https://ftag.docs.cern.ch/algorithms/taggers/dips/),
    [GN2](https://ftag.docs.cern.ch/algorithms/taggers/GN2/)
    and more.

    Parameters
    ----------
    init_nets : list[dict]
        Keyword arguments for one or more initialisation networks.
        See [`salt.models.InitNet`][salt.models.InitNet].
        Each initialisation network produces an initial input embedding for
        a single input type.
    tasks : nn.ModuleList
        Task heads, see [`salt.models.TaskBase`][salt.models.TaskBase].
        These can be used to implement object tagging, vertexing, regression,
        classification, etc.
    encoder : nn.Module
        Main input encoder, which takes the output of the initialisation
        networks and produces a single embedding for each constituent.
        If not provided this model is essentially a DeepSets model.
    mask_decoder : nn.Module
        Mask decoder, which takes the output of the encoder and produces a
        series of learned embeddings to represent object masks
    pool_net : nn.Module
        Pooling network which computes a global representation of the object
        by aggregating over the constituents. If not provided, assume that
        the only inputs are global features (i.e. no constituents).
    num_register_tokens : int
        Number of randomly initialised register tokens of the same length as
        any other input sequences after initialiser networks (e.g. tracks).
        See https://arxiv.org/abs/2309.16588.
    merge_dict : dict[str, list[str]] | None
        A dictionary that lets Salt concatenate the representations of
        the inputs named in list[str] and act on them in following layers
        (e.g. transformer or tasks) as if they came from one input type.
    featurewise_nets : list[dict]
        Keyword arguments for featurewise transformation networks that perform
        featurewise scaling and biasing.
    """
    super().__init__()

    self.featurewise_nets = None
    if featurewise_nets:
        self.featurewise_nets = nn.ModuleList([
            FeaturewiseTransformation(**featurewise_net) for featurewise_net in featurewise_nets
        ])
    self.featurewise_nets_map = (
        {featurewise_net.layer: featurewise_net for featurewise_net in self.featurewise_nets}
        if self.featurewise_nets
        else {}
    )
    # if available, add featurewise net to init net config
    if "input" in self.featurewise_nets_map:
        for init_net in init_nets:
            init_net["featurewise"] = self.featurewise_nets_map["input"]

    self.init_nets = nn.ModuleList([InitNet(**init_net) for init_net in init_nets])
    self.tasks = tasks
    self.encoder = encoder
    self.mask_decoder = mask_decoder

    self.pool_net = pool_net
    self.merge_dict = merge_dict
    self.num_register_tokens = num_register_tokens

    # init register tokens
    if self.num_register_tokens and not self.encoder:
        raise ValueError("encoder must be set if num_register_tokens is set")
    if self.num_register_tokens and self.encoder:
        self.registers = torch.nn.Parameter(
            torch.normal(
                torch.zeros((self.num_register_tokens, self.encoder.embed_dim)), std=1e-4
            )
        )
        self.register_mask = torch.zeros(self.num_register_tokens, dtype=torch.bool)
        self.register_buffer("register_mask_buffer", self.register_mask)
    else:
        self.registers = None
        self.register_mask = None

    # checks for the global object only setup
    if self.pool_net is None:
        assert self.encoder is None, "pool_net must be set if encoder is set"
        assert len(self.init_nets) == 1, "pool_net must be set if more than one init_net is set"
        assert self.init_nets[0].input_name == self.init_nets[0].global_object
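
The consistency rules enforced at the end of the constructor can be summarised in a few lines. This is a sketch under stated assumptions: `check_salt_config` is a hypothetical helper, it takes plain flags and names instead of real modules, and it raises `ValueError` throughout where the constructor uses a mix of `ValueError` and `assert`. The input names `"jets"` and `"tracks"` are illustrative only.

```python
from __future__ import annotations


def check_salt_config(
    input_names: list[str],
    global_object: str,
    has_encoder: bool = False,
    has_pool_net: bool = False,
    num_register_tokens: int = 0,
) -> None:
    """Hypothetical helper mirroring the constructor's consistency checks."""
    # Register tokens are sized from encoder.embed_dim, so an encoder is required
    if num_register_tokens and not has_encoder:
        raise ValueError("encoder must be set if num_register_tokens is set")

    # Without a pooling network, only the global-object-only setup is valid
    if not has_pool_net:
        if has_encoder:
            raise ValueError("pool_net must be set if encoder is set")
        if len(input_names) != 1:
            raise ValueError("pool_net must be set if more than one init_net is set")
        if input_names[0] != global_object:
            raise ValueError("the single init_net must act on the global object")


# A global-features-only model: one init net, no encoder, no pooling
check_salt_config(["jets"], global_object="jets")

# Two input types without a pooling network are rejected
try:
    check_salt_config(["jets", "tracks"], global_object="jets")
except ValueError as err:
    print(err)
```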

forward #

Forward pass through the SaltModel.

Don't call this method directly; use __call__ instead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | salt.stypes.Tensors | Dict of input tensors for each modality. Each tensor is of shape (batch_size, num_inputs, input_size). | required |
| pad_masks | salt.stypes.BoolTensors | Dict of input padding mask tensors for each modality. Each tensor is of shape (batch_size, num_inputs). | None |
| labels | salt.stypes.NestedTensors | Nested dict of label tensors. The outer dict is keyed by input modality, the inner dict is keyed by label variable name. Each tensor is of shape (batch_size, num_inputs). If not specified, assume we are running model inference (i.e. no loss computation). | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| preds | salt.stypes.NestedTensors | Dict of model predictions for each task, separated by input modality. Tensors have varying shapes depending on the task. |
| loss | salt.stypes.Tensors | Dict of losses for each task, aggregated over the batch. |


Source code in salt/models/saltmodel.py
def forward(
    self,
    inputs: Tensors,
    pad_masks: BoolTensors | None = None,
    labels: NestedTensors | None = None,
) -> tuple[NestedTensors, Tensors]:
    """Forward pass through the `SaltModel`.

    Don't call this method directly; use `__call__` instead.

    Parameters
    ----------
    inputs : Tensors
        Dict of input tensors for each modality. Each tensor is of shape
        `(batch_size, num_inputs, input_size)`.
    pad_masks : BoolTensors, optional
        Dict of input padding mask tensors for each modality. Each tensor is of
        shape `(batch_size, num_inputs)`.
    labels : NestedTensors, optional
        Nested dict of label tensors. The outer dict is keyed by input modality,
        the inner dict is keyed by label variable name. Each tensor is of shape
        `(batch_size, num_inputs)`. If not specified, assume we are running model
        inference (i.e. no loss computation).

    Returns
    -------
    preds : NestedTensors
        Dict of model predictions for each task, separated by input modality.
        Tensors have varying shapes depending on the task.
    loss : Tensors
        Dict of losses for each task, aggregated over the batch.
    """
    # initial input projections
    xs = {}

    for init_net in self.init_nets:
        xs[init_net.input_name] = init_net(inputs)

    if self.num_register_tokens:
        batch_size = xs[next(iter(xs))].shape[0]
        xs["REGISTERS"] = self.registers.expand(batch_size, -1, -1)
        if pad_masks:
            pad_masks["REGISTERS"] = self.register_mask_buffer.expand(batch_size, -1)

    # handle edge features if present
    edge_x = xs.pop("EDGE", None)
    kwargs = {} if edge_x is None else {"edge_x": edge_x}

    if isinstance(self.merge_dict, dict):
        for merge_name, merge_types in self.merge_dict.items():
            xs[merge_name] = cat([xs.pop(mt) for mt in merge_types], dim=1)
        if pad_masks is not None:
            for merge_name, merge_types in self.merge_dict.items():
                pad_masks[merge_name] = cat([pad_masks.pop(mt) for mt in merge_types], dim=1)
        for merge_name, merge_types in self.merge_dict.items():
            if labels is not None:
                labels[merge_name] = {}
                for var in labels[merge_types[0]]:
                    labels[merge_name].update({
                        var: cat([labels[mt][var] for mt in merge_types], dim=1)
                    })

    # Generate embedding from encoder, or by concatenating the init net outputs
    if self.encoder:
        preds = {"embed_xs": self.encoder(xs, pad_mask=pad_masks, **kwargs)}
    else:
        preds = {"embed_xs": flatten_tensor_dict(xs)}

    preds, labels, loss = (
        self.mask_decoder(preds, self.tasks, pad_masks, labels)
        if self.mask_decoder
        else (preds, labels, {})
    )

    # apply featurewise transformation to global track representations if configured
    if "global" in self.featurewise_nets_map:
        preds["embed_xs"] = self.featurewise_nets_map["global"](inputs, preds["embed_xs"])

    # pooling
    if self.pool_net:
        global_rep = self.pool_net(preds, pad_mask=pad_masks)
    else:
        global_rep = preds["embed_xs"]

    # add global features to global representation
    if (global_feats := inputs.get("GLOBAL")) is not None:
        global_rep = torch.cat([global_rep, global_feats], dim=-1)
    preds["global_rep"] = global_rep

    # run tasks
    task_preds, task_loss = self.run_tasks(preds, pad_masks, labels)
    preds.update(task_preds)
    loss.update(task_loss)

    return preds, loss
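
The `merge_dict` step in the forward pass concatenates several input types along the constituent dimension (`dim=1`). The sketch below illustrates that idea with nested lists in place of tensors; `merge_sequences` is a hypothetical helper, and the names `"tracks"`, `"flows"` and `"constituents"` are illustrative. Unlike the forward pass, which also merges `pad_masks` and `labels` and modifies those dicts in place, this sketch handles only the embeddings and returns a new dict.

```python
def merge_sequences(xs: dict, merge_dict: dict) -> dict:
    """Hypothetical stand-in for the merge step, using nested lists as tensors.

    Each value in xs has the layout [batch][constituent][feature].
    """
    out = dict(xs)  # shallow copy so the caller's dict keeps its keys
    for merge_name, merge_types in merge_dict.items():
        parts = [out.pop(mt) for mt in merge_types]
        # Analogue of cat(parts, dim=1): join the constituent lists per batch entry
        out[merge_name] = [sum(per_batch, []) for per_batch in zip(*parts)]
    return out


xs = {
    "tracks": [[[1.0], [2.0]], [[4.0], [5.0]]],  # batch of 2, 2 constituents each
    "flows": [[[3.0]], [[6.0]]],                  # batch of 2, 1 constituent each
}
merged = merge_sequences(xs, {"constituents": ["tracks", "flows"]})
print(merged)
```

After merging, the original keys are gone and the single `"constituents"` entry holds three constituents per batch entry. With real tensors, `torch.cat` along `dim=1` additionally requires that the merged inputs agree in every other dimension, including the embedding size.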

Wrappers#

salt.modelwrapper.ModelWrapper #

Bases: lightning.LightningModule

A wrapper class for any model implemented in Salt.

This wrapper class is as generic as possible. It wraps SaltModel, but could also be used to wrap any other model if you want to train something that doesn't fit into the SaltModel architecture.

This class is responsible for containing things that are common to all salt models. These are:

  • A generic forward pass, including input normalisation
  • Training, validation and test steps, which include logging
  • Some sanity checks on the model configuration

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | torch.nn.Module | Model to be wrapped. | required |
| lrs_config | collections.abc.Mapping[str, float] | LRS config, which has to be set manually for now; see https://github.com/omni-us/jsonargparse/issues/170#issuecomment-1288167674 | required |
| global_object | str | Name of the global input object, as opposed to the constituent-level inputs. This argument is set automatically by the framework. | required |
| norm_config | dict | Keyword arguments for salt.models.InputNorm. Although the default is None, the constructor asserts that a value is provided. | None |
| name | str | Name of the model, used for logging and inference output names. | 'salt' |
| muP_config | dict \| None | The muP configuration. | None |

Source code in salt/modelwrapper.py
def __init__(
    self,
    model: nn.Module,
    lrs_config: Mapping[str, float],
    global_object: str,
    norm_config: dict | None = None,
    name: str = "salt",
    muP_config: dict | None = None,
):
    """A wrapper class for any model implemented in Salt.

    This wrapper class is as generic as possible. It wraps
    [`SaltModel`][salt.models.SaltModel], but could also be used to
    wrap any other model if you want to train something that doesn't
    fit into the [`SaltModel`][salt.models.SaltModel] architecture.

    This class is responsible for containing things that are common to all
    salt models. These are:

    - A generic forward pass, including input normalisation
    - Training, validation and test steps, which include logging
    - Some sanity checks on the model configuration

    Parameters
    ----------
    model : nn.Module
        Model to be wrapped
    lrs_config: Mapping
        LRS config which has to be set manually for now
        https://github.com/omni-us/jsonargparse/issues/170#issuecomment-1288167674
    global_object : str
        Name of the global input object, as opposed to the constituent-level
        inputs. This argument is set automatically by the framework.
    norm_config : dict, optional
        Keyword arguments for [`salt.models.InputNorm`][salt.models.InputNorm].
    name: str, optional
        Name of the model, used for logging and inference output names
    muP_config: dict, optional
        The muP configuration.
    """
    super().__init__()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        self.save_hyperparameters(logger=False)

    self.model = model
    self.lrs_config = lrs_config
    self.global_object = global_object
    self.name = name
    self.muP = muP_config if muP_config else {}
    self.last_val_batch_outs = None
    # Here the model should pick it up
    if self.muP:
        from salt.utils.muP_utils.configuration_muP import instantiate_mup

        load_path = None
        if "shape_path" in self.muP:
            load_path = self.muP["shape_path"]
        instantiate_mup(model, load_path)

    # all tasks should inherit the global object type
    self.model.global_object = self.global_object
    for task in self.model.tasks:
        task.global_object = self.global_object

    # ensure unique names for init_nets and tasks
    check_unique(self.model.init_nets, "input_name")
    check_unique(self.model.tasks, "name")

    # check that the model has the same output size for all init nets
    assert len({t.net.output_size for t in self.model.init_nets if t.input_name != "EDGE"}) == 1

    # create input normaliser
    assert norm_config is not None
    self.norm = InputNorm(**norm_config)
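
Two of the sanity checks in the constructor above, unique names and a shared embedding size, can be sketched as follows. These are hypothetical stand-ins rather than Salt's own `check_unique`, whose exact behaviour is not shown here. The sketch uses `SimpleNamespace` objects with a flat `output_size` attribute, whereas the constructor reads `t.net.output_size`, and it raises `ValueError` where the constructor uses `assert`.

```python
from types import SimpleNamespace


def check_unique_attr(modules, attr_name: str) -> None:
    """Hypothetical stand-in: fail if two modules share the same attribute value."""
    values = [getattr(m, attr_name) for m in modules]
    duplicates = sorted({v for v in values if values.count(v) > 1})
    if duplicates:
        raise ValueError(f"Duplicate {attr_name} values: {duplicates}")


def check_same_output_size(init_nets) -> None:
    """All init nets except the one named EDGE must share one output size."""
    sizes = {n.output_size for n in init_nets if n.input_name != "EDGE"}
    if len(sizes) != 1:
        raise ValueError(f"init_nets disagree on output size: {sorted(sizes)}")


# Illustrative init nets: the EDGE net may use a different embedding size
init_nets = [
    SimpleNamespace(input_name="jets", output_size=64),
    SimpleNamespace(input_name="tracks", output_size=64),
    SimpleNamespace(input_name="EDGE", output_size=16),
]
check_unique_attr(init_nets, "input_name")
check_same_output_size(init_nets)
```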

forward #

Generic forward pass through any salt-compatible model.

This function performs input normalisation and then calls self.model's forward pass. Don't call this method directly; use __call__ instead.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | | Any generic input to the model. | required |
| pad_masks | | Input padding masks. | None |
| labels | | Training targets. If not specified, assume we are running model inference (i.e. no loss computation). | None |

Returns:

| Type | Description |
| --- | --- |
| | Whatever is returned by `self.model`'s forward pass. |

Source code in salt/modelwrapper.py
def forward(self, inputs, pad_masks=None, labels=None):
    """Generic forward pass through any salt-compatible model.

    This function performs input normalisation and then calls the `self.model`'s
    forward pass. Don't call this method directly; use `__call__` instead.

    Parameters
    ----------
    inputs
        Any generic input to the model.
    pad_masks
        Input padding masks.
    labels
        Training targets. If not specified, assume we are running model inference
        (i.e. no loss computation).

    Returns
    -------
    Whatever is returned by `self.model`'s forward pass.
    """
    x = self.norm(inputs)
    return self.model(x, pad_masks, labels)
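
The normalise-then-delegate pattern of this forward pass can be shown without any framework. Everything in the sketch below is hypothetical: `TinyWrapper`, `toy_norm` and `toy_model` are illustrative stand-ins, and the shift-and-scale normalisation is an assumption, not the behaviour of `InputNorm`.

```python
class TinyWrapper:
    """Hypothetical, framework-free analogue of the wrapper's forward pass."""

    def __init__(self, model, norm):
        self.model = model
        self.norm = norm

    def __call__(self, inputs, pad_masks=None, labels=None):
        x = self.norm(inputs)  # normalise the raw inputs first
        return self.model(x, pad_masks, labels)  # then delegate unchanged


def toy_norm(inputs):
    # Assumed normalisation: shift by 1.0 and scale by 2.0 for every input
    return {name: [(v - 1.0) / 2.0 for v in values] for name, values in inputs.items()}


def toy_model(x, pad_masks=None, labels=None):
    preds = {name: sum(values) for name, values in x.items()}
    # No labels means inference, so the loss dict stays empty
    loss = {} if labels is None else {n: abs(preds[n] - labels[n]) for n in preds}
    return preds, loss


wrapper = TinyWrapper(toy_model, toy_norm)
preds, loss = wrapper({"jets": [1.0, 3.0, 5.0]})
print(preds, loss)
```

Passing `labels` switches the toy model from inference to loss computation, which matches the contract described for the `labels` argument above.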

Last update: January 25, 2024
Created: October 20, 2023