# Transformer

## salt.models.transformer.Attention
Bases: torch.nn.Module
Multihead attention module with optional differential attention and norms.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `embed_dim` | `int` | Input (and output) embedding dimension. | *required* |
| `num_heads` | `int` | Number of attention heads. The default is `1`. | `1` |
| `attn_type` | `str` | Backend attention kernel to use. The default is `'torch-meff'`. | `'torch-meff'` |
| `dropout` | `float` | Dropout rate applied in attention. The default is `0.0`. | `0.0` |
| `bias` | `bool` | Whether to include bias terms in projections. The default is `True`. | `True` |
| `do_qk_norm` | `bool` | Whether to apply RMSNorm to Q and K per head. The default is `False`. | `False` |
| `do_v_norm` | `bool` | Whether to apply RMSNorm to V per head. The default is `False`. | `False` |
| `mup` | `bool` | Whether to use the muP parametrisation. The default is `False`. | `False` |
Source code in salt/models/attention.py
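The following sketch shows how the module might be constructed from the parameters above; the values are illustrative only, and the import path simply mirrors the heading of this entry.

```python
from salt.models.transformer import Attention

# Illustrative configuration; all keyword names are taken from the table above.
attn = Attention(
    embed_dim=64,
    num_heads=4,
    attn_type="torch-meff",
    dropout=0.1,
    do_qk_norm=True,  # RMSNorm on Q and K per head
)
```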
### forward
Attention forward pass, dispatching to the appropriate backend.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `torch.Tensor` | Query input. | *required* |
| `kv` | `torch.Tensor \| None` | Optional key/value input for cross-attention. The default is `None`. | `None` |
| `mask` | `torch.BoolTensor \| None` | Padding mask for `x`. The default is `None`. | `None` |
| `kv_mask` | `torch.BoolTensor \| None` | Padding mask for `kv`. The default is `None`. | `None` |
| `attn_mask` | `torch.BoolTensor \| None` | Attention mask. The default is `None`. | `None` |
| `culens` | `torch.Tensor \| None` | Cumulative sequence lengths, required for the varlen flash backend. The default is `None`. | `None` |
| `maxlen` | `int \| None` | Maximum sequence length, required for the varlen flash backend. The default is `None`. | `None` |
Returns:
| Type | Description |
|---|---|
| `torch.Tensor` | Attention output. |
Source code in salt/models/attention.py
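A minimal self-attention call using the arguments documented above. The `(batch, sequence, embed_dim)` input layout and the all-False padding mask are assumptions for illustration, not details stated on this page.

```python
import torch

from salt.models.transformer import Attention

attn = Attention(embed_dim=64, num_heads=4, attn_type="torch-math")

x = torch.randn(2, 10, 64)                   # assumed (batch, seq, embed_dim) layout
mask = torch.zeros(2, 10, dtype=torch.bool)  # assumed padding-mask convention (no padding here)

out = attn(x, mask=mask)  # kv is None, so this is plain self-attention
print(out.shape)
```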
## salt.models.transformer.GLU
Bases: torch.nn.Module
Dense update with a (gated) linear unit.
See https://arxiv.org/abs/2002.05202.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `embed_dim` | `int` | Input/output embedding dimension. | *required* |
| `hidden_dim` | `int \| None` | Hidden dimension. If `None`, a default based on `embed_dim` is used. | `None` |
| `activation` | `str` | Name of the activation class to use. The default is `'SiLU'`. | `'SiLU'` |
| `dropout` | `float` | Dropout probability. The default is `0.0`. | `0.0` |
| `bias` | `bool` | Whether to include bias terms. The default is `True`. | `True` |
| `gated` | `bool` | If `True`, use the gated variant of the linear unit. The default is `False`. | `False` |
| `mup` | `bool` | Whether to use μP parameterization. The default is `False`. | `False` |
Source code in salt/models/transformer.py
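A short sketch of the dense update; the input layout is an assumption, and only constructor arguments listed above are used.

```python
import torch

from salt.models.transformer import GLU

glu = GLU(embed_dim=64, hidden_dim=256, activation="SiLU", gated=True)

x = torch.randn(2, 10, 64)  # assumed (batch, seq, embed_dim) layout
out = glu(x)                # position-wise update; the embedding dimension is preserved
print(out.shape)
```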
## salt.models.transformer.EncoderLayer
Bases: torch.nn.Module
Transformer encoder layer: self-attention + feed-forward.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `embed_dim` | `int` | Embedding dimension. | *required* |
| `norm` | `str` | Normalization style (name of the normalization class). The default is `'LayerNorm'`. | `'LayerNorm'` |
| `ls_init` | `float \| None` | Initial LayerScale value. If `None`, LayerScale is not applied. | `None` |
| `drop_path` | `float` | Drop-path rate. The default is `0.0`. | `0.0` |
| `depth` | `int` | Layer depth index, used for differential attention weighting. The default is `1`. | `1` |
| `dense_kwargs` | `dict \| None` | Keyword args for the dense `GLU` block. The default is `None`. | `None` |
| `attn_kwargs` | `dict \| None` | Keyword args for the `Attention` module. The default is `None`. | `None` |
| `norm_type` | `str` | Placement of the normalization layers. The default is `'pre'`. | `'pre'` |
| `edge_embed_dim` | `int` | Model embedding dimension for edge features. The default is `0`. | `0` |
| `update_edges` | `bool` | If `True`, update edge features in the layer. The default is `False`. | `False` |
| `mup` | `bool` | Whether to use μP parameterization. The default is `False`. | `False` |
Source code in salt/models/transformer.py
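A construction sketch showing how `attn_kwargs` and `dense_kwargs` could be threaded through to the attention and GLU sub-modules. The nested keyword names come from the tables above; the layer's forward signature is not documented here, so only the build step is shown.

```python
from salt.models.transformer import EncoderLayer

# Illustrative values only; nested kwargs reuse names from the Attention and GLU tables.
layer = EncoderLayer(
    embed_dim=64,
    norm="LayerNorm",
    norm_type="pre",
    drop_path=0.1,
    attn_kwargs={"num_heads": 4, "attn_type": "torch-math"},
    dense_kwargs={"activation": "SiLU", "gated": True},
)
```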
## salt.models.transformer.Transformer
Bases: torch.nn.Module
Transformer encoder stack with optional registers and output projection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_layers` | `int` | Number of encoder layers. | *required* |
| `embed_dim` | `int` | Embedding dimension. | *required* |
| `out_dim` | `int \| None` | Optional output projection dimension. If `None`, no output projection is applied. | `None` |
| `norm` | `str` | Normalization style (name of the normalization class). The default is `'LayerNorm'`. | `'LayerNorm'` |
| `attn_type` | `str` | Attention backend kernel. The default is `'torch-math'`. | `'torch-math'` |
| `do_final_norm` | `bool` | Whether to apply a final normalization layer. The default is `True`. | `True` |
| `num_registers` | `int` | Number of learned register tokens appended to the end of the sequence. The default is `1`. | `1` |
| `drop_registers` | `bool` | If `True`, drop the register tokens from the output. The default is `False`. | `False` |
| `edge_embed_dim` | `int` | Model embedding dimension for edge features. The default is `0`. | `0` |
| `update_edges` | `bool` | If `True`, update edge features between layers. The default is `False`. | `False` |
| `mup` | `bool` | Whether to use μP parameterization. The default is `False`. | `False` |
| `**kwargs` | `typing.Any` | Extra keyword arguments forwarded to `EncoderLayer`. | `{}` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If an invalid configuration is provided. |
Source code in salt/models/transformer.py
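A sketch of building and running the full encoder stack. Constructor arguments follow the table above (per-layer options would be passed through `**kwargs` to `EncoderLayer`); the forward call assumes the module is callable on a `(batch, sequence, embed_dim)` tensor, which is not spelled out on this page.

```python
import torch

from salt.models.transformer import Transformer

encoder = Transformer(
    num_layers=4,
    embed_dim=64,
    out_dim=128,             # project the final embeddings to 128 dimensions
    attn_type="torch-math",
    num_registers=4,         # learned register tokens appended to the sequence
)

x = torch.randn(2, 10, 64)  # assumed (batch, seq, embed_dim) layout
out = encoder(x)            # assumed minimal call signature
print(out.shape)
```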
Created: January 25, 2024