Models

Components

salt.models.Dense

Bases: torch.nn.Module

A fully connected feed-forward neural network that can optionally take in additional contextual information.

Parameters:

• input_size (int, required): Input size.
• output_size (int | None, default: None): Output size. If not specified, it is the same as the input size.
• hidden_layers (list[int] | None, default: None): Number of nodes per hidden layer. If not specified, the network has a single hidden layer of size input_size * hidden_dim_scale.
• hidden_dim_scale (int, default: 2): Scale factor for the default hidden layer size.
• activation (str, default: 'ReLU'): Activation function for the hidden layers. Must be a valid torch.nn activation function.
• final_activation (str | None, default: None): Activation function for the output layer. Must be a valid torch.nn activation function.
• dropout (float, default: 0.0): Apply dropout with the supplied probability.
• bias (bool, default: True): Whether to use bias in the linear layers.
• context_size (int, default: 0): Size of the context tensor; 0 means no context information is provided.
• mup (bool, default: False): Whether to use the muP parametrisation (affects initialisation).
Source code in salt/models/dense.py
def __init__(
    self,
    input_size: int,
    output_size: int | None = None,
    hidden_layers: list[int] | None = None,
    hidden_dim_scale: int = 2,
    activation: str = "ReLU",
    final_activation: str | None = None,
    dropout: float = 0.0,
    bias: bool = True,
    context_size: int = 0,
    mup: bool = False,
) -> None:
    super().__init__()

    if output_size is None:
        output_size = input_size
    if hidden_layers is None:
        hidden_layers = [input_size * hidden_dim_scale]

    # Save the networks input and output sizes
    self.input_size = input_size
    self.output_size = output_size
    self.context_size = context_size
    self.mup = mup

    # build nodelist
    self.node_list = [input_size + context_size, *hidden_layers, output_size]

    # input and hidden layers
    layers = []

    num_layers = len(self.node_list) - 1
    for i in range(num_layers):
        if dropout:
            layers.append(nn.Dropout(dropout))

        # linear projection
        layers.append(nn.Linear(self.node_list[i], self.node_list[i + 1], bias=bias))

        # activation for all but the final layer
        if i != num_layers - 1:
            layers.append(getattr(nn, activation)())

        # final layer: return logits by default, or activation if specified
        elif final_activation:
            layers.append(getattr(nn, final_activation)())

    # build the net
    self.net = nn.Sequential(*layers)

    if self.mup:
        self._reset_parameters()
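
For illustration, a minimal usage sketch. This assumes salt is installed; the forward signature is not shown on this page, so the call below assumes the module is applied directly to the input tensor, as is conventional for an nn.Module:

```python
import torch
from salt.models import Dense

# 32 -> 16 MLP with two hidden layers of 64 nodes, ReLU activations,
# and dropout applied before each linear layer.
net = Dense(
    input_size=32,
    output_size=16,
    hidden_layers=[64, 64],
    activation="ReLU",
    dropout=0.1,
)

x = torch.randn(100, 32)  # (batch_size, input_size)
out = net(x)              # assumed call signature, with context_size=0
assert out.shape == (100, 16)
```

With the defaults (output_size=None, hidden_layers=None), the network instead maps 32 features back to 32 through a single hidden layer of input_size * hidden_dim_scale = 64 nodes, following the node_list construction shown above.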

Models

salt.models.SaltModel

Bases: torch.nn.Module

Generic multi-modal, multi-task neural network.

This model can implement a wide range of architectures, such as DL1, DIPS, and GN2.

Parameters:

• init_nets (list[dict], required): Keyword arguments for one or more initialisation networks. Each initialisation network produces an initial input embedding for a single input type. See salt.models.InitNet.
• tasks (torch.nn.ModuleList, required): Task heads applied to the encoded/pooled features. See salt.models.TaskBase.
• encoder (torch.nn.Module | None, default: None): Main input encoder, which takes the outputs of the init nets and produces per-constituent embeddings. If not provided, the model is effectively a DeepSets model.
• mask_decoder (torch.nn.Module | None, default: None): Mask decoder, which takes the encoder output and produces learned embeddings representing object masks.
• pool_net (salt.models.Pooling | None, default: None): Pooling network that computes a global representation by aggregating over the constituents. If not provided, only global features are assumed to be present.
• merge_dict (dict[str, list[str]] | None, default: None): Dictionary specifying which input types should be concatenated into a single stream (e.g. the transformer input).
• featurewise_nets (list[dict] | None, default: None): Keyword arguments for featurewise transformation networks, which perform per-feature scaling and biasing.
Source code in salt/models/saltmodel.py
def __init__(
    self,
    init_nets: list[dict],
    tasks: nn.ModuleList,
    encoder: nn.Module | None = None,
    mask_decoder: nn.Module | None = None,
    pool_net: Pooling | None = None,
    merge_dict: dict[str, list[str]] | None = None,
    featurewise_nets: list[dict] | None = None,
):
    super().__init__()

    # init featurewise networks
    if featurewise_nets:
        self.init_featurewise(featurewise_nets, init_nets, encoder)

    self.init_nets = nn.ModuleList([InitNet(**init_net) for init_net in init_nets])
    self.tasks = tasks
    self.encoder = encoder
    self.mask_decoder = mask_decoder

    self.pool_net = pool_net
    self.merge_dict = merge_dict

    # checks for the global object only setup
    if self.pool_net is None:
        assert self.encoder is None, "pool_net must be set if encoder is set"
        assert len(self.init_nets) == 1, "pool_net must be set if more than one init_net is set"
        assert self.init_nets[0].input_name == self.init_nets[0].global_object
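
To illustrate the merge_dict argument, the sketch below reproduces with plain tensors the concatenation performed in forward (shown further down): each listed input type is concatenated along the constituent axis into a single named stream. The stream names ("tracks", "hits") are illustrative only.

```python
import torch

# Two constituent-level streams with a shared embedding dimension.
xs = {
    "tracks": torch.randn(8, 40, 128),  # (batch, num_tracks, embed_dim)
    "hits": torch.randn(8, 100, 128),   # (batch, num_hits, embed_dim)
}

# merge_dict maps a merged stream name to the input types it combines.
merge_dict = {"tracks_hits": ["tracks", "hits"]}
for merge_name, merge_types in merge_dict.items():
    xs[merge_name] = torch.cat([xs.pop(mt) for mt in merge_types], dim=1)

assert xs["tracks_hits"].shape == (8, 140, 128)
```

Pad masks and labels are merged in the same way in the forward pass.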

forward

Forward pass through the model.

Parameters:

• inputs (salt.stypes.Tensors, required): Dict of input tensors for each modality, of shape (batch_size, num_inputs, input_size).
• pad_masks (salt.stypes.BoolTensors | None, default: None): Dict of input padding mask tensors for each modality, of shape (batch_size, num_inputs).
• labels (salt.stypes.NestedTensors | None, default: None): Nested dict of label tensors. The outer dict is keyed by input modality, the inner dict by label variable. Each tensor has shape (batch_size, num_inputs). If None, run inference without loss computation.

Returns:

• preds (salt.stypes.NestedTensors): Dict of model predictions for each task, separated by input modality.
• loss (salt.stypes.Tensors): Dict of losses for each task, aggregated over the batch.

Source code in salt/models/saltmodel.py
def forward(
    self,
    inputs: Tensors,
    pad_masks: BoolTensors | None = None,
    labels: NestedTensors | None = None,
) -> tuple[NestedTensors, Tensors]:
    """Forward pass through the model.

    Parameters
    ----------
    inputs : Tensors
        Dict of input tensors for each modality of shape
        ``(batch_size, num_inputs, input_size)``.
    pad_masks : BoolTensors | None, optional
        Dict of input padding mask tensors for each modality of shape
        ``(batch_size, num_inputs)``. The default is ``None``.
    labels : NestedTensors | None, optional
        Nested dict of label tensors. Outer dict keyed by input modality,
        inner dict keyed by label variable. Each tensor has shape
        ``(batch_size, num_inputs)``. If ``None``, run inference without
        loss computation.

    Returns
    -------
    preds : NestedTensors
        Dict of model predictions for each task separated by input modality.
    loss : Tensors
        Dict of losses for each task aggregated over the batch.
    """
    # initial input projections
    xs = {}

    for init_net in self.init_nets:
        xs[init_net.input_name] = init_net(inputs)

    # handle edge features if present
    edge_x = xs.pop("EDGE", None)
    kwargs = {} if edge_x is None else {"edge_x": edge_x}

    # merge multiple streams if requested
    if isinstance(self.merge_dict, dict):
        for merge_name, merge_types in self.merge_dict.items():
            xs[merge_name] = cat([xs.pop(mt) for mt in merge_types], dim=1)
        if pad_masks is not None:
            for merge_name, merge_types in self.merge_dict.items():
                pad_masks[merge_name] = cat([pad_masks.pop(mt) for mt in merge_types], dim=1)
        for merge_name, merge_types in self.merge_dict.items():
            if labels is not None:
                labels[merge_name] = {}
                for var in labels[merge_types[0]]:
                    labels[merge_name].update({
                        var: cat([labels[mt][var] for mt in merge_types], dim=1)
                    })

    # encode
    if self.encoder:
        embed_xs = self.encoder(xs, pad_mask=pad_masks, inputs=inputs, **kwargs)
        if isinstance(embed_xs, tuple):
            embed_xs, pad_masks = embed_xs
        preds = {"embed_xs": embed_xs}
    else:
        preds = {"embed_xs": flatten_tensor_dict(xs)}

    preds, labels, loss = (
        self.mask_decoder(preds, self.tasks, pad_masks, labels)
        if self.mask_decoder
        else (preds, labels, {})
    )

    # apply featurewise transformation to global track embeddings if configured
    if hasattr(self, "featurewise_global") and self.featurewise_global:
        preds["embed_xs"] = self.featurewise_global(inputs, preds["embed_xs"])

    # pooling
    if self.pool_net:
        global_rep = self.pool_net(preds, pad_mask=pad_masks)
    else:
        global_rep = preds["embed_xs"]

    # add global features to global representation
    if (global_feats := inputs.get("global")) is not None:
        global_rep = torch.cat([global_rep, global_feats], dim=-1)
    preds["global_rep"] = global_rep

    # run tasks
    task_preds, task_loss = self.run_tasks(preds, pad_masks, labels)
    preds.update(task_preds)
    loss.update(task_loss)

    return preds, loss
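
To make the expected dict structures concrete, here is a shape-only sketch of the forward arguments. The tensors are dummies and the label name is hypothetical; building a real model requires component configs not shown on this page.

```python
import torch

batch_size, num_tracks, num_feats = 8, 40, 21

# One entry per input modality, shaped (batch_size, num_inputs, input_size).
inputs = {"tracks": torch.randn(batch_size, num_tracks, num_feats)}

# Boolean padding masks, shaped (batch_size, num_inputs).
pad_masks = {"tracks": torch.zeros(batch_size, num_tracks, dtype=torch.bool)}

# Outer key: input modality. Inner key: label variable (name hypothetical).
labels = {"tracks": {"origin": torch.zeros(batch_size, num_tracks, dtype=torch.long)}}

# preds, loss = model(inputs, pad_masks, labels)  # (NestedTensors, Tensors)
```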

Wrappers

salt.modelwrapper.ModelWrapper

Bases: lightning.LightningModule

A generic wrapper class for Salt-compatible models.

This class wraps SaltModel, but can also be used to wrap arbitrary PyTorch models for training with Lightning. It handles:

  • A generic forward pass including input normalization
  • Training, validation, and test steps with logging
  • Sanity checks on the model configuration
  • Optimizer and scheduler setup

Parameters:

• model (torch.nn.Module, required): Model to be wrapped.
• lrs_config (collections.abc.Mapping[str, float], required): Learning rate schedule configuration.
• global_object (str, required): Name of the global input object, as opposed to the constituent-level inputs.
• norm_config (dict | None, default: None): Keyword arguments for salt.models.InputNorm.
• name (str, default: 'salt'): Name of the model, used for logging and inference outputs.
• mup_config (dict | None, default: None): Configuration for muP scaling.
• loss_mode (str, default: 'wsum'): Loss reduction mode. The other option is 'GLS'.
• optimizer (str, default: 'AdamW'): Optimizer to use. The other option is 'lion'.
Source code in salt/modelwrapper.py
def __init__(
    self,
    model: nn.Module,
    lrs_config: Mapping[str, float],
    global_object: str,
    norm_config: dict | None = None,
    name: str = "salt",
    mup_config: dict | None = None,
    loss_mode: str = "wsum",
    optimizer: str = "AdamW",
):
    super().__init__()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        self.save_hyperparameters(logger=False)

    self.model = model
    self.lrs_config = lrs_config
    self.global_object = global_object
    self.name = name
    self.mup = mup_config or {}
    self.last_val_batch_outs = None

    # MuP initialization if configured
    if self.mup:
        load_path = self.mup.get("shape_path")
        instantiate_mup(model, load_path)

    # propagate metadata to tasks
    self.model.global_object = self.global_object
    for task in self.model.tasks:
        task.global_object = self.global_object
        task.model_name = self.name

    # sanity checks
    check_unique(self.model.init_nets, "input_name")
    check_unique(self.model.tasks, "name")

    assert len({t.net.output_size for t in self.model.init_nets if t.input_name != "EDGE"}) == 1

    # input normalizer
    assert norm_config is not None
    self.norm = InputNorm(**norm_config)

    allowed_loss_modes = ["wsum", "GLS"]
    assert loss_mode in allowed_loss_modes, f"Loss mode must be one of {allowed_loss_modes}"
    self.loss_mode = loss_mode
    if loss_mode == "GLS":
        assert all(
            task.weight == 1.0 for task in self.model.tasks
        ), "GLS does not utilise task weights - set all weights to 1"

    allowed_optimizers = ["lion", "AdamW"]
    assert optimizer in allowed_optimizers, (
        f"Optimizer {optimizer} not implemented, " f"please choose from {allowed_optimizers}"
    )
    self.optimizer = optimizer
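
For intuition, a sketch of the two loss-reduction modes accepted above. The semantics here are assumptions rather than code shown on this page: "wsum" as a weighted sum of per-task losses, and "GLS" as a geometric mean, which would explain why GLS requires all task weights to be set to 1.

```python
import torch

losses = {"classification": torch.tensor(0.7), "vertexing": torch.tensor(0.2)}
weights = {"classification": 1.0, "vertexing": 0.5}

# "wsum": weighted sum of the per-task losses.
wsum = sum(weights[name] * value for name, value in losses.items())

# "GLS" (assumed): geometric mean of the per-task losses, in which
# per-task weights play no role and must all be 1.
stacked = torch.stack(list(losses.values()))
gls = stacked.prod() ** (1.0 / len(stacked))
```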

forward

Forward pass through the wrapped model with input normalization.

Parameters:

• inputs (torch.Tensor | dict, required): Model inputs.
• pad_masks (torch.Tensor | None, default: None): Padding masks for variable-length inputs.
• labels (torch.Tensor | None, default: None): Training targets. If not provided, inference mode is assumed.

Returns:

• typing.Any: Whatever is returned by the wrapped model's forward pass.

Source code in salt/modelwrapper.py
def forward(
    self,
    inputs: torch.Tensor | dict,
    pad_masks: torch.Tensor | None = None,
    labels: torch.Tensor | None = None,
):
    """Forward pass through the wrapped model with input normalization.

    Parameters
    ----------
    inputs : torch.Tensor | dict
        Model inputs.
    pad_masks : torch.Tensor | None, optional
        Padding masks for variable-length inputs.
    labels : torch.Tensor | None, optional
        Training targets. If not provided, inference mode is assumed.

    Returns
    -------
    Any
        Whatever is returned by the wrapped model's forward pass.
    """
    x = self.norm(inputs)
    return self.model(x, pad_masks, labels)
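
The wrapper's forward is simply "normalise, then delegate", so the wrapped model always sees normalised inputs. Below is a self-contained sketch of the same pattern, with a LayerNorm standing in for salt.models.InputNorm:

```python
import torch
from torch import nn

class NormThenModel(nn.Module):
    """Apply a normalisation layer, then delegate to the wrapped model."""

    def __init__(self, model: nn.Module, size: int):
        super().__init__()
        self.norm = nn.LayerNorm(size)  # stand-in for InputNorm
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(self.norm(x))

wrapped = NormThenModel(nn.Linear(16, 4), size=16)
out = wrapped(torch.randn(2, 16))  # shape (2, 4)
```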

Last update: January 25, 2024
Created: October 20, 2023