Transformer
salt.models.transformer_v2.Attention
Bases: torch.nn.Module
Multihead attention module with optional differential attention and norms.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`embed_dim` | `int` | Input (and output) embedding dimension. | required |
`num_heads` | `int` | Number of attention heads. The default is `1`. | `1` |
`attn_type` | `str` | Backend kernel to use. The default is `'torch-meff'`. | `'torch-meff'` |
`dropout` | `float` | Dropout rate applied in attention. The default is `0.0`. | `0.0` |
`bias` | `bool` | Whether to include bias terms in projections. The default is `True`. | `True` |
`diff_attention` | `bool` | Enable differential attention (splits heads into two branches). The default is `False`. | `False` |
`depth` | `int` | Layer depth index (used to set differential attention weights). The default is `1`. | `1` |
`do_qk_norm` | `bool` | Whether to apply RMSNorm to Q and K per head. The default is `False`. | `False` |
`do_v_norm` | `bool` | Whether to apply RMSNorm to V per head. The default is `False`. | `False` |
Source code in salt/models/transformer_v2.py, lines 229-286.
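For orientation, here is a minimal construction sketch based only on the parameter table above; the values are illustrative and the exact constructor behaviour is defined in `salt/models/transformer_v2.py`:

```python
from salt.models.transformer_v2 import Attention

# Hedged sketch: parameter names follow the table above, values are illustrative.
attn = Attention(
    embed_dim=64,             # input/output embedding dimension
    num_heads=4,              # typically must divide embed_dim
    attn_type="torch-math",   # portable backend choice; 'torch-meff' is the documented default
    dropout=0.1,
    do_qk_norm=True,          # apply RMSNorm to per-head Q and K
)
```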
forward
Attention forward pass, dispatching to the appropriate backend.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`x` | `torch.Tensor` | Query input tensor. | required |
`kv` | `torch.Tensor \| None` | Optional key/value input; if not given, self-attention over `x` is used. | `None` |
`mask` | `torch.BoolTensor \| None` | Padding mask for `x`. | `None` |
`kv_mask` | `torch.BoolTensor \| None` | Padding mask for `kv`. | `None` |
`attn_mask` | `torch.BoolTensor \| None` | Attention mask. | `None` |
`culens` | `torch.Tensor \| None` | Cumulative sequence lengths; required for the varlen flash backend. | `None` |
`maxlen` | `int \| None` | Maximum sequence length; required for the varlen flash backend. | `None` |
Returns:
Type | Description |
---|---|
`torch.Tensor` | Output tensor with the same shape as `x`. |
Source code in salt/models/transformer_v2.py, lines 531-577.
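A hedged usage sketch of the forward pass for self- and cross-attention; the tensor layout and the padding-mask convention are assumptions, not taken from the source:

```python
import torch

from salt.models.transformer_v2 import Attention

attn = Attention(embed_dim=64, num_heads=4, attn_type="torch-math")

x = torch.randn(2, 10, 64)                     # queries (assumed batch, seq, embed_dim layout)
kv = torch.randn(2, 7, 64)                     # keys/values for cross-attention
kv_mask = torch.zeros(2, 7, dtype=torch.bool)  # padding mask for kv (assumed: True marks padded slots)

out_self = attn(x)                             # self-attention: kv is omitted
out_cross = attn(x, kv=kv, kv_mask=kv_mask)    # cross-attention from x to kv
```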
salt.models.transformer_v2.GLU
Bases: torch.nn.Module
Dense update with a (gated) linear unit.
See https://arxiv.org/abs/2002.05202.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`embed_dim` | `int` | Input/output embedding dimension. | required |
`hidden_dim` | `int \| None` | Hidden dimension. If `None`, a default derived from `embed_dim` is used. | `None` |
`activation` | `str` | Name of the activation class in `torch.nn`. The default is `'SiLU'`. | `'SiLU'` |
`dropout` | `float` | Dropout probability. The default is `0.0`. | `0.0` |
`bias` | `bool` | Whether to include bias terms. The default is `True`. | `True` |
`gated` | `bool` | If `True`, use the gated variant of the linear unit. The default is `False`. | `False` |
Source code in salt/models/transformer_v2.py, lines 601-620.
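A short sketch of the feed-forward block, assuming the constructor matches the table above; the hidden width is illustrative:

```python
import torch

from salt.models.transformer_v2 import GLU

# Gated feed-forward update (see https://arxiv.org/abs/2002.05202).
ff = GLU(embed_dim=64, hidden_dim=256, activation="SiLU", gated=True)

x = torch.randn(2, 10, 64)
y = ff(x)  # dense update: the output keeps the embed_dim of the input
```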
salt.models.transformer_v2.EncoderLayer
Bases: torch.nn.Module
Transformer encoder layer: self-attention + feed-forward.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`embed_dim` | `int` | Embedding dimension. | required |
`norm` | `str` | Normalization style (name of the normalization class). The default is `'LayerNorm'`. | `'LayerNorm'` |
`ls_init` | `float \| None` | Initial LayerScale value. If `None`, LayerScale is not used. | `None` |
`drop_path` | `float` | Drop-path rate. The default is `0.0`. | `0.0` |
`depth` | `int` | Layer depth index, used for differential attention weighting. The default is `1`. | `1` |
`dense_kwargs` | `dict \| None` | Keyword arguments for the dense (GLU) block. | `None` |
`attn_kwargs` | `dict \| None` | Keyword arguments for the Attention block. | `None` |
`norm_type` | `str` | Where to apply normalization (e.g. `'pre'` for pre-norm). The default is `'pre'`. | `'pre'` |
Source code in salt/models/transformer_v2.py, lines 792-833.
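A hedged sketch of configuring a single pre-norm layer; the nested `attn_kwargs`/`dense_kwargs` keys are assumed to be forwarded to Attention and GLU as described above:

```python
import torch

from salt.models.transformer_v2 import EncoderLayer

layer = EncoderLayer(
    embed_dim=64,
    norm="LayerNorm",
    norm_type="pre",
    attn_kwargs={"num_heads": 4, "attn_type": "torch-math"},
    dense_kwargs={"gated": True},
)

x = torch.randn(2, 10, 64)
out = layer(x)  # residual self-attention + feed-forward; embedding dimension is preserved
```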
salt.models.transformer_v2.TransformerV2
Bases: torch.nn.Module
Transformer encoder stack with optional registers and output projection.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
`num_layers` | `int` | Number of encoder layers. | required |
`embed_dim` | `int` | Embedding dimension. | required |
`out_dim` | `int \| None` | Optional output projection dimension. If `None`, no output projection is applied. | `None` |
`norm` | `str` | Normalization style (name of the normalization class). The default is `'LayerNorm'`. | `'LayerNorm'` |
`attn_type` | `str` | Attention backend. The default is `'torch-math'`. | `'torch-math'` |
`do_final_norm` | `bool` | Whether to apply a final normalization layer. The default is `True`. | `True` |
`num_registers` | `int` | Number of learned register tokens appended to the end of the sequence. The default is `1`. | `1` |
`drop_registers` | `bool` | If `True`, drop the register tokens from the output. The default is `False`. | `False` |
`**kwargs` | `typing.Any` | Extra keyword arguments forwarded to the encoder layers. | `{}` |
Raises:
Type | Description |
---|---|
`ValueError` | If an invalid configuration is given. |
Source code in salt/models/transformer_v2.py, lines 975-1027.
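Finally, a sketch of a small encoder stack built from the parameters above; the forward signature, the handling of register tokens in the output, and the forwarding of `attn_kwargs` through `**kwargs` are assumptions rather than documented behaviour:

```python
import torch

from salt.models.transformer_v2 import TransformerV2

encoder = TransformerV2(
    num_layers=4,
    embed_dim=64,
    out_dim=128,                   # optional output projection
    attn_type="torch-math",
    num_registers=4,               # learned tokens appended to the sequence
    attn_kwargs={"num_heads": 4},  # assumed to reach the encoder layers via **kwargs
)

x = torch.randn(2, 40, 64)  # (batch, constituents, embed_dim), illustrative layout
out = encoder(x)            # encoded sequence; registers stay appended unless drop_registers=True
```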
Created: January 25, 2024