Initialisation
salt.models.InputNorm
Bases: torch.nn.Module
Normalise inputs on the fly using a pre-computed normalisation dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `norm_dict` | `pathlib.Path` | Path to file containing normalisation parameters | *required* |
| `variables` | `dict` | Input variables for each type of input | *required* |
| `global_object` | `str` | Name of the global input object, as opposed to the constituent-level inputs | *required* |
| `input_map` | `dict` | Map names to the corresponding dataset names in the input h5 file. Set automatically by the framework. | *required* |
Source code in salt/models/inputnorm.py
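To illustrate the idea behind `InputNorm`, here is a minimal pure-Python sketch of normalising inputs on the fly from a precomputed dictionary of per-variable statistics. The dictionary layout and the `mean`/`std` key names are illustrative assumptions, not necessarily the format salt stores in its norm file.

```python
# Hypothetical precomputed statistics, keyed by input type then variable.
norm_dict = {
    "jets": {"pt": {"mean": 50.0, "std": 20.0}, "eta": {"mean": 0.0, "std": 2.5}},
}

def normalise(input_name, values):
    """Scale each variable to zero mean and unit variance using norm_dict."""
    stats = norm_dict[input_name]
    return {
        var: [(x - stats[var]["mean"]) / stats[var]["std"] for x in xs]
        for var, xs in values.items()
    }

normed = normalise("jets", {"pt": [50.0, 70.0], "eta": [2.5, -2.5]})
# normed["pt"] → [0.0, 1.0]; normed["eta"] → [1.0, -1.0]
```

In salt itself this happens inside the module's forward pass, so the raw h5 values never need to be pre-normalised on disk.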
salt.models.InitNet
Bases: torch.nn.Module
Initial input embedding network.
This class can handle global input concatenation and positional encoding.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `input_name` | `str` | Name of the input, must match the input types in the data config | *required* |
| `dense_config` | `dict` | Keyword arguments for the dense embedding network | *required* |
| `variables` | `salt.stypes.Vars` | Input variables used in the forward pass, set automatically by the framework | *required* |
| `global_object` | `str` | Name of the global object, set automatically by the framework | *required* |
| `attach_global` | `bool` | Concatenate global-level inputs with constituent-level inputs before embedding | `True` |
| `pos_enc` | `salt.models.posenc.PositionalEncoder` | Positional encoder module to use. See `salt.models.PositionalEncoder`. | `None` |
| `muP` | `bool` | Whether to use the muP parametrisation (impacts initialisation) | `False` |
| `featurewise` | `salt.models.FeaturewiseTransformation \| None` | Networks to apply featurewise transformations to inputs, set automatically by the framework | `None` |
Source code in salt/models/initnet.py
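The `attach_global` behaviour described above can be sketched in a few lines: global-level features are broadcast to every constituent and concatenated onto its feature vector before the dense embedding is applied. The shapes and variable names here are assumptions for illustration, not salt's actual tensor layout.

```python
def attach_global(constituents, global_feats):
    """Append the per-object global features to each constituent's features.

    constituents: list of per-constituent feature lists (e.g. tracks in a jet)
    global_feats: single feature list for the global object (e.g. the jet)
    """
    return [track + global_feats for track in constituents]

tracks = [[0.1, 0.2], [0.3, 0.4]]  # two constituents, two features each
jet = [5.0]                        # one global (jet-level) feature
combined = attach_global(tracks, jet)
# combined → [[0.1, 0.2, 5.0], [0.3, 0.4, 5.0]]
```

Each constituent then carries the global context into the embedding, which is why the concatenation happens before, not after, the dense network.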
salt.models.PositionalEncoder
Bases: torch.nn.Module
Positional encoder.
Evenly share the embedding space between the different variables to be encoded. Any remaining dimensions are left as zeros.
TODO: alpha should be set for each variable
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `variables` | `list[str]` | List of variables to apply the positional encoding to | *required* |
| `dim` | `int` | Dimension of the positional encoding | *required* |
| `alpha` | `int` | Scaling factor for the positional encoding | `100` |
Source code in salt/models/posenc.py
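The scheme described above, splitting the embedding space evenly between the encoded variables and leaving leftover dimensions at zero, can be sketched as follows. The sinusoidal frequency formula is an assumption modelled on standard positional encodings, not necessarily salt's exact implementation.

```python
import math

def pos_enc(values, variables, dim, alpha=100):
    """Encode each variable into an equal slice of a dim-sized vector.

    Leftover dimensions (when dim is not divisible by the number of
    variables) are left as zeros, as described in the docstring above.
    """
    per_var = dim // len(variables)
    out = [0.0] * dim
    for i, var in enumerate(variables):
        x = alpha * values[var]  # alpha scales the input before encoding
        for j in range(per_var):
            # Alternate sin/cos with geometrically decreasing frequencies.
            freq = 1.0 / (10000 ** (2 * (j // 2) / per_var))
            out[i * per_var + j] = (
                math.sin(x * freq) if j % 2 == 0 else math.cos(x * freq)
            )
    return out

enc = pos_enc({"phi": 0.0, "eta": 1.2}, ["phi", "eta"], dim=9)
# dim=9 with two variables → 4 dims each; enc[8] stays 0.0
```

A per-variable `alpha`, as the TODO above suggests, would simply replace the single scalar with a lookup per variable.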
Created: January 25, 2024