A place where things are developed before moving to the relevant modules
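Minimal imports assumed by the cells below (a sketch of what needs to be in scope; Module, default, Shift, Scale and the Transformer* / CharLMConfig classes come from this repo's own modules):
import torch
from torch import nn
from collections import OrderedDict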
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = TransformerEncoderBlockNLN(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
x = torch.randn(bs, sl, d)
m = TransformerEncoderNLN(d, n_layers=2)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = TransformerLMNLN(vocab_sz, d, n_layers=2, causal=True)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
class FeedForwardFixup(Module):
    """
    Position-wise feed-forward block with learnable shifts and a scale, for Fixup initialization
    """
    def __init__(self, d_model:int, d_ff:int=None, dropout:float=0.):
        d_ff = default(d_ff, 4 * d_model)
        # shifts are applied before each linear layer and the activation; the output is rescaled
        layers = OrderedDict(
            [('shift1',   Shift()),
             ('fc1',      nn.Linear(d_model, d_ff)),
             ('shift2',   Shift()),
             ('act',      nn.GELU()),
             ('dropout1', nn.Dropout(dropout)),
             ('shift3',   Shift()),
             ('fc2',      nn.Linear(d_ff, d_model)),
             ('dropout2', nn.Dropout(dropout)),
             ('scale',    Scale())])
        self.net = nn.Sequential(layers)
        self._init()

    def forward(self, x):
        return self.net(x)

    def _init(self):
        # Xavier init for the weight matrices; biases and the Fixup shift/scale parameters keep their defaults
        for p in self.parameters():
            if p.dim() > 1: nn.init.xavier_uniform_(p)
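FeedForwardFixup relies on Shift and Scale modules defined elsewhere in the repo. Below is a minimal sketch of how they typically look for Fixup (a scalar bias initialized to zero and a scalar gain initialized to one; the repo's actual definitions may differ), followed by a quick shape check in the same style as the other cells.
class ShiftSketch(nn.Module):
    "Learnable scalar bias initialized to zero (hypothetical stand-in for the repo's Shift)"
    def __init__(self):
        super().__init__()
        self.shift = nn.Parameter(torch.zeros(1))
    def forward(self, x): return x + self.shift

class ScaleSketch(nn.Module):
    "Learnable scalar gain initialized to one (hypothetical stand-in for the repo's Scale)"
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
    def forward(self, x): return x * self.scale

# quick sanity check of FeedForwardFixup itself, mirroring the smoke tests above
ff = FeedForwardFixup(d, dropout=0.1)
out = ff(torch.randn(bs, sl, d))
assert (out.size() == (bs, sl, d))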
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = TransformerEncoderBlockNLN2(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
x = torch.randn(bs, sl, d)
m = TransformerEncoderNLN2(d, n_layers=2)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
bs = 4
sl = 128
d = 64
vocab_sz = 256
x = torch.randint(vocab_sz, (bs, sl))
model = TransformerLMNLN2(vocab_sz, d, n_layers=2, causal=True)
out = model(x)
assert (out.size() == (bs, sl, vocab_sz))
bs = 4
sl = 128
d = 64
x = torch.randn(bs, sl, d)
m = TransformerEncoderBlockAdmin(d)
out = m(x)
assert (out.size() == (bs, sl, d))
out.shape
#...
# config = CharLMConfig(d_model=512, n_layers=6, max_seq_len=512,
# pad_idx=pad_id)
# learn = Learner(dls, TransformerLMAdmin.from_config(config),
# loss_func=CrossEntropyLossFlat(ignore_index=pad_id),
# cbs = [GradientClip(1.0),
# SaveModelCallback(with_opt=True)],
# metrics=[accuracy, perplexity, bpc]).to_fp16()
# learn.add_cb(ActivationStats(modules=res_submodules(learn.model)))
# len(learn.activation_stats.modules)
# learn.fit(1, 1e-3)