Train a Llama-2 LLM for free to be an email autocomplete assistant using your own email data! Part 1: Introduction and Data Preparation

Using LLMs like ChatGPT and GPT-4 has received a lot of attention this past year. Although zero-shot, one-shot, and few-shot prompting have worked well with models like GPT-4, OpenAI has begun letting customers fine-tune its models on their own data: see their recent blog post here about fine-tuning GPT-3.5-Turbo.

I would like to train a model that has a deeper understanding of my email writing style, without my having to state my style preferences explicitly. In other words, I don't want to prompt GPT-4 with all my preferences, like "I prefer to respond to 'are you available' emails with a short one-sentence reply". Instead, I would like to give a model every email I have ever replied to and have it implicitly figure out how I handle certain categories of email. All of those emails together exceed the maximum context length a model can support, so we need to fine-tune a model instead of relying on prompt engineering.

Although we only deal with email here, the same approach applies to a broad range of tasks. This tutorial shows how to create a dataset from your own data and how to easily and cheaply train a state-of-the-art model on it.

Overview

This project has two main components.

First, we will create a dataset of emails in which each item contains a message from someone else and my reply to it.
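As a rough illustration of what one item looks like, here is a hypothetical `EmailPair` record. Its fields mirror the columns the parsing script below stores (`subject`, `author`, `date`, `my_message`, `reply_to`); the class itself is only a sketch, and the example values are made up.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class EmailPair:
    """One training example: a message someone sent me and my reply (hypothetical schema)."""
    subject: str
    author: str      # the From header of my sent reply, i.e. my own address
    date: str
    my_message: str  # what I wrote: the text the model should learn to produce
    reply_to: str    # the earlier message I was answering: the model's input


example = EmailPair(
    subject="RE: Band practice",
    author="me@example.com",
    date="Tue, 12 Sep 2023 09:30:00 -0700",
    my_message="No, I have something else on my calendar that day.",
    reply_to="Are you free for band practice on Thursday?",
)
row = asdict(example)  # a plain dict, the shape of one DataFrame/Dataset row
```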

Second, we will fine-tune the 7B-parameter Llama-2-Chat model on this dataset. Llama-2-Chat has already been fine-tuned to answer questions and follow instructions, so our data should represent only a small domain shift for the model. Using the Low-Rank Adaptation (LoRA) technique and Google Colab with a T4 GPU, we can do this without spending any money!
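To see why LoRA makes this affordable, here is a back-of-the-envelope count. LoRA freezes a d × k weight matrix W and learns a low-rank update BA, with B of size d × r and A of size r × k, so it trains r(d + k) parameters instead of dk. The sketch assumes rank 8 applied to the query and value projections in each of Llama-2-7B's 32 layers (hidden size 4096). These are common defaults, not necessarily the settings Part 2 will use.

```python
def lora_params(d: int, k: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d x k weight: B is d x r, A is r x k."""
    return r * (d + k)


def full_params(d: int, k: int) -> int:
    """Parameters updated when fine-tuning the same d x k weight directly."""
    return d * k


HIDDEN = 4096  # Llama-2-7B hidden size
LAYERS = 32    # Llama-2-7B transformer blocks
RANK = 8       # assumed LoRA rank
TARGETS = 2    # assumed: query and value projections in each block

full = LAYERS * TARGETS * full_params(HIDDEN, HIDDEN)
lora = LAYERS * TARGETS * lora_params(HIDDEN, HIDDEN, RANK)
print(f"full: {full:,}  LoRA: {lora:,}  ratio: {lora / full:.4f}")
# prints: full: 1,073,741,824  LoRA: 4,194,304  ratio: 0.0039
```

Under these assumptions, LoRA trains 1/256 of the parameters in the adapted matrices, and the base model weights stay frozen.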

This is a multipart series. This first post focuses on obtaining and preparing our custom dataset.

The Data

Machine learning is mostly about the data. With a good dataset, training a model is fairly straightforward; with bad data, getting a well-performing model is nearly impossible.

Researchers from Meta, CMU, USC, and Tel Aviv University published a paper earlier this year (2023), called LIMA, on the importance of high-quality fine-tuning data. It showed that fine-tuning on a dataset of only 1,000 carefully curated examples with the standard supervised objective (cross-entropy loss) produced competitive results. This calls into question whether Reinforcement Learning from Human Feedback (RLHF, used to create ChatGPT) is necessary when the dataset is high quality.
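For reference, that supervised objective is next-token cross-entropy. For an input x (here, the email being replied to) and a target reply y = (y_1, ..., y_T), it is the average negative log-likelihood of each target token given the input and the preceding target tokens. Computing the loss only over the reply tokens, as written here, is a common choice for this kind of fine-tuning, but it is an assumption on my part rather than something LIMA prescribes for this task.

```latex
\mathcal{L}(\theta) = -\frac{1}{T} \sum_{t=1}^{T} \log p_\theta\left(y_t \mid x, y_{<t}\right)
```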

Generally, a high-quality dataset contains examples with accurate, complete, and diverse information. We must carefully filter our email dataset to remove any items that aren't useful for our task. We will download all the emails we have sent and create a dataset that lets us train a model to answer emails we might receive.
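As one possible shape for that filtering, here is a minimal sketch. It drops pairs with an empty side, replies shorter than a hypothetical `min_words` threshold, and exact duplicates (ignoring whitespace). The function name, the threshold, and the dict keys (chosen to match the columns used in the script below) are assumptions to adapt to your own mail.

```python
def filter_pairs(pairs, min_words=2):
    """Keep only (reply_to, my_message) pairs likely to be useful for training.

    pairs: iterable of dicts with 'reply_to' and 'my_message' keys.
    min_words: hypothetical minimum length of my reply, in whitespace-separated words.
    """
    seen = set()
    kept = []
    for pair in pairs:
        reply_to = (pair.get("reply_to") or "").strip()
        my_message = (pair.get("my_message") or "").strip()
        if not reply_to or not my_message:
            continue  # one side of the conversation is missing
        if len(my_message.split()) < min_words:
            continue  # too short to teach the model much about my style
        # Collapse whitespace so trivially different copies count as duplicates.
        key = (" ".join(reply_to.split()), " ".join(my_message.split()))
        if key in seen:
            continue
        seen.add(key)
        kept.append(pair)
    return kept
```

Be careful with the length threshold: if you habitually answer some emails with a one-line reply, filtering too aggressively removes exactly the behavior you want the model to learn.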

When you reply to an email, you often draw on information that the email itself doesn't contain. For instance, if my friend Matt asked "Are you free for band practice on Thursday?", the LLM we train will not be able to accurately respond "No, I have something else on my calendar that day", because it has no access to my actual calendar. However, the model can serve as an "autocomplete" system, similar to what is beginning to appear in the Outlook, Gmail, and iCloud Mail clients: if we start the reply with "No, I have", it's reasonable to expect our LLM to suggest the rest of the sentence, " something else on my calendar".
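One way to turn a reply into an autocomplete-style example is to split it at a word boundary. The email being answered plus the first part of my reply become the prompt, and the rest of my reply is what the model should complete. The sketch below wraps the email in the Llama-2-Chat `[INST] <<SYS>> ... <</SYS>> ... [/INST]` template as I understand Meta's published format. The system prompt, the hypothetical helper names, and the split point are all assumptions, and Part 2 may format the data differently.

```python
SYSTEM_PROMPT = "You are an assistant that drafts email replies in my writing style."  # assumption


def split_for_autocomplete(my_message: str, n_prefix_words: int) -> tuple:
    """Split my reply into (prefix the user has typed, continuation the model should suggest).

    Whitespace, including line breaks, is collapsed to single spaces.
    """
    words = my_message.split()
    return " ".join(words[:n_prefix_words]), " ".join(words[n_prefix_words:])


def build_prompt(reply_to: str, prefix: str) -> str:
    """Llama-2-Chat style prompt that ends with the partially written reply.

    The BOS token <s> is normally added by the tokenizer, so it is left out here.
    """
    prompt = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n{reply_to} [/INST]"
    return f"{prompt} {prefix}" if prefix else prompt


# Usage with the band-practice example from the text:
prefix, continuation = split_for_autocomplete(
    "No, I have something else on my calendar that day.", 3)
prompt = build_prompt("Are you free for band practice on Thursday?", prefix)
# prefix is "No, I have"; continuation is "something else on my calendar that day."
```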

Gmail

Google makes it fairly convenient to access your data. Follow the steps here to download your Gmail data with Google Takeout, which exports your mail as an mbox file. Specifically, you want to download all Gmail items from your "Sent" folder.

This Tony Teaches Tech video also walks you through the process:

https://www.youtube.com/watch?v=x26B1_eHP3o

Here is the code to parse this mbox file. Because messages are formatted in many different ways, this exact code won't work for everyone, but it should be a good starting point for your own. I wrote the script with help from the VSCode debugger, which let me inspect the emails as they were parsed and see which edge cases I needed to handle. One edge case was that Gmail included sent items from my Google Voice account (SMS messages). I excluded these from the dataset because my texting behavior differs from my email behavior and isn't relevant to this project.
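If you want a quicker way to spot noise like this than stepping through every message in a debugger, one rough approach is to tally the domains of the recipients in the mbox; unusual ones such as `txt.voice.google.com` stand out near the top. The function name `recipient_domains` is hypothetical; it uses only the standard library's `mailbox` module and `email.utils.getaddresses`.

```python
import mailbox
from collections import Counter
from email.utils import getaddresses


def recipient_domains(mbox_path: str) -> Counter:
    """Count the (lowercased) domains of all To/Cc addresses in an mbox file."""
    counts = Counter()
    for message in mailbox.mbox(mbox_path):
        # str() because header values can come back as email.header.Header objects.
        headers = [str(v) for v in message.get_all("to", []) + message.get_all("cc", [])]
        for _name, address in getaddresses(headers):
            if "@" in address:
                counts[address.rsplit("@", 1)[1].lower()] += 1
    return counts


# Usage: print the 20 most common recipient domains.
# for domain, n in recipient_domains("/path/to/the/exported/mboxfile.mbox").most_common(20):
#     print(f"{n:6d}  {domain}")
```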

This code uses Python 3 and the following packages:

pip install datasets pandas
import mailbox  # mbox docs are here: https://docs.python.org/3/library/mailbox.html
import re

import pandas as pd
from datasets import Dataset

MBOX_FILENAME = '/path/to/the/exported/mboxfile.mbox'  # path to your mbox file
MY_ADDRESS = '<my_email>@gmail.com'  # replace with your own address
# In my case, Craigslist and Google Voice emails were in this mbox file and muddying the dataset.
EXCLUSION_LIST = ['craigslist', 'txt.voice.google.com']
SEPARATOR = '-----Original Message-----'  # some emails use this line to separate messages
# One or more '>' at the start of a line mark text quoted from a previous email.
QUOTE_MARKERS = re.compile(r'^(?:>[ \t]?)+', flags=re.MULTILINE)
COLUMNS = ['subject', 'author', 'date', 'my_message', 'reply_to']


def get_plain_text(message):
    """Return the first text/plain body of a message as a str, or None if it has none."""
    # walk() yields the message itself for single-part mail and every part for multipart mail.
    # Calling get_payload(decode=True) directly on a multipart message returns None.
    for part in message.walk():
        if part.get_content_type() != 'text/plain' or part.get_content_disposition() == 'attachment':
            continue
        payload = part.get_payload(decode=True)  # undoes the base64/quoted-printable encoding
        if payload is None:
            continue
        try:
            return payload.decode(part.get_content_charset() or 'utf-8')
        except (UnicodeDecodeError, LookupError):
            # Older emails may use the Windows encoding instead; replace bytes it can't map.
            return payload.decode('windows-1252', errors='replace')
    return None


my_mbox = mailbox.mbox(MBOX_FILENAME)
print(f"This mailbox has {len(my_mbox)} items in it")
rows = []  # collect rows in a list and build the DataFrame once at the end
for message in my_mbox:
    subject = message['subject']
    # Skip messages without a subject: we only want replies, which almost always have
    # something in the subject line, even if it's only 'RE: '.
    if subject is None:
        continue
    subject = str(subject)  # headers can come back as email.header.Header objects
    to = str(message['to'] or '')
    author = str(message['from'] or '')
    date = str(message['date'] or '')
    if (any(x in to for x in EXCLUSION_LIST)
            or MY_ADDRESS not in author
            or subject.startswith(('FW:', 'Fwd:'))):
        continue
    payload = get_plain_text(message)
    if payload is None:  # no text/plain part, e.g. an empty or HTML-only message
        continue
    payload = payload.replace('\r\n', '\n')  # normalize CRLF into LF
    content = payload.split(SEPARATOR)
    if len(content) == 1:  # no separator: look for Gmail-style "On ... wrote:" lines instead
        content = QUOTE_MARKERS.sub('', content[0])  # strip quote markers to clean the formatting
        matching_lines = [line.strip() for line in content.split('\n')
                          if line.strip().startswith('On') and line.strip().endswith('wrote:')]
        if not matching_lines:
            continue
        my_message, reply_to = content.split(matching_lines[0], 1)
        if len(matching_lines) > 1:
            # A chain of replies to replies: keep only the message I replied to, drop older ones.
            reply_to = reply_to.split(matching_lines[1], 1)[0]
    else:
        my_message = QUOTE_MARKERS.sub('', content[0])
        reply_to = QUOTE_MARKERS.sub('', content[1])
    my_message, reply_to = my_message.strip(), reply_to.strip()
    if not my_message or not reply_to:
        continue
    rows.append(dict(zip(COLUMNS, [subject, author, date, my_message, reply_to])))

df = pd.DataFrame(rows, columns=COLUMNS)
# preserve_index=False keeps the pandas index out of the dataset's columns.
dataset = Dataset.from_pandas(df, preserve_index=False)
# Create a train/test split; the fixed seed makes it reproducible.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
dataset.save_to_disk('dataset/email_dataset')

This code creates a Hugging Face Dataset, splits it into train and test sets, and saves it to the dataset/email_dataset folder. In my case, the dataset had 199 items: 159 in the train split and 40 in the test split.
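The trickiest part of the script is separating my reply from the quoted text. Here is roughly the same heuristic as the loop above, pulled out into a standalone, hypothetical `split_reply` function so you can test it on a few sample bodies before running it over the whole mailbox. It strips the `>` quote markers, finds the attribution lines that start with "On" and end with "wrote:", and keeps everything before the first one as my message and the text up to the second one as the message I replied to.

```python
import re
from typing import Optional, Tuple

ATTRIBUTION = re.compile(r"^On .*wrote:$")  # e.g. "On Tue, ... Matt <matt@example.com> wrote:"
QUOTE_MARKERS = re.compile(r"^(?:>[ \t]?)+", flags=re.MULTILINE)
SEPARATOR = "-----Original Message-----"


def split_reply(body: str) -> Optional[Tuple[str, str]]:
    """Return (my_message, reply_to), or None if no quoted message is found.

    Lines are stripped, so leading indentation inside the messages is lost.
    """
    body = body.replace("\r\n", "\n")
    if SEPARATOR in body:
        # Separator style: my text comes first, the original message follows.
        mine, _, theirs = body.partition(SEPARATOR)
        return mine.strip(), theirs.strip()
    lines = [line.strip() for line in QUOTE_MARKERS.sub("", body).split("\n")]
    attributions = [i for i, line in enumerate(lines) if ATTRIBUTION.match(line)]
    if not attributions:
        return None
    first = attributions[0]
    # Stop at the next attribution so older messages further down the thread are ignored.
    end = attributions[1] if len(attributions) > 1 else len(lines)
    my_message = "\n".join(lines[:first]).strip()
    reply_to = "\n".join(lines[first + 1:end]).strip()
    return my_message, reply_to


sample = (
    "No, I have something else on my calendar that day.\n\n"
    "On Tue, Sep 12, 2023 at 9:30 AM Matt <matt@example.com> wrote:\n"
    "> Are you free for band practice on Thursday?\n"
    ">\n"
    "> On Mon, Sep 11, 2023 at 8:00 PM Me <me@example.com> wrote:\n"
    "> > We should practice this week.\n"
)
result = split_reply(sample)
```

Running it on a handful of real bodies from your own mbox is a cheap way to find formats the heuristic misses before they silently drop out of the dataset.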

Feel free to reach out to me via email or Twitter to discuss this or if you have any questions 😊

Come back for Part 2 of this series, where we will use this dataset to fine-tune the Llama-2 model and see how well it works.