
Transformer#

salt.models.transformer_v2.Attention #

Bases: torch.nn.Module

Multihead attention module with optional differential attention and norms.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embed_dim` | `int` | Input (and output) embedding dimension. | *required* |
| `num_heads` | `int` | Number of attention heads. | `1` |
| `attn_type` | `str` | Backend kernel to use. One of `{"torch-math", "torch-flash", "torch-meff", "flash-varlen"}`. | `'torch-meff'` |
| `dropout` | `float` | Dropout rate applied in attention. | `0.0` |
| `bias` | `bool` | Whether to include bias terms in projections. | `True` |
| `diff_attention` | `bool` | Enable differential attention (splits each head into two branches). | `False` |
| `depth` | `int` | Layer depth index (used to set the differential attention weights). | `1` |
| `do_qk_norm` | `bool` | Whether to apply RMSNorm to Q and K per head. | `False` |
| `do_v_norm` | `bool` | Whether to apply RMSNorm to V per head. | `False` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 1,
    attn_type: str = "torch-meff",
    dropout: float = 0.0,
    bias: bool = True,
    diff_attention: bool = False,
    depth: int = 1,
    do_qk_norm: bool = False,
    do_v_norm: bool = False,
) -> None:
    super().__init__()
    assert embed_dim % num_heads == 0, "Dim not div by the number of heads!"
    assert attn_type in ATTN_TYPES, "Invalid attention type!"

    # Attributes
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.dropout = dropout
    self.bias = bias
    self.attn_type = attn_type
    self.diff_attention = diff_attention
    self.depth = depth
    self.do_qk_norm = do_qk_norm
    self.do_v_norm = do_v_norm

    if self.diff_attention:
        self.head_dim = self.head_dim // 2
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )

        self.subln = layernorms.RMSNorm(2 * self.head_dim)

    if self.do_qk_norm:
        self.q_norm = layernorms.RMSNorm(self.head_dim)
        self.k_norm = layernorms.RMSNorm(self.head_dim)
    if self.do_v_norm:
        self.v_norm = layernorms.RMSNorm(self.head_dim)

    # Better parallelism for self-attention when using parameters directly
    self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
    self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
    self.reset_parameters()
    self.set_backend(attn_type)
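
A minimal construction sketch using only the parameters documented above (the values are illustrative, not project defaults):

```python
from salt.models.transformer_v2 import Attention

# Standard (non-differential) multihead attention block.
attn = Attention(
    embed_dim=128,
    num_heads=8,
    attn_type="torch-math",  # plain PyTorch backend; the other options are listed above
    dropout=0.1,
)
print(attn.head_dim)  # 16, i.e. embed_dim // num_heads
```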

forward #

Attention forward pass, dispatching to the appropriate backend.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | Query input of shape `[B, L_q, D]`. | *required* |
| `kv` | `torch.Tensor \| None` | Optional key/value input of shape `[B, L_kv, D]` for cross-attention. If `None`, self-attention is used. | `None` |
| `mask` | `torch.BoolTensor \| None` | Padding mask for `x` where padded positions are `True`. | `None` |
| `kv_mask` | `torch.BoolTensor \| None` | Padding mask for `kv` where padded positions are `True`. | `None` |
| `attn_mask` | `torch.BoolTensor \| None` | Attention mask of shape `[B, L_q, L_kv]` where allowed positions are `True`. | `None` |
| `culens` | `torch.Tensor \| None` | Cumulative lengths for varlen flash. Required for `attn_type="flash-varlen"`. | `None` |
| `maxlen` | `int \| None` | Maximum sequence length. Required for `attn_type="flash-varlen"`. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | Output of shape `[B, L_q, D]`. |

Source code in salt/models/transformer_v2.py
def forward(
    self,
    x: Tensor,
    kv: Tensor | None = None,
    mask: BoolTensor | None = None,
    kv_mask: BoolTensor | None = None,
    attn_mask: BoolTensor | None = None,
    culens: Tensor | None = None,
    maxlen: int | None = None,
) -> Tensor:
    """Attention forward pass, dispatching to the appropriate backend.

    Parameters
    ----------
    x : Tensor
        Query input of shape ``[B, L_q, D]``.
    kv : Tensor | None, optional
        Optional key/value input of shape ``[B, L_kv, D]`` for cross-attention.
        If ``None``, self-attention is used.
    mask : BoolTensor | None, optional
        Padding mask for ``x`` where padded positions are ``True``.
    kv_mask : BoolTensor | None, optional
        Padding mask for ``kv`` where padded positions are ``True``.
    attn_mask : BoolTensor | None, optional
        Attention mask of shape ``[B, L_q, L_kv]`` where allowed positions are ``True``.
    culens : Tensor | None, optional
        Cumulative lengths for varlen flash. Required for ``attn_type="flash-varlen"``.
    maxlen : int | None, optional
        Maximum sequence length. Required for ``attn_type="flash-varlen"``.

    Returns
    -------
    Tensor
        Output of shape ``[B, L_q, D]``.
    """
    if self.attn_type == "flash-varlen":
        assert kv is None, "flash-varlen only supports self attention!"
        assert attn_mask is None, "flash-varlen does not support attention masks!"
        assert culens is not None, "flash-varlen requires culens!"
        assert maxlen is not None, "flash-varlen requires maxlen!"
        if self.diff_attention:
            return self._flash_diff_forward(x, culens, maxlen)
        return self._flash_forward(x, culens, maxlen)

    if self.diff_attention:
        return self._torch_diff_forward(x, kv, mask, kv_mask, attn_mask)
    return self._torch_forward(x, kv, mask, kv_mask, attn_mask)
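
A usage sketch of the documented forward signature, covering self- and cross-attention with padding masks. The shapes and the `"torch-math"` backend are chosen purely for illustration:

```python
import torch

from salt.models.transformer_v2 import Attention

attn = Attention(embed_dim=64, num_heads=4, attn_type="torch-math")

# Self-attention: x is [B, L_q, D]; padded positions in `mask` are True.
x = torch.randn(2, 10, 64)
pad_mask = torch.zeros(2, 10, dtype=torch.bool)  # no padding in this toy batch
out = attn(x, mask=pad_mask)
print(out.shape)  # torch.Size([2, 10, 64])

# Cross-attention: pass a separate key/value sequence and its own padding mask.
kv = torch.randn(2, 7, 64)
kv_mask = torch.zeros(2, 7, dtype=torch.bool)
out = attn(x, kv=kv, kv_mask=kv_mask)
print(out.shape)  # still [B, L_q, D] = torch.Size([2, 10, 64])
```

For `attn_type="flash-varlen"`, the inputs must instead be packed and passed with `culens` and `maxlen`; cross-attention and attention masks are not supported, as the assertions in the source above show.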

salt.models.transformer_v2.GLU #

Bases: torch.nn.Module

Dense update with a (gated) linear unit.

See https://arxiv.org/abs/2002.05202.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embed_dim` | `int` | Input/output embedding dimension. | *required* |
| `hidden_dim` | `int \| None` | Hidden dimension. If `None`, defaults to `2 * embed_dim`. | `None` |
| `activation` | `str` | Name of the activation class in `torch.nn` (e.g. `"SiLU"`). | `'SiLU'` |
| `dropout` | `float` | Dropout probability. | `0.0` |
| `bias` | `bool` | Whether to include bias terms. | `True` |
| `gated` | `bool` | If `True`, uses a gated branch (splits the hidden projection in two). | `False` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    hidden_dim: int | None = None,
    activation: str = "SiLU",
    dropout: float = 0.0,
    bias: bool = True,
    gated: bool = False,
):
    super().__init__()

    if hidden_dim is None:
        hidden_dim = embed_dim * 2

    self.gated = gated
    self.embed_dim = embed_dim
    self.in_proj = nn.Linear(embed_dim, hidden_dim + hidden_dim * gated, bias=bias)
    self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
    self.drop = nn.Dropout(dropout)
    self.activation = getattr(nn, activation)()
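
A small sketch showing how the `gated` flag changes the projection sizes set up by the constructor above (the dimensions are illustrative):

```python
from salt.models.transformer_v2 import GLU

# Ungated: hidden_dim defaults to 2 * embed_dim, so in_proj maps D -> 2D.
ff = GLU(embed_dim=64)
print(ff.in_proj.weight.shape)   # torch.Size([128, 64])

# Gated: the input projection is doubled again so the hidden activation can
# be split into a value branch and a gate branch.
gff = GLU(embed_dim=64, hidden_dim=256, gated=True)
print(gff.in_proj.weight.shape)  # torch.Size([512, 64])
print(gff.out_proj.weight.shape) # torch.Size([64, 256])
```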

salt.models.transformer_v2.EncoderLayer #

Bases: torch.nn.Module

Transformer encoder layer: self-attention + feed-forward.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embed_dim` | `int` | Embedding dimension. | *required* |
| `norm` | `str` | Normalization style (class name from `salt.models.layernorm`). | `'LayerNorm'` |
| `ls_init` | `float \| None` | Initial LayerScale value. If `None`, LayerScale is disabled. | `None` |
| `drop_path` | `float` | Drop-path rate. | `0.0` |
| `depth` | `int` | Layer depth index, used for differential attention weighting. | `1` |
| `dense_kwargs` | `dict \| None` | Keyword args for `GLU`. | `None` |
| `attn_kwargs` | `dict \| None` | Keyword args for `Attention`. | `None` |
| `norm_type` | `str` | One of `{"pre", "post", "hybrid"}`. | `'pre'` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    norm: str = "LayerNorm",
    ls_init: float | None = None,
    drop_path: float = 0.0,
    depth: int = 1,
    dense_kwargs: dict | None = None,
    attn_kwargs: dict | None = None,
    norm_type: str = "pre",
) -> None:
    super().__init__()

    # Safe defaults
    if attn_kwargs is None:
        attn_kwargs = {}
    if dense_kwargs is None:
        dense_kwargs = {}

    # Attributes
    self.embed_dim = embed_dim
    self.norm_type = norm_type
    if norm_type == "hybrid":
        attn_kwargs["do_qk_norm"] = True
        attn_kwargs["do_v_norm"] = True
        residual_norm_type = "pre" if depth == 0 else "none"
        self.norm = (
            nn.Identity(embed_dim) if depth == 0 else getattr(layernorms, norm)(embed_dim)
        )
    else:
        residual_norm_type = norm_type

    # Submodules
    residual = partial(
        NormResidual,
        norm=norm,
        ls_init=ls_init,
        drop_path=drop_path,
        norm_type=residual_norm_type,
    )
    self.attn = residual(Attention(embed_dim, depth=depth, **attn_kwargs))
    self.dense = residual(GLU(embed_dim, **dense_kwargs))
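
A construction sketch combining the options above; `attn_kwargs` and `dense_kwargs` are forwarded to `Attention` and `GLU` respectively, and all values here are illustrative:

```python
from salt.models.transformer_v2 import EncoderLayer

layer = EncoderLayer(
    embed_dim=128,
    norm="LayerNorm",
    ls_init=1e-5,    # enable LayerScale on both residual branches
    drop_path=0.1,
    depth=0,
    attn_kwargs={"num_heads": 8, "attn_type": "torch-math"},
    dense_kwargs={"gated": True},
    norm_type="pre",
)
```

With `norm_type="hybrid"`, the layer additionally switches on `do_qk_norm` and `do_v_norm` in the attention block, as shown in the source above.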

salt.models.transformer_v2.TransformerV2 #

Bases: torch.nn.Module

Transformer encoder stack with optional registers and output projection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_layers` | `int` | Number of encoder layers. | *required* |
| `embed_dim` | `int` | Embedding dimension. | *required* |
| `out_dim` | `int \| None` | Optional output projection dimension. If `None`, equals `embed_dim`. | `None` |
| `norm` | `str` | Normalization style (class name from `salt.models.layernorm`). | `'LayerNorm'` |
| `attn_type` | `str` | Attention backend, one of `{"torch-math", "torch-flash", "torch-meff", "flash-varlen"}`. | `'torch-math'` |
| `do_final_norm` | `bool` | Whether to apply a final normalization layer. | `True` |
| `num_registers` | `int` | Number of learned register tokens appended to the end of the sequence. | `1` |
| `drop_registers` | `bool` | If `True`, registers are dropped from the outputs. | `False` |
| `**kwargs` | `typing.Any` | Extra keyword arguments forwarded to `EncoderLayer` (e.g. `attn_kwargs`, `dense_kwargs`, `ls_init`). | `{}` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `num_registers < 1`. |

Source code in salt/models/transformer_v2.py
def __init__(
    self,
    num_layers: int,
    embed_dim: int,
    out_dim: int | None = None,
    norm: str = "LayerNorm",
    attn_type: str = "torch-math",
    do_final_norm: bool = True,
    num_registers: int = 1,
    drop_registers: bool = False,
    **kwargs: Any,
) -> None:
    super().__init__()

    # Check the inputs
    if num_registers < 1:
        raise ValueError(
            "Some jets have no tracks, which causes NaNs in the attention scores. "
            "To avoid this, set num_registers to at least 1",
        )

    # Attributes
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.out_dim = out_dim or embed_dim
    self.do_final_norm = do_final_norm
    self.do_out_proj = out_dim is not None
    self.attn_type = attn_type
    self.num_registers = num_registers
    self.drop_registers = drop_registers

    # Submodules
    kwargs["attn_kwargs"]["attn_type"] = self.attn_type
    self.layers = torch.nn.ModuleList([
        EncoderLayer(embed_dim=embed_dim, norm=norm, depth=depth, **kwargs)
        for depth in range(num_layers)
    ])

    # Check and set the attention type
    assert self.attn_type in ATTN_TYPES, "Invalid attention type!"
    self.set_backend(self.attn_type)

    # Optional submodules
    if self.do_out_proj:
        self.out_proj = nn.Linear(self.embed_dim, out_dim)
    if self.do_final_norm:
        self.out_norm = getattr(layernorms, norm)(self.out_dim)
    if self.num_registers:
        self.registers = nn.Parameter(
            torch.normal(torch.zeros((self.num_registers, self.embed_dim)), std=1e-4)
        )
        self.register_buffer("register_mask", torch.zeros(num_registers, dtype=torch.bool))
    self.featurewise = nn.ModuleList()
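
A construction sketch based on the constructor shown above. Note that it assigns into `kwargs["attn_kwargs"]`, so an `attn_kwargs` dict should be passed explicitly; the remaining values are illustrative:

```python
from salt.models.transformer_v2 import TransformerV2

encoder = TransformerV2(
    num_layers=4,
    embed_dim=128,
    out_dim=64,                    # adds an output projection to 64 dimensions
    attn_type="torch-math",
    num_registers=4,
    attn_kwargs={"num_heads": 8},  # attn_type is injected into this dict
    dense_kwargs={"gated": True},
)
print(len(encoder.layers))      # 4
print(encoder.registers.shape)  # torch.Size([4, 128]), the learned register tokens
```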
