
Initialisation

salt.models.InputNorm #

Bases: torch.nn.Module

Normalise inputs on the fly using a pre-computed normalisation dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `norm_dict` | `pathlib.Path` | Path to file containing normalisation parameters | *required* |
| `variables` | `dict` | Input variables for each type of input | *required* |
| `global_object` | `str` | Name of the global input object, as opposed to the constituent-level inputs | *required* |
| `input_map` | `dict` | Map names to the corresponding dataset names in the input h5 file. Set automatically by the framework. | *required* |

Source code in salt/models/inputnorm.py
def __init__(
    self, norm_dict: Path, variables: Vars, global_object: str, input_map: dict[str, str]
) -> None:
    """Normalise inputs on the fly using a pre-computed normalisation dictionary.

    Parameters
    ----------
    norm_dict : Path
        Path to file containing normalisation parameters
    variables : dict
        Input variables for each type of input
    global_object : str
        Name of the global input object, as opposed to the constituent-level
        inputs
    input_map : dict
        Map names to the corresponding dataset names in the input h5 file.
        Set automatically by the framework.
    """
    super().__init__()
    self.variables = variables
    self.global_object = global_object
    self.NO_NORM = ["EDGE", "PARAMETERS"]
    with open(norm_dict) as f:
        self.norm_dict = yaml.safe_load(f)

    # get the keys that need to be normalised
    if input_map is None:
        input_map = {k: k for k in variables}
    keys = {input_map[k] for k in set(variables.keys())}
    keys.discard("EDGE")
    keys.discard("PARAMETERS")
    if "GLOBAL" in keys:
        keys.remove("GLOBAL")
        keys.add(self.global_object)

    # check we have all required keys in the normalisation dictionary
    if missing := keys - set(self.norm_dict):
        raise ValueError(
            f"Missing input types {missing} in {norm_dict}. Choose from"
            f" {self.norm_dict.keys()}."
        )

    # check we have all required variables for each input type
    for k, vs in variables.items():
        if k in self.NO_NORM:
            continue
        name = input_map[k]
        if k == "GLOBAL":
            name = self.global_object
        if missing := set(vs) - set(self.norm_dict[name]):
            raise ValueError(
                f"Missing variables {missing} for {name} in {norm_dict}. Choose from"
                f" {self.norm_dict[name].keys()}."
            )

        # store normalisation parameters with the model
        means = torch.as_tensor([self.norm_dict[name][v]["mean"] for v in vs])
        stds = torch.as_tensor([self.norm_dict[name][v]["std"] for v in vs])
        self.register_buffer(f"{k}_means", means)
        self.register_buffer(f"{k}_stds", stds)

        # check normalisation parameters are ok
        if not torch.isfinite(means).all() or not torch.isfinite(stds).all():
            raise ValueError(f"Non-finite normalisation parameters for {name} in {norm_dict}.")
        if any(stds == 0):
            raise ValueError(f"Zero standard deviation for {name} in {norm_dict}.")

salt.models.InitNet #

Bases: torch.nn.Module

Initial input embedding network.

This class can handle global input concatenation and positional encoding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_name` | `str` | Name of the input, must match the input types in the data config | *required* |
| `dense_config` | `dict` | Keyword arguments for `salt.models.Dense`, the dense network producing the initial embedding. The `input_size` argument is inferred automatically by the framework | *required* |
| `variables` | `salt.stypes.Vars` | Input variables used in the forward pass, set automatically by the framework | *required* |
| `global_object` | `str` | Name of the global object, set automatically by the framework | *required* |
| `attach_global` | `bool` | Concatenate global-level inputs with constituent-level inputs before embedding | `True` |
| `pos_enc` | `salt.models.posenc.PositionalEncoder` | Positional encoder module to use. See `salt.models.PositionalEncoder` for details. | `None` |
| `muP` | `bool` | Whether to use the muP parametrisation (impacts initialisation). | `False` |
| `featurewise` | `salt.models.FeaturewiseTransformation \| None` | Networks to apply featurewise transformations to inputs, set automatically by the framework | `None` |

Source code in salt/models/initnet.py
def __init__(
    self,
    input_name: str,
    dense_config: dict,
    variables: Vars,
    global_object: str,
    attach_global: bool = True,
    pos_enc: PositionalEncoder | None = None,
    muP: bool = False,
    featurewise: FeaturewiseTransformation | None = None,
):
    """Initial input embedding network.

    This class can handle global input concatenation and positional encoding.

    Parameters
    ----------
    input_name : str
        Name of the input, must match the input types in the data config
    dense_config : dict
        Keyword arguments for [`salt.models.Dense`][salt.models.Dense],
        the dense network producing the initial embedding. The `input_size`
        argument is inferred automatically by the framework
    variables : Vars
        Input variables used in the forward pass, set automatically by the framework
    global_object : str
        Name of the global object, set automatically by the framework
    attach_global : bool, optional
        Concatenate global-level inputs with constituent-level inputs before embedding
    pos_enc : PositionalEncoder, optional
        Positional encoder module to use. See
        [`salt.models.PositionalEncoder`][salt.models.PositionalEncoder] for details.
    muP : bool, optional
        Whether to use the muP parametrisation (impacts initialisation).
    featurewise : FeaturewiseTransformation, optional
        Networks to apply featurewise transformations to inputs, set automatically by
        the framework
    """
    super().__init__()

    # set input size
    if "input_size" not in dense_config:
        dense_config["input_size"] = len(variables[input_name])
        if attach_global and input_name != "EDGE":
            dense_config["input_size"] += len(variables[global_object])
            if not featurewise:
                dense_config["input_size"] += len(variables.get("PARAMETERS", []))

    self.input_name = input_name
    self.net = Dense(**dense_config)
    self.variables = variables
    self.attach_global = attach_global
    self.global_object = global_object
    self.pos_enc = pos_enc
    self.muP = muP
    if muP:
        self.net.reset_parameters()
    self.featurewise = featurewise
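
As a rough illustration of the automatic `input_size` inference, the sketch below builds an `InitNet` for a hypothetical constituent-level input with global concatenation enabled. The input and object names are placeholders, and the `dense_config` keys other than `input_size` are assumptions about `salt.models.Dense` rather than values documented on this page.

```python
# Hypothetical example: "tracks"/"jets" and the dense_config keys other than
# "input_size" are assumptions used only for illustration.
from salt.models import InitNet

variables = {
    "tracks": ["d0", "z0", "phi"],  # constituent-level inputs (3 variables)
    "jets": ["pt", "eta"],          # global-level inputs (2 variables)
}

init_net = InitNet(
    input_name="tracks",
    dense_config={"output_size": 64, "hidden_layers": [128]},  # no "input_size"
    variables=variables,
    global_object="jets",
    attach_global=True,
)

# with attach_global=True and no featurewise networks, the inferred input size
# is len(tracks variables) + len(jets variables) = 3 + 2 = 5
```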

salt.models.PositionalEncoder #

Bases: torch.nn.Module

Positional encoder.

Evenly share the embedding space between the different variables to be encoded. Any remaining dimensions are left as zeros.

TODO: alpha should be set for each variable

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `variables` | `list[str]` | List of variables to apply the positional encoding to. | *required* |
| `dim` | `int` | Dimension of the positional encoding. | *required* |
| `alpha` | `int` | Scaling factor for the positional encoding, by default 100. | `100` |

Source code in salt/models/posenc.py
def __init__(self, variables: list[str], dim: int, alpha: int = 100):
    """Positional encoder.

    Evenly share the embedding space between the different variables to be encoded.
    Any remaining dimensions are left as zeros.

    TODO: alpha should be set for each variable

    Parameters
    ----------
    variables : list[str]
        List of variables to apply the positional encoding to.
    dim : int
        Dimension of the positional encoding.
    alpha : int, optional
        Scaling factor for the positional encoding, by default 100.
    """
    super().__init__()
    self.variables = variables
    self.dim = dim
    self.alpha = alpha
    print(dim, len(self.variables))
    self.per_input_dim = self.dim // (2 * len(self.variables))
    self.last_dim = self.dim % (2 * len(self.variables))
    print(self.per_input_dim, self.last_dim)
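
To make the dimension split concrete, the sketch below (with placeholder variable names) shows how the encoding dimension is shared between two variables; any remainder is left as zeros, as described above.

```python
# Hypothetical example: "phi" and "theta" are placeholder variable names.
from salt.models import PositionalEncoder

pos_enc = PositionalEncoder(variables=["phi", "theta"], dim=32)

# the per-variable share is dim // (2 * n_vars) = 32 // 4 = 8, and the
# dim % (2 * n_vars) = 0 leftover dimensions are left as zeros
assert pos_enc.per_input_dim == 8
assert pos_enc.last_dim == 0
```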

Last update: January 25, 2024
Created: January 25, 2024