Generic model for testing new building blocks. WIP...

Helpers

wrap_sublayer[source]

wrap_sublayer(sublayer:Module, method:str, d_model)

Wraps a sublayer with a skip connection defined by method. Currently supported (a sketch of each scheme follows the list):

  • postnorm
  • prenorm
  • admin
  • rezero
  • ...
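
A minimal sketch of how these schemes differ, assuming a sublayer that maps [bs, sl, d_model] to the same shape. ResidualWrapper and its parameter names are illustrative only, not the library's wrap_sublayer implementation, and the admin branch simplifies the ADMIN initialization to a vector of ones.

import torch
from torch import nn

class ResidualWrapper(nn.Module):
    "Illustrative skip-connection wrapper; not the library implementation"
    def __init__(self, sublayer: nn.Module, method: str, d_model: int):
        super().__init__()
        self.sublayer, self.method = sublayer, method
        if method in ('postnorm', 'prenorm', 'admin'):
            self.norm = nn.LayerNorm(d_model)
        if method == 'rezero':
            # learnable gate initialized to zero: the block starts as the identity
            self.alpha = nn.Parameter(torch.zeros(1))
        if method == 'admin':
            # ADMIN rescales the skip branch with a learnable profile (simplified to ones here)
            self.omega = nn.Parameter(torch.ones(d_model))

    def forward(self, x, **kwargs):
        if self.method == 'postnorm':   # residual add, then LayerNorm
            return self.norm(x + self.sublayer(x, **kwargs))
        if self.method == 'prenorm':    # LayerNorm inside the residual branch
            return x + self.sublayer(self.norm(x), **kwargs)
        if self.method == 'rezero':     # scaled residual branch, no normalization
            return x + self.alpha * self.sublayer(x, **kwargs)
        if self.method == 'admin':      # postnorm with a rescaled skip connection
            return self.norm(self.omega * x + self.sublayer(x, **kwargs))
        raise ValueError(f'Unknown method: {self.method}')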

Bricks

Architecture specific layers, blocks and containers.

Encoder

class XEncoderBlock[source]

XEncoderBlock(d_model:int, n_heads:int=8, attn_module:Module=Attention, ff_module:Module=FeedForward, d_ff:int=None, attn_dropout:float=0.1, ff_dropout:float=0.1, causal:bool=False, attn_bias:bool=False, residual_type:str='postnorm', shared_qk:bool=False, **kwargs) :: Module

Experimental encoder block

bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = XEncoderBlock(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])
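
The documented options can be combined; for example (a quick sketch reusing the tensors above, exercising the residual_type, shared_qk and causal arguments from the signature):

m = XEncoderBlock(d, residual_type='prenorm', shared_qk=True, causal=True)
out = m(x)
assert (out.size() == (bs, sl, d))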

class XEncoder[source]

XEncoder(d_model, n_layers=6, n_heads=8, d_ff=None, attn_module:Module=Attention, ff_module:Module=FeedForward, ff_dropout=0.1, attn_dropout=0.1, attn_bias=False, causal=False, residual_type:str='postnorm', shared_qk:bool=False, final_norm=None, **kwargs) :: Module

Stack of XEncoderBlocks

x = torch.randn(bs, sl, d)
m = XEncoder(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])
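# ReZero scales each residual branch by a learnable factor initialized to zero,
# so an untrained stack acts as the identity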
m = XEncoder(d, residual_type='rezero')
out = m(x)
assert (out.size() == (bs, sl, d))
assert (out == x).all()

Decoder

class XDecoderBlock[source]

XDecoderBlock(d_model, n_heads=8, d_ff=None, attn_dropout=0.1, ff_dropout=0.1, mask=None, attn_bias=False, residual_type='postnorm') :: Module

Standard transformer decoder block. Consists of self-attention, encoder-decoder attention and position-wise feed-forward layers

class XDecoderBlockV2[source]

XDecoderBlockV2(d_model, n_heads=8, mask=None, d_ff=None, attn_dropout=0.1, ff_dropout=0.1, attn_bias=False, residual_type='postnorm') :: Module

Transformer decoder block that uses a single additive attention layer in place of self-attention followed by cross-attention
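
A usage sketch for both decoder blocks, assuming their forward takes the target sequence plus an encoder context of matching dimension, as XDecoder does below:

x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
for block in (XDecoderBlock(d), XDecoderBlockV2(d)):
    out = block(x, context)
    assert (out.size() == (bs, sl, d))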

class XDecoder[source]

XDecoder(d_model, n_layers=6, n_heads=8, d_ff=None, attn_dropout=0.1, ff_dropout=0.1, residual_type='postnorm', comb_attn=False, attn_bias=False, final_norm=None) :: Module

Stack of XDecoderBlock layers

x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
m = XDecoder(d)
out = m(x, context)
assert (out.size() == (bs, sl, d))
out.shape
torch.Size([4, 128, 64])

Models

Language model

class XTransformerLM[source]

XTransformerLM(vocab_sz:int, d_model:int, n_layers:int=6, n_heads:int=8, d_ff:int=None, attn_module:Module=Attention, ff_module:Module=FeedForward, attn_dropout:float=0.1, ff_dropout:float=0.1, emb_dropout:float=0.1, tie_weights:bool=True, causal:bool=True, pos_enc:str='absolute', max_seq_len:int=512, axial_shape:tuple=None, axial_emb_dims:tuple=None, pad_idx:int=None, residual_type:str='postnorm', attn_bias:bool=False, shared_qk:bool=False) :: Module

Basic Transformer for language modelling

Parameters:

* vocab_sz: int - vocabulary size
* d_model: int - inner dimension of the model
* n_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* causal: bool (default: True) - if True, causal masking is applied automatically
* max_seq_len: int (default: 512)
* tie_weights: bool - if True, the target embedding weights are reused for the output projection
* residual_type: str - one of {'postnorm', 'prenorm', 'admin', 'rezero'}
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id, required for automatic generation of the padding mask
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - [optional] should be factors of max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model

Inputs:

* x - input ids, shape [bs, sl]
* mask - optional boolean mask, shape [bs, sl]

Returns:

* logits - target token logits, shape [bs, sl, vocab_sz]
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = XTransformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
out.shape
torch.Size([4, 128, 256])
#add tests for various configs here
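
A few hedged examples of the kind of configuration tests intended here, exercising documented options; for the axial case, axial_shape must factor max_seq_len and axial_emb_dims must sum to d_model:

# causal LM with prenorm residuals
model = XTransformerLM(vocab_sz, d, n_layers=2, causal=True, residual_type='prenorm')
assert (model(x).size() == (bs, sl, vocab_sz))

# axial positional encoding: 16*8 == sl and 32+32 == d
model = XTransformerLM(vocab_sz, d, n_layers=2, max_seq_len=sl, pos_enc='axial',
                       axial_shape=(16, 8), axial_emb_dims=(32, 32))
assert (model(x).size() == (bs, sl, vocab_sz))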

Encoder-Decoder model

class XTransformer[source]

XTransformer(enc_vocab_sz, dec_vocab_sz, d_model, n_enc_layers=6, n_dec_layers=6, n_heads=8, d_ff=None, pad_idx=None, tie_weights=True, shared_emb=False, attn_dropout=0.1, ff_dropout=0.1, emb_dropout=0.1, prenorm=False, attn_bias=False, comb_attn=False, pos_enc='absolute', max_seq_len=512, axial_shape=None, axial_emb_dims=None) :: Module

Basic Transformer Encoder-Decoder model

Parameters:

* enc_vocab_sz: int - source vocab size
* dec_vocab_sz: int - target vocab size
* d_model: int - inner dimension of the model
* n_enc_layers: int (default: 6)
* n_dec_layers: int (default: 6)
* n_heads: int (default: 8)
* d_ff: int - inner dimension of the pointwise FeedForward net, if None defaults to 4*d_model
* attn_dropout: float - attention dropout
* ff_dropout: float - feed-forward dropout
* emb_dropout: float - embedding dropout
* max_seq_len: int (default: 512)
* prenorm: bool - whether to use PreNorm or PostNorm
* attn_bias: bool - whether to allow biases in attention projection layers
* pad_idx: int - padding token id; if provided and no mask/context_mask are passed to the forward method, it is used to generate padding masks
* tie_weights: bool - if True, the target embedding weights are reused for the output projection
* shared_emb: bool - if True encoder and decoder will use shared embedding layer
* pos_enc: str from {'absolute', 'fixed', 'axial'} - type of positional encoding to use
* axial_shape: tuple - [optional] should be factors of max_seq_len
* axial_emb_dims: tuple - [optional] axial embedding components, should sum to d_model

Inputs:

* src - source input ids, shape [bs, src_sl]
* tgt - target input ids, shape [bs, tgt_sl]
* src_mask - optional boolean source mask, shape [bs, src_sl]
* tgt_mask - optional boolean target mask, shape [bs, tgt_sl]

Returns:

* logits - target token logits, shape [bs, tgt_sl, tgt_vocab_sz]
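
A usage sketch following the documented inputs and returns; the vocabulary sizes and layer counts here are illustrative:

src_vocab_sz, tgt_vocab_sz = 256, 260
src = torch.randint(src_vocab_sz, (bs, sl))
tgt = torch.randint(tgt_vocab_sz, (bs, sl))
model = XTransformer(src_vocab_sz, tgt_vocab_sz, d, n_enc_layers=2, n_dec_layers=2)
logits = model(src, tgt)
assert (logits.size() == (bs, sl, tgt_vocab_sz))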

Config for experiments

class XConfig[source]

XConfig(vocab_sz=256, d_model=512, n_layers=12, n_heads=8, attn_module=Attention, ff_module=FeedForward, d_ff=None, attn_dropout=0.1, ff_dropout=0.1, emb_dropout=0.1, tie_weights=True, causal=True, pos_enc='absolute', max_seq_len=512, axial_shape=None, axial_emb_dims=None, pad_idx=None, attn_bias=False, shared_qk=False, residual_type='postnorm') :: ConfigBase

Config for enwik8 Experiment. See https://arampacha.github.io/reformer_fastai/experiment.enwik8-baseline.html for details

model = XTransformerLM.from_config(XConfig(n_layers=2, residual_type='rezero'))
model
XTransformerLM(
  (emb): TransformerEmbedding(
    (emb): Embedding(256, 512)
    (dropout): Dropout(p=0.1, inplace=False)
    (pos_enc): AbsolutePositionalEmbedding(
      (emb): Embedding(512, 512)
    )
  )
  (encoder): XEncoder(
    (layers): ModuleList(
      (0): XEncoderBlock(
        (attn): ReZero(
          (sublayer): Attention(
            (in_proj): AttnInProjV2(
              (to_q): Linear(in_features=512, out_features=512, bias=False)
              (to_kv): Linear(in_features=512, out_features=1024, bias=False)
            )
            (attn): ScaledDotProdAttention(
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (out_proj): Linear(in_features=512, out_features=512, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (scale): Scale()
        )
        (ff): ReZero(
          (sublayer): FeedForward(
            (net): Sequential(
              (fc1): Linear(in_features=512, out_features=2048, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.1, inplace=False)
              (fc2): Linear(in_features=2048, out_features=512, bias=True)
              (drop2): Dropout(p=0.1, inplace=False)
            )
          )
          (scale): Scale()
        )
      )
      (1): XEncoderBlock(
        (attn): ReZero(
          (sublayer): Attention(
            (in_proj): AttnInProjV2(
              (to_q): Linear(in_features=512, out_features=512, bias=False)
              (to_kv): Linear(in_features=512, out_features=1024, bias=False)
            )
            (attn): ScaledDotProdAttention(
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (out_proj): Linear(in_features=512, out_features=512, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (scale): Scale()
        )
        (ff): ReZero(
          (sublayer): FeedForward(
            (net): Sequential(
              (fc1): Linear(in_features=512, out_features=2048, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.1, inplace=False)
              (fc2): Linear(in_features=2048, out_features=512, bias=True)
              (drop2): Dropout(p=0.1, inplace=False)
            )
          )
          (scale): Scale()
        )
      )
    )
  )
  (proj): Linear(in_features=512, out_features=256, bias=True)
)