ALUM
Adversarial training for large neural language models (ALUM), as presented in https://arxiv.org/abs/2004.08994.
Algorithm
Input: $T$: total number of iterations, $\mathcal X$: the dataset, $f(x; \theta)$: model parameterized by $\theta$, $\sigma^2$: variance for random initialization of perturbation $\delta$, $\epsilon$: perturbation bound, $K$: number of iterations for updating $\delta$, $\eta$: learning rate for updating $\delta$, $\tau$: global learning rate, $\alpha$: adversarial loss weight, $\Pi$: projection operation.
01: for $t = 1,...,T$ do
02: $\quad$ for $(x,y) \in \mathcal X$ do
03: $\quad \quad$ $\delta \sim \mathcal{N} (0, \sigma^2 I)$
04: $\quad \quad$ for $m = 1,...,K$ do
05: $\quad \quad \quad$ $g_{adv} \leftarrow \nabla_\delta l(f(x;\theta), f(x+\delta; \theta))$
06: $\quad \quad \quad$ $\delta \leftarrow \Pi_{\|\delta\|_\infty \le \epsilon}(\delta + \eta g_{adv})$
07: $\quad \quad$ end for
08: $\quad \quad$ $g_\theta \leftarrow \nabla_\theta l(f(x;\theta), y) + \alpha \nabla_\theta l(f(x;\theta), f(x+\delta;\theta))$
09: $\quad \quad$ $\theta \leftarrow \theta - \tau g_\theta$
10: $\quad$ end for
11: end for
Output: $\theta$
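Below is a minimal PyTorch sketch of the loop above, assuming the symmetric KL divergence from the paper for $l(\cdot,\cdot)$ and cross-entropy as an example task loss. The names `symmetric_kl` and `alum_step` and all default hyperparameter values are illustrative assumptions, not this library's implementation.

import torch
import torch.nn.functional as F

def symmetric_kl(p_logits, q_logits):
    # l(p, q): symmetric KL between the two output distributions
    p_log, q_log = F.log_softmax(p_logits, -1), F.log_softmax(q_logits, -1)
    return (F.kl_div(q_log, p_log, log_target=True, reduction='batchmean')
            + F.kl_div(p_log, q_log, log_target=True, reduction='batchmean'))

def alum_step(model, x, y, sigma=1e-5, eps=1e-5, K=1, eta=1e-3, alpha=1.0):
    # one (x, y) update, following steps 03-09 of the algorithm above
    delta = sigma * torch.randn_like(x)                      # 03: delta ~ N(0, sigma^2 I)
    for _ in range(K):                                       # 04-07: ascend l w.r.t. delta
        delta.requires_grad_(True)
        adv = symmetric_kl(model(x).detach(), model(x + delta))
        g_adv, = torch.autograd.grad(adv, delta)             # 05: g_adv
        delta = (delta + eta * g_adv).clamp(-eps, eps).detach()  # 06: step + l_inf projection
    loss = F.cross_entropy(model(x), y) \
           + alpha * symmetric_kl(model(x), model(x + delta))    # 08: task + adversarial loss
    loss.backward()                                          # 09: gradients for the theta update
    return loss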
from torch import nn
from fastai.test_utils import synth_learner

# Toy two-layer model; ALUMCallback is pointed at model[0], which stands in
# for the embedding layer that ALUM perturbs in a language model.
model = nn.Sequential(
    nn.Linear(1, 10, bias=False),
    nn.Linear(10, 1, bias=False)
)
learn = synth_learner(model=model, cbs=ALUMCallback(model[0]))
learn.fit(2, 1e-3)
SMART
Smoothness-inducing adversarial regularization with Bregman proximal point optimization, as presented in https://arxiv.org/abs/1911.03437.
Algorithm
Notation:
$g_i(\tilde{x}_i, \bar{\theta}_s) = \frac{1}{|\mathcal{B}|}\sum_{x_i \in \mathcal{B}} \nabla_{\tilde{x}} \ell_s(f(x_i; \bar{\theta}_s), f(\tilde{x}_i; \bar{\theta}_s))$;
$\mathrm{AdamUpdate}_{\mathcal B}$ - the ADAM update for optimizing $\theta_{t+1} = \mathrm{argmin}_\theta \mathcal{F}(\theta) + \mu \mathcal{D}_{\mathrm{Breg}}(\theta, \tilde{\theta}_t)$;
$\Pi_{\mathcal A}$ - projection onto $\mathcal A$
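For concreteness, in the classification setting the SMART paper instantiates the Bregman divergence with the symmetric KL loss $\ell_s$ evaluated on the unperturbed inputs:
$\mathcal{D}_{\mathrm{Breg}}(\theta, \tilde{\theta}_t) = \frac{1}{n}\sum_{i=1}^{n} \ell_s(f(x_i; \theta), f(x_i; \tilde{\theta}_t))$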
Input: $T$: total number of iterations, $\mathcal X$: the dataset, $\theta_0$: pre-trained model parameters, $S$: total number of iterations for the Bregman proximal point method, $\sigma^2$: variance for random initialization of perturbation, $T_{\bar{x}}$: number of iterations for updating $\tilde{x}_i$, $\eta$: learning rate for updating $\tilde{x}_i$, $\beta$: momentum parameter, $\epsilon$: perturbation bound.
01: $\tilde{\theta}_1 \leftarrow \theta_0$
02: for $t = 1,...,T$ do
03: $\quad$ $\bar{\theta}_1 \leftarrow \theta_{t-1}$
04: $\quad$ for $s = 1,...,S$ do
05: $\quad \quad$ Sample $\mathcal{B}$ from $\mathcal X$
06: $\quad \quad$ $\tilde{x}_i \leftarrow x_i + \nu_i$ where $\nu_i \sim \mathcal{N}(0, \sigma^2)$
07: $\quad \quad$ for $m = 1,...,T_{\bar{x}}$ do
08: $\quad \quad \quad$ $\tilde{g}_i \leftarrow \frac{g_i(\tilde{x}_i,\bar{\theta}_s)}{\|g_i(\tilde{x}_i,\bar{\theta}_s)\|_\infty}$
09: $\quad \quad \quad$ $\tilde{x}_i \leftarrow \Pi_{\|\tilde{x}_i-x_i\|_\infty \le \epsilon}(\tilde{x}_i + \eta \tilde{g}_i)$
10: $\quad \quad$ end for
11: $\quad \quad$ $\bar{\theta}_{s+1} \leftarrow \mathrm{AdamUpdate}_{\mathcal B}(\bar{\theta}_s)$
12: $\quad$ end for
13: $\quad$ $\theta_t \leftarrow \bar{\theta}_{S}$
14: $\quad$ $\tilde{\theta}_{t+1} \leftarrow (1-\beta) \bar{\theta}_{S} + \beta \tilde{\theta}_t$
15: end for
Output: $\theta_T$
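A matching PyTorch sketch of the inner adversarial loop (steps 06-10) and the momentum step (14). Here `loss_fn` stands for $\ell_s$ (e.g. the `symmetric_kl` helper from the ALUM sketch above); `smart_adv_loss`, `momentum_update`, and all defaults are illustrative assumptions rather than this library's API.

import torch

def smart_adv_loss(model, x, loss_fn, sigma=1e-5, eps=1e-5, T_x=1, eta=1e-3):
    # steps 06-10: build x_tilde by l_inf-normalized gradient ascent on l_s
    x_tilde = x + sigma * torch.randn_like(x)       # 06: x_i + nu_i, nu_i ~ N(0, sigma^2)
    for _ in range(T_x):                            # 07-10
        x_tilde.requires_grad_(True)
        l = loss_fn(model(x).detach(), model(x_tilde))
        g, = torch.autograd.grad(l, x_tilde)
        g = g / (g.abs().amax(-1, keepdim=True) + 1e-12)  # 08: per-example l_inf norm
        x_tilde = (x + (x_tilde + eta * g - x).clamp(-eps, eps)).detach()  # 09: projection
    return loss_fn(model(x), model(x_tilde))        # smoothness-inducing regularizer

@torch.no_grad()
def momentum_update(theta_tilde, theta_bar, beta=0.99):
    # step 14: theta_tilde <- (1 - beta) * theta_bar + beta * theta_tilde
    for p_t, p_b in zip(theta_tilde.parameters(), theta_bar.parameters()):
        p_t.mul_(beta).add_(p_b, alpha=1 - beta)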
# Same toy setup, this time with SMARTCallback wrapping the first layer;
# learn.fit(2) uses the learner's default learning rate.
model = nn.Sequential(
    nn.Linear(1, 10, bias=False),
    nn.Linear(10, 1, bias=False)
)
learn = synth_learner(model=model, cbs=SMARTCallback(model[0]))
learn.fit(2)