
# Transformer

## salt.models.transformer_v2.Attention

Bases: torch.nn.Module

Multihead attention module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embed_dim` | `int` | Dimension of the input. | *required* |
| `num_heads` | `int` | Number of attention heads. | `1` |
| `attn_type` | `str` | Name of backend kernel to use. | `'torch-meff'` |
| `dropout` | `float` | Dropout rate. | `0.0` |
| `bias` | `bool` | Whether to include bias terms. | `True` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    num_heads: int = 1,
    attn_type: str = "torch-meff",
    dropout: float = 0.0,
    bias: bool = True,
) -> None:
    """Multihead attention module.

    Parameters
    ----------
    embed_dim : int
        Dimension of the input.
    num_heads : int
        Number of attention heads.
    attn_type : str, optional
        Name of backend kernel to use.
    dropout : float, optional
        Dropout rate.
    bias : bool, optional
        Whether to include bias terms.
    """
    super().__init__()
    assert embed_dim % num_heads == 0, "Dim not div by the number of heads!"
    assert attn_type in {
        "torch-flash",
        "torch-math",
        "torch-meff",
        "flash-varlen",
    }, "Invalid attention type!"

    # Attributes
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    self.dropout = dropout
    self.bias = bias

    # Better parallelism for self-attention when using parameters directly
    self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
    self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
    self.reset_parameters()
    self.set_backend(attn_type)
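
A minimal construction and self-attention sketch (not taken from the salt docs themselves): the import path follows the module name above, and the `torch-math` backend is chosen here since it does not require flash-attention kernels. Shapes follow the `forward` docstring below.

```python
import torch

from salt.models.transformer_v2 import Attention

# embed_dim must be divisible by num_heads (enforced by the assert above)
attn = Attention(embed_dim=128, num_heads=8, attn_type="torch-math", dropout=0.1)

x = torch.randn(4, 40, 128)  # pointcloud of shape (batch, x_len, dim)
out = attn(x)                # self-attention, output shape (4, 40, 128)
```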

### forward

Attention forward pass.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | The pointcloud of shape `(batch, x_len, dim)`. | *required* |
| `kv` | `torch.Tensor` | Optional second pointcloud for cross-attn with shape `(batch, kv_len, dim)`. | `None` |
| `mask` | `torch.BoolTensor` | Mask for the pointcloud `x`, by default None. | `None` |
| `kv_mask` | `torch.BoolTensor` | Mask for the `kv` pointcloud, by default None. | `None` |
| `attn_mask` | `torch.BoolTensor` | Full attention mask, by default None. | `None` |
| `culens` | `torch.Tensor` | Cumulative lengths of the sequences in `x`, by default None. Only used for the `flash-varlen` backend. | `None` |
| `maxlen` | `int` | Maximum length of a sequence in `x`, by default None. Only used for the `flash-varlen` backend. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | Output of shape `(batch, x_len, dim)`. |

Source code in salt/models/transformer_v2.py
def forward(
    self,
    x: Tensor,
    kv: Tensor | None = None,
    mask: BoolTensor | None = None,
    kv_mask: BoolTensor | None = None,
    attn_mask: BoolTensor | None = None,
    culens: Tensor | None = None,
    maxlen: int | None = None,
) -> Tensor:
    """Attention forward pass.

    Parameters
    ----------
    x : Tensor
        The pointcloud of shape (batch, x_len, dim).
    kv : Tensor
        Optional second pointcloud for cross-attn with shape (batch, kv_len, dim).
    mask : BoolTensor, optional
        Mask for the pointcloud x, by default None.
    kv_mask : BoolTensor, optional
        Mask for the kv pointcloud, by default None.
    attn_mask : BoolTensor, optional
        Full attention mask, by default None.
    culens : Tensor, optional
        Cumulative lengths of the sequences in x, by default None.
        Only used for the flash-varlen backend.
    maxlen : int, optional
        Maximum length of a sequence in x, by default None.
        Only used for the flash-varlen backend.

    Returns
    -------
    Tensor
        Output of shape (batch, x_len, dim).
    """
    # the varlen attention backend is called at the beginning (different args)
    if self.attn_type == "flash-varlen":
        assert kv is None, "flash-varlen only supports self attention!"
        assert attn_mask is None, "flash-varlen does not support attention masks!"
        assert culens is not None, "flash-varlen requires culens!"
        assert maxlen is not None, "flash-varlen requires maxlen!"
        return self._varlen_attention(x, culens, maxlen)

    # Otherwise perform standard attention
    B, S, D = x.shape

    # input projections -> B, S, D
    q, k, v = projection_packed(x, kv, self.in_proj_weight, self.in_proj_bias)

    # transform tensors to (B, Nh, S, Hd)
    shape = (B, -1, self.num_heads, self.head_dim)  # Don't use S for cross attn
    q, k, v = (t.view(shape).transpose(1, 2).contiguous() for t in (q, k, v))

    # run attention
    s_mask = mask if kv is None else kv_mask  # Who is sending, x or kv
    mask = merge_masks(s_mask, attn_mask, q.shape)
    dropout = self.dropout if self.training else 0.0
    a_out = torch_attn(q, k, v, mask, dropout, self.attn_type)

    # recombine heads
    a_out = a_out.transpose(1, 2).contiguous().view(B, S, D)

    # mix with final linear layer
    return self.out_proj(a_out)
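
A hedged cross-attention usage sketch: the boolean convention of `kv_mask` (whether `True` marks padded or valid positions) is not spelled out on this page, so the mask line below is an assumption to be checked against `merge_masks` in `transformer_v2.py`.

```python
import torch

from salt.models.transformer_v2 import Attention

attn = Attention(embed_dim=64, num_heads=4, attn_type="torch-math")

x = torch.randn(2, 10, 64)  # query pointcloud (batch, x_len, dim)
kv = torch.randn(2, 6, 64)  # second pointcloud (batch, kv_len, dim)

# Padding mask for the kv pointcloud; the True/False convention here is an
# assumption, check merge_masks for the exact semantics.
kv_mask = torch.zeros(2, 6, dtype=torch.bool)

self_out = attn(x)                           # (2, 10, 64)
cross_out = attn(x, kv=kv, kv_mask=kv_mask)  # (2, 10, 64)
```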

## salt.models.transformer_v2.GLU

Bases: torch.nn.Module

Dense update with gated linear unit.

See [2002.05202](https://arxiv.org/abs/2002.05202).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embed_dim` | `int` | Dimension of the input and output. | *required* |
| `hidden_dim` | `int \| None` | Dimension of the hidden layer. If None, defaults to `embed_dim * 2`. | `None` |
| `activation` | `str` | Activation function. | `'SiLU'` |
| `dropout` | `float` | Dropout rate. | `0.0` |
| `bias` | `bool` | Whether to include bias in the linear layers. | `True` |
| `gated` | `bool` | Whether to gate the output of the hidden layer. | `False` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    hidden_dim: int | None = None,
    activation: str = "SiLU",
    dropout: float = 0.0,
    bias: bool = True,
    gated: bool = False,
):
    """Dense update with gated linear unit.

    See [2002.05202](https://arxiv.org/abs/2002.05202).

    Parameters
    ----------
    embed_dim : int
        Dimension of the input and output.
    hidden_dim : int | None, optional
        Dimension of the hidden layer. If None, defaults to embed_dim * 2.
    activation : str, optional
        Activation function.
    dropout : float, optional
        Dropout rate.
    bias : bool, optional
        Whether to include bias in the linear layers.
    gated : bool, optional
        Whether to gate the output of the hidden layer.
    """
    super().__init__()

    if hidden_dim is None:
        hidden_dim = embed_dim * 2

    self.gated = gated
    self.embed_dim = embed_dim
    self.in_proj = nn.Linear(embed_dim, hidden_dim + hidden_dim * gated, bias=bias)
    self.out_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
    self.drop = nn.Dropout(dropout)
    self.activation = getattr(nn, activation)()

## salt.models.transformer_v2.EncoderLayer

Bases: torch.nn.Module

Encoder layer consisting of a self-attention and a feed-forward layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `embed_dim` | `int` | Dimension of the embeddings at each layer. | *required* |
| `norm` | `str` | Normalization style, by default "LayerNorm". | `'LayerNorm'` |
| `drop_path` | `float` | Drop path rate, by default 0.0. | `0.0` |
| `ls_init` | `float \| None` | Initial value for the layerscale, by default None. | `None` |
| `dense_kwargs` | `dict \| None` | Keyword arguments for `salt.models.transformer_v2.GLU`. | `None` |
| `attn_kwargs` | `dict \| None` | Keyword arguments for `salt.models.transformer_v2.Attention`. | `None` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    embed_dim: int,
    norm: str = "LayerNorm",
    ls_init: float | None = None,
    drop_path: float = 0.0,
    dense_kwargs: dict | None = None,
    attn_kwargs: dict | None = None,
) -> None:
    """Encoder layer consisting of a self-attention and a feed-forward layer.

    Parameters
    ----------
    embed_dim : int
        Dimension of the embeddings at each layer.
    norm : str, optional
        Normalization style, by default "LayerNorm".
    drop_path : float, optional
        Drop path rate, by default 0.0.
    ls_init : float | None, optional
        Initial value for the layerscale, by default None.
    dense_kwargs : dict | None, optional
        Keyword arguments for [salt.models.transformer_v2.GLU][salt.models.transformer_v2.GLU].
    attn_kwargs : dict | None, optional
        Keyword arguments for
        [salt.models.transformer_v2.Attention][salt.models.transformer_v2.Attention].
    """
    super().__init__()

    # Safe defaults
    if attn_kwargs is None:
        attn_kwargs = {}
    if dense_kwargs is None:
        dense_kwargs = {}

    # Attributes
    self.embed_dim = embed_dim

    # Submodules
    residual = partial(PreNormResidual, norm=norm, ls_init=ls_init, drop_path=drop_path)
    self.attn = residual(Attention(embed_dim, **attn_kwargs))
    self.dense = residual(GLU(embed_dim, **dense_kwargs))
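
A construction sketch: `attn_kwargs` and `dense_kwargs` are forwarded to `Attention` and `GLU` as documented above, while the call at the end assumes the layer's forward accepts the pointcloud directly (the forward method is not reproduced on this page).

```python
import torch

from salt.models.transformer_v2 import EncoderLayer

layer = EncoderLayer(
    embed_dim=128,
    attn_kwargs={"num_heads": 8, "attn_type": "torch-math"},
    dense_kwargs={"gated": True},
)

x = torch.randn(4, 40, 128)
out = layer(x)  # assumed call signature; pre-norm residual self-attention + GLU update
```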

## salt.models.transformer_v2.TransformerV2

Bases: torch.nn.Module

Transformer model consisting of a stack of Transformer encoder layers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num_layers` | `int` | Number of layers. | *required* |
| `embed_dim` | `int` | Dimension of the embeddings at each layer. | *required* |
| `out_dim` | `int \| None` | Optionally project the output to a different dimension. | `None` |
| `norm` | `str` | Normalization style, by default "LayerNorm". | `'LayerNorm'` |
| `attn_type` | `str` | The backend for the attention mechanism, by default "torch-math". Provided here because the varlen backend requires pre/post processing. | `'torch-math'` |
| `do_final_norm` | `bool` | Whether to apply a final normalization layer, by default True. | `True` |
| `num_registers` | `int` | The number of registers to add to the end of the input sequence. Registers are randomly initialised tokens with the same dimension as the other inputs after the initialiser networks. See [2309.16588](https://arxiv.org/abs/2309.16588). | `1` |
| `drop_registers` | `bool` | Whether to drop the registers from the outputs. | `False` |
| `kwargs` | `dict` | Keyword arguments for `salt.models.transformer_v2.EncoderLayer`. | `{}` |
Source code in salt/models/transformer_v2.py
def __init__(
    self,
    num_layers: int,
    embed_dim: int,
    out_dim: int | None = None,
    norm: str = "LayerNorm",
    attn_type: str = "torch-math",
    do_final_norm: bool = True,
    num_registers: int = 1,
    drop_registers: bool = False,
    **kwargs,
) -> None:
    """Transformer model consisting of a stack of Transformer encoder layers.

    Parameters
    ----------
    num_layers : int
        Number of layers.
    embed_dim : int
        Dimension of the embeddings at each layer.
    out_dim : int | None, optional
        Optionally project the output to a different dimension.
    norm : str, optional
        Normalization style, by default "LayerNorm".
    attn_type : str, optional
        The backend for the attention mechanism, by default "torch-math".
        Provided here because the varlen backend requires pre/post processing.
    do_final_norm : bool, optional
        Whether to apply a final normalization layer, by default True.
    num_registers : int, optional
        The number of registers to add to the END of the input sequence.
        Registers are randomly initialised tokens of the same dimension as
        any other inputs after initialiser networks. See 2309.16588.
    drop_registers : bool, optional
        Whether to drop the registers from the outputs.
    kwargs : dict
        Keyword arguments for [salt.models.transformer_v2.EncoderLayer].
    """
    super().__init__()

    # Check the inputs
    if num_registers < 1:
        raise ValueError(
            "Some jets have no tracks, which causes NaNs in the attention scores. ",
            "To avoid this, set num_registers to at least 1",
        )

    # Attributes
    self.num_layers = num_layers
    self.embed_dim = embed_dim
    self.out_dim = out_dim or embed_dim
    self.do_final_norm = do_final_norm
    self.do_out_proj = out_dim is not None
    self.attn_type = attn_type
    self.num_registers = num_registers
    self.drop_registers = drop_registers

    # Submodules
    self.layers = torch.nn.ModuleList([
        EncoderLayer(embed_dim=embed_dim, norm=norm, **kwargs) for _ in range(num_layers)
    ])
    self.attn_type = self.set_backend(attn_type)

    # Optional submodules
    if self.do_out_proj:
        self.out_proj = nn.Linear(self.embed_dim, out_dim)
    if self.do_final_norm:
        self.out_norm = getattr(layernorms, norm)(self.out_dim)
    if self.num_registers:
        self.registers = nn.Parameter(
            torch.normal(torch.zeros((self.num_registers, self.embed_dim)), std=1e-4)
        )
        self.register_buffer("register_mask", torch.zeros(num_registers, dtype=torch.bool))
    self.featurewise = nn.ModuleList()
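
A construction sketch: the extra `attn_kwargs` entry is forwarded to each `EncoderLayer` via `**kwargs` as documented above. The forward signature of `TransformerV2` is not reproduced on this page, so only construction is shown.

```python
from salt.models.transformer_v2 import TransformerV2

encoder = TransformerV2(
    num_layers=4,
    embed_dim=128,
    out_dim=64,                    # adds a final projection to 64 dimensions
    attn_type="torch-math",
    num_registers=4,               # learned register tokens appended to the end of the sequence
    attn_kwargs={"num_heads": 8},  # forwarded to EncoderLayer, then to Attention
)

# num_registers must be >= 1 (see the check above); drop_registers controls
# whether the register tokens are removed from the outputs.
```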

Last update: May 9, 2024
Created: January 25, 2024