# Models

## Components
### salt.models.Dense
Bases: torch.nn.Module
A fully connected feed-forward neural network, which can take in additional contextual information.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_size` | `int` | Input size. | *required* |
| `output_size` | `int \| None` | Output size. If not specified, this will be the same as the input size. | `None` |
| `hidden_layers` | `list[int] \| None` | Number of nodes per hidden layer. If not specified, the network will have a single hidden layer sized according to `hidden_dim_scale`. | `None` |
| `hidden_dim_scale` | `int` | Scale factor for the hidden layer size. | `2` |
| `activation` | `str` | Activation function for hidden layers. Must be a valid `torch.nn` activation function. | `'ReLU'` |
| `final_activation` | `str \| None` | Activation function for the output layer. Must be a valid `torch.nn` activation function. | `None` |
| `dropout` | `float` | Apply dropout with the supplied probability. | `0.0` |
| `bias` | `bool` | Whether to use bias in the linear layers. | `True` |
| `context_size` | `int` | Size of the context tensor; `0` means no context information is provided. | `0` |
| `mup` | `bool` | Whether to use the muP parametrisation (impacts initialisation). | `False` |
Source code in `salt/models/dense.py`
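
As a quick illustration, the sketch below instantiates a small `Dense` network using the parameters documented above. It assumes `salt` is installed; the forward call signature shown is an assumption and may differ from the actual implementation.

```python
import torch
from salt.models import Dense  # assumes salt is installed

# A small context-aware MLP: 32 input features, two hidden layers of 64 nodes,
# an 8-dimensional output, and a 16-dimensional context vector.
net = Dense(
    input_size=32,
    output_size=8,
    hidden_layers=[64, 64],
    activation="ReLU",
    dropout=0.1,
    context_size=16,
)

x = torch.randn(100, 32)        # batch of 100 constituents
context = torch.randn(100, 16)  # per-constituent context information

out = net(x, context)  # assumed call signature: (inputs, context)
print(out.shape)       # expected: torch.Size([100, 8])
```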
## Models
### salt.models.SaltModel
Bases: torch.nn.Module
Generic multi-modal, multi-task neural network.
This model can implement a wide range of architectures such as DL1, DIPS, GN2 and more.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `init_nets` | `list[dict]` | Keyword arguments for one or more initialisation networks. Each initialisation network produces an initial input embedding for a single input type. | *required* |
| `tasks` | `torch.nn.ModuleList` | Task heads to apply to the encoded/pooled features. | *required* |
| `encoder` | `torch.nn.Module \| None` | Main input encoder which takes the outputs of the init nets and produces per-constituent embeddings. If not provided, the model is effectively a DeepSets model. | `None` |
| `mask_decoder` | `torch.nn.Module \| None` | Mask decoder which takes the encoder output and produces learned embeddings to represent object masks. | `None` |
| `pool_net` | `salt.models.Pooling \| None` | Pooling network computing a global representation by aggregating over the constituents. If not provided, only global features are assumed to be present. | `None` |
| `merge_dict` | `dict[str, list[str]] \| None` | Dictionary specifying which input types should be concatenated into a single stream (e.g. the transformer input), as sketched below. | `None` |
| `featurewise_nets` | `list[dict] \| None` | Keyword arguments for featurewise transformation networks that perform per-feature scaling and biasing. | `None` |
Source code in `salt/models/saltmodel.py`
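
For illustration, a possible `merge_dict` value is sketched below: it concatenates two constituent-level input types into a single stream before the encoder. The stream and input names are examples only, not part of the documented API.

```python
# Hypothetical merge_dict: combine two constituent-level input types into a
# single "tracks" stream that is fed to the encoder as one sequence.
merge_dict = {
    "tracks": ["tracks_loose", "hits"],
}
```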
#### forward
Forward pass through the model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `salt.stypes.Tensors` | Dict of input tensors, one per input modality. | *required* |
| `pad_masks` | `salt.stypes.BoolTensors \| None` | Dict of input padding mask tensors, one per input modality. | `None` |
| `labels` | `salt.stypes.NestedTensors \| None` | Nested dict of label tensors. The outer dict is keyed by input modality, the inner dict by label variable. | `None` |
Returns:

| Name | Type | Description |
|---|---|---|
| `preds` | `salt.stypes.NestedTensors` | Dict of model predictions for each task, separated by input modality. |
| `loss` | `salt.stypes.Tensors` | Dict of losses for each task, aggregated over the batch. |
Source code in `salt/models/saltmodel.py`
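
As a rough illustration of the expected input structure, the sketch below builds the `inputs` and `pad_masks` dicts for a hypothetical model with a global `jets` stream and a constituent-level `tracks` stream. The names, feature counts, and padding convention are assumptions, not part of the documented API.

```python
import torch

batch_size, num_tracks = 10, 40

# One tensor per input modality (names and feature counts are illustrative).
inputs = {
    "jets": torch.randn(batch_size, 2),                 # global-level features
    "tracks": torch.randn(batch_size, num_tracks, 21),  # padded constituent features
}

# One boolean padding mask per variable-length modality; here True is assumed
# to mark padded (invalid) constituent slots.
pad_masks = {
    "tracks": torch.zeros(batch_size, num_tracks, dtype=torch.bool),
}

# With a constructed SaltModel instance, the forward pass would return the
# documented (preds, loss) pair:
# preds, loss = model(inputs, pad_masks=pad_masks, labels=None)
```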
## Wrappers
### salt.modelwrapper.ModelWrapper
Bases: lightning.LightningModule
A generic wrapper class for Salt-compatible models.
This class wraps `SaltModel`, but can also be used to wrap arbitrary PyTorch models for training with Lightning. It handles:
- A generic forward pass including input normalization
- Training, validation, and test steps with logging
- Sanity checks on the model configuration
- Optimizer and scheduler setup
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `torch.nn.Module` | Model to be wrapped. | *required* |
| `lrs_config` | `collections.abc.Mapping[str, float]` | Learning rate schedule configuration. | *required* |
| `global_object` | `str` | Name of the global input object, as opposed to constituent-level inputs. | *required* |
| `norm_config` | `dict \| None` | Keyword arguments for input normalisation. | `None` |
| `name` | `str` | Name of the model, used for logging and inference outputs. | `'salt'` |
| `mup_config` | `dict \| None` | Configuration for muP scaling. | `None` |
| `loss_mode` | `str` | Loss reduction mode. | `'wsum'` |
| `optimizer` | `str` | Optimizer to use. | `'AdamW'` |
Source code in `salt/modelwrapper.py`
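
A minimal construction sketch is shown below, assuming `salt` is installed. The wrapped module would normally be a `SaltModel`; a bare `nn.Linear` is used only to illustrate the call, and the `lrs_config` keys are placeholders for whatever the configured scheduler expects. In practice the wrapper is usually built from a YAML config rather than by hand, and additional configuration may be required.

```python
import torch.nn as nn
from salt.modelwrapper import ModelWrapper  # assumes salt is installed

# Hedged sketch: wrap a stand-in module for training with Lightning.
wrapper = ModelWrapper(
    model=nn.Linear(4, 2),                                   # stand-in for a SaltModel
    lrs_config={"initial": 1e-7, "max": 5e-4, "end": 1e-5},  # placeholder keys
    global_object="jets",                                    # example global object name
    name="my_model",
)
```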
#### forward
Forward pass through the wrapped model with input normalization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `torch.Tensor \| dict` | Model inputs. | *required* |
| `pad_masks` | `torch.Tensor \| None` | Padding masks for variable-length inputs. | `None` |
| `labels` | `torch.Tensor \| None` | Training targets. If not provided, inference mode is assumed. | `None` |
Returns:

| Type | Description |
|---|---|
| `typing.Any` | Whatever is returned by the wrapped model's forward pass. |
Source code in `salt/modelwrapper.py`
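
Since `ModelWrapper` is a `LightningModule`, a trained wrapper can be restored with Lightning's standard checkpoint API and called directly for inference. The sketch below assumes `salt` is installed and that the checkpoint stores the wrapper's hyperparameters; the checkpoint path, input names, and shapes are placeholders that mirror the `SaltModel.forward` sketch above.

```python
import torch
from salt.modelwrapper import ModelWrapper  # assumes salt is installed

# Restore a trained wrapper (placeholder path) and switch to evaluation mode.
wrapper = ModelWrapper.load_from_checkpoint("path/to/checkpoint.ckpt")
wrapper.eval()

# Illustrative inputs: a global "jets" stream and a padded "tracks" stream.
inputs = {
    "jets": torch.randn(10, 2),
    "tracks": torch.randn(10, 40, 21),
}
pad_masks = {"tracks": torch.zeros(10, 40, dtype=torch.bool)}

with torch.no_grad():
    # Omitting `labels` selects inference mode, as noted in the table above.
    outputs = wrapper(inputs, pad_masks=pad_masks)
```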
Created: October 20, 2023