# Task Heads
## `salt.models.TaskBase`

Bases: `torch.nn.Module`, `abc.ABC`
Base class for task heads.
Tasks wrap a dense network, a loss, a target label, and a scalar weight.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`name` | `str` | Arbitrary name of the task, used for logging and inference. | required |
`input_name` | `str` | Name of the input stream consumed by this task. | required |
`dense_config` | `dict` | Keyword arguments used to build the task's dense network. | required |
`loss` | `torch.nn.Module` | Loss function applied to the head outputs. | required |
`weight` | `float` | Scalar multiplier for the task loss in the overall objective. | `1.0` |
Source code in salt/models/task.py
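The `weight` parameter scales each task's loss before it enters the overall training objective. The following is a minimal, framework-free sketch of that weighting, assuming the combined objective is a plain weighted sum of per-task losses; `TaskLossSpec`, `combined_loss`, and the task names are hypothetical and not part of the `salt` API.

```python
from dataclasses import dataclass


@dataclass
class TaskLossSpec:
    """Hypothetical stand-in for a task head, keeping only what loss weighting needs."""

    name: str
    weight: float = 1.0  # same default as the documented `weight` parameter


def combined_loss(tasks: list[TaskLossSpec], losses: dict[str, float]) -> float:
    """Assumed objective: the weighted sum of per-task losses, looked up by task name."""
    return sum(task.weight * losses[task.name] for task in tasks)


# Two hypothetical tasks; the second contributes at half strength.
tasks = [TaskLossSpec(name="task_a"), TaskLossSpec(name="task_b", weight=0.5)]
total = combined_loss(tasks, {"task_a": 0.75, "task_b": 1.5})
print(total)  # 1.5
```

A weight below `1.0` down-weights a task relative to the others; in this weighted-sum sketch, multiplying every weight by the same factor only rescales the total objective.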
## `salt.models.ClassificationTask`

Bases: `salt.models.task.TaskBase`
Multi-class or binary classification task head.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`label` | `str` | Label name for the task. | required |
`class_names` | `list[str] \| None` | Ordered class names, index-aligned with the outputs. If … | `None` |
`label_map` | `collections.abc.Mapping \| None` | Mapping that remaps integer labels for training (e.g., {0,4,5} → {0,1,2}). | `None` |
`sample_weight` | `str \| None` | Key of a per-sample weight found in … | `None` |
`use_class_dict` | `bool` | If … | `False` |
`**kwargs` | | Forwarded to the base class. | `{}` |
Raises:
Type | Description |
---|---|
`ValueError` | If a label map is provided without class names, or if the number of outputs does not match the number of classes. |
Source code in salt/models/task.py
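The `label_map` option remaps sparse integer labels onto a contiguous range, and the constructor rejects inconsistent configurations. Below is a minimal pure-Python sketch of the two documented checks (a label map requires class names; the output count must equal the number of classes) followed by the remapping itself. The helper `check_and_remap`, its argument layout, and the placeholder class names are hypothetical, not the `salt` implementation.

```python
from __future__ import annotations

from collections.abc import Mapping


def check_and_remap(
    labels: list[int],
    n_outputs: int,
    class_names: list[str] | None = None,
    label_map: Mapping[int, int] | None = None,
) -> list[int]:
    """Hypothetical sketch of the documented ClassificationTask checks plus label remapping."""
    if label_map is not None and class_names is None:
        raise ValueError("A label map requires class names.")
    # Assumption: the number of classes is len(class_names) whenever names are given.
    if class_names is not None and n_outputs != len(class_names):
        raise ValueError(f"Got {n_outputs} outputs for {len(class_names)} classes.")
    if label_map is None:
        return list(labels)
    # Map each raw integer label onto its contiguous training index.
    return [label_map[label] for label in labels]


# The {0,4,5} -> {0,1,2} example from the table, with placeholder class names.
remapped = check_and_remap(
    labels=[0, 5, 4, 0],
    n_outputs=3,
    class_names=["class_a", "class_b", "class_c"],
    label_map={0: 0, 4: 1, 5: 2},
)
print(remapped)  # [0, 2, 1, 0]
```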
## `salt.models.RegressionTaskBase`

Bases: `salt.models.task.TaskBase`, `abc.ABC`
Base class for regression tasks with optional target scaling.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`targets` | `list[str] \| str` | Regression target name(s). | required |
`scaler` | `salt.utils.scalers.RegressionTargetScaler \| None` | Functional scaler for the targets. Mutually exclusive with the other scaling options. | `None` |
`target_denominators` | `list[str] \| str \| None` | Denominator variable(s) used to form ratio targets. Mutually exclusive with the other scaling options. | `None` |
`norm_params` | `dict \| None` | Mean/std normalization parameters for each target. Mutually exclusive with the other scaling options. Expected keys: … | `None` |
`custom_output_names` | `list[str] \| str \| None` | Optional custom output names that override the defaults. | `None` |
`**kwargs` | | Forwarded to the base class. | `{}` |
Raises:
Type | Description |
---|---|
`ValueError` | If more than one scaling method is set, or if the number of scaling parameters does not match the number of targets. |
Source code in salt/models/task.py
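The three scaling options are mutually exclusive, and each must supply one parameter per target. The sketch below illustrates those rules for a single example: ratio targets divide each target by its denominator variable, and mean/std normalization computes `(y - mean) / std`. It assumes a `norm_params` layout with per-target lists under `"mean"` and `"std"` (the real expected keys are not listed here), and it does not model the functional `scaler`; `as_list` and `scale_targets` are hypothetical helpers.

```python
from __future__ import annotations


def as_list(names: list[str] | str) -> list[str]:
    """Accept a single name or a list of names, as the documented signatures do."""
    return [names] if isinstance(names, str) else list(names)


def scale_targets(
    values: dict[str, float],
    targets: list[str] | str,
    target_denominators: list[str] | str | None = None,
    norm_params: dict[str, list[float]] | None = None,
    scaler: object | None = None,
) -> list[float]:
    """Hypothetical sketch: scale one example's regression targets.

    `values` maps variable names to raw values for a single example, and
    `norm_params` is assumed to hold per-target lists under "mean" and "std".
    """
    names = as_list(targets)
    # The documented scaling options are mutually exclusive.
    if sum(opt is not None for opt in (scaler, target_denominators, norm_params)) > 1:
        raise ValueError("Set at most one scaling method.")
    if scaler is not None:
        # Applying a functional scaler is not modeled in this sketch.
        raise NotImplementedError("scaler is outside the scope of this sketch")
    if target_denominators is not None:
        denominators = as_list(target_denominators)
        if len(denominators) != len(names):
            raise ValueError("Expected one denominator per target.")
        # Ratio targets: each target divided by its denominator variable.
        return [values[t] / values[d] for t, d in zip(names, denominators)]
    if norm_params is not None:
        means, stds = norm_params["mean"], norm_params["std"]
        if len(means) != len(names) or len(stds) != len(names):
            raise ValueError("Expected one mean and one std per target.")
        # Standardization: (y - mean) / std for each target.
        return [(values[t] - m) / s for t, m, s in zip(names, means, stds)]
    # No scaling configured: return the raw target values.
    return [values[t] for t in names]


print(scale_targets({"y": 6.0, "ref": 3.0}, "y", target_denominators="ref"))  # [2.0]
print(scale_targets({"y": 6.0}, ["y"], norm_params={"mean": [4.0], "std": [2.0]}))  # [1.0]
```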
## `salt.models.RegressionTask`

Bases: `salt.models.task.RegressionTaskBase`
Standard regression task head.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`scaler` | `salt.utils.scalers.RegressionTargetScaler \| None` | Backward-compatibility placeholder; if provided, it is stored on the instance. | `None` |
`**kwargs` | | Forwarded to the base class. | `{}` |
Raises:
Type | Description |
---|---|
`ValueError` | If the number of outputs does not match the number of regression targets. |
Source code in salt/models/task.py
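The output-count check ties the width of the head to the number of targets: a standard regression head needs one output per target, while the Gaussian head documented next needs a mean and a variance per target, i.e. `2 * len(targets)`. This sketch encodes that rule; `output_size` stands for the dense network's final output width, and since the `dense_config` key that sets it is not given here, it and both helper names are hypothetical.

```python
from __future__ import annotations


def expected_output_size(targets: list[str] | str, gaussian: bool = False) -> int:
    """Number of head outputs implied by the targets.

    RegressionTask: one output per target.
    GaussianRegressionTask: a mean and a variance per target, i.e. 2 * len(targets).
    """
    n_targets = 1 if isinstance(targets, str) else len(targets)
    return 2 * n_targets if gaussian else n_targets


def check_output_size(output_size: int, targets: list[str] | str, gaussian: bool = False) -> None:
    """Hypothetical validation mirroring the documented ValueError."""
    expected = expected_output_size(targets, gaussian)
    if output_size != expected:
        raise ValueError(f"Head has {output_size} outputs, expected {expected}.")


check_output_size(3, ["x", "y", "z"])                 # passes: 3 targets, 3 outputs
check_output_size(6, ["x", "y", "z"], gaussian=True)  # passes: 3 means + 3 variances
```

Accepting a bare string for `targets` mirrors the `list[str] | str` type documented for `RegressionTaskBase`.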
## `salt.models.GaussianRegressionTask`

Bases: `salt.models.task.RegressionTaskBase`
Gaussian regression task head (predicts a mean and a variance for each target).
The head outputs `2 * len(targets)` values per example (the means and the variances). The loss is the negative log-likelihood under a Gaussian distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`**kwargs` | `typing.Any` | Forwarded to the base class. | `{}` |
Raises:
Type | Description |
---|---|
`ValueError` | If the number of outputs is not `2 * len(targets)`. |
Source code in salt/models/task.py
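For a single target with predicted mean `mu` and variance `var`, the Gaussian negative log-likelihood of a true value `y` is `0.5 * (log(2 * pi * var) + (y - mu)**2 / var)`. The sketch below evaluates that loss for one example in pure Python. It assumes the `2 * len(targets)` outputs are laid out as all means followed by all variances, that the variances are already strictly positive, and that the loss is averaged over targets; none of these details is confirmed by the documentation above, and `gaussian_nll` is a hypothetical helper.

```python
from __future__ import annotations

import math


def gaussian_nll(outputs: list[float], targets: list[float]) -> float:
    """Gaussian negative log-likelihood for one example, averaged over targets.

    Assumes outputs = [mu_1, ..., mu_n, var_1, ..., var_n] (all means first,
    then all variances) and that every variance is already strictly positive.
    """
    n = len(targets)
    if len(outputs) != 2 * n:
        raise ValueError(f"Expected {2 * n} outputs, got {len(outputs)}.")
    means, variances = outputs[:n], outputs[n:]
    total = 0.0
    for y, mu, var in zip(targets, means, variances):
        # -log N(y | mu, var) = 0.5 * (log(2 * pi * var) + (y - mu)**2 / var)
        total += 0.5 * (math.log(2 * math.pi * var) + (y - mu) ** 2 / var)
    return total / n


# One target: predicted mean 0.0, variance 1.0, true value 0.0.
print(gaussian_nll([0.0, 1.0], [0.0]))  # ≈ 0.919, i.e. 0.5 * log(2 * pi)
```

The two terms trade off: a large variance shrinks the squared-error term but is penalized through `log(var)`, so the head cannot inflate its uncertainty for free. Library implementations such as `torch.nn.GaussianNLLLoss` omit the constant `log(2 * pi)` term by default (`full=False`), which shifts the loss value but leaves the gradients unchanged.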
Created: October 20, 2023