
Transformer

salt.models.transformer.Attention

Bases: torch.nn.Module

Multihead attention module with optional differential attention and norms.

Parameters:

- embed_dim (int): Input (and output) embedding dimension. Required.
- num_heads (int): Number of attention heads. Default: 1.
- attn_type (str): Backend kernel to use. One of {"torch-math", "torch-flash", "torch-meff", "flash-varlen"}. Default: "torch-meff".
- dropout (float): Dropout rate applied in attention. Default: 0.0.
- bias (bool): Whether to include bias terms in the projections. Default: True.
- do_qk_norm (bool): Whether to apply RMSNorm to Q and K per head. Default: False.
- do_v_norm (bool): Whether to apply RMSNorm to V per head. Default: False.
- mup (bool): Whether to use the muP parametrisation. Affects initialisation and changes the dot-product scale from 1/sqrt(head_dim) to 1/head_dim. Ref: https://arxiv.org/abs/2203.03466. Default: False.
Source code in salt/models/attention.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 1,
    attn_type: str = "torch-meff",
    dropout: float = 0.0,
    bias: bool = True,
    do_qk_norm: bool = False,
    do_v_norm: bool = False,
    mup: bool = False,
) -> None:
    super().__init__()
    assert embed_dim % num_heads == 0, "Dim not div by the number of heads!"
    assert attn_type in ATTN_TYPES, "Invalid attention type!"

    # Attributes
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.dropout = dropout
    self.bias = bias
    self.attn_type = attn_type
    self.do_qk_norm = do_qk_norm
    self.do_v_norm = do_v_norm
    self.mup = mup

    self.scale = 1 / self.head_dim if mup else 1 / math.sqrt(self.head_dim)

    if self.do_qk_norm:
        self.q_norm = layernorms.RMSNorm(self.head_dim)
        self.k_norm = layernorms.RMSNorm(self.head_dim)
    if self.do_v_norm:
        self.v_norm = layernorms.RMSNorm(self.head_dim)

    # Better parallelism for self-attention when using parameters directly
    self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
    self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
    self.reset_parameters()
    self.set_backend(attn_type)
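
A minimal construction sketch, based on the __init__ shown above. It assumes the salt package and torch are installed, with Attention importable from salt.models.transformer as this page's heading suggests:

import math

from salt.models.transformer import Attention

attn = Attention(embed_dim=128, num_heads=4, attn_type="torch-math")

# The fused input projection packs Q, K and V together: [3 * embed_dim, embed_dim]
assert attn.in_proj_weight.shape == (3 * 128, 128)
assert attn.head_dim == 128 // 4

# Dot-product scale is 1/sqrt(head_dim) by default, and 1/head_dim with mup=True
assert math.isclose(attn.scale, 1 / math.sqrt(attn.head_dim))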

forward

Attention forward pass, dispatching to the appropriate backend.

Parameters:

- x (torch.Tensor): Query input of shape [B, L_q, D]. Required.
- kv (torch.Tensor | None): Optional key/value input of shape [B, L_kv, D] for cross-attention. If None, self-attention is used. Default: None.
- mask (torch.BoolTensor | None): Padding mask for x where padded positions are True. Default: None.
- kv_mask (torch.BoolTensor | None): Padding mask for kv where padded positions are True. Default: None.
- attn_mask (torch.BoolTensor | None): Attention mask of shape [B, L_q, L_kv] where allowed positions are True. Default: None.
- culens (torch.Tensor | None): Cumulative lengths for varlen flash. Required for attn_type="flash-varlen". Default: None.
- maxlen (int | None): Maximum sequence length. Required for attn_type="flash-varlen". Default: None.

Returns:

- torch.Tensor: Output of shape [B, L_q, D].

Source code in salt/models/attention.py
def forward(
    self,
    x: Tensor,
    kv: Tensor | None = None,
    mask: BoolTensor | None = None,
    kv_mask: BoolTensor | None = None,
    attn_mask: BoolTensor | None = None,
    culens: Tensor | None = None,
    maxlen: int | None = None,
) -> Tensor:
    """Attention forward pass, dispatching to the appropriate backend.

    Parameters
    ----------
    x : Tensor
        Query input of shape ``[B, L_q, D]``.
    kv : Tensor | None, optional
        Optional key/value input of shape ``[B, L_kv, D]`` for cross-attention.
        If ``None``, self-attention is used.
    mask : BoolTensor | None, optional
        Padding mask for ``x`` where padded positions are ``True``.
    kv_mask : BoolTensor | None, optional
        Padding mask for ``kv`` where padded positions are ``True``.
    attn_mask : BoolTensor | None, optional
        Attention mask of shape ``[B, L_q, L_kv]`` where allowed positions are ``True``.
    culens : Tensor | None, optional
        Cumulative lengths for varlen flash. Required for ``attn_type="flash-varlen"``.
    maxlen : int | None, optional
        Maximum sequence length. Required for ``attn_type="flash-varlen"``.

    Returns
    -------
    Tensor
        Output of shape ``[B, L_q, D]``.
    """
    if self.attn_type == "flash-varlen":
        assert kv is None, "flash-varlen only supports self attention!"
        assert attn_mask is None, "flash-varlen does not support attention masks!"
        assert culens is not None, "flash-varlen requires culens!"
        assert maxlen is not None, "flash-varlen requires maxlen!"
        return self._flash_forward(x, culens, maxlen)

    return self._torch_forward(x, kv, mask, kv_mask, attn_mask)
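
A usage sketch for the forward pass above, covering a padding mask in self-attention and an optional kv input for cross-attention. Shapes follow the docstring; salt and torch are assumed to be installed:

import torch

from salt.models.transformer import Attention

attn = Attention(embed_dim=64, num_heads=2, attn_type="torch-math")

x = torch.randn(2, 10, 64)                   # query input [B, L_q, D]
mask = torch.zeros(2, 10, dtype=torch.bool)  # True marks padded positions
mask[1, 7:] = True                           # second sample has 7 valid tokens

out = attn(x, mask=mask)                     # self-attention, output [B, L_q, D]

kv = torch.randn(2, 5, 64)                   # key/value input [B, L_kv, D]
kv_mask = torch.zeros(2, 5, dtype=torch.bool)
out_cross = attn(x, kv=kv, kv_mask=kv_mask)  # cross-attention, output [B, L_q, D]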

salt.models.transformer.GLU

Bases: torch.nn.Module

Dense update with a (gated) linear unit.

See https://arxiv.org/abs/2002.05202.

Parameters:

- embed_dim (int): Input/output embedding dimension. Required.
- hidden_dim (int | None): Hidden dimension. If None, defaults to 2 * embed_dim. Default: None.
- activation (str): Name of the activation class in torch.nn (e.g., "SiLU"). Default: "SiLU".
- dropout (float): Dropout probability. Default: 0.0.
- bias (bool): Whether to include bias terms. Default: True.
- gated (bool): If True, uses a gated branch (splits the hidden projection in two). Default: False.
- mup (bool): Whether to use the μP parameterization. Default: False.
Source code in salt/models/transformer.py
def __init__(
    self,
    embed_dim: int,
    hidden_dim: int | None = None,
    activation: str = "SiLU",
    dropout: float = 0.0,
    bias: bool = True,
    gated: bool = False,
    mup: bool = False,
):
    super().__init__()
    self.mup = mup

    if hidden_dim is None:
        hidden_dim = embed_dim * 2

    self.gated = gated
    self.embed_dim = embed_dim
    self.in_proj = nn.Linear(embed_dim, hidden_dim + hidden_dim * gated, bias=bias)
    self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
    self.drop = nn.Dropout(dropout)
    self.activation = getattr(nn, activation)()

    if self.mup:
        for proj in [self.in_proj, self.out_proj]:
            nn.init.normal_(proj.weight, mean=0.0, std=1.0 / (proj.weight.shape[0] ** 0.5))
            if bias:
                nn.init.zeros_(proj.bias)
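
A short sketch of how the projections are sized, following the __init__ above (salt assumed installed; dimensions are illustrative):

from salt.models.transformer import GLU

# hidden_dim defaults to 2 * embed_dim, and gated=True doubles the input
# projection so the hidden activation can be split into value and gate branches.
glu = GLU(embed_dim=64, gated=True)
assert glu.in_proj.weight.shape == (2 * 128, 64)   # hidden_dim = 128, doubled for the gate
assert glu.out_proj.weight.shape == (64, 128)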

salt.models.transformer.EncoderLayer

Bases: torch.nn.Module

Transformer encoder layer: self-attention + feed-forward.

Parameters:

- embed_dim (int): Embedding dimension. Required.
- norm (str): Normalization style (class name from salt.models.layernorm). Default: "LayerNorm".
- ls_init (float | None): Initial LayerScale value. If None, LayerScale is disabled. Default: None.
- drop_path (float): Drop-path rate. Default: 0.0.
- depth (int): Layer depth index, used for differential attention weighting. Default: 1.
- dense_kwargs (dict | None): Keyword arguments for GLU. Default: None.
- attn_kwargs (dict | None): Keyword arguments for Attention. Default: None.
- norm_type (str): One of {"pre", "post", "hybrid"}. Default: "pre".
- edge_embed_dim (int): Model embedding dimension for edge features. Default: 0.
- update_edges (bool): If True, edge features are updated after attention. Default: False.
- mup (bool): Whether to use the μP parameterization. Default: False.
Source code in salt/models/transformer.py
def __init__(
    self,
    embed_dim: int,
    norm: str = "LayerNorm",
    ls_init: float | None = None,
    drop_path: float = 0.0,
    depth: int = 1,
    dense_kwargs: dict | None = None,
    attn_kwargs: dict | None = None,
    norm_type: str = "pre",
    edge_embed_dim: int = 0,
    update_edges: bool = False,
    mup: bool = False,
) -> None:
    super().__init__()
    self.mup = mup
    self.update_edges = update_edges

    # Safe defaults
    if attn_kwargs is None:
        attn_kwargs = {}
    if dense_kwargs is None:
        dense_kwargs = {}

    # Attributes
    self.embed_dim = embed_dim
    self.norm_type = norm_type
    if norm_type == "hybrid":
        attn_kwargs["do_qk_norm"] = True
        attn_kwargs["do_v_norm"] = True
        residual_norm_type = "pre" if depth == 0 else "none"
        self.norm = (
            nn.Identity(embed_dim) if depth == 0 else getattr(layernorms, norm)(embed_dim)
        )
    else:
        residual_norm_type = norm_type

    if self.mup:
        attn_kwargs["mup"] = True
        dense_kwargs["mup"] = True

    # Choose attention type
    attn_class: type[Attention | EdgeAttention]
    if edge_embed_dim > 0:
        attn_class = EdgeAttention
        attn_kwargs["edge_embed_dim"] = edge_embed_dim
        attn_kwargs["update_edges"] = update_edges

        self.edge_prenorm = getattr(layernorms, norm)(edge_embed_dim)
        if self.update_edges:
            self.edge_postnorm = getattr(layernorms, norm)(edge_embed_dim)
    else:
        attn_class = Attention

    # Submodules
    residual = partial(
        NormResidual,
        norm=norm,
        ls_init=ls_init,
        drop_path=drop_path,
        norm_type=residual_norm_type,
    )
    self.attn = residual(attn_class(embed_dim, **attn_kwargs))
    self.dense = residual(GLU(embed_dim, **dense_kwargs))
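
A construction sketch, passing attention and feed-forward options through attn_kwargs and dense_kwargs as documented above. The layer's forward signature is not shown on this page, so only construction is illustrated (salt assumed installed):

from salt.models.transformer import EncoderLayer

layer = EncoderLayer(
    embed_dim=128,
    norm="LayerNorm",
    norm_type="pre",
    attn_kwargs={"num_heads": 4, "attn_type": "torch-math"},
    dense_kwargs={"gated": True, "activation": "SiLU"},
)
# With norm_type="hybrid", the __init__ above additionally forces do_qk_norm and
# do_v_norm on the attention block and applies a pre-norm residual only at depth 0.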

salt.models.transformer.Transformer

Bases: torch.nn.Module

Transformer encoder stack with optional registers and output projection.

Parameters:

- num_layers (int): Number of encoder layers. Required.
- embed_dim (int): Embedding dimension. Required.
- out_dim (int | None): Optional output projection dimension. If None, equals embed_dim. Default: None.
- norm (str): Normalization style (class name from salt.models.layernorm). Default: "LayerNorm".
- attn_type (str): Attention backend, one of {"torch-math", "torch-flash", "torch-meff", "flash-varlen"}. Default: "torch-math".
- do_final_norm (bool): Whether to apply a final normalization layer. Default: True.
- num_registers (int): Number of learned register tokens appended to the end of the sequence. Default: 1.
- drop_registers (bool): If True, registers are dropped from the outputs. Default: False.
- edge_embed_dim (int): Model embedding dimension for edge features. Default: 0.
- update_edges (bool): If True, edge features are updated after attention. Default: False.
- mup (bool): Whether to use the μP parameterization. Default: False.
- **kwargs (typing.Any): Extra keyword arguments forwarded to EncoderLayer (e.g., attn_kwargs, dense_kwargs, ls_init). Default: {}.

Raises:

- ValueError: If num_registers < 1.

Source code in salt/models/transformer.py
def __init__(
    self,
    num_layers: int,
    embed_dim: int,
    out_dim: int | None = None,
    norm: str = "LayerNorm",
    attn_type: str = "torch-math",
    do_final_norm: bool = True,
    num_registers: int = 1,
    drop_registers: bool = False,
    edge_embed_dim: int = 0,
    update_edges: bool = False,
    mup: bool = False,
    **kwargs: Any,
) -> None:
    super().__init__()

    # Check the inputs
    if num_registers < 1:
        raise ValueError(
            "Some jets have no tracks, which causes NaNs in the attention scores. "
            "To avoid this, set num_registers to at least 1",
        )

    # Attributes
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.out_dim = out_dim or embed_dim
    self.do_final_norm = do_final_norm
    self.do_out_proj = out_dim is not None
    self.attn_type = attn_type
    self.num_registers = num_registers
    self.drop_registers = drop_registers
    self.edge_embed_dim = edge_embed_dim
    self.update_edges = update_edges
    self.mup = mup

    if self.update_edges:
        assert edge_embed_dim > 0, "Cannot update edges with edge_embed_dim=0"

    if self.mup:
        assert _MuReadout is not None, "mup is not installed!"
        assert self.do_out_proj, (
            "Need the out_dim layer for muP, \
            as this is the last layer of the muP-part of the model"
        )

    # Set the attention type if no edge features are used
    if edge_embed_dim == 0:
        kwargs["attn_kwargs"]["attn_type"] = self.attn_type

    # Submodules
    self.layers = torch.nn.ModuleList([
        EncoderLayer(
            embed_dim=embed_dim,
            norm=norm,
            depth=depth,
            edge_embed_dim=edge_embed_dim,
            update_edges=update_edges,
            **kwargs,
        )
        for depth in range(num_layers)
    ])

    # Only set the attention type if no edge features are used
    if self.edge_embed_dim == 0:
        # Check and set the attention type
        assert self.attn_type in ATTN_TYPES, "Invalid attention type!"
        self.set_backend(self.attn_type)

    # Optional submodules
    if self.do_out_proj:
        self.out_proj = nn.Linear(self.embed_dim, self.out_dim)
        if self.mup and _MuReadout is not None:
            self.out_proj = _MuReadout(embed_dim, self.out_dim)
            self.out_proj.bias.data.zero_()
            self.out_proj.weight.data.zero_()
    if self.do_final_norm:
        self.out_norm = getattr(layernorms, norm)(self.out_dim)
    if self.num_registers:
        self.registers = nn.Parameter(
            torch.normal(torch.zeros((self.num_registers, self.embed_dim)), std=1e-4)
        )
        self.register_buffer("register_mask", torch.zeros(num_registers, dtype=torch.bool))
    self.featurewise = nn.ModuleList()
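
A construction sketch for the full encoder stack (salt assumed installed). Note that when edge_embed_dim == 0 the __init__ above writes into kwargs["attn_kwargs"], so an attn_kwargs dict should be passed explicitly:

from salt.models.transformer import Transformer

encoder = Transformer(
    num_layers=4,
    embed_dim=128,
    out_dim=128,
    attn_type="torch-math",
    num_registers=4,               # must be at least 1, see the ValueError above
    attn_kwargs={"num_heads": 8},  # __init__ sets attn_kwargs["attn_type"] in place
)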
