Task Heads

salt.models.TaskBase

Bases: torch.nn.Module, abc.ABC

Task head base class.

Tasks wrap a dense network, a loss function, and a loss weight; subclasses add the label (classification) or the regression targets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `name` | str | Arbitrary name of the task, used for logging and inference. | required |
| `input_name` | str | Type of input object the task operates on, e.g. jet, track, or flow. | required |
| `dense_config` | dict | Keyword arguments for salt.models.Dense, the dense network producing the task outputs. | required |
| `loss` | torch.nn.Module | Loss function applied to the dense network outputs. | required |
| `weight` | float | Weight of this task's loss in the overall loss. | 1.0 |
Source code in salt/models/task.py
```python
def __init__(
    self,
    name: str,
    input_name: str,
    dense_config: dict,
    loss: nn.Module,
    weight: float = 1.0,
):
    """Task head base class.

    Tasks wrap a dense network, a loss, a label, and a weight.

    Parameters
    ----------
    name : str
        Arbitrary name of the task, used for logging and inference.
    input_name : str
        Which type of object is input to the task e.g. jet/track/flow.
    dense_config : dict
        Keyword arguments for [`salt.models.Dense`][salt.models.Dense],
        the dense network producing the task outputs.
    loss : nn.Module
        Loss function applied to the dense network outputs.
    weight : float
        Weight in the overall loss.
    """
    super().__init__()

    self.name = name
    self.input_name = input_name
    self.net = Dense(**dense_config)
    self.loss = loss
    self.weight = weight
```
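Each task contributes its loss to the overall training loss, scaled by `weight`. The sketch below assumes the overall loss is a plain weighted sum of the per-task losses. `TaskLoss`, `combine_task_losses`, and the task names are hypothetical and not part of salt:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TaskLoss:
    """Hypothetical record of one task head's name, loss weight, and computed loss."""

    name: str
    weight: float
    value: float


def combine_task_losses(task_losses: list[TaskLoss]) -> tuple[float, dict[str, float]]:
    """Return the weighted total loss and the weighted per-task losses (e.g. for logging)."""
    weighted = {t.name: t.weight * t.value for t in task_losses}
    return sum(weighted.values()), weighted


total, weighted = combine_task_losses(
    [
        TaskLoss(name="jets_classification", weight=1.0, value=0.5),
        TaskLoss(name="track_origin", weight=0.5, value=2.0),
    ]
)
print(total)  # 1.5
```

Because the per-task losses in this sketch are keyed by task name, each task needs a unique `name` for its logged loss to stay distinguishable.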

salt.models.ClassificationTask

Bases: salt.models.task.TaskBase

Classification task.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `label` | str | Label name for the task. | required |
| `class_names` | list[str] \| None | List of class names, ordered by output index. If not specified, the names are looked up from the label name. | None |
| `label_map` | collections.abc.Mapping \| None | Remaps integer labels for training (e.g. 0, 4, 5 -> 0, 1, 2). Requires `class_names` to be set. | None |
| `sample_weight` | str \| None | Name of a per-sample weight to apply in the loss function. Requires a loss with `reduction="none"`. | None |
| `use_class_dict` | bool | If True, read class weights for the loss from the class_dict file. | False |
| `**kwargs` | | Keyword arguments for salt.models.TaskBase. | {} |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If `label_map` is defined but `class_names` is not. |
| ValueError | If the number of network outputs does not match the number of class names. |
| AssertionError | If `sample_weight` is set but the loss does not use `reduction="none"`. |
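The `reduction="none"` requirement exists because per-sample weights can only be applied to unreduced, per-sample losses: once a loss has been averaged to a scalar, the individual samples can no longer be reweighted. Below is a minimal pure-Python sketch. It assumes the weighted losses are then averaged (the final reduction salt applies is not documented here), and `weighted_mean_loss` is a hypothetical helper:

```python
from __future__ import annotations


def weighted_mean_loss(per_sample_losses: list[float], sample_weights: list[float]) -> float:
    """Weight unreduced per-sample losses, then average them (assumed final reduction)."""
    if len(per_sample_losses) != len(sample_weights):
        raise ValueError("Need exactly one weight per sample")
    weighted = [loss * w for loss, w in zip(per_sample_losses, sample_weights)]
    return sum(weighted) / len(weighted)


# Unreduced losses, as returned by a loss module with reduction="none"
losses = [1.0, 0.5, 2.0]
weights = [1.0, 2.0, 0.5]
print(weighted_mean_loss(losses, weights))  # 1.0
```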

Source code in salt/models/task.py
```python
def __init__(
    self,
    label: str,
    class_names: list[str] | None = None,
    label_map: Mapping | None = None,
    sample_weight: str | None = None,
    use_class_dict: bool = False,
    **kwargs,
):
    """Classification task.

    Parameters
    ----------
    label : str
        Label name for the task
    class_names : list[str] | None, optional
        List of class names, ordered by output index. If not specified attempt to
        automatically determine these from the label name.
    label_map : Mapping | None, optional
        Remap integer labels for training (e.g. 0,4,5 -> 0,1,2).
    sample_weight : str | None, optional
        Name of a per sample weighting to apply in the loss function.
    use_class_dict : bool, optional
        If True, read class weights for the loss from the class_dict file.
    **kwargs
        Keyword arguments for [`salt.models.TaskBase`][salt.models.TaskBase].

    Raises
    ------
    ValueError
        If label map is defined but corresponding class names are not.
    ValueError
        If the defined number of outputs does not match the defined number of classes.
    """
    super().__init__(**kwargs)
    self.label = label
    self.class_names = class_names
    self.label_map = label_map
    if self.label_map is not None and self.class_names is None:
        raise ValueError("Specify class names when using label_map.")
    if hasattr(self.loss, "ignore_index"):
        self.loss.ignore_index = -1
    self.sample_weight = sample_weight
    if self.sample_weight is not None:
        assert (
            self.loss.reduction == "none"
        ), "Sample weights only supported for reduction='none'"
    if self.class_names is None:
        self.class_names = CLASS_NAMES[self.label]
    if len(self.class_names) != self.net.output_size:
        raise ValueError(
            f"{self.name}: "
            f"Number of outputs ({self.net.output_size}) does not match "
            f"number of class names ({len(self.class_names)}). Class names: {self.class_names}"
        )
    self.use_class_dict = use_class_dict
```
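The `label_map` remapping and the two `ValueError` checks can be illustrated without salt or PyTorch. This is a minimal sketch. It assumes `label_map` is a plain dict from raw to training labels, as the "0,4,5 -> 0,1,2" example suggests. `remap_labels`, `check_class_config`, and the class names are hypothetical:

```python
from __future__ import annotations

from collections.abc import Mapping, Sequence


def remap_labels(labels: Sequence[int], label_map: Mapping[int, int]) -> list[int]:
    """Map raw integer labels onto contiguous training indices."""
    return [label_map[label] for label in labels]


def check_class_config(
    class_names: list[str] | None,
    label_map: Mapping[int, int] | None,
    output_size: int,
) -> None:
    """Mirror the two ValueError checks in ClassificationTask.__init__.

    Unlike the real constructor, this does not look up default class names
    from the label when class_names is None.
    """
    if label_map is not None and class_names is None:
        raise ValueError("Specify class names when using label_map.")
    if class_names is not None and len(class_names) != output_size:
        raise ValueError(
            f"Number of outputs ({output_size}) does not match "
            f"number of class names ({len(class_names)})"
        )


label_map = {0: 0, 4: 1, 5: 2}  # the 0,4,5 -> 0,1,2 example from the docs
print(remap_labels([5, 0, 4, 4], label_map))  # [2, 0, 1, 1]
check_class_config(["light", "charm", "bottom"], label_map, output_size=3)  # passes
```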

salt.models.RegressionTaskBase

Bases: salt.models.task.TaskBase, abc.ABC

Base class for regression tasks.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `targets` | list[str] \| str | Regression target(s). | required |
| `scaler` | salt.utils.scalers.RegressionTargetScaler \| None | Functional scaler for the regression targets. Cannot be combined with the other target scaling options. | None |
| `target_denominators` | list[str] \| str \| None | Variable(s) to divide the regression target(s) by, e.g. to regress a ratio. Cannot be combined with the other target scaling options. | None |
| `norm_params` | dict \| None | Per-target mean and std normalization parameters used for scaling. Cannot be combined with the other target scaling options. | None |
| `custom_output_names` | list[str] \| str \| None | Name(s) of the regression output(s); overrides the standard "model name + numerator" naming. | None |
| `**kwargs` | | Keyword arguments for salt.models.TaskBase. | {} |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If more than one scaling method is specified, or if the number of target denominators, or of means or stds in `norm_params`, does not match the number of regression targets. |
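Two of the three mutually exclusive scaling options can be illustrated on plain numbers. This sketch makes two assumptions. First, `norm_params` standardizes each target as (target - mean) / std; the description above only says the parameters are used for scaling. Second, each denominator divides its matching target element-wise. The helpers and values are hypothetical. The functional scaler is not shown because its interface is not documented here beyond the `scale(target, tensor)` call in the constructor:

```python
from __future__ import annotations


def divide_by_denominators(targets: list[float], denominators: list[float]) -> list[float]:
    """target_denominators option: regress target / denominator instead of the raw target."""
    return [t / d for t, d in zip(targets, denominators)]


def standardize(targets: list[float], norm_params: dict[str, list[float]]) -> list[float]:
    """norm_params option (assumed form): (target - mean) / std, per target."""
    means, stds = norm_params["mean"], norm_params["std"]
    return [(t - m) / s for t, m, s in zip(targets, means, stds)]


def unstandardize(values: list[float], norm_params: dict[str, list[float]]) -> list[float]:
    """Inverse of standardize: map scaled predictions back to the original units."""
    means, stds = norm_params["mean"], norm_params["std"]
    return [v * s + m for v, m, s in zip(values, means, stds)]


# Hypothetical values for two regression targets
params = {"mean": [40.0, 30.0], "std": [5.0, 10.0]}
print(divide_by_denominators([50.0, 30.0], [100.0, 60.0]))  # [0.5, 0.5]
print(standardize([50.0, 30.0], params))  # [2.0, 0.0]
```

Since the network is trained on the scaled values, predictions have to be mapped back with the inverse operation before they are compared with the raw targets. In this sketch, that means multiplying by the denominator or applying `unstandardize`.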

Source code in salt/models/task.py
```python
def __init__(
    self,
    targets: list[str] | str,
    scaler: RegressionTargetScaler | None = None,
    target_denominators: list[str] | str | None = None,
    norm_params: dict | None = None,
    custom_output_names: list[str] | str | None = None,
    **kwargs,
):
    """Base class for regression tasks.

    Parameters
    ----------
    targets : list[str] | str
        Regression target(s).
    scaler : RegressionTargetScaler
        Functional scaler for regression targets
            - cannot be used with other target scaling options.
    target_denominators : list[str] | str | None, optional
        Variables to divide regression target(s) by (i.e. for regressing a ratio).
            - cannot be used with other target scaling options.
    norm_params : dict | None, optional
        Mean and std normalization parameters for each target, used for scaling.
            - cannot be used with other target scaling options.
    custom_output_names : list[str] | str | None, optional
        Name(s) of regression output(s), overwrites the standard "model name + numerator"
    **kwargs
        Keyword arguments for [`salt.models.TaskBase`][salt.models.TaskBase].

    Raises
    ------
    ValueError
        If multiple flags for scaling methods are defined.
        If number of computed parameters does not match number of regression targets.
    """
    super().__init__(**kwargs)
    self.scaler = scaler
    self.targets = listify(targets)
    self.target_denominators = listify(target_denominators)
    self.custom_output_names = listify(custom_output_names)
    if norm_params:
        norm_params["mean"] = listify(norm_params["mean"])
        norm_params["std"] = listify(norm_params["std"])
    self.norm_params = norm_params

    if [scaler, target_denominators, norm_params].count(None) not in {2, 3}:
        raise ValueError("Can only use a single scaling method")

    if self.scaler:
        for target in self.targets:
            self.scaler.scale(target, torch.Tensor(1))
    if self.target_denominators and len(self.targets) != len(self.target_denominators):
        raise ValueError(
            f"{self.name}: "
            f"Number of targets ({len(self.targets)}) does not match "
            f"number of target denominators ({len(self.target_denominators)})"
        )
    if self.norm_params and len(self.norm_params["mean"]) != len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of means in norm_params ({len(self.norm_params['mean'])}) does not match "
            f"number of targets ({len(self.targets)})"
        )
    if self.norm_params and len(self.norm_params["std"]) != len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of stds in norm_params ({len(self.norm_params['std'])}) does not match "
            f"number of targets ({len(self.targets)})"
        )
```
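The constructor's validation can be traced on plain Python values. This sketch mirrors the checks above. `as_list` is a hypothetical stand-in for salt's `listify`, and its exact behavior is assumed. `check_scaling_config` and the target name are also illustrative:

```python
from __future__ import annotations

from typing import Any


def as_list(value: Any) -> list | None:
    """Hypothetical stand-in for salt's listify: None stays None, a scalar becomes [scalar]."""
    if value is None:
        return None
    return list(value) if isinstance(value, (list, tuple)) else [value]


def check_scaling_config(
    targets: list[str] | str,
    scaler: object | None = None,
    target_denominators: list[str] | str | None = None,
    norm_params: dict | None = None,
) -> None:
    """Mirror the RegressionTaskBase.__init__ checks without building a network."""
    # At most one of the three scaling options may be set (i.e. two or three are None)
    if [scaler, target_denominators, norm_params].count(None) not in {2, 3}:
        raise ValueError("Can only use a single scaling method")
    targets = as_list(targets)
    denominators = as_list(target_denominators)
    if denominators and len(denominators) != len(targets):
        raise ValueError("Number of targets does not match number of target denominators")
    if norm_params:
        for key in ("mean", "std"):
            if len(as_list(norm_params[key])) != len(targets):
                raise ValueError(f"Number of {key} values in norm_params does not match targets")


check_scaling_config("pt_ratio", norm_params={"mean": 1.0, "std": 0.2})  # passes
```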

salt.models.RegressionTask

Bases: salt.models.task.RegressionTaskBase

Regression task.

Parameters:

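| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `scaler` | | Functional scaler for the regression targets. It is assigned after `RegressionTaskBase.__init__` has run, so the base-class scaler checks are not applied to it. | None |
| `**kwargs` | | Keyword arguments for salt.models.RegressionTaskBase. | {} |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the number of network outputs does not match the number of regression targets. |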

Source code in salt/models/task.py
```python
def __init__(self, scaler=None, **kwargs):
    """Regression task.

    Parameters
    ----------
    scaler : optional
        Functional scaler for the regression targets. It is assigned after the
        base class is initialized, so the base-class scaler checks do not apply to it.
    **kwargs
        Keyword arguments for
        [`salt.models.RegressionTaskBase`][salt.models.RegressionTaskBase].

    Raises
    ------
    ValueError
        If number of outputs does not match number of regression targets.
    """
    super().__init__(**kwargs)
    if self.net.output_size != len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of outputs ({self.net.output_size}) does not match "
            f"number of targets ({len(self.targets)})"
        )
    self.scaler = scaler
```
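A `RegressionTask` needs exactly one network output per target, so each output can be read as the prediction for one target. Below is a sketch of that correspondence. It assumes the outputs are ordered like `targets`. `name_outputs` and the target names are hypothetical:

```python
from __future__ import annotations


def name_outputs(outputs: list[float], targets: list[str]) -> dict[str, float]:
    """Pair each network output with its regression target (one output per target)."""
    if len(outputs) != len(targets):
        raise ValueError(
            f"Number of outputs ({len(outputs)}) does not match "
            f"number of targets ({len(targets)})"
        )
    return dict(zip(targets, outputs))


print(name_outputs([0.9, 1.1], ["pt_ratio", "mass_ratio"]))  # {'pt_ratio': 0.9, 'mass_ratio': 1.1}
```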

salt.models.GaussianRegressionTask

Bases: salt.models.task.RegressionTaskBase

Regression task that outputs a mean and variance for each target. The loss function is the negative log likelihood of a Gaussian distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**kwargs` | | Keyword arguments for salt.models.RegressionTaskBase. | {} |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the number of network outputs is not twice the number of regression targets. |

Source code in salt/models/task.py
```python
def __init__(self, **kwargs):
    """Regression task that outputs a mean and variance for each target.
    The loss function is the negative log likelihood of a Gaussian distribution.

    Parameters
    ----------
    **kwargs
        Keyword arguments for
        [`salt.models.RegressionTaskBase`][salt.models.RegressionTaskBase].

    Raises
    ------
    ValueError
        If the number of outputs is not twice the number of regression targets.
    """
    super().__init__(**kwargs)
    if self.net.output_size != 2 * len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of outputs ({self.net.output_size}) is not twice the "
            f"number of targets ({len(self.targets)})"
        )
```
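The per-target loss is the Gaussian negative log-likelihood, NLL = 0.5 * [log(var) + (y - mean)^2 / var] + 0.5 * log(2 * pi). Below is a pure-Python sketch. The split of the network output into means followed by variances is an assumption; the source only fixes the output size at twice the number of targets. `gaussian_nll` and `split_outputs` are hypothetical helpers. PyTorch's `torch.nn.GaussianNLLLoss` computes a closely related loss and omits the constant term unless `full=True`. Whether salt uses it is not stated here.

```python
from __future__ import annotations

import math


def gaussian_nll(target: float, mean: float, variance: float, full: bool = True) -> float:
    """Negative log-likelihood of `target` under a Gaussian N(mean, variance)."""
    if variance <= 0:
        raise ValueError("variance must be positive")
    nll = 0.5 * (math.log(variance) + (target - mean) ** 2 / variance)
    if full:
        # Constant term; it shifts the loss but does not change its gradients
        nll += 0.5 * math.log(2 * math.pi)
    return nll


def split_outputs(outputs: list[float], n_targets: int) -> tuple[list[float], list[float]]:
    """Assumed layout: first n_targets outputs are means, the remaining ones are variances."""
    if len(outputs) != 2 * n_targets:
        raise ValueError("Number of outputs must be twice the number of targets")
    return outputs[:n_targets], outputs[n_targets:]


targets = [1.0, 0.0]
means, variances = split_outputs([1.0, 2.0, 0.5, 4.0], n_targets=len(targets))
loss = sum(gaussian_nll(t, m, v) for t, m, v in zip(targets, means, variances))
```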

Last update: November 16, 2023
Created: October 20, 2023