# Initialisation
## salt.models.InputNorm
Bases: torch.nn.Module
Normalise inputs on the fly using a pre-computed normalisation dictionary.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `norm_dict` | `pathlib.Path` | Path to file containing normalisation parameters | required |
| `variables` | `salt.stypes.Vars` | Input variables for each type of input | required |
| `global_object` | `str` | Name of the global input object, as opposed to the constituent-level inputs | required |
| `input_map` | `dict[str, str]` | Map names to the corresponding dataset names in the input h5 file. Set automatically by the framework. | required |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If norm values for an input can't be found in the normalisation dict, if a normalisation value for an input is non-finite, or if an input has a zero standard deviation |
Source code in salt/models/inputnorm.py
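The normalisation applied is standard scaling per variable. Below is a minimal, framework-free sketch of the idea; the dict layout, function name, and numbers are illustrative assumptions, not salt's actual API:

```python
import math

# Hypothetical layout for a pre-computed normalisation dictionary:
# {input_name: {variable: {"mean": m, "std": s}}}
NORM_DICT = {
    "jets": {"pt": {"mean": 50.0, "std": 10.0}, "eta": {"mean": 0.0, "std": 2.0}},
}

def normalise(input_name, values, norm_dict=NORM_DICT):
    """Standard-scale each variable: (x - mean) / std."""
    out = {}
    for var, x in values.items():
        params = norm_dict[input_name].get(var)
        if params is None:
            raise ValueError(f"no norm values for {var!r} in {input_name!r}")
        mean, std = params["mean"], params["std"]
        if not (math.isfinite(mean) and math.isfinite(std)):
            raise ValueError(f"non-finite norm value for {var!r}")
        if std == 0:
            raise ValueError(f"zero standard deviation for {var!r}")
        out[var] = (x - mean) / std
    return out

normalise("jets", {"pt": 60.0, "eta": 1.0})  # {'pt': 1.0, 'eta': 0.5}
```

In the real module the parameters are read from the `norm_dict` file and the scaling is presumably applied to tensors inside `forward`; the error conditions above mirror the `ValueError` cases listed in the table.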
## salt.models.InitNet
Bases: torch.nn.Module
Initial input embedding network.
This class can handle global input concatenation and positional encoding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_name` | `str` | Name of the input; must match the input types in the data config | required |
| `dense_config` | `dict` | Keyword arguments for the dense embedding network | required |
| `variables` | `salt.stypes.Vars` | Input variables used in the forward pass, set automatically by the framework | required |
| `global_object` | `str` | Name of the global object, set automatically by the framework | required |
| `attach_global` | `bool` | Concatenate global-level inputs with constituent-level inputs before embedding | `True` |
| `pos_enc` | `salt.models.posenc.PositionalEncoder \| None` | Positional encoder module to use; see `salt.models.PositionalEncoder` | `None` |
| `mup` | `bool` | Whether to use the muP parametrisation (impacts initialisation) | `False` |
| `featurewise` | `salt.models.FeaturewiseTransformation \| None` | Networks to apply featurewise transformations to inputs, set automatically by the framework | `None` |
Source code in salt/models/initnet.py
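With `attach_global` enabled, global-level (e.g. jet-level) inputs are concatenated onto each constituent before the dense embedding. A toy sketch of that broadcast-and-concatenate step, using plain Python lists and a hypothetical function name rather than the module's tensor code:

```python
def attach_global_inputs(constituents, global_feats):
    """Broadcast global-level features to every constituent and concatenate,
    as the attach_global option does before embedding (illustrative only)."""
    return [list(c) + list(global_feats) for c in constituents]

tracks = [[0.1, 0.2], [0.3, 0.4]]  # two constituents with two variables each
jet = [1.5]                        # a single global (jet-level) variable
attach_global_inputs(tracks, jet)  # [[0.1, 0.2, 1.5], [0.3, 0.4, 1.5]]
```

The embedded dimension of each constituent therefore grows by the number of global variables before the dense network configured via `dense_config` is applied.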
## salt.models.PositionalEncoder
Bases: torch.nn.Module
Positional encoder.
Evenly share the embedding space between the different variables to be encoded. Any remaining dimensions are left as zeros.
TODO: alpha should be set for each variable
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `variables` | `list[str]` | List of variables to apply the positional encoding to | required |
| `dim` | `int` | Dimension of the positional encoding | required |
| `alpha` | `int` | Scaling factor for the positional encoding | `100` |
Source code in salt/models/posenc.py
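A rough sketch of the scheme described above: the `dim` encoding dimensions are split evenly between the variables, and any leftover dimensions stay zero. The frequency formula here is an illustrative assumption, not necessarily the module's exact one:

```python
import math

def positional_encoding(values, variables, dim, alpha=100):
    """Sketch: each variable gets dim // len(variables) sinusoidal
    dimensions; any remaining dimensions are left as zeros."""
    per_var = dim // len(variables)
    out = [0.0] * dim
    for i, var in enumerate(variables):
        x = values[var]
        for j in range(per_var):
            # alternate sin/cos over geometrically spaced frequencies,
            # with alpha setting the frequency range (assumed form)
            freq = alpha ** (-(j // 2) / max(per_var // 2, 1))
            out[i * per_var + j] = math.sin(x * freq) if j % 2 == 0 else math.cos(x * freq)
    return out

enc = positional_encoding({"phi": 0.5, "eta": 1.0}, ["phi", "eta"], dim=9)
# two variables share 8 of the 9 dimensions; the last stays 0.0
```

With `dim=9` and two variables, each variable fills four dimensions and the ninth is left as zero, matching the "remaining dimensions are left as zeros" behaviour.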
Created: January 25, 2024