Task Heads
salt.models.TaskBase
Bases: torch.nn.Module, abc.ABC
Task head base class.
Tasks wrap a dense network, a loss, a label, and a weight.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | Arbitrary name of the task, used for logging and inference. | required |
input_name | str | Type of object passed as input to the task, e.g. jet, track or flow. | required |
dense_config | dict | Keyword arguments for | required |
loss | torch.nn.Module | Loss function applied to the dense network outputs. | required |
weight | float | Weight of this task in the overall loss. | 1.0 |
Source code in salt/models/task.py
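As a rough illustration of how the `weight` parameter enters training, the sketch below assumes the overall loss is the weight-scaled sum of the individual task losses. It mirrors the abstract-base-class pattern (`abc.ABC`) in plain Python, without PyTorch, so it runs anywhere. `ToyTask`, `MeanAbsoluteErrorTask`, `compute_loss` and `total_loss` are hypothetical names, not part of salt.

```python
from abc import ABC, abstractmethod


class ToyTask(ABC):
    """Hypothetical stand-in for a task head: a name, an input type and a loss weight."""

    def __init__(self, name: str, input_name: str, weight: float = 1.0) -> None:
        self.name = name
        self.input_name = input_name
        self.weight = weight

    @abstractmethod
    def compute_loss(self, preds: list[float], targets: list[float]) -> float:
        """Return the unweighted loss for this task."""


class MeanAbsoluteErrorTask(ToyTask):
    def compute_loss(self, preds: list[float], targets: list[float]) -> float:
        return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def total_loss(tasks, preds, targets) -> dict[str, float]:
    """Scale each task loss by its weight and sum them (assumed combination rule)."""
    losses = {
        t.name: t.weight * t.compute_loss(preds[t.name], targets[t.name]) for t in tasks
    }
    losses["total"] = sum(losses.values())
    return losses


# Illustrative task names and values only.
tasks = [
    MeanAbsoluteErrorTask("jet_pt", input_name="jets", weight=1.0),
    MeanAbsoluteErrorTask("track_d0", input_name="tracks", weight=0.5),
]
preds = {"jet_pt": [1.0, 2.0], "track_d0": [0.0, 4.0]}
targets = {"jet_pt": [1.5, 2.5], "track_d0": [1.0, 1.0]}
print(total_loss(tasks, preds, targets))  # {'jet_pt': 0.5, 'track_d0': 1.0, 'total': 1.5}
```

Because `ToyTask` is abstract, it cannot be instantiated directly; only subclasses that implement `compute_loss` can, which is the role `abc.ABC` plays in the class hierarchy listed above.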
salt.models.ClassificationTask
Bases: salt.models.task.TaskBase
Classification task.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
label | str | Label name for the task. | required |
class_names | list[str] \| None | List of class names, ordered by output index. If not specified, salt attempts to determine them automatically from the label name. | None |
label_map | collections.abc.Mapping \| None | Remap integer labels for training (e.g. 0,4,5 -> 0,1,2). | None |
sample_weight | str \| None | Name of a per-sample weight to apply in the loss function. | None |
use_class_dict | bool | If True, read class weights for the loss from the class_dict file. | False |
`**kwargs` |  | Keyword arguments for | {} |
Source code in salt/models/task.py
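The `label_map`, `class_names` and `sample_weight` options can be pictured with a small pure-Python sketch. It assumes `label_map` maps raw integer labels to contiguous output indices (as in the 0,4,5 -> 0,1,2 example) and that a per-sample weight multiplies each sample's cross-entropy before averaging. `remap_labels` and `weighted_cross_entropy` are hypothetical helpers, not salt functions, the class names are illustrative, and normalizing by the sum of weights is one reasonable choice rather than salt's documented behaviour.

```python
import math
from collections.abc import Mapping, Sequence


def remap_labels(labels: Sequence[int], label_map: Mapping[int, int]) -> list[int]:
    """Map raw integer labels (e.g. 0, 4, 5) onto contiguous output indices (0, 1, 2)."""
    return [label_map[label] for label in labels]


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    sample_weights: Sequence[float],
) -> float:
    """Per-sample cross-entropy, scaled by each sample's weight and normalized by the total weight."""
    total, norm = 0.0, 0.0
    for row, label, w in zip(logits, labels, sample_weights):
        # negative log-softmax of the true class, using the max trick for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += w * (log_sum_exp - row[label])
        norm += w
    return total / norm


class_names = ["light", "charm", "bottom"]  # ordered by output index (illustrative)
label_map = {0: 0, 4: 1, 5: 2}              # raw label value -> output index (illustrative)

raw_labels = [5, 0, 4]
labels = remap_labels(raw_labels, label_map)
print(labels)                               # [2, 0, 1]
print([class_names[i] for i in labels])     # ['bottom', 'light', 'charm']
```

A sample with weight 0 drops out of the loss entirely, while larger weights make the corresponding samples count proportionally more.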
salt.models.RegressionTaskBase
Bases: salt.models.task.TaskBase, abc.ABC
Base class for regression tasks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
targets | list[str] \| str | Regression target(s). | required |
scaler | salt.utils.scalers.RegressionTargetScaler | Functional scaler for regression targets; cannot be used with the other target scaling options. | None |
target_denominators | list[str] \| str \| None | Variable(s) to divide the regression target(s) by, e.g. for regressing a ratio; cannot be used with the other target scaling options. | None |
norm_params | dict \| None | Mean and standard deviation normalization parameters for each target, used for scaling; cannot be used with the other target scaling options. | None |
`**kwargs` |  | Keyword arguments for | {} |
|
Source code in salt/models/task.py
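The three scaling options are documented as mutually exclusive. The sketch below shows one way to enforce that and what the ratio and mean/std options do to a single target value. It assumes `norm_params` supplies a mean and a standard deviation per target and that `target_denominators` names a per-sample variable the target is divided by. `scale_target` and `unscale_target` are hypothetical helpers, and the functional `scaler` is modelled as a plain callable because its interface is not described on this page.

```python
from typing import Callable, Optional


def scale_target(
    value: float,
    *,
    scaler: Optional[Callable[[float], float]] = None,
    denominator: Optional[float] = None,
    mean: Optional[float] = None,
    std: Optional[float] = None,
) -> float:
    """Apply at most one target-scaling option to a single regression target."""
    n_options = sum(
        [scaler is not None, denominator is not None, mean is not None or std is not None]
    )
    if n_options > 1:
        raise ValueError("only one target scaling option may be used at a time")
    if scaler is not None:
        # stand-in for a functional scaler, modelled here as a plain callable
        return scaler(value)
    if denominator is not None:
        # regress a ratio, e.g. the target divided by a reference variable
        return value / denominator
    if mean is not None or std is not None:
        if mean is None or std is None:
            raise ValueError("mean/std scaling needs both a mean and a std")
        # standardize the target to zero mean and unit variance
        return (value - mean) / std
    return value


def unscale_target(
    scaled: float,
    *,
    denominator: Optional[float] = None,
    mean: Optional[float] = None,
    std: Optional[float] = None,
) -> float:
    """Invert the ratio or mean/std scaling to get back to the original units."""
    if denominator is not None:
        return scaled * denominator
    if mean is not None and std is not None:
        return scaled * std + mean
    return scaled


print(scale_target(50.0, mean=40.0, std=5.0))   # 2.0
print(unscale_target(2.0, mean=40.0, std=5.0))  # 50.0
print(scale_target(30.0, denominator=60.0))     # 0.5
```

Keeping the inverse next to the forward transform matters because predictions made in scaled units have to be mapped back before they are compared with the unscaled targets.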
salt.models.RegressionTask
Bases: salt.models.task.RegressionTaskBase
Regression task.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
scaler |  | Functional scaler for regression targets (see RegressionTaskBase). | None |
`**kwargs` |  | Keyword arguments for | {} |
|
Source code in salt/models/task.py
salt.models.GaussianRegressionTask
Bases: salt.models.task.RegressionTaskBase
Regression task that outputs a mean and a variance for each target. The loss function is the negative log-likelihood of a Gaussian distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`**kwargs` |  | Keyword arguments for | {} |
|
Source code in salt/models/task.py
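For reference, the negative log-likelihood of a target y under a Gaussian with predicted mean mu and variance var is 0.5 * (log(2 * pi * var) + (y - mu)**2 / var). The sketch below evaluates it in plain Python, averaged over samples. How salt parametrizes the predicted variance is not stated on this page, so the sketch simply takes var as a positive input; `gaussian_nll` is a hypothetical helper. PyTorch's `torch.nn.GaussianNLLLoss` computes the same quantity but omits the constant 0.5 * log(2 * pi) term unless `full=True`.

```python
import math
from collections.abc import Sequence


def gaussian_nll(
    targets: Sequence[float],
    means: Sequence[float],
    variances: Sequence[float],
    include_constant: bool = True,
) -> float:
    """Mean Gaussian negative log-likelihood over samples.

    Per sample: 0.5 * (log(var) + (y - mu)**2 / var), plus 0.5 * log(2 * pi)
    when include_constant is True.
    """
    total = 0.0
    for y, mu, var in zip(targets, means, variances):
        if var <= 0:
            raise ValueError("variance must be positive")
        nll = 0.5 * (math.log(var) + (y - mu) ** 2 / var)
        if include_constant:
            nll += 0.5 * math.log(2 * math.pi)
        total += nll
    return total / len(targets)


# With y == mu and var == 1 only the constant term survives: 0.5 * log(2 * pi).
print(gaussian_nll([1.0], [1.0], [1.0]))
```

Because the log(var) term grows with the variance while the squared-error term shrinks, the loss for a fixed residual is smallest when var equals the squared residual, which is what lets the network learn a per-target uncertainty alongside the mean.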
Created: October 20, 2023