Transformer#

salt.models.transformer_v2.Attention #

Bases: torch.nn.Module

Multihead attention module.

Parameters:

embed_dim : int
    Dimension of the input. Required.
num_heads : int, optional
    Number of attention heads. Default: 1.
attn_type : str, optional
    Name of the backend kernel to use. Default: "torch-meff".
dropout : float, optional
    Dropout rate. Default: 0.0.
bias : bool, optional
    Whether to include bias terms. Default: True.
diff_attention : bool, optional
    Whether to use differential attention. Default: False.
depth : int, optional
    Index of the current attention layer. Default: 1.
do_qk_norm : bool, optional
    Whether to apply normalisation to q and k. Default: False.
do_v_norm : bool, optional
    Whether to apply normalisation to v. Default: False.
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 1,
    attn_type: str = "torch-meff",
    dropout: float = 0.0,
    bias: bool = True,
    diff_attention: bool = False,
    depth: int = 1,
    do_qk_norm: bool = False,
    do_v_norm: bool = False,
) -> None:
    """Multihead attention module.

    Parameters
    ----------
    embed_dim : int
        Dimension of the input.
    num_heads : int
        Number of attention heads.
    attn_type : str, optional
        Name of backend kernel to use.
    dropout : float, optional
        Dropout rate.
    bias : bool, optional
        Whether to include bias terms.
    diff_attention : bool, optional
        Whether to use differential attention.
    depth : int, optional
        Index of the current attention layer.
    do_qk_norm : bool, optional
        Whether to apply normalisation to q and k.
    do_v_norm : bool, optional
        Whether to apply normalisation to v.
    """
    super().__init__()
    assert embed_dim % num_heads == 0, "Dim not div by the number of heads!"
    assert attn_type in ATTN_TYPES, "Invalid attention type!"

    # Attributes
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.dropout = dropout
    self.bias = bias
    self.attn_type = attn_type
    self.diff_attention = diff_attention
    self.depth = depth
    self.do_qk_norm = do_qk_norm
    self.do_v_norm = do_v_norm

    if self.diff_attention:
        self.head_dim = self.head_dim // 2
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )

        self.subln = layernorms.RMSNorm(2 * self.head_dim)

    if self.do_qk_norm:
        self.q_norm = layernorms.RMSNorm(self.head_dim)
        self.k_norm = layernorms.RMSNorm(self.head_dim)
    if self.do_v_norm:
        self.v_norm = layernorms.RMSNorm(self.head_dim)

    # Better parallelism for self-attention when using parameters directly
    self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
    self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
    self.reset_parameters()
    self.set_backend(attn_type)
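
For orientation, here is a minimal construction-and-call sketch. It assumes salt is importable as shown and that the "torch-math" backend runs on CPU; neither is confirmed by this page, so treat it as illustrative rather than canonical.

import torch
from salt.models.transformer_v2 import Attention

# Standard multihead self-attention; embed_dim must be divisible by num_heads.
attn = Attention(embed_dim=128, num_heads=8, attn_type="torch-math")
x = torch.randn(2, 10, 128)   # (batch, x_len, dim)
out = attn(x)                 # (2, 10, 128)

# With diff_attention=True the per-head dimension is halved internally and the
# lambda parameters are created, with lambda_init = 0.8 - 0.6 * exp(-0.3 * depth).
diff_attn = Attention(embed_dim=128, num_heads=8, attn_type="torch-math",
                      diff_attention=True, depth=2)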

forward #

Attention forward pass, dispatches to the appropriate backend.

Parameters:

x : torch.Tensor
    The pointcloud of shape (batch, x_len, dim). Required.
kv : torch.Tensor, optional
    Optional second pointcloud for cross-attention, of shape (batch, kv_len, dim). Default: None.
mask : torch.BoolTensor, optional
    Mask for the pointcloud x. Default: None.
kv_mask : torch.BoolTensor, optional
    Mask for the kv pointcloud. Default: None.
attn_mask : torch.BoolTensor, optional
    Full attention mask. Default: None.
culens : torch.Tensor, optional
    Cumulative lengths of the sequences in x. Only used by the flash-varlen backend. Default: None.
maxlen : int, optional
    Maximum length of a sequence in x. Only used by the flash-varlen backend. Default: None.

Returns:

torch.Tensor
    Output of shape (batch, x_len, dim).

Source code in salt/models/transformer_v2.py
def forward(
    self,
    x: Tensor,
    kv: Tensor | None = None,
    mask: BoolTensor | None = None,
    kv_mask: BoolTensor | None = None,
    attn_mask: BoolTensor | None = None,
    culens: Tensor | None = None,
    maxlen: int | None = None,
) -> Tensor:
    """Attention forward pass, dispatches to the appropriate backend.

    Parameters
    ----------
    x : Tensor
        The pointcloud of shape (batch, x_len, dim).
    kv : Tensor
        Optional second pointcloud for cross-attn with shape (batch, kv_len, dim).
    mask : BoolTensor, optional
        Mask for the pointcloud x, by default None.
    kv_mask : BoolTensor, optional
        Mask for the kv pointcloud, by default None.
    attn_mask : BoolTensor, optional
        Full attention mask, by default None.
    culens : Tensor, optional
        Cumulative lengths of the sequences in x, by default None.
        Only used for the flash-varlen backend.
    maxlen : int, optional
        Maximum length of a sequence in x, by default None.
        Only used for the flash-varlen backend.

    Returns
    -------
    Tensor
        Output of shape (batch, x_len, dim).
    """
    if self.attn_type == "flash-varlen":
        assert kv is None, "flash-varlen only supports self attention!"
        assert attn_mask is None, "flash-varlen does not support attention masks!"
        assert culens is not None, "flash-varlen requires culens!"
        assert maxlen is not None, "flash-varlen requires maxlen!"
        if self.diff_attention:
            return self._flash_diff_forward(x, culens, maxlen)
        return self._flash_forward(x, culens, maxlen)

    # Otherwise perform standard attention
    if self.diff_attention:
        return self._torch_diff_forward(x, kv, mask, kv_mask, attn_mask)
    return self._torch_forward(x, kv, mask, kv_mask, attn_mask)
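
The sketch below illustrates the two dispatch paths described above. The mask semantics (whether True marks padded or valid positions) are not stated on this page, so the kv_mask here is only a placeholder, and the packed input layout expected by the flash-varlen backend is not shown; the culens and maxlen values simply follow the parameter descriptions.

import torch
from salt.models.transformer_v2 import Attention

attn = Attention(embed_dim=64, num_heads=4, attn_type="torch-math")

x = torch.randn(2, 6, 64)                      # (batch, x_len, dim)
kv = torch.randn(2, 9, 64)                     # (batch, kv_len, dim)
kv_mask = torch.zeros(2, 9, dtype=torch.bool)  # BoolTensor mask for kv (placeholder)

out_self = attn(x)                             # self-attention
out_cross = attn(x, kv=kv, kv_mask=kv_mask)    # cross-attention

# The flash-varlen backend supports self-attention only and takes cumulative
# sequence lengths plus the maximum sequence length instead of a padding mask,
# e.g. for two sequences of lengths 4 and 6:
culens = torch.tensor([0, 4, 10], dtype=torch.int32)
maxlen = 6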

salt.models.transformer_v2.GLU #

Bases: torch.nn.Module

Dense update with gated linear unit.

See 2002.05202.

Parameters:

embed_dim : int
    Dimension of the input and output. Required.
hidden_dim : int | None, optional
    Dimension of the hidden layer. If None, defaults to embed_dim * 2. Default: None.
activation : str, optional
    Activation function. Default: "SiLU".
dropout : float, optional
    Dropout rate. Default: 0.0.
bias : bool, optional
    Whether to include bias in the linear layers. Default: True.
gated : bool, optional
    Whether to gate the output of the hidden layer. Default: False.
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    hidden_dim: int | None = None,
    activation: str = "SiLU",
    dropout: float = 0.0,
    bias: bool = True,
    gated: bool = False,
):
    """Dense update with gated linear unit.

    See [2002.05202](https://arxiv.org/abs/2002.05202).

    Parameters
    ----------
    embed_dim : int
        Dimension of the input and output.
    hidden_dim : int | None, optional
        Dimension of the hidden layer. If None, defaults to embed_dim * 2.
    activation : str, optional
        Activation function.
    dropout : float, optional
        Dropout rate.
    bias : bool, optional
        Whether to include bias in the linear layers.
    gated : bool, optional
        Whether to gate the output of the hidden layer.
    """
    super().__init__()

    if hidden_dim is None:
        hidden_dim = embed_dim * 2

    self.gated = gated
    self.embed_dim = embed_dim
    self.in_proj = nn.Linear(embed_dim, hidden_dim + hidden_dim * gated, bias=bias)
    self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
    self.drop = nn.Dropout(dropout)
    self.activation = getattr(nn, activation)()
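
A usage sketch follows. The forward pass is not reproduced on this page; the shape comments only restate what the __init__ above implies (with gated=True the input projection produces hidden_dim values plus hidden_dim gates, in the spirit of the gated blocks of 2002.05202), and the call itself is assumed to take a single tensor.

import torch
from salt.models.transformer_v2 import GLU

# in_proj: 128 -> 256 + 256 (values + gates); out_proj: 256 -> 128
ffn = GLU(embed_dim=128, hidden_dim=256, activation="SiLU", gated=True)

x = torch.randn(2, 10, 128)   # (batch, seq, embed_dim)
out = ffn(x)                  # (2, 10, 128)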

salt.models.transformer_v2.EncoderLayer #

Bases: torch.nn.Module

Encoder layer consisting of a self-attention and a feed-forward layer.

Parameters:

embed_dim : int
    Dimension of the embeddings at each layer. Required.
norm : str, optional
    Normalization style. Default: "LayerNorm".
drop_path : float, optional
    Drop path rate. Default: 0.0.
ls_init : float | None, optional
    Initial value for the layerscale. Default: None.
depth : int, optional
    The depth (index) of the layer. Default: 1.
dense_kwargs : dict | None, optional
    Keyword arguments for salt.models.transformer_v2.GLU. Default: None.
attn_kwargs : dict | None, optional
    Keyword arguments for salt.models.transformer_v2.Attention. Default: None.
norm_type : str, optional
    Normalization type, one of ['pre', 'post', 'hybrid']. Default: 'pre'.
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    norm: str = "LayerNorm",
    ls_init: float | None = None,
    drop_path: float = 0.0,
    depth: int = 1,
    dense_kwargs: dict | None = None,
    attn_kwargs: dict | None = None,
    norm_type: str = "pre",
) -> None:
    """Encoder layer consisting of a self-attention and a feed-forward layer.

    Parameters
    ----------
    embed_dim : int
        Dimension of the embeddings at each layer.
    norm : str, optional
        Normalization style, by default "LayerNorm".
    drop_path : float, optional
        Drop path rate, by default 0.0.
    ls_init : float | None, optional
        Initial value for the layerscale, by default None.
    depth : int, optional
        The depth of the layer, by default 1.
    dense_kwargs : dict | None, optional
        Keyword arguments for [salt.models.transformer_v2.GLU][salt.models.transformer_v2.GLU].
    attn_kwargs : dict | None, optional
        Keyword arguments for
        [salt.models.transformer_v2.Attention][salt.models.transformer_v2.Attention].
    norm_type : str, optional
        Normalization type, can be ['pre', 'post', 'hybrid'], by default 'pre'.
    """
    super().__init__()

    # Safe defaults
    if attn_kwargs is None:
        attn_kwargs = {}
    if dense_kwargs is None:
        dense_kwargs = {}

    # Attributes
    self.embed_dim = embed_dim
    self.norm_type = norm_type
    if norm_type == "hybrid":
        attn_kwargs["do_qk_norm"] = True
        attn_kwargs["do_v_norm"] = True
        residual_norm_type = "pre" if depth == 0 else "none"
        self.norm = (
            nn.Identity(embed_dim) if depth == 0 else getattr(layernorms, norm)(embed_dim)
        )
    else:
        residual_norm_type = norm_type

    # Submodules
    residual = partial(
        NormResidual,
        norm=norm,
        ls_init=ls_init,
        drop_path=drop_path,
        norm_type=residual_norm_type,
    )
    self.attn = residual(Attention(embed_dim, depth=depth, **attn_kwargs))
    self.dense = residual(GLU(embed_dim, **dense_kwargs))
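
A construction sketch, grounded in the signature above: attn_kwargs and dense_kwargs are forwarded to Attention and GLU, and norm_type="hybrid" forces do_qk_norm and do_v_norm on. The forward signature is not shown on this page, so the final call is an assumption that mirrors Attention.forward.

import torch
from salt.models.transformer_v2 import EncoderLayer

layer = EncoderLayer(
    embed_dim=128,
    norm="LayerNorm",
    norm_type="pre",
    depth=0,
    attn_kwargs={"num_heads": 8, "attn_type": "torch-math"},
    dense_kwargs={"hidden_dim": 256, "gated": True},
)

x = torch.randn(2, 10, 128)
out = layer(x)  # assumed to accept a (batch, seq, dim) tensor like Attention.forward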

salt.models.transformer_v2.TransformerV2 #

Bases: torch.nn.Module

Transformer model consisting of a stack of Transformer encoder layers.

Parameters:

num_layers : int
    Number of layers. Required.
embed_dim : int
    Dimension of the embeddings at each layer. Required.
out_dim : int | None, optional
    Optionally project the output to a different dimension. Default: None.
norm : str, optional
    Normalization style. Default: "LayerNorm".
attn_type : str, optional
    The backend for the attention mechanism. Provided here because the varlen backend requires pre/post processing. Default: "torch-math".
do_final_norm : bool, optional
    Whether to apply a final normalization layer. Default: True.
num_registers : int, optional
    The number of registers to add to the END of the input sequence. Registers are randomly initialised tokens of the same dimension as any other inputs after the initialiser networks. See 2309.16588. Default: 1.
drop_registers : bool, optional
    Whether to drop the registers from the outputs. Default: False.
kwargs : dict
    Keyword arguments for salt.models.transformer_v2.EncoderLayer. Default: {}.
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    num_layers: int,
    embed_dim: int,
    out_dim: int | None = None,
    norm: str = "LayerNorm",
    attn_type: str = "torch-math",
    do_final_norm: bool = True,
    num_registers: int = 1,
    drop_registers: bool = False,
    **kwargs,
) -> None:
    """Transformer model consisting of a stack of Transformer encoder layers.

    Parameters
    ----------
    num_layers : int
        Number of layers.
    embed_dim : int
        Dimension of the embeddings at each layer.
    out_dim : int | None, optional
        Optionally project the output to a different dimension.
    norm : str, optional
        Normalization style, by default "LayerNorm".
    attn_type : str, optional
        The backend for the attention mechanism, by default "torch-math".
        Provided here because the varlen backend requires pre/post processing.
    do_final_norm : bool, optional
        Whether to apply a final normalization layer, by default True.
    num_registers : int, optional
        The number of registers to add to the END of the input sequence.
        Registers are randomly initialised tokens of the same dimension as
        any other inputs after initialiser networks. See 2309.16588.
    drop_registers : bool, optional
        Whether to drop the registers from the outputs.
    kwargs : dict
        Keyword arguments for
        [salt.models.transformer_v2.EncoderLayer][salt.models.transformer_v2.EncoderLayer].
    """
    super().__init__()

    # Check the inputs
    if num_registers < 1:
        raise ValueError(
            "Some jets have no tracks, which causes NaNs in the attention scores. "
            "To avoid this, set num_registers to at least 1."
        )

    # Attributes
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.out_dim = out_dim or embed_dim
    self.do_final_norm = do_final_norm
    self.do_out_proj = out_dim is not None
    self.attn_type = attn_type
    self.num_registers = num_registers
    self.drop_registers = drop_registers

    # Submodules
    kwargs["attn_kwargs"]["attn_type"] = self.attn_type
    self.layers = torch.nn.ModuleList([
        EncoderLayer(embed_dim=embed_dim, norm=norm, depth=depth, **kwargs)
        for depth in range(num_layers)
    ])

    # Check and set the attention type
    assert self.attn_type in ATTN_TYPES, "Invalid attention type!"
    self.set_backend(self.attn_type)

    # Optional submodules
    if self.do_out_proj:
        self.out_proj = nn.Linear(self.embed_dim, out_dim)
    if self.do_final_norm:
        self.out_norm = getattr(layernorms, norm)(self.out_dim)
    if self.num_registers:
        self.registers = nn.Parameter(
            torch.normal(torch.zeros((self.num_registers, self.embed_dim)), std=1e-4)
        )
        self.register_buffer("register_mask", torch.zeros(num_registers, dtype=torch.bool))
    self.featurewise = nn.ModuleList()
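
A construction sketch. The __init__ above indexes kwargs["attn_kwargs"] directly, so an attn_kwargs dict appears to be required even if otherwise empty. The forward pass (including how registers are appended and optionally dropped) is not shown on this page, so only construction is illustrated.

from salt.models.transformer_v2 import TransformerV2

model = TransformerV2(
    num_layers=4,
    embed_dim=128,
    out_dim=64,
    attn_type="torch-math",
    num_registers=4,
    attn_kwargs={"num_heads": 8},   # required: __init__ sets attn_kwargs["attn_type"]
    dense_kwargs={"gated": True},   # forwarded to each EncoderLayer's GLU
)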

Last update: May 9, 2024
Created: January 25, 2024