# Models

## Components

### `salt.models.Dense`

Bases: `torch.nn.Module`

A fully connected feed-forward neural network, which can optionally take in additional contextual information.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_size` | `int` | Input size. | *required* |
| `output_size` | `int` | Output size. If not specified, this will be the same as the input size. | `None` |
| `hidden_layers` | `list` | Number of nodes per layer. If not specified, the network will have a single hidden layer, sized via `hidden_dim_scale`. | `None` |
| `hidden_dim_scale` | `int` | Scale factor for the hidden layer size. | `2` |
| `activation` | `str` | Activation function for hidden layers. Must be a valid `torch.nn` activation function. | `'ReLU'` |
| `final_activation` | `str` | Activation function for the output layer. Must be a valid `torch.nn` activation function. | `None` |
| `dropout` | `float` | Apply dropout with the supplied probability. | `0.0` |
| `bias` | `bool` | Whether to use bias in the linear layers. | `True` |
| `context_size` | `int` | Size of the context tensor; 0 means no context information is provided. | `0` |
| `muP` | `bool` | Whether to use the muP parametrisation (impacts initialisation). | `False` |
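As a rough illustration of the sizing defaults above, the following pure-Python sketch (not the salt implementation) shows how the layer sizes could be derived from the constructor arguments. The assumption that the default single hidden layer has size `input_size * hidden_dim_scale` is inferred from the `hidden_dim_scale` description and may differ from the actual code.

```python
def dense_layer_sizes(input_size, output_size=None, hidden_layers=None,
                      hidden_dim_scale=2):
    """Return the full list of layer sizes for a Dense-style MLP.

    Assumption: if hidden_layers is not given, use a single hidden layer
    of size input_size * hidden_dim_scale; if output_size is not given,
    it defaults to input_size.
    """
    if hidden_layers is None:
        hidden_layers = [input_size * hidden_dim_scale]
    if output_size is None:
        output_size = input_size
    return [input_size, *hidden_layers, output_size]

print(dense_layer_sizes(16))                                  # [16, 32, 16]
print(dense_layer_sizes(16, output_size=4, hidden_layers=[64, 64]))
```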
Source code in `salt/models/dense.py`
## Models

### `salt.models.SaltModel`

Bases: `torch.nn.Module`

A generic multi-modal, multi-task neural network.

This model can be used to implement a wide range of models, including DL1, DIPS, GN2 and more.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `init_nets` | `list[dict]` | Keyword arguments for one or more initialisation networks. | *required* |
| `tasks` | `torch.nn.ModuleList` | Task heads. | *required* |
| `encoder` | `torch.nn.Module` | Main input encoder, which takes the output of the initialisation networks and produces a single embedding for each constituent. If not provided, this model is essentially a DeepSets model. | `None` |
| `mask_decoder` | `torch.nn.Module` | Mask decoder, which takes the output of the encoder and produces a series of learned embeddings to represent object masks. | `None` |
| `pool_net` | `torch.nn.Module` | Pooling network which computes a global representation of the object by aggregating over the constituents. If not provided, assume that the only inputs are global features (i.e. no constituents). | `None` |
| `merge_dict` | `dict[str, list[str]] \| None` | A dictionary that lets salt concatenate the input representations of the listed input types and treat them in subsequent layers (e.g. the transformer or tasks) as if they come from a single input type. | `None` |
| `featurewise_nets` | `list[dict]` | Keyword arguments for featurewise transformation networks that perform featurewise scaling and biasing. | `None` |
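The components above fit together as a pipeline: initialisation networks embed each constituent per modality, the encoder processes the combined constituents, the pooling network aggregates them into a global representation, and the task heads produce predictions. A minimal plain-Python sketch of that data flow, using stand-in callables and illustrative names (not the salt API):

```python
def salt_forward(inputs, init_nets, encoder, pool_net, tasks):
    # per-modality initialisation networks embed each constituent
    embeds = {m: [init_nets[m](x) for x in xs] for m, xs in inputs.items()}
    # the encoder acts on all constituents together
    tokens = encoder([e for es in embeds.values() for e in es])
    # pooling aggregates constituents into one global representation
    global_rep = pool_net(tokens)
    # each task head produces a prediction from the pooled representation
    return {name: head(global_rep) for name, head in tasks.items()}

preds = salt_forward(
    inputs={"tracks": [1.0, 2.0, 3.0]},
    init_nets={"tracks": lambda x: 2 * x},  # stand-in embedding
    encoder=lambda toks: toks,              # identity "encoder"
    pool_net=sum,                           # e.g. sum pooling
    tasks={"jet_class": lambda g: g / 10},  # stand-in task head
)
print(preds)  # {'jet_class': 1.2}
```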
Source code in `salt/models/saltmodel.py`
### `forward`

Forward pass through the `SaltModel`.

Don't call this method directly; instead use `__call__`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | `salt.stypes.Tensors` | Dict of input tensors for each modality. | *required* |
| `pad_masks` | `salt.stypes.BoolTensors` | Dict of input padding mask tensors for each modality. | `None` |
| `labels` | `salt.stypes.Tensors` | Nested dict of label tensors. The outer dict is keyed by input modality, the inner dict is keyed by label variable name. | `None` |
Returns:

| Name | Type | Description |
|---|---|---|
| `preds` | `salt.stypes.NestedTensors` | Dict of model predictions for each task, separated by input modality. Tensors have varying shapes depending on the task. |
| `loss` | `salt.stypes.Tensors` | Dict of losses for each task, aggregated over the batch. |
Source code in `salt/models/saltmodel.py`
## Wrappers

### `salt.modelwrapper.ModelWrapper`

Bases: `lightning.LightningModule`

A wrapper class for any model implemented in Salt.

This wrapper class is as generic as possible. It wraps `SaltModel`, but could also be used to wrap any other model if you want to train something that doesn't fit into the `SaltModel` architecture.

This class is responsible for things that are common to all salt models:

- A generic forward pass, including input normalisation
- Training, validation and test steps, which include logging
- Some sanity checks on the model configuration
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `torch.nn.Module` | Model to be wrapped. | *required* |
| `lrs_config` | `collections.abc.Mapping[str, float]` | LRS config, which has to be set manually for now: https://github.com/omni-us/jsonargparse/issues/170#issuecomment-1288167674 | *required* |
| `global_object` | `str` | Name of the global input object, as opposed to the constituent-level inputs. This argument is set automatically by the framework. | *required* |
| `norm_config` | `dict` | Keyword arguments for the input normalisation. | `None` |
| `name` | `str` | Name of the model, used for logging and inference output names. | `'salt'` |
| `muP_config` | `dict \| None` | The muP configuration. | `None` |
| `loss_mode` | `str` | The loss mode to use. Default is `'wsum'` (weighted sum). Other options are `'GLS'`: arxiv.org/1904.08492 | `'wsum'` |
| `optimizer` | `str` | Optimizer to use. Default is `'AdamW'`. Other options are `'lion'`: https://github.com/lucidrains/lion-pytorch | `'AdamW'` |
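The two loss modes can be sketched as follows. Treat this as an assumed illustration rather than the salt implementation: `wsum` is taken as a weighted sum of the per-task losses, and `GLS` as the geometric mean of the task losses, following the idea in the cited paper; the exact weighting and formulation in salt may differ.

```python
import math

def aggregate_losses(losses, mode="wsum", weights=None):
    """Aggregate a dict of per-task losses into a single scalar (sketch)."""
    if mode == "wsum":
        # weighted sum; unit weights if none are supplied
        weights = weights or {k: 1.0 for k in losses}
        return sum(weights[k] * v for k, v in losses.items())
    if mode == "GLS":
        # geometric mean of the task losses (assumed GLS form)
        return math.prod(losses.values()) ** (1 / len(losses))
    raise ValueError(f"unknown loss mode: {mode}")

losses = {"jet_class": 0.9, "track_origin": 0.4}
print(aggregate_losses(losses))          # unweighted sum
print(aggregate_losses(losses, "GLS"))   # geometric mean
```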
Source code in `salt/modelwrapper.py`
### `forward`

Generic forward pass through any salt-compatible model.

This function performs input normalisation and then calls the `self.model` forward pass. Don't call this method directly; instead use `__call__`.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inputs` | | Any generic input to the model. | *required* |
| `pad_masks` | | Input padding masks. | `None` |
| `labels` | | Training targets. If not specified, assume we are running model inference (i.e. no loss computation). | `None` |
Returns:

Whatever is returned by `self.model`'s forward pass.
Source code in `salt/modelwrapper.py`
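The wrapper's forward contract described above (normalise the inputs, then delegate to the wrapped model) can be sketched with plain-Python stand-ins; `WrapperSketch` and its simple per-value normaliser are illustrative, not the salt classes.

```python
class WrapperSketch:
    def __init__(self, model, norm):
        self.model = model
        self.norm = norm  # e.g. (x - mean) / std per input value

    def forward(self, inputs, pad_masks=None, labels=None):
        # normalise every value in every modality, then delegate
        inputs = {k: [self.norm(x) for x in v] for k, v in inputs.items()}
        return self.model(inputs, pad_masks, labels)

wrapper = WrapperSketch(
    model=lambda inputs, masks, labels: sum(inputs["tracks"]),
    norm=lambda x: (x - 5.0) / 2.0,   # assumed mean=5, std=2
)
print(wrapper.forward({"tracks": [3.0, 7.0]}))  # 0.0
```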
Created: October 20, 2023