Train LlaMA-2 LLM on your own emails, Part 2. Model Training
- Author: Nathan Brake (@njbrake)
Introduction
In part 1 we created a dataset of about 150 emails, each pairing an email I received with my reply to it. Using this small dataset, I will demonstrate how to further fine-tune Meta's LlaMA-2 Chat LLM so that the model helps generate emails written in my style.
Strategy
Fine-tuning will be done with the newly released LlaMA-2 model from Meta AI. We will use the variant that was fine-tuned for chat/instruction following, which closely matches our goal of instruction fine-tuning. The LlaMA-2 paper is well written and well worth the read.
To fine-tune this 7B-parameter model on low-cost GPU hardware (a T4 GPU on Google Colab), we use Low-Rank Adaptation (LoRA) and Quantized LoRA (QLoRA). In short, all 7B base parameters are quantized to 4 bits and frozen, and only a small pair of "low-rank" matrices that sits next to each linear layer (the MLP and attention layers) is trained. These papers are also excellent and explain it better than I can: LoRA, QLoRA.
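To get a feel for how small the trainable part is, here is a back-of-the-envelope count of LoRA parameters. The layer sizes (hidden size 4096, MLP size 11008, 32 decoder layers) are assumptions taken from the public Llama-2-7b config, and r=32 matches the LoraConfig used later in this post. The count covers only the seven projection matrices LoRA is attached to; embeddings and the output layer are left out.
```python
HIDDEN = 4096         # hidden size (assumed from the Llama-2-7b config)
INTERMEDIATE = 11008  # MLP intermediate size (assumed)
NUM_LAYERS = 32       # number of decoder layers (assumed)
RANK = 32             # LoRA rank r used in this post

# (in_features, out_features) for each linear layer that gets a LoRA adapter, per decoder layer
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(d_in: int, d_out: int, r: int) -> int:
    # LoRA freezes the d_out x d_in weight W and trains A (r x d_in) and B (d_out x r);
    # the layer computes W x + (lora_alpha / r) * B A x.
    return r * d_in + d_out * r

per_layer_lora = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in TARGET_SHAPES.values())
per_layer_base = sum(d_in * d_out for d_in, d_out in TARGET_SHAPES.values())
total_lora = per_layer_lora * NUM_LAYERS
total_base = per_layer_base * NUM_LAYERS

print(f"trainable LoRA parameters: {total_lora:,}")
print(f"frozen parameters in those layers: {total_base:,}")
print(f"ratio: {total_lora / total_base:.2%}")
```
With r=32 on all seven projections, this works out to 79,953,920 trainable parameters. That is about 1.2% of the roughly 6.5B frozen weights in those layers.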
I drew some inspiration from Phil Schmid's code at https://www.philschmid.de/instruction-tune-llama-2 but made lots of tweaks. I simplified the code a bit and adapted it specifically to the LlaMA-2 chat fine-tuned model instead of the base LlaMA-2 model.
The Code
Let's get to it!
This code was run in Google Colab on the free T4 GPU. The script requires a GPU and won't work on a CPU (as of 8/28/23, 4-bit quantization with bitsandbytes in the Huggingface (HF) stack requires a GPU).
First, install the required packages. We install transformers from the main branch, but any version from v4.33 onward is fine (at the time of writing, v4.32 was the latest release, so we needed to install from source).
pip install git+https://github.com/huggingface/transformers \
git+https://github.com/huggingface/peft \
git+https://github.com/huggingface/accelerate \
bitsandbytes datasets torch --quiet
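If you are unsure which transformers build ended up installed, a quick check like the sketch below can save a confusing error later. It uses `importlib.metadata` from the standard library; `release_tuple` and `meets_minimum` are hypothetical helpers written for this post. A dev build from main, such as `4.33.0.dev0`, is deliberately treated as meeting the v4.33 requirement. This differs from strict PEP 440 ordering, where a dev release sorts before the final release.
```python
import re
from importlib.metadata import PackageNotFoundError, version

MIN_TRANSFORMERS = "4.33"

def release_tuple(v: str) -> tuple:
    """Leading numeric release segment, e.g. '4.33.0.dev0' -> (4, 33, 0)."""
    parts = []
    for piece in v.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at suffix segments such as 'dev0'
        parts.append(int(match.group()))
        if match.group() != piece:
            break  # stop after a segment like '0rc1'
    return tuple(parts)

def meets_minimum(installed: str, minimum: str = MIN_TRANSFORMERS) -> bool:
    # Tuple comparison: (4, 33, 0) >= (4, 33) is True, (4, 32, 1) >= (4, 33) is False
    return release_tuple(installed) >= release_tuple(minimum)

if __name__ == "__main__":
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        print("transformers is not installed")
    else:
        status = "OK" if meets_minimum(installed) else f"need >= {MIN_TRANSFORMERS}"
        print(f"transformers {installed}: {status}")
```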
Next, use the huggingface-cli (installed along with the transformers library) to log in to your Huggingface account. You need to log in to download the LlaMA-2 model. Information about HF user access tokens can be found here.
huggingface-cli login --token <your_token_here>
Next, in your Python script or Jupyter notebook, import the packages we will use:
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    set_seed,
    default_data_collator,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)
import torch
import bitsandbytes as bnb
from huggingface_hub import login, HfFolder
from peft import (
get_peft_model,
LoraConfig,
prepare_model_for_kbit_training,
)
Now the fun stuff! We need to load the dataset that we created in part 1 of this series, and convert it into the format used for instruction fine-tuning. If you've looked at instruction fine-tuning of LLMs before, you may have seen a data format that looks something like:
### Instruction:
You're a chatbot, designed to answer questions about food.
### Input:
Does pizza normally have dairy in it?
### Response:
Yes, it does!
That format comes from the Stanford Alpaca project. However, the LlaMA-2 chat model was instruction fine-tuned with a different format built from <<SYS>> and [INST] tags. The layout used in this post looks like:
<<SYS>>
You're a chatbot, designed to answer questions about food.
<</SYS>>
[INST]
Does pizza normally have dairy in it?
[/INST]
Yes, it does!
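Both layouts are plain strings, so they are easy to build in code. The sketch below builds the layout used in this post, which is the same one as the training template further down. For comparison, it also builds the single-turn layout from Meta's reference llama code, which nests the <<SYS>> block inside the first [INST] block. The function names are hypothetical.
```python
def post_prompt(system: str, user: str, response: str = "") -> str:
    # Layout used in this post: <<SYS>> block first, then the [INST] block, then the reply.
    return f"<<SYS>>\n{system}\n<</SYS>>\n[INST]\n{user}\n[/INST]\n{response}"

def reference_prompt(system: str, user: str) -> str:
    # Single-turn layout from Meta's reference chat code: the system prompt sits inside
    # the first [INST] block. The <s> BOS token is added by the tokenizer, not the string.
    return f"[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]"

system = "You're a chatbot, designed to answer questions about food."
user = "Does pizza normally have dairy in it?"
print(post_prompt(system, user, "Yes, it does!"))
print()
print(reference_prompt(system, user))
```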
So our dataset will use this formatting style, to better align with the data the LlaMA-2 chat model has already seen during its fine-tuning. This should help performance when, as here, we only have a small number of examples for our task.
We will use the meta-llama/Llama-2-7b-chat-hf model. Note that to get access to this model you must accept a license agreement from Meta, which is linked from the HF model page.
Let's load the dataset, and the LlaMA tokenizer:
path_to_dataset = "/content/email_dataset" # Path to the dataset created in part 1 of this series.
dataset = load_from_disk(f"{path_to_dataset}/train")
test_dataset = load_from_disk(f"{path_to_dataset}/test")
model_id = "meta-llama/Llama-2-7b-chat-hf" # sharded weights
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)
Now, set up the code that converts each sample into the desired format and tokenizes it, then apply it to both the training and test datasets.
template=f"<<SYS>>\n{{instruction}}\n<</SYS>>\n[INST]\n{{context}}\n[/INST]\n{{response}}{tokenizer.eos_token}"
def preprocess(sample):
    instruction = "Respond to the given email using the style of Nate Brake"
    context = sample['reply_to']     # the email I received
    response = sample['my_message']  # my reply to it
    sample["text"] = template.format(instruction=instruction, context=context, response=response)
    sample['input_ids'] = tokenizer(sample["text"]).input_ids
    # Causal LM training: labels are a copy of the input ids (the model shifts them internally)
    sample['labels'] = sample["input_ids"].copy()
    return sample
dataset = dataset.map(preprocess)
test_dataset = test_dataset.map(preprocess)
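The template line relies on a two-stage trick that is easy to misread. The f-string substitutes the EOS token immediately, while the doubled braces survive as ordinary `{instruction}`, `{context}` and `{response}` fields that `str.format` fills per sample. The sketch below replays this without downloading anything. It assumes `"</s>"` as a stand-in for `tokenizer.eos_token`, which is LlaMA-2's EOS string, and the email text is made up.
```python
EOS = "</s>"  # stand-in for tokenizer.eos_token (assumption: LlaMA-2's EOS string)

# Stage 1: the f-string inserts EOS now; {{...}} become literal {...} placeholders.
template = f"<<SYS>>\n{{instruction}}\n<</SYS>>\n[INST]\n{{context}}\n[/INST]\n{{response}}{EOS}"

# Stage 2: str.format fills the placeholders for one (made-up) sample.
text = template.format(
    instruction="Respond to the given email using the style of Nate Brake",
    context="Hi Nate, is the packet ready?",
    response="Yes, I'll drop it off tomorrow.",
)
print(text)
```
Ending each sample with EOS teaches the model to stop after the reply. Because `labels` is a plain copy of `input_ids`, the loss also covers the system and email text, not only the reply.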
To keep training free and reduce its memory footprint, remove any samples longer than 2,048 tokens. LlaMA-2 supports a 4,096-token context, so if you run this on a GPU with more memory than a T4, you can raise the limit to 4,096.
dataset = dataset.filter(lambda example: len(example['input_ids']) <= 2048)
test_dataset = test_dataset.filter(lambda example: len(example['input_ids']) <= 2048)
# Print total number of samples
print(f"Train samples: {len(dataset)}")
print(f"Test samples: {len(test_dataset)}")
Next up, training the model. Create some basic arguments for easy tweaking later on:
args = {
    "epochs": 2,
    "per_device_train_batch_size": 1,
    "learning_rate": 5e-5,  # AdamW (the Trainer default) adapts per-parameter step sizes, so the exact value matters somewhat less
    "seed": 42,
    "gradient_checkpointing": True,
    "bf16": False,  # T4 doesn't support bf16; set this to True if your GPU does
}
set_seed(args['seed']) #for reproducibility
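It helps to know how many optimizer steps these settings produce, because `logging_steps` and `eval_steps` in the Trainer configuration are counted in steps. This sketch assumes a single GPU and no gradient accumulation. It uses 154 samples, the training-set size reported later in this post.
```python
import math

def optimizer_steps(num_samples: int, batch_size: int, epochs: int) -> int:
    # Single GPU, no gradient accumulation: one optimizer step per batch,
    # and a final partial batch still counts as a step.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch * epochs

total_steps = optimizer_steps(num_samples=154, batch_size=1, epochs=2)
print(total_steps)        # total optimizer steps
print(total_steps // 10)  # log entries at logging_steps=10
print(total_steps // 20)  # evaluation rounds when evaluating every eval_steps=20
```
For 154 samples that is 308 steps, or about 30 log entries and 15 evaluation rounds.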
Next, load the model, quantize it to 4 bits, and prepare it for training by freezing all the base model weights so that only the LoRA weights are trainable:
# load model from the hub with a bnb config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
model_id,
use_cache=False,
device_map="auto",
quantization_config=bnb_config,
use_auth_token=True
)
model = prepare_model_for_kbit_training(
    model, use_gradient_checkpointing=args['gradient_checkpointing']
)
if args['gradient_checkpointing']:
    model.gradient_checkpointing_enable()
peft_config = LoraConfig(
r=32,
lora_alpha=16,
target_modules=['v_proj', 'o_proj', 'down_proj', 'gate_proj', 'k_proj', 'up_proj', 'q_proj'], # all the linear weights (MLP and Attention)
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM",  # string form, so TaskType doesn't need to be imported
)
model = get_peft_model(model, peft_config)
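These quantization settings are what let the 7B model fit on a 16 GB T4. The sketch below gives a rough estimate of weight storage. It follows the QLoRA paper's accounting: one 32-bit constant per block of 64 weights, and with double quantization, 8-bit constants plus one 32-bit constant per 256 of those. The ~6.74B parameter count is an assumption. The estimate treats every weight as quantized, although some layers such as the embeddings stay in higher precision, and it ignores activations, LoRA weights, and optimizer state.
```python
WEIGHT_BITS = 4   # NF4 weights
BLOCK = 64        # weights per quantization block
DQ_BLOCK = 256    # first-level constants per second-level block (double quantization)
NUM_PARAMS = 6.74e9  # approximate LlaMA-2-7B parameter count (assumption)

def bits_per_param(double_quant: bool) -> float:
    if not double_quant:
        # one fp32 absmax constant per 64 weights
        return WEIGHT_BITS + 32 / BLOCK
    # 8-bit constants per 64 weights, plus one fp32 constant per 64 * 256 weights
    return WEIGHT_BITS + 8 / BLOCK + 32 / (BLOCK * DQ_BLOCK)

for label, bits in [("fp16", 16.0),
                    ("4-bit", bits_per_param(False)),
                    ("4-bit + double quant", bits_per_param(True))]:
    gigabytes = NUM_PARAMS * bits / 8 / 1e9
    print(f"{label:>22}: {bits:.3f} bits/param, ~{gigabytes:.2f} GB")
```
The weights shrink from about 13.5 GB in fp16 to under 4 GB in 4-bit. Double quantization then saves roughly another 0.37 bits per parameter, about 0.3 GB for a 7B model.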
Almost there! Now create the Huggingface Trainer, which takes care of the nitty gritty of training for us:
output_dir = "/tmp/llama2-email/"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=args['per_device_train_batch_size'],
    do_eval=True,
    evaluation_strategy="steps",  # required for eval_steps to take effect (renamed eval_strategy in newer transformers)
    bf16=args['bf16'],
    learning_rate=args['learning_rate'],
    num_train_epochs=args['epochs'],
    gradient_checkpointing=args['gradient_checkpointing'],
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    eval_steps=20,
    save_strategy="no",
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=test_dataset,
data_collator=default_data_collator,
)
trainer.train()
trainer.save_model(output_dir)
With my dataset of 154 training items, training takes about 45 minutes.
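A batch size of 1 is also why `default_data_collator` works here. It stacks each example's lists into tensors as-is and does no padding, so a batch of examples with different lengths would fail. If you raise `per_device_train_batch_size`, you need padding. The pure-Python sketch below shows what a padding collator does. `pad_collate` and the pad id of 0 are hypothetical choices; the LlaMA-2 tokenizer defines no pad token, so you would pick one yourself. A real collator would also return tensors. transformers ships padding collators such as `DataCollatorForSeq2Seq` that handle this.
```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss (and HF models) skip

def pad_collate(batch, pad_id):
    """Right-pad a list of {"input_ids", "labels"} dicts to the longest example."""
    max_len = max(len(example["input_ids"]) for example in batch)
    out = {"input_ids": [], "attention_mask": [], "labels": []}
    for example in batch:
        n_pad = max_len - len(example["input_ids"])
        out["input_ids"].append(example["input_ids"] + [pad_id] * n_pad)
        # 1 = real token, 0 = padding the model should not attend to
        out["attention_mask"].append([1] * len(example["input_ids"]) + [0] * n_pad)
        # padded positions are excluded from the loss
        out["labels"].append(example["labels"] + [IGNORE_INDEX] * n_pad)
    return out

batch = [
    {"input_ids": [1, 15, 27, 2], "labels": [1, 15, 27, 2]},
    {"input_ids": [1, 42, 2], "labels": [1, 42, 2]},
]
collated = pad_collate(batch, pad_id=0)
print(collated)
```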
After training, let's see how it works.
Load the model:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    '/tmp/llama2-email/',  # the output_dir that trainer.save_model wrote to
    device_map="auto",
    use_auth_token=True,
    load_in_4bit=True
)  # If you are coming straight from training, you could instead use model = trainer.model
Build a prompt in the same format used for training and tokenize it:
prompt = """<<SYS>>
Respond to the given email using the style of Nate Brake
<</SYS>>
[INST]
Hi Nate,
Sorry for the delay we have been away on vacation. There is a copy of the packet in the main office. If that is not convenient I am going to try and scan the packet in and then I can send it as an attachment.
Mr. Smith
[/INST]
"""
input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
Generate an output email using greedy decoding. You could also try other strategies, which are described in this great article: decoding strategies.
outputs = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=False) # Greedy decoding
print(f"Prompt:\n{prompt}\n")
# Decode only the newly generated tokens instead of slicing the decoded string by len(prompt)
generated_ids = outputs[0][input_ids.shape[1]:]
print(f"Generated email:\n{tokenizer.decode(generated_ids, skip_special_tokens=True)}")
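With `do_sample=False` and the default `num_beams=1`, `generate()` picks the highest-scoring token at every step, so rerunning the same prompt gives the same email. The toy sketch below shows that loop. The bigram score table is made up and stands in for the model's logits; real `generate()` works on token ids and logits rather than strings.
```python
def greedy_decode(next_token_scores, prompt, eos, max_new_tokens):
    """Greedy decoding: repeatedly append the single highest-scoring next token."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        scores = next_token_scores(tokens)  # dict mapping candidate token -> score
        best = max(scores, key=scores.get)  # argmax; no sampling, so output is deterministic
        tokens.append(best)
        if best == eos:
            break
    return tokens[len(prompt):]             # only the newly generated tokens

# Made-up bigram scores standing in for the model's logits.
BIGRAMS = {
    "[/INST]": {"Sounds": 2.0, "Unfortunately": 1.5},
    "Sounds": {"good.": 3.0, "great.": 1.0},
    "good.": {"Thanks!": 2.0, "</s>": 0.5},
    "Thanks!": {"</s>": 4.0},
    "Unfortunately": {"I": 2.0},
    "I": {"can't.": 1.0},
    "can't.": {"</s>": 1.0},
}

def toy_scores(tokens):
    # Unknown contexts fall back to ending the sequence.
    return BIGRAMS.get(tokens[-1], {"</s>": 0.0})

print(greedy_decode(toy_scores, ["[/INST]"], eos="</s>", max_new_tokens=100))
# Seeding the reply: the forced token becomes part of the prompt.
print(greedy_decode(toy_scores, ["[/INST]", "Unfortunately"], eos="</s>", max_new_tokens=100))
```
Greedy decoding never takes the "Unfortunately" branch on its own, because "Sounds" scores higher. Putting "Unfortunately" into the prompt forces that branch, and the seeded-prompt experiment in this post does the same thing with the real model.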
Using these settings, this is the output I received back:
Sounds good. I'll check the main office when I get to school tomorrow. Thanks!
Now, suppose I seed the start of the reply by slightly changing the prompt (the only change is adding "Unfortunately," after [/INST]):
prompt = """<<SYS>>
Respond to the given email using the style of Nate Brake
<</SYS>>
[INST]
Hi Nate,
Sorry for the delay we have been away on vacation. There is a copy of the packet in the main office. If that is not convenient I am going to try and scan the packet in and then I can send it as an attachment.
Mr. Smith
[/INST]
Unfortunately,
"""
Now I get the output:
I am not able to access the main office. I am on vacation in Maine. I will try to get a copy of the packet from my mom. I will let you know if I am able to get it.
Nate
This output makes sense: Mr. Smith was a teacher at my high school, and the main office he mentions is the school's. And since I grew up in New England, a reply saying I'm on vacation in Maine and will try to get the packet from my mom is a reasonable autocomplete. The dataset contains lots of emails I sent in high school, because in college I used a .edu email address, and after college most of my emails were sent from a work address.
Adding emails sent from my work or college accounts to the dataset would make the model's predictions better match the tone of the emails I send nowadays.
Future work could turn this model into a plugin that adds autocomplete to Outlook or another mail client, or tweak the dataset to include additional data, such as the time and date of each email. The possibilities are endless, but I think this project showcases a little of what can be done with only a little data!
You could also swap in a dataset of essays, blog posts, or other writing of your own, which would let the model generate that kind of content in your voice as well.
Got a question? If I missed something or anything is confusing, feel free to reach out on Twitter or by email. Happy hacking!