
Transformer#

salt.models.transformer_v2.Attention #

Bases: torch.nn.Module

Multihead attention module.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `embed_dim` | `int` | Dimension of the input. | *required* |
| `num_heads` | `int` | Number of attention heads. | `1` |
| `attn_type` | `str` | Name of the backend kernel to use. | `'torch-meff'` |
| `dropout` | `float` | Dropout rate. | `0.0` |
| `bias` | `bool` | Whether to include bias terms. | `True` |
| `diff_attention` | `bool` | Whether to use differential attention. | `False` |
| `depth` | `int` | Index of the current attention layer. | `1` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 1,
    attn_type: str = "torch-meff",
    dropout: float = 0.0,
    bias: bool = True,
    diff_attention: bool = False,
    depth: int = 1,
) -> None:
    """Multihead attention module.

    Parameters
    ----------
    embed_dim : int
        Dimension of the input.
    num_heads : int
        Number of attention heads.
    attn_type : str, optional
        Name of backend kernel to use.
    dropout : float, optional
        Dropout rate.
    bias : bool, optional
        Whether to include bias terms.
    diff_attention : bool, optional
        Whether to use differential attention.
    depth : int, optional
        Index of the current attention layer.
    """
    super().__init__()
    assert embed_dim % num_heads == 0, "Dim not div by the number of heads!"
    assert attn_type in ATTN_TYPES, "Invalid attention type!"

    # Attributes
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.dropout = dropout
    self.bias = bias
    self.attn_type = attn_type
    self.diff_attention = diff_attention
    self.depth = depth

    if self.diff_attention:
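        # Differential attention: each head is split into two halves whose attention
        # maps are subtracted, weighted by a learned, depth-dependent lambda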
        self.head_dim = self.head_dim // 2
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1)
        )

        self.subln = layernorms.RMSNorm(2 * self.head_dim)

    # Better parallelism for self-attention when using parameters directly
    self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
    self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
    self.reset_parameters()
    self.set_backend(attn_type)

forward #

Attention forward pass, dispatches to the appropriate backend.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `torch.Tensor` | The pointcloud of shape `(batch, x_len, dim)`. | *required* |
| `kv` | `torch.Tensor` | Optional second pointcloud for cross-attention, of shape `(batch, kv_len, dim)`. | `None` |
| `mask` | `torch.BoolTensor` | Mask for the pointcloud `x`, by default None. | `None` |
| `kv_mask` | `torch.BoolTensor` | Mask for the `kv` pointcloud, by default None. | `None` |
| `attn_mask` | `torch.BoolTensor` | Full attention mask, by default None. | `None` |
| `culens` | `torch.Tensor` | Cumulative lengths of the sequences in `x`, by default None. Only used by the flash-varlen backend. | `None` |
| `maxlen` | `int` | Maximum length of a sequence in `x`, by default None. Only used by the flash-varlen backend. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `torch.Tensor` | Output of shape `(batch, x_len, dim)`. |

Source code in salt/models/transformer_v2.py
def forward(
    self,
    x: Tensor,
    kv: Tensor | None = None,
    mask: BoolTensor | None = None,
    kv_mask: BoolTensor | None = None,
    attn_mask: BoolTensor | None = None,
    culens: Tensor | None = None,
    maxlen: int | None = None,
) -> Tensor:
    """Attention forward pass, dispatches to the appropriate backend.

    Parameters
    ----------
    x : Tensor
        The pointcloud of shape (batch, x_len, dim).
    kv : Tensor
        Optional second pointcloud for cross-attn with shape (batch, kv_len, dim).
    mask : BoolTensor, optional
        Mask for the pointcloud x, by default None.
    kv_mask : BoolTensor, optional
        Mask for the kv pointcloud, by default None.
    attn_mask : BoolTensor, optional
        Full attention mask, by default None.
    culens : Tensor, optional
        Cumulative lengths of the sequences in x, by default None.
        Only used for the flash-varlen backend.
    maxlen : int, optional
        Maximum length of a sequence in x, by default None.
        Only used for the flash-varlen backend.

    Returns
    -------
    Tensor
        Output of shape (batch, x_len, dim).
    """
    if self.attn_type == "flash-varlen":
        assert kv is None, "flash-varlen only supports self attention!"
        assert attn_mask is None, "flash-varlen does not support attention masks!"
        assert culens is not None, "flash-varlen requires culens!"
        assert maxlen is not None, "flash-varlen requires maxlen!"
        if self.diff_attention:
            return self._flash_diff_forward(x, culens, maxlen)
        return self._flash_forward(x, culens, maxlen)

    # Otherwise perform standard attention
    if self.diff_attention:
        return self._torch_diff_forward(x, kv, mask, kv_mask, attn_mask)
    return self._torch_forward(x, kv, mask, kv_mask, attn_mask)
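
As a usage sketch for cross-attention (illustrative shapes; the boolean convention of `kv_mask`, i.e. whether `True` marks valid or padded entries, is not documented on this page, so the mask line below is only indicative):

```python
import torch

from salt.models.transformer_v2 import Attention

attn = Attention(embed_dim=64, num_heads=2)

x = torch.randn(4, 10, 64)                      # queries: (batch, x_len, dim)
kv = torch.randn(4, 25, 64)                     # keys/values: (batch, kv_len, dim)
kv_mask = torch.zeros(4, 25, dtype=torch.bool)  # indicative padding mask for the kv pointcloud

out = attn(x, kv=kv, kv_mask=kv_mask)           # output shape (4, 10, 64)
```
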

salt.models.transformer_v2.GLU #

Bases: torch.nn.Module

Dense update with gated linear unit.

See 2002.05202.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `embed_dim` | `int` | Dimension of the input and output. | *required* |
| `hidden_dim` | `int` or `None` | Dimension of the hidden layer. If None, defaults to `embed_dim * 2`. | `None` |
| `activation` | `str` | Activation function. | `'SiLU'` |
| `dropout` | `float` | Dropout rate. | `0.0` |
| `bias` | `bool` | Whether to include bias in the linear layers. | `True` |
| `gated` | `bool` | Whether to gate the output of the hidden layer. | `False` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    hidden_dim: int | None = None,
    activation: str = "SiLU",
    dropout: float = 0.0,
    bias: bool = True,
    gated: bool = False,
):
    """Dense update with gated linear unit.

    See [2002.05202](https://arxiv.org/abs/2002.05202).

    Parameters
    ----------
    embed_dim : int
        Dimension of the input and output.
    hidden_dim : int | None, optional
        Dimension of the hidden layer. If None, defaults to embed_dim * 2.
    activation : str, optional
        Activation function.
    dropout : float, optional
        Dropout rate.
    bias : bool, optional
        Whether to include bias in the linear layers.
    gated : bool, optional
        Whether to gate the output of the hidden layer.
    """
    super().__init__()

    if hidden_dim is None:
        hidden_dim = embed_dim * 2

    self.gated = gated
    self.embed_dim = embed_dim
    self.in_proj = nn.Linear(embed_dim, hidden_dim + hidden_dim * gated, bias=bias)
    self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
    self.drop = nn.Dropout(dropout)
    self.activation = getattr(nn, activation)()

salt.models.transformer_v2.EncoderLayer #

Bases: torch.nn.Module

Encoder layer consisting of a self-attention and a feed-forward layer.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `embed_dim` | `int` | Dimension of the embeddings at each layer. | *required* |
| `norm` | `str` | Normalization style, by default "LayerNorm". | `'LayerNorm'` |
| `drop_path` | `float` | Drop path rate, by default 0.0. | `0.0` |
| `ls_init` | `float` or `None` | Initial value for the layerscale, by default None. | `None` |
| `depth` | `int` | The depth of the layer, by default 1. | `1` |
| `dense_kwargs` | `dict` or `None` | Keyword arguments for salt.models.transformer_v2.GLU. | `None` |
| `attn_kwargs` | `dict` or `None` | Keyword arguments for salt.models.transformer_v2.Attention. | `None` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    norm: str = "LayerNorm",
    ls_init: float | None = None,
    drop_path: float = 0.0,
    depth: int = 1,
    dense_kwargs: dict | None = None,
    attn_kwargs: dict | None = None,
) -> None:
    """Encoder layer consisting of a self-attention and a feed-forward layer.

    Parameters
    ----------
    embed_dim : int
        Dimension of the embeddings at each layer.
    norm : str, optional
        Normalization style, by default "LayerNorm".
    drop_path : float, optional
        Drop path rate, by default 0.0.
    ls_init : float | None, optional
        Initial value for the layerscale, by default None.
    depth : int, optional
        The depth of the layer, by default 1.
    dense_kwargs : dict | None, optional
        Keyword arguments for [salt.models.transformer_v2.GLU][salt.models.transformer_v2.GLU].
    attn_kwargs : dict | None, optional
        Keyword arguments for
        [salt.models.transformer_v2.Attention][salt.models.transformer_v2.Attention].
    """
    super().__init__()

    # Safe defaults
    if attn_kwargs is None:
        attn_kwargs = {}
    if dense_kwargs is None:
        dense_kwargs = {}

    # Attributes
    self.embed_dim = embed_dim

    # Submodules
    residual = partial(PreNormResidual, norm=norm, ls_init=ls_init, drop_path=drop_path)
    self.attn = residual(Attention(embed_dim, depth=depth, **attn_kwargs))
    self.dense = residual(GLU(embed_dim, **dense_kwargs))

salt.models.transformer_v2.TransformerV2 #

Bases: torch.nn.Module

Transformer model consisting of a stack of Transformer encoder layers.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `num_layers` | `int` | Number of layers. | *required* |
| `embed_dim` | `int` | Dimension of the embeddings at each layer. | *required* |
| `out_dim` | `int` or `None` | Optionally project the output to a different dimension. | `None` |
| `norm` | `str` | Normalization style, by default "LayerNorm". | `'LayerNorm'` |
| `attn_type` | `str` | The backend for the attention mechanism, by default "torch-math". Provided here because the varlen backend requires pre/post processing. | `'torch-math'` |
| `do_final_norm` | `bool` | Whether to apply a final normalization layer, by default True. | `True` |
| `num_registers` | `int` | The number of registers to add to the END of the input sequence. Registers are randomly initialised tokens of the same dimension as any other inputs after initialiser networks. See 2309.16588. | `1` |
| `drop_registers` | `bool` | Whether to drop the registers from the outputs. | `False` |
| `kwargs` | `dict` | Keyword arguments for salt.models.transformer_v2.EncoderLayer. | `{}` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    num_layers: int,
    embed_dim: int,
    out_dim: int | None = None,
    norm: str = "LayerNorm",
    attn_type: str = "torch-math",
    do_final_norm: bool = True,
    num_registers: int = 1,
    drop_registers: bool = False,
    **kwargs,
) -> None:
    """Transformer model consisting of a stack of Transformer encoder layers.

    Parameters
    ----------
    num_layers : int
        Number of layers.
    embed_dim : int
        Dimension of the embeddings at each layer.
    out_dim : int | None, optional
        Optionally project the output to a different dimension.
    norm : str, optional
        Normalization style, by default "LayerNorm".
    attn_type : str, optional
        The backend for the attention mechanism, by default "torch-math".
        Provided here because the varlen backend requires pre/post processing.
    do_final_norm : bool, optional
        Whether to apply a final normalization layer, by default True.
    num_registers : int, optional
        The number of registers to add to the END of the input sequence.
        Registers are randomly initialised tokens of the same dimension as
        any other inputs after initialiser networks. See 2309.16588.
    drop_registers : bool, optional
        Whether to drop the registers from the outputs.
    kwargs : dict
        Keyword arguments for
        [salt.models.transformer_v2.EncoderLayer][salt.models.transformer_v2.EncoderLayer].
    """
    super().__init__()

    # Check the inputs
    if num_registers < 1:
        raise ValueError(
            "Some jets have no tracks, which causes NaNs in the attention scores. ",
            "To avoid this, set num_registers to at least 1",
        )

    # Attributes
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.out_dim = out_dim or embed_dim
    self.do_final_norm = do_final_norm
    self.do_out_proj = out_dim is not None
    self.attn_type = attn_type
    self.num_registers = num_registers
    self.drop_registers = drop_registers

    # Submodules
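    # Force every encoder layer to use the transformer-level attention backend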
    kwargs["attn_kwargs"]["attn_type"] = self.attn_type
    self.layers = torch.nn.ModuleList([
        EncoderLayer(embed_dim=embed_dim, norm=norm, depth=depth, **kwargs)
        for depth in range(num_layers)
    ])

    # Check and set the attention type
    assert self.attn_type in ATTN_TYPES, "Invalid attention type!"
    self.set_backend(self.attn_type)

    # Optional submodules
    if self.do_out_proj:
        self.out_proj = nn.Linear(self.embed_dim, out_dim)
    if self.do_final_norm:
        self.out_norm = getattr(layernorms, norm)(self.out_dim)
    if self.num_registers:
        self.registers = nn.Parameter(
            torch.normal(torch.zeros((self.num_registers, self.embed_dim)), std=1e-4)
        )
        self.register_buffer("register_mask", torch.zeros(num_registers, dtype=torch.bool))
    self.featurewise = nn.ModuleList()
