Create a chitchat bot by fine-tuning DialoGPT on The Simpsons scripts.
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, concatenate_datasets, DatasetDict
from fastai.text.all import *
from fasthugs.learner import TransLearner
from fasthugs.data import *
model_name = "arampacha/DialoGPT-medium-simpsons"
# data
bs = 4
val_bs = bs*4
eff_bs = 128
# training
lr = 3e-5
The data is obtained from this Kaggle dataset.
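If you don't have the CSV locally, it can be fetched with the Kaggle API; a minimal sketch, assuming the dataset slug is wcukierski/the-simpsons-by-the-data (verify against the dataset page) and that your Kaggle credentials are configured:
import kaggle
# the dataset slug below is an assumption - check the Kaggle page before running
kaggle.api.dataset_download_files('wcukierski/the-simpsons-by-the-data', path='.', unzip=True)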
filename = "simpsons_script_lines.csv"
df = pd.read_csv(filename, index_col='id')
df = df[df.spoken_words.notna()]
df.sort_index(inplace=True)
df.shape
ids = []
for id, x in zip(df.index, df.word_count):
    try:
        int(x)
    except (ValueError, TypeError):
        ids.append(id)
Remove the lines where word_count is not convertible to an integer:
df.drop(index=ids, inplace=True)
df['word_count'] = df.word_count.astype(int)
df.head(5)
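As a side note, the same cleanup can be done without an explicit loop via pandas' to_numeric with errors='coerce'; an equivalent vectorized sketch (producing df_clean rather than modifying df):
# coerce non-numeric word_count values to NaN, drop them, and cast the rest to int
wc = pd.to_numeric(df.word_count, errors='coerce')
df_clean = df[wc.notna()].copy()
df_clean['word_count'] = wc[wc.notna()].astype(int)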
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
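DialoGPT uses the GPT-2 tokenizer, which has no dedicated padding token, so the EOS token is reused for padding. A quick check:
# for GPT-2 based tokenizers this prints '<|endoftext|> <|endoftext|> 50256'
print(tokenizer.pad_token, tokenizer.eos_token, tokenizer.eos_token_id)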
from tqdm.auto import tqdm
max_context = 100  # maximum length of the accumulated context, in words
min_context = 5    # minimum context length (in words) before a training sample is emitted
res = []
e = -1
loc = -1
for _, row in tqdm(df.iterrows(), total=df.shape[0]):
    prev_e, e = e, row['episode_id']
    prev_loc, loc = loc, row['location_id']
    # start a new dialog whenever the episode or the location changes
    if (prev_e != e) or (prev_loc != loc):
        context = []
        total_context_length = 0
        context_lens = []
    line = row['spoken_words'] + tokenizer.eos_token
    # skip lines that are too long to leave room for any context
    if row.word_count > max_context//2:
        continue
    # emit a (context, response) pair once enough context has accumulated
    if total_context_length >= min_context:
        res.append({'response': line,
                    'context': ''.join(context),
                    'context_length': total_context_length,
                    'episode': row['episode_id']})
    context.append(line)
    context_lens.append(row.word_count)
    total_context_length += row.word_count
    # trim the oldest lines until the context fits into max_context words
    to_remove = 0
    while total_context_length > max_context:
        total_context_length -= context_lens[to_remove]
        to_remove += 1
    context = context[to_remove:]
dialog_df = pd.DataFrame(res)
dialog_df.shape
dialog_df['line'] = dialog_df.context + dialog_df.response
def tokenize(batch):
    return tokenizer(batch['line'], return_attention_mask=True, verbose=False, return_length=True)
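The return_length=True flag makes the tokenizer report the token count of each sample, which is used below for filtering and for length-based batching. For instance (illustrative input):
sample = tokenize({'line': ["Hi Homer!" + tokenizer.eos_token]})
sample.keys(), sample['length']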
dialog_df[dialog_df.episode < 550].to_csv('simpsons_dialog_train.csv')
dialog_df[dialog_df.episode >= 550].to_csv('simpsons_dialog_valid.csv')
ds = DatasetDict.from_csv({'train':'simpsons_dialog_train.csv', 'validation':'simpsons_dialog_valid.csv'})
Tokenize the lines in the dataset:
ds = ds.map(tokenize, batched=True, batch_size=100, remove_columns=ds['train'].column_names, num_proc=2)
And remove excessively long samples:
ds = ds.filter(lambda x: x['length'] < 300)
len(ds['train']), len(ds['validation'])
train_idx, valid_idx = get_splits(ds)
train_ds = concatenate_datasets([ds['train'], ds['validation']])
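get_splits comes from fasthugs; together with concatenate_datasets it turns the two splits into index lists over the combined dataset. A sketch of the equivalent manual computation, assuming get_splits simply returns positional ranges:
# hypothetical equivalent of get_splits for this two-split DatasetDict
n_train, n_valid = len(ds['train']), len(ds['validation'])
train_idx, valid_idx = list(range(n_train)), list(range(n_train, n_train + n_valid))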
The batching I want to use is somewhat different from the batching used for regular causal language modeling. The samples are padded with eos_token, and the turns in a dialog are separated with the same token. I want to ignore padding when computing the loss, so the targets corresponding to padding are set to -100. This is easily done using the attention_mask.
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from transformers import PreTrainedTokenizerBase

@dataclass
class DataCollatorForDialog:
    """
    Data collator used for dialog modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length. The labels are constructed according to the attention mask, setting `label=-100`
    where `attention_mask == 0`.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, will pad the sequence to a multiple of the provided value.
    """
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # pad dynamically to the longest sequence in the batch
        batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        # copy input_ids to labels, masking padded positions with -100 so they are ignored by the loss
        labels = torch.where(batch["attention_mask"].bool(), batch["input_ids"], torch.tensor(-100))
        batch["labels"] = labels
        return batch
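A quick sanity check of the collator on a toy batch (the two lines below are arbitrary):
collate = DataCollatorForDialog(tokenizer)
toy_batch = collate([tokenizer("Hi Homer!" + tokenizer.eos_token),
                     tokenizer("Mmm... donuts." + tokenizer.eos_token)])
# padded positions must carry the ignore index
assert (toy_batch['labels'][toy_batch['attention_mask'] == 0] == -100).all()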
To speed up training, samples are grouped by length.
dblock = DataBlock(blocks=[TransformersLMBlock(tokenizer=tokenizer,
                                               masking_func=DataCollatorForDialog(tokenizer),
                                               group_by_len=True,
                                               skip_special_tokens=True)],
                   splitter=IndexSplitter(valid_idx))
train_lens = ds['train']['length']
valid_lens = ds['validation']['length']
# pass the precomputed sample lengths so the dataloaders can sort and group by length
dl_kwargs = [{'res':train_lens}, {'val_res':valid_lens}]
dls = dblock.dataloaders(train_ds, bs=bs, val_bs=val_bs, num_workers=2, dl_kwargs=dl_kwargs)
dls.show_batch()
model = AutoModelForCausalLM.from_pretrained(model_name)
# with labels in the batch the HF model computes the loss itself, so the fastai loss_func is a no-op
learn = TransLearner(dls, model, loss_func=noop, metrics=perplexity)
learn.validate()
cbs = [GradientAccumulation(eff_bs)] if eff_bs != bs else []
learn.fit_one_cycle(1, 1e-5, cbs=cbs)
model = learn.model
model.cpu()
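Optionally, persist the fine-tuned weights and tokenizer (the output directory name below is just a placeholder):
model.save_pretrained('dialogpt-simpsons-finetuned')      # placeholder path
tokenizer.save_pretrained('dialogpt-simpsons-finetuned')  # placeholder path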
for step in range(5):
    # encode the user input and append the EOS token
    new_user_input_ids = tokenizer.encode(input(">> User: ") + tokenizer.eos_token, return_tensors='pt')
    # append the new user input to the chat history (if any)
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
    # generate a response with nucleus (top-p) sampling
    chat_history_ids = model.generate(bot_input_ids,
                                      max_length=1000,
                                      pad_token_id=tokenizer.eos_token_id,
                                      do_sample=True,
                                      top_p=0.9)
    # print only the newly generated tokens
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))