Skip to content

Task Heads#

salt.models.TaskBase #

Bases: torch.nn.Module, abc.ABC

Task head base class.

Tasks wrap a dense network, a loss, a label, and a weight.

Parameters:

Name Type Description Default
name str

Arbitrary name of the task, used for logging and inference.

required
input_name str

Type of input object for the task, e.g. jet/track/flow.

required
dense_config dict

Keyword arguments for salt.models.Dense, the dense network producing the task outputs.

required
loss torch.nn.Module

Loss function applied to the dense network outputs.

required
weight float

Weight in the overall loss.

1.0
Source code in salt/models/task.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(
    self,
    name: str,
    input_name: str,
    dense_config: dict,
    loss: nn.Module,
    weight: float = 1.0,
):
    """Initialise the state shared by every task head.

    A task bundles a dense network, the loss applied to that network's
    outputs, and the relative weight of the loss in the overall objective.

    Parameters
    ----------
    name : str
        Arbitrary name of the task, used for logging and inference.
    input_name : str
        Type of input object for the task, e.g. jet/track/flow.
    dense_config : dict
        Keyword arguments for [`salt.models.Dense`][salt.models.Dense],
        the dense network producing the task outputs.
    loss : nn.Module
        Loss function applied to the dense network outputs.
    weight : float, optional
        Weight of this task's loss in the overall loss, by default 1.0.
    """
    # nn.Module.__init__ must run before any submodule attribute is assigned.
    super().__init__()

    # Identification: the task's own name and the input type it consumes.
    self.name, self.input_name = name, input_name

    # Build the head network from its keyword configuration.
    self.net = Dense(**dense_config)

    # Objective: the loss module and its weight in the combined loss.
    self.loss, self.weight = loss, weight

salt.models.ClassificationTask #

Bases: salt.models.task.TaskBase

Classification task.

Parameters:

Name Type Description Default
label str

Label name for the task

required
class_names list[str] | None

List of class names, ordered by output index. If not specified, an attempt is made to determine them automatically from the label name.

None
label_map collections.abc.Mapping | None

Remap integer labels for training (e.g. 0,4,5 -> 0,1,2).

None
sample_weight str | None

Name of a per sample weighting to apply in the loss function.

None
use_class_dict bool

If True, read class weights for the loss from the class_dict file.

False
**kwargs

Keyword arguments for salt.models.TaskBase.

{}
Source code in salt/models/task.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def __init__(
    self,
    label: str,
    class_names: list[str] | None = None,
    label_map: Mapping | None = None,
    sample_weight: str | None = None,
    use_class_dict: bool = False,
    **kwargs,
):
    """Classification task.

    Parameters
    ----------
    label : str
        Label name for the task.
    class_names : list[str] | None, optional
        List of class names, ordered by output index. If not specified, they are
        looked up in ``CLASS_NAMES`` using the label name.
    label_map : Mapping | None, optional
        Remap integer labels for training (e.g. 0,4,5 -> 0,1,2). Requires
        ``class_names`` to be given explicitly.
    sample_weight : str | None, optional
        Name of a per sample weighting to apply in the loss function. Requires the
        loss to have ``reduction="none"``.
    use_class_dict : bool, optional
        If True, read class weights for the loss from the class_dict file.
    **kwargs
        Keyword arguments for [`salt.models.TaskBase`][salt.models.TaskBase].

    Raises
    ------
    ValueError
        If ``label_map`` is given without ``class_names``, if ``sample_weight`` is
        given but the loss reduction is not ``"none"``, or if the number of class
        names does not match the network output size.
    """
    super().__init__(**kwargs)
    self.label = label
    self.class_names = class_names
    self.label_map = label_map
    if self.label_map is not None and self.class_names is None:
        raise ValueError("Specify class names when using label_map.")
    # Losses exposing `ignore_index` are configured to skip targets equal to -1.
    # NOTE(review): presumably -1 marks padded/unlabelled entries — confirm.
    if hasattr(self.loss, "ignore_index"):
        self.loss.ignore_index = -1
    self.sample_weight = sample_weight
    # Sample weights need the unreduced loss. Raise instead of `assert` so the
    # check is not stripped under `python -O` and matches the other checks here.
    if self.sample_weight is not None and self.loss.reduction != "none":
        raise ValueError("Sample weights only supported for reduction='none'")
    if self.class_names is None:
        # Fall back to the class names registered for this label.
        self.class_names = CLASS_NAMES[self.label]
    if len(self.class_names) != self.net.output_size:
        raise ValueError(
            f"{self.name}: "
            f"Number of outputs ({self.net.output_size}) does not match "
            f"number of class names ({len(self.class_names)}). Class names: {self.class_names}"
        )
    self.use_class_dict = use_class_dict

salt.models.RegressionTaskBase #

Bases: salt.models.task.TaskBase, abc.ABC

Base class for regression tasks.

Parameters:

Name Type Description Default
targets list[str] | str

Regression target(s).

required
scaler salt.utils.scalers.RegressionTargetScaler | None

Functional scaler for the regression targets. Cannot be used with other target scaling options.

None
target_denominators list[str] | str | None

Variables to divide the regression target(s) by (e.g. for regressing a ratio). Cannot be used with other target scaling options.

None
norm_params dict | None

Mean and std normalization parameters for each target, used for scaling. Cannot be used with other target scaling options.

None
custom_output_names list[str] | str | None

Name(s) of the regression output(s), overriding the default "model name + numerator" naming.

None
**kwargs

Keyword arguments for salt.models.TaskBase.

{}
Source code in salt/models/task.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
def __init__(
    self,
    targets: list[str] | str,
    scaler: RegressionTargetScaler | None = None,
    target_denominators: list[str] | str | None = None,
    norm_params: dict | None = None,
    custom_output_names: list[str] | str | None = None,
    **kwargs,
):
    """Base class for regression tasks.

    At most one target scaling method (``scaler``, ``target_denominators`` or
    ``norm_params``) may be given.

    Parameters
    ----------
    targets : list[str] | str
        Regression target(s).
    scaler : RegressionTargetScaler | None, optional
        Functional scaler for regression targets. Cannot be used with other
        target scaling options.
    target_denominators : list[str] | str | None, optional
        Variables to divide regression target(s) by (e.g. for regressing a ratio).
        Cannot be used with other target scaling options.
    norm_params : dict | None, optional
        Mean and std normalization parameters for each target, used for scaling.
        Must contain ``"mean"`` and ``"std"`` entries. Cannot be used with other
        target scaling options. The caller's dict is not modified.
    custom_output_names : list[str] | str | None, optional
        Name(s) of regression output(s), overriding the default
        "model name + numerator" naming.
    **kwargs
        Keyword arguments for [`salt.models.TaskBase`][salt.models.TaskBase].

    Raises
    ------
    ValueError
        If more than one scaling method is given, or if the number of target
        denominators, means or stds does not match the number of targets.
    """
    super().__init__(**kwargs)
    self.scaler = scaler
    self.targets = listify(targets)
    self.target_denominators = listify(target_denominators)
    self.custom_output_names = listify(custom_output_names)
    if norm_params:
        # Shallow-copy so normalising mean/std to lists does not mutate the
        # caller's (possibly shared) configuration dict.
        norm_params = dict(norm_params)
        norm_params["mean"] = listify(norm_params["mean"])
        norm_params["std"] = listify(norm_params["std"])
    self.norm_params = norm_params

    # 2 or 3 Nones means zero or one scaling method was supplied.
    if [scaler, target_denominators, norm_params].count(None) not in {2, 3}:
        raise ValueError("Can only use a single scaling method")

    if self.scaler:
        # Probe the scaler once per target so unknown targets fail at construction.
        for target in self.targets:
            self.scaler.scale(target, torch.Tensor(1))
    if self.target_denominators and len(self.targets) != len(self.target_denominators):
        raise ValueError(
            f"{self.name}: "
            f"Number of targets ({len(self.targets)}) does not match "
            f"number of target denominators ({len(self.target_denominators)})"
        )
    if self.norm_params and len(self.norm_params["mean"]) != len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of means in norm_params ({len(self.norm_params['mean'])}) does not match "
            f"number of targets ({len(self.targets)})"
        )
    if self.norm_params and len(self.norm_params["std"]) != len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of stds in norm_params ({len(self.norm_params['std'])}) does not match "
            f"number of targets ({len(self.targets)})"
        )

salt.models.RegressionTask #

Bases: salt.models.task.RegressionTaskBase

Regression task.

Parameters:

Name Type Description Default
scaler

Functional scaler for the regression targets.

None
**kwargs

Keyword arguments for salt.models.RegressionTaskBase.

{}
Source code in salt/models/task.py
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
def __init__(self, scaler=None, **kwargs):
    """Regression task with one network output per target.

    Parameters
    ----------
    scaler : RegressionTargetScaler | None, optional
        Functional scaler for the regression targets. It is forwarded to
        [`salt.models.RegressionTaskBase`][salt.models.RegressionTaskBase], which
        stores it, checks it is the only scaling method in use, and probes it for
        every target.
    **kwargs
        Keyword arguments for
        [`salt.models.RegressionTaskBase`][salt.models.RegressionTaskBase].

    Raises
    ------
    ValueError
        If the number of network outputs differs from the number of targets.
    """
    # Forward the scaler instead of assigning it after construction; otherwise
    # the base-class scaler validation and single-scaling-method check are skipped.
    super().__init__(scaler=scaler, **kwargs)
    if self.net.output_size != len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of outputs ({self.net.output_size}) does not match "
            f"number of targets ({len(self.targets)})"
        )

salt.models.GaussianRegressionTask #

Bases: salt.models.task.RegressionTaskBase

Regression task that outputs a mean and variance for each target. The loss function is the negative log likelihood of a Gaussian distribution.

Parameters:

Name Type Description Default
**kwargs

Keyword arguments for salt.models.RegressionTaskBase.

{}
Source code in salt/models/task.py
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
def __init__(self, **kwargs):
    """Regression task that outputs a mean and variance for each target.
    The loss function is the negative log likelihood of a Gaussian distribution.

    Parameters
    ----------
    **kwargs
        Keyword arguments for
        [`salt.models.RegressionTaskBase`][salt.models.RegressionTaskBase].

    Raises
    ------
    ValueError
        If the network does not produce exactly two outputs per target.
    """
    super().__init__(**kwargs)
    # Two outputs per target: one for the mean and one for the variance.
    if self.net.output_size != 2 * len(self.targets):
        raise ValueError(
            f"{self.name}: "
            f"Number of outputs ({self.net.output_size}) is not twice the "
            f"number of targets ({len(self.targets)})"
        )

Last update: November 16, 2023
Created: October 20, 2023