
Initialisation

salt.models.InputNorm

Bases: torch.nn.Module

Normalise inputs on the fly using a pre-computed normalisation dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `norm_dict` | `pathlib.Path` | Path to file containing normalisation parameters | required |
| `variables` | `salt.stypes.Vars` | Input variables for each type of input | required |
| `global_object` | `str` | Name of the global input object, as opposed to the constituent-level inputs | required |
| `input_map` | `dict[str, str]` | Map names to the corresponding dataset names in the input h5 file. Set automatically by the framework. | required |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If norm values for an input can't be found in the normalisation dict, if there is a non-finite normalisation value for an input, or if there is a zero standard deviation for an input. |

Source code in salt/models/inputnorm.py
def __init__(
    self,
    norm_dict: Path,
    variables: Vars,
    global_object: str,
    input_map: dict[str, str],
) -> None:
    super().__init__()
    self.variables = variables
    self.global_object = global_object
    self.NO_NORM = ["EDGE", "parameters"]
    with open(norm_dict) as f:
        self.norm_dict = yaml.safe_load(f)

    # get the keys that need to be normalised
    if input_map is None:
        input_map = {k: k for k in variables}
    keys = {input_map[k] for k in set(variables.keys())}
    keys.discard("EDGE")
    keys.discard("parameters")
    if "global" in keys:
        keys.remove("global")
        keys.add(self.global_object)

    # check we have all required keys in the normalisation dictionary
    if missing := keys - set(self.norm_dict):
        raise ValueError(
            f"Missing input types {missing} in {norm_dict}. Choose from"
            f" {self.norm_dict.keys()}."
        )

    # check we have all required variables for each input type
    for k, vs in variables.items():
        if k in self.NO_NORM:
            continue
        name = input_map[k]
        if k == "global":
            name = self.global_object
        if missing := set(vs) - set(self.norm_dict[name]):
            raise ValueError(
                f"Missing variables {missing} for {name} in {norm_dict}. Choose from"
                f" {self.norm_dict[name].keys()}."
            )

        # store normalisation parameters with the model
        means = torch.as_tensor([self.norm_dict[name][v]["mean"] for v in vs])
        stds = torch.as_tensor([self.norm_dict[name][v]["std"] for v in vs])
        self.register_buffer(f"{k}_means", means)
        self.register_buffer(f"{k}_stds", stds)

        # check normalisation parameters are ok
        if not torch.isfinite(means).all() or not torch.isfinite(stds).all():
            raise ValueError(f"Non-finite normalisation parameters for {name} in {norm_dict}.")
        if any(stds == 0):
            raise ValueError(f"Zero standard deviation for {name} in {norm_dict}.")

salt.models.InitNet

Bases: torch.nn.Module

Initial input embedding network.

This class can handle global input concatenation and positional encoding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_name` | `str` | Name of the input, must match the input types in the data config | required |
| `dense_config` | `dict` | Keyword arguments for salt.models.Dense, the dense network producing the initial embedding. The input_size argument is inferred automatically by the framework | required |
| `variables` | `salt.stypes.Vars` | Input variables used in the forward pass, set automatically by the framework | required |
| `global_object` | `str` | Name of the global object, set automatically by the framework | required |
| `attach_global` | `bool` | Concatenate global-level inputs with constituent-level inputs before embedding | True |
| `pos_enc` | `salt.models.posenc.PositionalEncoder \| None` | Positional encoder module to use. See salt.models.PositionalEncoder for details | None |
| `mup` | `bool` | Whether to use the muP parametrisation (impacts initialisation) | False |
| `featurewise` | `salt.models.FeaturewiseTransformation \| None` | Networks to apply featurewise transformations to inputs, set automatically by the framework | None |
Source code in salt/models/initnet.py
def __init__(
    self,
    input_name: str,
    dense_config: dict,
    variables: Vars,
    global_object: str,
    attach_global: bool = True,
    pos_enc: PositionalEncoder | None = None,
    mup: bool = False,
    featurewise: FeaturewiseTransformation | None = None,
):
    super().__init__()

    # set input size
    if "input_size" not in dense_config:
        dense_config["input_size"] = len(variables[input_name])
        if attach_global and input_name != "EDGE":
            dense_config["input_size"] += len(variables[global_object])
            if not featurewise:
                dense_config["input_size"] += len(variables.get("parameters", []))

    self.input_name = input_name
    self.net = Dense(**dense_config)
    self.variables = variables
    self.attach_global = attach_global
    self.global_object = global_object
    self.pos_enc = pos_enc
    self.mup = mup
    if mup:
        self.net.reset_parameters()
    self.featurewise = featurewise
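When input_size is not given in dense_config, it is inferred from the number of variables for the input, plus the global-object (and any parameter) variables if attach_global is set. A small worked sketch of that inference, using assumed, illustrative variable lists:

```python
# Sketch of the input_size inference above; the variable lists are illustrative.
variables = {
    "jets": ["pt", "eta"],                   # global object variables
    "tracks": ["d0", "z0", "phi", "theta"],  # constituent-level variables
}
input_name = "tracks"
global_object = "jets"
attach_global = True
featurewise = None

input_size = len(variables[input_name])          # 4 constituent features
if attach_global and input_name != "EDGE":
    input_size += len(variables[global_object])  # + 2 global features
    if not featurewise:
        # any concatenated "parameters" variables also count towards the size
        input_size += len(variables.get("parameters", []))

assert input_size == 6  # the Dense network is built with this input size
```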

salt.models.PositionalEncoder

Bases: torch.nn.Module

Positional encoder.

Evenly share the embedding space between the different variables to be encoded. Any remaining dimensions are left as zeros.

TODO: alpha should be set for each variable

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `variables` | `list[str]` | List of variables to apply the positional encoding to | required |
| `dim` | `int` | Dimension of the positional encoding | required |
| `alpha` | `int` | Scaling factor for the positional encoding | 100 |
Source code in salt/models/posenc.py
def __init__(self, variables: list[str], dim: int, alpha: int = 100):
    super().__init__()
    self.variables = variables
    self.dim = dim
    self.alpha = alpha
    print(dim, len(self.variables))
    self.per_input_dim = self.dim // (2 * len(self.variables))
    self.last_dim = self.dim % (2 * len(self.variables))
    print(self.per_input_dim, self.last_dim)
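The dimension bookkeeping above can be made concrete with a small worked example. The values of dim and the variable names here are illustrative, not taken from the framework.

```python
# Worked example of the dimension split; dim and variable names are illustrative.
dim = 32
variables = ["phi", "eta", "dphi"]

per_input_dim = dim // (2 * len(variables))  # 32 // 6 = 5
last_dim = dim % (2 * len(variables))        # 32 % 6 = 2

# Each variable is allotted 2 * per_input_dim = 10 dimensions of the encoding,
# using 30 of the 32 dimensions in total; the remaining last_dim = 2 dimensions
# are left as zeros, as noted in the class description.
assert 2 * per_input_dim * len(variables) + last_dim == dim
```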
