Create a chitchat bot by fine-tuning DialoGPT on The Simpsons scripts.
(c) 20th Century Fox Television


import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset, concatenate_datasets, DatasetDict

from fastai.text.all import *
from fasthugs.learner import TransLearner
# NOTE(review): the module name was missing here (`from import *` is a syntax
# error). `fasthugs.data` is the assumed source of `TransformersLMBlock` and
# `get_splits`, which are used further down — confirm against the fasthugs package.
from fasthugs.data import *
# Hub checkpoint used to load both the tokenizer and the model.
model_name = "arampacha/DialoGPT-medium-simpsons"

# data
bs = 4            # training batch size
val_bs = 4 * bs   # validation batch size
eff_bs = 128      # value handed to GradientAccumulation when it differs from bs

# training
# NOTE(review): the visible fit_one_cycle call passes a literal 1e-5 instead
# of this value — confirm which learning rate is intended.
lr = 3e-5

Data preprocessing

The data is obtained from this Kaggle dataset.

filename = "simpsons_script_lines.csv"
# Read the script lines indexed by their `id` column and keep only the rows
# that actually contain spoken words.
df = pd.read_csv(filename, index_col='id').dropna(subset=['spoken_words'])
(132112, 12)
# Collect the index labels of rows whose `word_count` cannot be converted to
# an integer, so they can be dropped before the column is cast to int.
# NOTE(review): the loop body was missing in the source; it is reconstructed
# from the accompanying description ("remove the lines where word_count is
# not convertible to integer") — confirm against the original notebook.
ids = []
for row_id, x in zip(df.index, df.word_count):
    try:
        int(x)
    except (TypeError, ValueError, OverflowError):
        # int() rejects non-numeric strings, None, NaN and infinities.
        ids.append(row_id)

Remove the lines where `word_count` is not convertible to an integer.

# Discard the rows collected in `ids`, then store `word_count` as integers.
df = df.drop(index=ids)
df['word_count'] = df['word_count'].astype(int)
episode_id number raw_text timestamp_in_ms speaking_line character_id location_id raw_character_text raw_location_text spoken_words normalized_text word_count
3 1 2 Marge Simpson: Ooo, careful, Homer. 8000 true 1 2.0 Marge Simpson Car Ooo, careful, Homer. ooo careful homer 3
4 1 3 Homer Simpson: There's no time to be careful. 10000 true 2 2.0 Homer Simpson Car There's no time to be careful. theres no time to be careful 6
5 1 4 Homer Simpson: We're late. 10000 true 2 2.0 Homer Simpson Car We're late. were late 2
8 1 7 Marge Simpson: (HUSHED VOICE) Sorry, Excuse us. Pardon me... 24000 true 1 4.0 Marge Simpson Auditorium Sorry, Excuse us. Pardon me... sorry excuse us pardon me 5
9 1 8 Homer Simpson: (SIMULTANEOUSLY) Hey, Norman. How's it going? So you got dragged down here, too... heh, heh. How ya doing, Fred? Excuse me, Fred. 26000 true 2 4.0 Homer Simpson Auditorium Hey, Norman. How's it going? So you got dragged down here, too... heh, heh. How ya doing, Fred? Excuse me, Fred. hey norman hows it going so you got dragged down here too heh heh how ya doing fred excuse me fred 21
# Load the tokenizer that belongs to the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pad with the end-of-sequence token; the same token also terminates every
# dialog line when the samples are built.
tokenizer.pad_token = tokenizer.eos_token

Preparing dialog data

# NOTE(review): the module name was missing from this import; `tqdm.auto` is assumed.
from tqdm.auto import tqdm

# Word budgets for the sliding context window (measured with `word_count`).
max_context = 100  # the context is trimmed from the front once it exceeds this
min_context = 5    # a sample is emitted only if its context has at least this many words

res = []
e = -1
loc = -1

# Walk the script in order. Each line becomes a response whose context is the
# preceding lines spoken in the same episode and location.
for _, row in tqdm(df.iterrows(), total=df.shape[0]):
    prev_e, e = e, row['episode_id']
    prev_loc, loc = loc, row['location_id']
    # A change of episode or location starts a new scene: reset the window.
    if (prev_e != e) or (prev_loc != loc):
        context = []
        total_context_length = 0
        context_lens = []
    line = row['spoken_words'] + tokenizer.eos_token
    if row.word_count > max_context//2:
        # NOTE(review): the body of this `if` was missing in the source;
        # skipping over-long lines entirely is the assumed intent.
        continue
    if total_context_length >= min_context:
        res.append({'responce':line, 'context':''.join(context), 'context_length':total_context_length, 'episode':row['episode_id']})

    # NOTE(review): the two appends below and the `context_lens` trim at the
    # end were absent from the source. Without them the context stays empty
    # and the trimming loop indexes into an empty list. Reconstructed —
    # confirm against the original notebook.
    context.append(line)
    context_lens.append(row.word_count)
    total_context_length += row.word_count

    # Drop the oldest lines until the window fits into `max_context` words.
    to_remove = 0
    while total_context_length > max_context:
        total_context_length -= context_lens[to_remove]
        to_remove += 1

    context = context[to_remove:]
    context_lens = context_lens[to_remove:]

dialog_df = pd.DataFrame(res)

(113900, 4)
# Full training text of a sample: the context followed by the response.
dialog_df['line'] = dialog_df['context'] + dialog_df['responce']
def tokenize(batch):
    """Tokenize the ``line`` field of ``batch`` with the module-level ``tokenizer``.

    The tokenizer is asked to return the attention mask and the length of
    every encoded sample in addition to its default outputs.
    """
    texts = batch['line']
    options = dict(return_attention_mask=True, return_length=True, verbose=False)
    return tokenizer(texts, **options)
from datasets import DatasetDict

# Episodes below 550 go to the training file, the remaining ones to the
# validation file; both files are then loaded back as one DatasetDict.
split_files = {'train': 'simpsons_dialog_train.csv', 'validation': 'simpsons_dialog_valid.csv'}
in_train = dialog_df.episode < 550
in_valid = dialog_df.episode >= 550
dialog_df[in_train].to_csv(split_files['train'])
dialog_df[in_valid].to_csv(split_files['validation'])
ds = DatasetDict.from_csv(split_files)

Tokenize the lines in dataset:

# Tokenize every split in batches of 100 using two worker processes; the
# original columns are removed so only the tokenizer outputs remain.
# NOTE(review): the call target was missing in the source (`ds =, batched=...`);
# `ds.map(tokenize, ...)` is the reconstructed call.
ds = ds.map(tokenize, batched=True, batch_size=100, remove_columns=ds['train'].column_names, num_proc=2)

And remove excessively long samples:

def _is_short_enough(sample):
    """Return True when the tokenized sample is shorter than 300 tokens."""
    return sample['length'] < 300


ds = ds.filter(_is_short_enough)
# Number of samples left in each split.
tuple(len(ds[split]) for split in ('train', 'validation'))
(108871, 3247)


# Stack both splits into one dataset, training samples first.
train_ds = concatenate_datasets([ds[split] for split in ('train', 'validation')])
# NOTE(review): get_splits presumably returns the index ranges of the two
# splits inside the stacked dataset — confirm against fasthugs.
train_idx, valid_idx = get_splits(ds)

The batching I want to use is somewhat different from the batching used for regular causal language modeling. The samples are padded with `eos_token`, and the replies within a dialog are separated by that same token, so padding cannot be told apart from real separators by token id alone. I want to ignore padding when computing the loss, so the targets corresponding to padding are set to -100. This can easily be done using `attention_mask`.

from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase, BatchEncoding
from typing import List, Dict, Union
@dataclass
class DataCollatorForDialog:
    """
    Data collator used for dialog modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length. The labels are constructed according to attention mask setting `label=-100`
    where `attention_mask == 0`.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
    """

    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad ``examples`` into one batch and attach a ``labels`` tensor.

        Returns the padded batch produced by ``tokenizer.pad`` with an extra
        ``labels`` entry: a copy of ``input_ids`` in which every position
        whose ``attention_mask`` is 0 is replaced by -100.
        """
        batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        # The pad token equals the eos token that separates dialog turns, so
        # padding is identified through the attention mask, not the token id.
        padding = ~batch["attention_mask"].bool()
        # masked_fill is out-of-place: input_ids stays untouched.
        batch["labels"] = batch["input_ids"].masked_fill(padding, -100)
        return batch

To speed up training, samples are grouped by length.

# NOTE(review): the DataBlock(...) call below is truncated in this copy — its
# remaining arguments and closing brackets are missing and must be restored
# from the original notebook before this cell can run.
dblock = DataBlock(blocks=[TransformersLMBlock(tokenizer=tokenizer,
# Token lengths of every sample, as returned by the tokenizer (`return_length=True`).
train_lens = ds['train']['length']
valid_lens = ds['validation']['length']
# One kwargs dict per dataloader (train, then validation).
# NOTE(review): presumably `res` / `val_res` hand the precomputed lengths to
# the length-grouping dataloaders — confirm against fastai/fasthugs.
dl_kwargs = [{'res':train_lens},{'val_res':valid_lens}]
dls = dblock.dataloaders(train_ds, bs=bs, val_bs=val_bs, num_workers=2, dl_kwargs=dl_kwargs)
# Load the causal language model from the same checkpoint as the tokenizer.
model = AutoModelForCausalLM.from_pretrained(model_name)
# NOTE(review): loss_func=noop presumably means the loss comes from the model
# itself (computed from the `labels` in each batch) — confirm against TransLearner.
learn = TransLearner(dls, model, loss_func=noop, metrics=perplexity)
(#2) [3.3612000942230225,28.823760986328125]
# Accumulate gradients only when the effective batch size differs from the
# per-step batch size.
cbs = []
if eff_bs != bs:
    cbs.append(GradientAccumulation(eff_bs))
# One epoch with the one-cycle schedule.
learn.fit_one_cycle(1, 1e-5, cbs=cbs)
epoch train_loss valid_loss perplexity time
0 3.181646 3.345813 28.383650 2:46:02
model = learn.model
# Chat with the fine-tuned model for five turns.
for step in range(5):
    # Encode the user's message, terminated by the eos token, as a PyTorch tensor.
    new_user_input_ids = tokenizer.encode(input(">> User: ") + tokenizer.eos_token, return_tensors='pt')

    # From the second turn on, prepend the whole conversation so far.
    # NOTE(review): the call target was missing in the source; `torch.cat` is
    # the reconstructed call.
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # NOTE(review): the arguments after `bot_input_ids` were truncated in the
    # source. `max_length=1000` and `pad_token_id=tokenizer.eos_token_id` are
    # the values from the standard DialoGPT usage example — confirm against
    # the original notebook (it may have used sampling options as well).
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # Print only the newly generated tokens, i.e. everything after the input.
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
>> User: Hi Homer. What's up?
DialoGPT: Uh, uh... nothing.
>> User: Are you hiding something. This sounds suspecious
DialoGPT: Uh, uh, you want me to give you some of those pills?
>> User: Did you take some pills? Which pills?
DialoGPT: Uh, I got the whole package.
>> User: I think you need to go to hospital!
DialoGPT: I've been a little sick lately.
>> User: Ok, let's call ambulance!
DialoGPT: Okay, but I'm a little worried about you. Do you know you're sick?