Generic model for testing new building blocks. WIP...
# Smoke test: XEncoderBlock must return a tensor with the same shape as its input.
# Sizes fed to torch.randn below (presumably batch size, sequence length, model dim).
bs, sl, d = 4, 128, 64
x = torch.randn(bs, sl, d)
m = XEncoderBlock(d)
out = m(x)
assert out.shape == (bs, sl, d)
# Bare expression kept so a notebook cell still displays the resulting shape.
out.shape
# Smoke test: XEncoder built with default options must preserve the input shape.
# Relies on bs, sl and d already being defined at module level.
x = torch.randn(bs, sl, d)
m = XEncoder(d)
out = m(x)
assert out.shape == (bs, sl, d)
# Bare expression kept so a notebook cell still displays the resulting shape.
out.shape
# XEncoder with residual_type='rezero', applied to the module-level tensor `x`.
# The checks below expect a shape-preserving output that is element-wise equal
# to the input (presumably because a freshly built rezero encoder acts as the
# identity -- confirm against the XEncoder implementation).
m = XEncoder(d, residual_type='rezero')
out = m(x)
assert out.shape == (bs, sl, d)
assert torch.equal(out, x)
# Smoke test: XDecoder called with an input and a context tensor of identical
# size must return a tensor shaped like the input.
# Relies on bs, sl and d already being defined at module level.
x = torch.randn(bs, sl, d)
context = torch.randn(bs, sl, d)
m = XDecoder(d)
out = m(x, context)
assert out.shape == (bs, sl, d)
# Bare expression kept so a notebook cell still displays the resulting shape.
out.shape
# Smoke test: XTransformerLM fed integer ids of shape (bs, sl) must return a
# tensor of shape (bs, sl, vocab_sz).
# Sizes are re-declared here so this cell does not depend on earlier ones.
bs, sl, d = 4, 128, 64
vocab_sz = 256
# Random integer ids drawn from [0, vocab_sz).
x = torch.randint(low=0, high=vocab_sz, size=(bs, sl))
model = XTransformerLM(vocab_sz, d, n_layers=2, causal=False)
out = model(x)
assert out.shape == (bs, sl, vocab_sz)
# Bare expression kept so a notebook cell still displays the resulting shape.
out.shape
# TODO: add tests for various configs here
Warning: Not implemented yet
# Build an XTransformerLM through its from_config constructor, using an XConfig
# that overrides n_layers and residual_type.
model = XTransformerLM.from_config(
    XConfig(
        n_layers=2,
        residual_type='rezero',
    )
)
# Bare expression kept so a notebook cell still displays the model repr.
model