Create a chitchat bot by fine-tuning DialoGPT on The Simpsons scripts.
 
Image: the-simpsons.png, (c) 20th Century Fox Television

Setup

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, concatenate_datasets, DatasetDict

from fastai.text.all import *
from fasthugs.learner import TransLearner
from fasthugs.data import *
model_name = "arampacha/DialoGPT-medium-simpsons"
# data
bs = 4
val_bs = bs*4
eff_bs = 128
# training
lr = 3e-5

Data preprocessing

The data is obtained from this Kaggle dataset.
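
If you need to fetch the file yourself, here is a minimal sketch using the Kaggle API (the dataset slug is an assumption, substitute whichever dataset you actually downloaded the CSV from):

# Assumes the Kaggle API is installed and configured (~/.kaggle/kaggle.json).
# The dataset slug below is a guess; replace it with the dataset you use.
import kaggle
kaggle.api.dataset_download_files("wcukierski/the-simpsons-by-the-data", path=".", unzip=True)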

filename = "simpsons_script_lines.csv"
df = pd.read_csv(filename, index_col='id')
df = df[df.spoken_words.notna()]
df.sort_index(inplace=True)
df.shape
(132112, 12)
# Collect ids of rows whose word_count cannot be parsed as an integer.
ids = []
for id, x in zip(df.index, df.word_count):
    try:
        int(x)
    except (ValueError, TypeError):
        ids.append(id)

Remove the rows where word_count is not convertible to an integer.

df.drop(index=ids, inplace=True)
df['word_count'] = df.word_count.astype(int)
df.head(5)
episode_id number raw_text timestamp_in_ms speaking_line character_id location_id raw_character_text raw_location_text spoken_words normalized_text word_count
id
3 1 2 Marge Simpson: Ooo, careful, Homer. 8000 true 1 2.0 Marge Simpson Car Ooo, careful, Homer. ooo careful homer 3
4 1 3 Homer Simpson: There's no time to be careful. 10000 true 2 2.0 Homer Simpson Car There's no time to be careful. theres no time to be careful 6
5 1 4 Homer Simpson: We're late. 10000 true 2 2.0 Homer Simpson Car We're late. were late 2
8 1 7 Marge Simpson: (HUSHED VOICE) Sorry, Excuse us. Pardon me... 24000 true 1 4.0 Marge Simpson Auditorium Sorry, Excuse us. Pardon me... sorry excuse us pardon me 5
9 1 8 Homer Simpson: (SIMULTANEOUSLY) Hey, Norman. How's it going? So you got dragged down here, too... heh, heh. How ya doing, Fred? Excuse me, Fred. 26000 true 2 4.0 Homer Simpson Auditorium Hey, Norman. How's it going? So you got dragged down here, too... heh, heh. How ya doing, Fred? Excuse me, Fred. hey norman hows it going so you got dragged down here too heh heh how ya doing fred excuse me fred 21
tokenizer = AutoTokenizer.from_pretrained(model_name)
# DialoGPT (GPT-2) has no pad token; reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token
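
As a quick illustration (not part of the original pipeline), DialoGPT expects consecutive dialog turns to be joined by the EOS token, which is also what we just reused as the padding token:

# Illustrative only: two turns separated by the EOS token.
example = "Ooo, careful, Homer." + tokenizer.eos_token + "There's no time to be careful." + tokenizer.eos_token
enc = tokenizer(example)
print(enc["input_ids"].count(tokenizer.eos_token_id))  # expect 2, one per turn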

Preparing dialog data

from tqdm.auto import tqdm
max_context = 100
min_context = 5


res = []
e = -1
loc = -1

for _, row in tqdm(df.iterrows(), total=df.shape[0]):
    prev_e, e = e, row['episode_id']
    prev_loc, loc = loc, row['location_id']

    # Start a fresh context whenever the episode or location changes.
    if (prev_e != e) or (prev_loc != loc):
        context = []
        total_context_length = 0
        context_lens = []
    line = row['spoken_words'] + tokenizer.eos_token

    # Skip lines that are too long to leave room for context.
    if row.word_count > max_context//2:
        continue
    # Record a (context, response) pair once enough context has accumulated.
    if total_context_length >= min_context:
        res.append({'response':line, 'context':''.join(l for l in context), 'context_length':total_context_length, 'episode':row['episode_id']})

    context.append(line)
    context_lens.append(row.word_count)
    total_context_length += row.word_count
    # Trim the oldest lines until the context fits within max_context words.
    to_remove = 0
    while total_context_length > max_context:
        total_context_length -= context_lens[to_remove]
        to_remove += 1

    context = context[to_remove:]

dialog_df = pd.DataFrame(res)

dialog_df.shape
(113900, 4)
dialog_df['line'] = dialog_df.context + dialog_df.response
def tokenize(batch):
    return tokenizer(batch['line'], return_attention_mask=True, verbose=False, return_length=True)
from datasets import DatasetDict
dialog_df[dialog_df.episode <  550].to_csv('simpsons_dialog_train.csv')
dialog_df[dialog_df.episode >= 550].to_csv('simpsons_dialog_valid.csv')
ds = DatasetDict.from_csv({'train':'simpsons_dialog_train.csv', 'validation':'simpsons_dialog_valid.csv'})

Tokenize the lines in the dataset:

ds = ds.map(tokenize, batched=True, batch_size=100, remove_columns=ds['train'].column_names, num_proc=2)

And remove excessively long samples:

ds = ds.filter(lambda x: x['length'] < 300)
len(ds['train']), len(ds['validation'])
(108871, 3247)

Training

train_idx, valid_idx = get_splits(ds)
train_ds = concatenate_datasets([ds['train'], ds['validation']])

The batching I want to use differs somewhat from the batching used for regular causal language modeling. The samples are padded with eos_token, and the turns within a dialog are separated by the same token. I want to ignore padding when computing the loss, so the targets corresponding to padding are set to -100. This can easily be done using the attention_mask.

from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase, BatchEncoding
from typing import Dict, List, Optional, Union
@dataclass
class DataCollatorForDialog:
    """
    Data collator used for dialog modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length. The labels are constructed according to attention mask setting `label=-100` 
    where `attention_mask == 0`. 

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
    """

    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        
        # Mask out padded positions so they are ignored by the loss (-100).
        labels = torch.where(batch["attention_mask"].bool(), batch["input_ids"], torch.tensor(-100))
        batch["labels"] = labels
        return batch
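
As a quick sanity check (illustrative only, not from the original notebook), padded positions should come out with label -100 while real tokens keep their ids:

# Illustrative check: pad two encodings of different length and inspect the labels.
collate_fn = DataCollatorForDialog(tokenizer)
batch = collate_fn([tokenizer("Ooo, careful, Homer." + tokenizer.eos_token),
                    tokenizer("We're late." + tokenizer.eos_token)])
print(batch["labels"])  # padded positions show -100, real tokens keep their ids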

To speed up training, samples are grouped by length.

dblock = DataBlock(blocks=[TransformersLMBlock(tokenizer=tokenizer,
                                               masking_func=DataCollatorForDialog(tokenizer),
                                               group_by_len=True,
                                               skip_special_tokens=True)],
                   splitter=IndexSplitter(valid_idx))
train_lens = ds['train']['length']
valid_lens = ds['validation']['length']
dl_kwargs = [{'res':train_lens},{'val_res':valid_lens}]
dls = dblock.dataloaders(train_ds, bs=bs, val_bs=val_bs, num_workers=2, dl_kwargs=dl_kwargs)
dls.show_batch()
text
0 Mr. Simpson, I dread the day when a hundred thousand dollars isn't worth groveling for.Get outta here.You just made yourself a very powerful enemy, old man.Here's the deal, Grampa. A guy, I think was an explorer, left this in the bar one night. It may be a map to ancient treasure, or directions to some guy's house, but to find out, we'll need money, we'll need provisions, and a two man diving bell.It's pretty stupid, but so far you're the front runner.It's a special isolation chamber. The subject pulls levers to receive food and warmth. The floor can become electrified and showers of icy water randomly fall on the subject. I call it the Monroe Box.Huh, uh. Well, it sounds interesting.Huh uh.How much will it cost to build?Oh, that's the beauty part, it's already built. I need the money to buy a baby to raise in the box until the
1 The noodles? What noodles?The noozle on the end of the hooze! Ach!Miss Simpson, do you find something funny about the word "tromboner?"No, sir. I was laughing at something outside.She was looking at Nelson!Lisa likes Nelson!She does not!Milhouse likes Lisa!He does not!Janey likes Milhouse!She does not!Uter likes Milhouse!Nobody likes Milhouse! Lisa, you've got detention!Oh, how does Bart do this every week?Hey, Brainiac, since when do you get detention?It's your fault! I accidentally laughed at your immature prank.Haw. Yeah, the best part was when he got wet! Hey, you're doin' that the stupid way.If you use that deal with the five chalks, you'll get done faster.Thanks, but I prefer the honest way.Whatever. Smell ya later!Wow, that was a good idea. And I can't believe it came from Nelson.He's not like anybody I've ever met. He's like a riddle wrapped in an enigma wrapped in a vest. He sure is ugly, though. So
2 Yes, I finked on Homer. But you know, he deserved it. Never have I seen such abuse of the "take a penny, leave a penny" tray.The tax men were merciless.Hey, they can't take our house. My pot-bellied pig is in there.Ohhhh, Mister. Porky.Inevitably, the behind-the-scenes turmoil took its toll on their TV series.Annd action!Hold on! Cut!Bart, if it's not too much trouble...Fine! I'll do "Teen Wolf III." I've got fair-weather friends to feed.Dad, I want to go to bed. Aren't there child labor laws?Who told you about those laws? Was it Marge?Hey, you've been riding me all day. Why don't you poop in your hat?Are you going to need us tonight?I have ballet tickets. Not that they'll do much good now.With the family in disarray, episodes increasingly resorted to gimmicky premises and nonsensical plots.I'm an imposter. That man is the real Seymour Skinner.Trendy guest stars were shamelessly trotted out to
3 Maybe you should let Dad read your book before you submit it to publishers.I suppose I better. Your father's a very private person.Marge! We're out of bath towels.Ooh, ice cream truck!HERE IN MY CAR / I AM HOSING OFF BLOOD / SOME OF IT'S MINE / BUT MOST OF IT'S NOT / HERE'S MARGE...Homie, I finished my novel.Ooh, typed!It's really important that you read it and tell me what you think.No problem.Two hundred and eighty-six pages!It's double-spaced.Woo hoo! I'm half-way through!All right, "Chapter One." Hm, that makes sense. "There once was a girl from Nantucket..." Good, good... "Her name was Temperance Barrows and her heart was heavy with feeling. She..."No! Gotta read Marge's book. Can't get distracted. Hm... "distracted," that's a funny word. Does anyone ever get "tracted?" Let me call the suicide hotline and ask them.Well?Well what?
model = AutoModelForCausalLM.from_pretrained(model_name)
learn = TransLearner(dls, model, loss_func=noop, metrics=perplexity)
learn.validate()
(#2) [3.3612000942230225,28.823760986328125]
# Accumulate gradients until eff_bs samples are seen, for an effective batch size of eff_bs.
cbs = [GradientAccumulation(eff_bs)] if eff_bs != bs else []
learn.fit_one_cycle(1, 1e-5, cbs=cbs)
epoch train_loss valid_loss perplexity time
0 3.181646 3.345813 28.383650 2:46:02
model = learn.model
model.cpu()

# Interactive chat: each user turn is appended to the running history,
# and the model generates the next reply conditioned on the whole dialog.
for step in range(5):
    new_user_input_ids = tokenizer.encode(input(">> User: ") + tokenizer.eos_token, return_tensors='pt')

    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    chat_history_ids = model.generate(bot_input_ids,
                                      max_length=1000,
                                      pad_token_id=tokenizer.eos_token_id,
                                      do_sample=True,
                                      top_p=0.9)

    # Decode only the newly generated part of the sequence.
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
>> User: Hi Homer. What's up?
DialoGPT: Uh, uh... nothing.
>> User: Are you hiding something. This sounds suspecious
DialoGPT: Uh, uh, you want me to give you some of those pills?
>> User: Did you take some pills? Which pills?
DialoGPT: Uh, I got the whole package.
>> User: I think you need to go to hospital!
DialoGPT: I've been a little sick lately.
>> User: Ok, let's call ambulance!
DialoGPT: Okay, but I'm a little worried about you. Do you know you're sick?