from datasets import load_dataset
import pandas as pd
from rich import print
# Load the datasets
train_dataset = load_dataset("strickvl/isafpressreleases", split="train")
test_dataset = load_dataset("strickvl/isafpressreleases", split="test")

# Convert the datasets to pandas DataFrames
train_df = pd.DataFrame(train_dataset)
test_df = pd.DataFrame(test_dataset)
We previously looked into how well the top LLMs did when given press releases and asked to extract structured data from them. I was glad to see that this clearly wasn’t a task they struggled with, but it was by no means simple for them either, and some basic evaluations I performed showed that there was room for improvement.
Since writing that post I’ve also heard from readers suggesting that perhaps I wasn’t using the OpenAI API in a way that would get the best results. In particular, function calling should give better accuracy than the raw prompting I was using. I’ll probably return to that in a separate post when I compare how well we’re doing with finetuning.
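For the curious, the suggestion looks roughly like this. Treat it as an unbenchmarked sketch: it assumes the `IsafEvent` Pydantic model we define below, a `press_release` string, and a placeholder model name.

from openai import OpenAI

client = OpenAI()

# Function-calling sketch: pass a JSON schema (derived from the IsafEvent
# Pydantic model defined later in this post) as a "tool" and force the model
# to call it, so the response comes back as schema-shaped JSON arguments.
response = client.chat.completions.create(
    model="gpt-4-turbo",  # placeholder model choice
    messages=[{"role": "user", "content": f"PRESS RELEASE TEXT: {press_release}"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "extract_event",
                "description": "Extract structured data about an ISAF event",
                "parameters": IsafEvent.model_json_schema(),
            },
        }
    ],
    tool_choice={"type": "function", "function": {"name": "extract_event"}},
)
json_args = response.choices[0].message.tool_calls[0].function.arguments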
As a quick reminder, we’re hoping to create something that will allow us to go from unstructured text (a press release, in our case) to a structured output that accurately extracts certain pieces of metadata from the text. Please give the first post in the series a read if you want more context on what we’re doing.
This blog post is about my first finetune(s) of some models: I’ll show how I got the data ready and then share some observations about the finetuning process in general.
Preparing data for finetuning
In my previous posts I’ve already shown how I converted my dataset into Pydantic models. This helps ensure uniformity in the data I’ll be using for finetuning. We’ll actually want to convert the labelled data from the Pydantic models into JSON strings, but as an interim data structure Pydantic is useful for the validation.
Loading the datasets
train_dataset
Dataset({
features: ['name', 'eventrefnumber', 'text', 'StartDate', 'eventtype', 'province', 'citydistrict', 'village', 'targetgroup', 'commander', 'position', 'minkilled', 'mincaptured', 'capturedcharacterisation', 'killedcharacterisation', 'killq', 'captureq', 'killcaptureraid', 'airstrike', 'noshotsfired', 'dataprocessed', 'flagged', 'glossarymeta', 'minleaderskilled', 'minfacilitatorskilled', 'minleaderscaptured', 'minfacilitatorscaptured', 'leaderq'],
num_rows: 4098
})
test_dataset
Dataset({
features: ['name', 'eventrefnumber', 'text', 'StartDate', 'eventtype', 'province', 'citydistrict', 'village', 'targetgroup', 'commander', 'position', 'minkilled', 'mincaptured', 'capturedcharacterisation', 'killedcharacterisation', 'killq', 'captureq', 'killcaptureraid', 'airstrike', 'noshotsfired', 'dataprocessed', 'flagged', 'glossarymeta', 'minleaderskilled', 'minfacilitatorskilled', 'minleaderscaptured', 'minfacilitatorscaptured', 'leaderq'],
num_rows: 724
})
We have 4098 training examples and 724 test examples, which seems a good split to me. I experimented a bit with the exact proportions and found that holding out 15% was a good compromise: we want enough data to get a good evaluation, but we also want to give our model enough examples to learn from. In the course, people frequently talked about somewhere on the order of mid-hundreds to low thousands of examples being the sweet spot, so I hope I’m firmly in that range.
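For reference, a split like this is easy to reproduce with the `datasets` library. This is just an illustrative sketch (the published dataset already ships with train and test splits, and the seed here is arbitrary):

from datasets import load_dataset

# Hold out 15% of the data as a test set, with a fixed seed for reproducibility
full_dataset = load_dataset("strickvl/isafpressreleases", split="train")
splits = full_dataset.train_test_split(test_size=0.15, seed=42)
train_dataset, test_dataset = splits["train"], splits["test"]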
It’s also worth reflecting that I’m lucky to have such a large, clean dataset to work with. In a later project I’d like to try working with much less data and slowly building up something more complex, since that’s a skill in and of itself.
Setting up our Pydantic models with validation
There’s a decent amount of code in the next cell (definitely read the previous posts to understand what all the pieces are about), but in a nutshell we’re setting ourselves up to extract structured data from the text. This Pydantic model is what will hold the data we’re interested in.
from enum import Enum
from typing import Set, Annotated, Optional
from pydantic import BaseModel, Field, validator, ValidationInfo
from datetime import date
class EventType(str, Enum):
    airstrike = "airstrike"
    detention = "detention"
    captureandkill = "captureandkill"
    insurgentskilled = "insurgentskilled"
    exchangeoffire = "exchangeoffire"
    civiliancasualty = "civiliancasualty"


class Province(str, Enum):
    badakhshan = "badakhshan"
    badghis = "badghis"
    baghlan = "baghlan"
    balkh = "balkh"
    bamyan = "bamyan"
    day_kundi = "day_kundi"
    farah = "farah"
    faryab = "faryab"
    ghazni = "ghazni"
    ghor = "ghor"
    helmand = "helmand"
    herat = "herat"
    jowzjan = "jowzjan"
    kabul = "kabul"
    kandahar = "kandahar"
    kapisa = "kapisa"
    khost = "khost"
    kunar = "kunar"
    kunduz = "kunduz"
    laghman = "laghman"
    logar = "logar"
    nangarhar = "nangarhar"
    nimroz = "nimroz"
    nuristan = "nuristan"
    paktya = "paktya"
    paktika = "paktika"
    panjshir = "panjshir"
    parwan = "parwan"
    samangan = "samangan"
    sar_e_pul = "sar_e_pul"
    takhar = "takhar"
    uruzgan = "uruzgan"
    wardak = "wardak"
    zabul = "zabul"


class TargetGroup(str, Enum):
    taliban = "taliban"
    haqqani = "haqqani"
    criminals = "criminals"
    aq = "aq"
    hig = "hig"
    let = "let"
    imu = "imu"
    judq = "judq"
    iju = "iju"
    hik = "hik"
    ttp = "ttp"
    other = "other"


def validate_event_type(value: str):
    valid_values = [
        "airstrike",
        "detention",
        "captureandkill",
        "insurgentskilled",
        "exchangeoffire",
        "civiliancasualty",
    ]
    if value.lower() not in valid_values:
        return "other"
    return value.lower()


def validate_province(value: str):
    valid_values = [
        "badakhshan",
        "badghis",
        "baghlan",
        "balkh",
        "bamyan",
        "day_kundi",
        "farah",
        "faryab",
        "ghazni",
        "ghor",
        "helmand",
        "herat",
        "jowzjan",
        "kabul",
        "kandahar",
        "kapisa",
        "khost",
        "kunar",
        "kunduz",
        "laghman",
        "logar",
        "nangarhar",
        "nimroz",
        "nuristan",
        "paktya",
        "paktika",
        "panjshir",
        "parwan",
        "samangan",
        "sar_e_pul",
        "takhar",
        "uruzgan",
        "wardak",
        "zabul",
    ]
    if value.lower() not in valid_values:
        return "other"
    return value.lower()


def validate_target_group(value: str):
    valid_values = [
        "taliban",
        "haqqani",
        "criminals",
        "aq",
        "hig",
        "let",
        "imu",
        "judq",
        "iju",
        "hik",
        "ttp",
        "other",
    ]
    if value.lower() not in valid_values:
        return "other"
    return value.lower()
class IsafEvent(BaseModel):
    name: str = Field(
        description="A title or name for the event which summarises the event as a headline"
    )
    text: Optional[str] = Field(description="The full text of the press release")
    start_date: date = Field(
        description="The start date of the event in YYYY-MM-DD format"
    )
    event_type: Set[Annotated[str, Field(validator=validate_event_type)]] = Field(
        description="The event type. Can be multiple types."
    )
    province: Set[Annotated[str, Field(validator=validate_province)]] = Field(
        description="The province in which the event occurred. Can be multiple provinces."
    )
    target_group: Set[Annotated[str, Field(validator=validate_target_group)]] = Field(
        description="The group that was targeted during the event. Can be multiple groups."
    )
    min_killed: int = Field(
        description="The minimum number of people killed during the event"
    )
    min_captured: int = Field(
        description="The minimum number of people captured during the event"
    )
    killq: bool = Field(
        description="Whether someone was killed or not during the event"
    )
    captureq: bool = Field(
        description="Whether someone was captured or not during the event"
    )
    killcaptureraid: bool = Field(
        description="Whether the event was a so-called 'kill-capture raid'."
    )
    airstrike: bool = Field(
        description="Whether an airstrike was used during the event"
    )
    noshotsfired: bool = Field(
        description="Whether no shots were fired during the event"
    )
    min_leaders_killed: int = Field(
        description="The minimum number of leaders killed during the event"
    )
    min_leaders_captured: int = Field(
        description="The minimum number of leaders captured during the event"
    )

    class Config:
        arbitrary_types_allowed = True
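As a quick illustration of how the validators behave (my own examples, not from the original notebook): values are lower-cased and anything outside the allowed list collapses to "other":

assert validate_event_type("Airstrike") == "airstrike"
assert validate_province("Helmand") == "helmand"
assert validate_province("narnia") == "other"
assert validate_target_group("TALIBAN") == "taliban"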
Here’s what a couple of examples of our training data look like as Pydantic models when we pass them in:
from typing import List

events: List[IsafEvent] = []

for i, row in list(train_df.iterrows()):
    event_types = set(
        eventtype.strip().lower() for eventtype in row["eventtype"].split(",")
    )
    provinces = set(province.strip().lower() for province in row["province"].split(","))
    target_groups = set(
        target_group.strip().lower() for target_group in row["targetgroup"].split(",")
    )
    events.append(
        IsafEvent(
            name=row["name"],
            text=row["text"],
            start_date=row["StartDate"].to_pydatetime().date(),
            event_type=event_types,
            province=provinces,
            target_group=target_groups,
            min_killed=int(row["minkilled"]),
            min_captured=int(row["mincaptured"]),
            killq=row["killq"] == "true",
            captureq=row["captureq"] == "true",
            killcaptureraid=row["killcaptureraid"] == "true",
            airstrike=row["airstrike"] == "true",
            noshotsfired=row["noshotsfired"] == "true",
            min_leaders_killed=int(row["minleaderskilled"]),
            min_leaders_captured=int(row["minleaderscaptured"]),
        )
    )
print(events[:2])
[ IsafEvent( name='Several insurgents killed in Helmand', text='ISAF Joint Command Evening Operational Update Feb. 19, 2011\nISAF Joint Command - Afghanistan\u20282011-02-S-143\u2028For Immediate Release \u2028\u2028KABUL, Afghanistan (Feb. 19)\u2028\u2028ISAF service members at a compound in Sangin district, Helmand province observed numerous insurgents north and south of their position talking on radios today. After gaining positive identification of the insurgent positions, the coalition troops engaged, killing several insurgents. Later, the ISAF troops observed more insurgents positioning in the area with weapons. After positive identification, coalition forces continued firing on the various insurgent positions, resulting in several more insurgents being killed.', start_date=datetime.date(2011, 2, 18), event_type={'insurgentskilled'}, province={'helmand'}, target_group={''}, min_killed=6, min_captured=0, killq=True, captureq=False, killcaptureraid=False, airstrike=False, noshotsfired=False, min_leaders_killed=0, min_leaders_captured=0 ), IsafEvent( name='Force Continues Targeting Haqqani Leadership', text='Force Continues Targeting Haqqani Leadership\nISAF Joint Command - Afghanistan\u20282010-09-CA-211 For Immediate Release\u2028Download PDF \u2028\u2028\u2028\xa0KABUL, Afghanistan (Sept. 20) - An Afghan and coalition security force detained two insurgents, including a Haqqani Network sub-commander operating in Khost province, Sunday. \u2028\u2028The commander coordinated and conducted attacks on coalition forces operating in the province and was formerly active in Kabul. \u2028\u2028Intelligence reports led the security force to a compound northwest of Khost City to search for the commander. Afghan forces called for all occupants to exit the buildings peacefully and then the combined force cleared and secured the compound. During the clearance, an armed individual came out of an adjacent building toward the security force. The forced engaged the individual and killed him. \u2028\u2028After the area was secure, the security force questioned the residents at the scene and detained the commander and one of his associates. The security force also found multiple automatic weapons, magazines and grenades at the scene. \u2028\u2028The assault force protected the women and children throughout the search.', start_date=datetime.date(2010, 9, 19), event_type={'captureandkill'}, province={'khost'}, target_group={'haqqani'}, min_killed=1, min_captured=2, killq=True, captureq=True, killcaptureraid=True, airstrike=False, noshotsfired=False, min_leaders_killed=0, min_leaders_captured=0 ) ]
So this is data that we’ve already labelled. You can see the text that we’ll provide as input to our model, and then you can see the various fields that we’re hoping our model can learn to extract. As a JSON string, the prediction that we’re hoping our model will output would look like this:
json_str = events[0].model_dump_json(exclude={"text"})
print(json_str)
{"name":"Several insurgents killed in Helmand","start_date":"2011-02-18","event_type":["insurgentskilled"],"province":["helmand"],"target_group":[""],"mi n_killed":6,"min_captured":0,"killq":true,"captureq":false,"killcaptureraid":false,"airstrike":false,"noshotsfired" :false,"min_leaders_killed":0,"min_leaders_captured":0}
If you wish to view more examples of the data: I created an interim dataset and uploaded it to the Hugging Face Hub. It’s not completely in the form required for finetuning, so I’ll just link to it here, and you can explore the pairings of input and output if you’re interested.
Writing our data as JSONL
`axolotl` likes its training data as a JSONL file, so that’s what we’ll write to disk to make training possible. We’ll write two different files, one for training and another for evaluation. `axolotl` handles making its own train-test split for us, so we can use our test set as a true held-out evaluation set for use later on.
The data needs to be formatted in a certain way for our model to learn to output JSON. I’ll show two different ways of doing this below, since I found the process a bit confusing: the first way I did it technically works, but might not be the best.
Writing Template-Free JSON
Another option available to us, especially if we are finetuning a base LLM (as opposed to one that has been instruction-tuned), is to write our data in a different format. Hamel’s written a guide for this on his blog, and that has also been absorbed into the official `axolotl` documentation, so read the blog if you want more information.
The basic idea is that instead of following a format like the one above, we can essentially create our own format that’s custom to our needs. You want this kind of freedom because following one of the standard templates sometimes means shooting yourself in the foot with artifacts of those templates that you don’t need in your output.
The key is to specify `train_on_inputs` as false in our `axolotl` config, which will allow us to mask certain segments of our input data. This means that our model won’t learn the inputs but only the outputs (which we’ll specify).
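In `axolotl` config terms, that corresponds to something like the following excerpt (matching the full config shown further down):

datasets:
  - path: data/templatefree_isaf_press_releases_ft_train.jsonl
    type: input_output

train_on_inputs: false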
All that we have to do is set up the JSONL output in a way that makes sense for our use case:
template_free_train_target_file_path = (
    "../data/templatefree_isaf_press_releases_ft_train.jsonl"
)
template_free_test_target_file_path = (
    "../data/templatefree_isaf_press_releases_ft_test.jsonl"
)
import json
import os


def write_data_to_jsonl(df: pd.DataFrame, target_file_path: str) -> None:
    events: List[IsafEvent] = []

    for i, row in list(df.iterrows()):
        event_types = set(
            eventtype.strip().lower() for eventtype in row["eventtype"].split(",")
        )
        provinces = set(
            province.strip().lower() for province in row["province"].split(",")
        )
        target_groups = set(
            target_group.strip().lower()
            for target_group in row["targetgroup"].split(",")
        )
        events.append(
            IsafEvent(
                name=row["name"],
                text=row["text"],
                start_date=row["StartDate"].to_pydatetime().date(),
                event_type=event_types,
                province=provinces,
                target_group=target_groups,
                min_killed=int(row["minkilled"]),
                min_captured=int(row["mincaptured"]),
                killq=row["killq"] == "true",
                captureq=row["captureq"] == "true",
                killcaptureraid=row["killcaptureraid"] == "true",
                airstrike=row["airstrike"] == "true",
                noshotsfired=row["noshotsfired"] == "true",
                min_leaders_killed=int(row["minleaderskilled"]),
                min_leaders_captured=int(row["minleaderscaptured"]),
            )
        )

    processed_data = [
        {
            "segments": [
                {
                    "label": False,
                    "text": "<s>You are an expert at identifying events in a press release. You are precise and always make sure you are correct, drawing inference from the text of the press release. event_types = ['airstrike', 'detention', 'captureandkill', 'insurgentskilled', 'exchangeoffire', 'civiliancasualty'], provinces = ['badakhshan', 'badghis', 'baghlan', 'balkh', 'bamyan', 'day_kundi', 'farah', 'faryab', 'ghazni', 'ghor', 'helmand', 'herat', 'jowzjan', 'kabul', 'kandahar', 'kapisa', 'khost', 'kunar', 'kunduz', 'laghman', 'logar', 'nangarhar', 'nimroz', 'nuristan', 'paktya', 'paktika', 'panjshir', 'parwan', 'samangan', 'sar_e_pul', 'takhar', 'uruzgan', 'wardak', 'zabul'], target_groups = ['taliban', 'haqqani', 'criminals', 'aq', 'hig', 'let', 'imu', 'judq', 'iju', 'hik', 'ttp', 'other']",
                },
                {"label": False, "text": f"PRESS RELEASE TEXT: {event.text}"},
                {
                    "label": True,
                    "text": f"{event.model_dump_json(exclude={'text'})}</s>",
                },
            ]
        }
        for event in events
    ]

    # Write the processed data to a JSONL file
    os.makedirs(os.path.dirname(target_file_path), exist_ok=True)
    with open(target_file_path, "w") as f:
        for item in processed_data:
            f.write(json.dumps(item) + "\n")


write_data_to_jsonl(train_df, template_free_train_target_file_path)
write_data_to_jsonl(test_df, template_free_test_target_file_path)
And you can now see the difference in the format of the JSONL dataset we’ve constructed:
with open(template_free_train_target_file_path, "r") as f:
    print(f.readline())
{"segments": [{"label": false, "text": "<s>You are an expert at identifying events in a press release. You are precise and always make sure you are correct, drawing inference from the text of the press release. event_types = ['airstrike', 'detention', 'captureandkill', 'insurgentskilled', 'exchangeoffire', 'civiliancasualty'], provinces = ['badakhshan', 'badghis', 'baghlan', 'balkh', 'bamyan', 'day_kundi', 'farah', 'faryab', 'ghazni', 'ghor', 'helmand', 'herat', 'jowzjan', 'kabul', 'kandahar', 'kapisa', 'khost', 'kunar', 'kunduz', 'laghman', 'logar', 'nangarhar', 'nimroz', 'nuristan', 'paktya', 'paktika', 'panjshir', 'parwan', 'samangan', 'sar_e_pul', 'takhar', 'uruzgan', 'wardak', 'zabul'], target_groups = ['taliban', 'haqqani', 'criminals', 'aq', 'hig', 'let', 'imu', 'judq', 'iju', 'hik', 'ttp', 'other']"}, {"label": false, "text": "PRESS RELEASE TEXT: ISAF Joint Command Evening Operational Update Feb. 19, 2011\nISAF Joint Command - Afghanistan\u20282011-02-S-143\u2028For Immediate Release \u2028\u2028KABUL, Afghanistan (Feb. 19)\u2028\u2028ISAF service members at a compound in Sangin district, Helmand province observed numerous insurgents north and south of their position talking on radios today. After gaining positive identification of the insurgent positions, the coalition troops engaged, killing several insurgents. Later, the ISAF troops observed more insurgents positioning in the area with weapons. After positive identification, coalition forces continued firing on the various insurgent positions, resulting in several more insurgents being killed."}, {"label": true, "text": "{\"name\":\"Several insurgents killed in Helmand\",\"start_date\":\"2011-02-18\",\"event_type\":[\"insurgentskilled\"],\"province\":[\"helmand\"],\"target_g roup\":[\"\"],\"min_killed\":6,\"min_captured\":0,\"killq\":true,\"captureq\":false,\"killcaptureraid\":false,\"air strike\":false,\"noshotsfired\":false,\"min_leaders_killed\":0,\"min_leaders_captured\":0}</s>"}]}
Finetuning our model
With our datasets ready, finetuning our model is a simple matter of running the following two commands:
# preprocess the data ahead of training
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess configs/tiny-llama/lora.yml
# train the model
accelerate launch -m axolotl.cli.train configs/tiny-llama/lora.yml
`axolotl` handles everything else. You’ll note that we’re using a pre-prepared config file; it’s quite long, but it’s basically the default config with only a few changes. We can use `bat` to view the config file:
!bat /home/strickvl/coding/isafpr_finetune/configs/tiny-llama/lora.yml
───────┬────────────────────────────────────────────────────────────────────────
│ File: /home/strickvl/coding/isafpr_finetune/configs/tiny-llama/lora.yml
───────┼────────────────────────────────────────────────────────────────────────
1 │ base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
2 │ model_type: LlamaForCausalLM
3 │ tokenizer_type: LlamaTokenizer
4 │
5 │ load_in_8bit: false
6 + │ # I'm training on 4090 GPUs
7 + │ # so I'm using 4-bit precision to save on memory
8 │ load_in_4bit: true
9 │ strict: false
10 │
11 │ data_seed: 42
12 │ seed: 42
13 │
14 │ datasets:
15 ~ │ - path: data/templatefree_isaf_press_releases_ft_train.jsonl
16 ~ │ type: input_output
17 │ dataset_prepared_path:
18 ~ │ val_set_size: 0.1
19 │ output_dir: ./outputs/tiny-llama/lora-out
20 │ hub_model_id: strickvl/isafpr-tiny-llama-lora
21 │
22 │ sequence_len: 4096
23 │ sample_packing: true
24 │ eval_sample_packing: false
25 │ pad_to_sequence_len: true
26 │
27 │ adapter: lora
28 │ lora_model_dir:
29 │ lora_r: 32
30 │ lora_alpha: 16
31 │ lora_dropout: 0.05
32 │ lora_target_linear: true
33 │ lora_fan_in_fan_out:
34 │
35 │ wandb_project: isaf_pr_ft
36 │ wandb_entity: strickvl
37 │ wandb_watch:
38 │ wandb_name:
39 │ wandb_log_model:
40 │
41 │ gradient_accumulation_steps: 4
42 │ micro_batch_size: 2
43 │ num_epochs: 4
44 │ optimizer: adamw_bnb_8bit
45 │ lr_scheduler: cosine
46 │ learning_rate: 0.0002
47 │
48 │ train_on_inputs: false
49 │ group_by_length: false
50 │ bf16: auto
51 │ fp16:
52 │ tf32: false
53 │
54 │ gradient_checkpointing: true
55 │ early_stopping_patience:
56 │ resume_from_checkpoint:
57 │ local_rank:
58 │ logging_steps: 1
59 │ xformers_attention:
60 │ flash_attention: true
61 │
62 │ warmup_steps: 10
63 │ evals_per_epoch: 4
64 │ saves_per_epoch: 1
65 │ debug:
66 │ deepspeed:
67 │ weight_decay: 0.0
68 │ fsdp:
69 │ fsdp_config:
70 │ special_tokens:
───────┴────────────────────────────────────────────────────────────────────────
You can check out some of the trainings I did with the following links:
| Name | Config | Model | Wandb Report |
| --- | --- | --- | --- |
| Tiny-Llama (Template Free) | link | link | link |
| Tiny-Llama (ShareGPT) | link | link | link |
| Llama-3 (Template Free) | link | link | link |
| Mistral (Template Free) | link | link | link |
Now that we have four models, we can try some out to see how they fare with data they haven’t yet seen (from the test set). I used the code from Hamel’s sanity-check notebook to generate some predictions and evaluate them:
from typing import Union

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer


def prompt(press_release):
    return f"""You are an expert at identifying events in a press release. You are precise and always make sure you are correct, drawing inference from the text of the press release. event_types = ['airstrike', 'detention', 'captureandkill', 'insurgentskilled', 'exchangeoffire', 'civiliancasualty'], provinces = ['badakhshan', 'badghis', 'baghlan', 'balkh', 'bamyan', 'day_kundi', 'farah', 'faryab', 'ghazni', 'ghor', 'helmand', 'herat', 'jowzjan', 'kabul', 'kandahar', 'kapisa', 'khost', 'kunar', 'kunduz', 'laghman', 'logar', 'nangarhar', 'nimroz', 'nuristan', 'paktya', 'paktika', 'panjshir', 'parwan', 'samangan', 'sar_e_pul', 'takhar', 'uruzgan', 'wardak', 'zabul'], target_groups = ['taliban', 'haqqani', 'criminals', 'aq', 'hig', 'let', 'imu', 'judq', 'iju', 'hik', 'ttp', 'other']
### Instruction:
PRESS RELEASE TEXT: "{press_release}"
### Response:
"""


def prompt_tok(
    model: AutoPeftModelForCausalLM,
    tokenizer: AutoTokenizer,
    press_release: str,
    return_ids: bool = False,
) -> Union[str, torch.Tensor]:
    _p = prompt(press_release)
    input_ids = tokenizer(_p, return_tensors="pt", truncation=True).input_ids.cuda()
    out_ids = model.generate(input_ids=input_ids, max_new_tokens=5000, do_sample=False)
    ids = out_ids.detach().cpu().numpy()
    if return_ids:
        return out_ids
    return tokenizer.batch_decode(ids, skip_special_tokens=True)[0][len(_p) :]
from rich import print
import json

tinyllama_templatefree_model_id = "strickvl/isafpr-tiny-llama-lora-templatefree"
model = AutoPeftModelForCausalLM.from_pretrained(tinyllama_templatefree_model_id).cuda()
tokenizer = AutoTokenizer.from_pretrained(tinyllama_templatefree_model_id)
tokenizer.pad_token = tokenizer.eos_token

press_release_sample1 = """2011-11-S-011 ISAF Joint Command - Afghanistan For Immediate Release KABUL, Afghanistan (Nov. 7, 2011) — A combined Afghan and coalition security force conducted an operation in search of a Haqqani facilitator in Argo district, Badakshan province. The facilitator coordinates suicide attacks with other insurgent leaders in the area. During the operation, a local national male failed to comply with repeated verbal warnings and displayed hostile intent toward the security force. The security force engaged the individual, resulting in his death. The security force confiscated a shotgun and intelligence linking the local national to the Haqqani network. The security force also detained two suspected insurgents during the operation."""

out = prompt_tok(model, tokenizer, press_release_sample1)
print(out)
out_dict = json.loads(out)
print(out_dict)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
{"name":"2011-11-07-airstrike","start_date":"2011-11-07","event_type":["airstrike"],"province":["badakhshan"],"targ et_group":["haqqani"],"min_killed":1,"min_captured":2,"killq":true,"captureq":true,"killcaptureraid":true,"airstrik e":true,"noshotsfired":false,"min_leaders_killed":0,"min_leaders_captured":0}
{ 'name': '2011-11-07-airstrike', 'start_date': '2011-11-07', 'event_type': ['airstrike'], 'province': ['badakhshan'], 'target_group': ['haqqani'], 'min_killed': 1, 'min_captured': 2, 'killq': True, 'captureq': True, 'killcaptureraid': True, 'airstrike': True, 'noshotsfired': False, 'min_leaders_killed': 0, 'min_leaders_captured': 0 }
The model has certainly learned to output JSON, and it’s done an OK job of parsing the contents of the text, but it has also made errors: it said that this was an airstrike, whereas no airstrike is mentioned in the text.
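To put a rough number on outputs like this, one option is a simple per-field accuracy against the labels. This is just a sketch (proper evaluations come later in this series), and `ground_truth_dict` stands in for a hypothetical dict of the labels for this press release:

def field_accuracy(pred: dict, truth: dict) -> float:
    # Score the fields both dicts share; list-valued fields compare as sets
    keys = pred.keys() & truth.keys()
    correct = 0
    for key in keys:
        p, t = pred[key], truth[key]
        if isinstance(p, list) and isinstance(t, list):
            correct += set(p) == set(t)
        else:
            correct += p == t
    return correct / len(keys)

score = field_accuracy(out_dict, ground_truth_dict)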
This is only a finetune of Tiny-Llama, a much smaller version of a Llama model (with the v2 architecture). Let’s check how our Mistral finetune did in comparison:
= "strickvl/isafpr-mistral-lora-templatefree"
mistral_templatefree_model_id = AutoPeftModelForCausalLM.from_pretrained(
mistral_model
mistral_templatefree_model_id
).cuda()= AutoTokenizer.from_pretrained(mistral_templatefree_model_id)
mistral_tokenizer = mistral_tokenizer.eos_token
mistral_tokenizer.pad_token
= prompt_tok(mistral_model, mistral_tokenizer, press_release_sample1)
mistral_out print(mistral_out)
= json.loads(mistral_out)
mistral_out_dict print(mistral_out_dict)
{
'name': '1',
'start_date': '2011-11-07',
'event_type': ['captureandkill'],
'province': ['badakhshan'],
'target_group': ['haqqani'],
'min_killed': 1,
'min_captured': 2,
'killq': True,
'captureq': True,
'killcaptureraid': True,
'airstrike': False,
'noshotsfired': False,
'min_leaders_killed': 0,
'min_leaders_captured': 0
}
I actually had to cheat a bit to get this output. I was getting an out-of-memory (OOM) error when trying to run the inference locally, so I ran it using Modal’s compute platform instead. You can see the script where I ran the one-off inference here. Note that I had to pass in my Hugging Face write token ahead of running the script, since Mistral is a gated model. (So the final command was `HF_TOKEN="hf_MY_TOKEN_VALUE_WENT_HERE" modal run notebooks/sanity_check.py`.)
The output it produced, however, is pretty spot on. I actually noticed that it outperformed the original data: it correctly identified the event type as `captureandkill` and specified the boolean `killcaptureraid` as `True`, even though in the ground-truth dataset it seems I mislabelled the data and stated that the `killcaptureraid` value was `False`.
One thing you might also have noticed is that the `name` attribute was predicted as being `1`. Actually, this is not really a problem. When I was labelling the dataset there were sometimes press releases that I needed to split up into separate reports, and I gave those splits numerical names. The name is really just a summary of the event, but a non-trivial number of events have numbers as their `name`, so it’s probably not even a useful field to train on or to try to predict. If it were really necessary, I could train a model to summarise the content as a headline, but for my specific use case I’m not sure it’s even useful to have this information at all.
I tried getting my Llama3 model to make its predictions, but all I got out was an endless stream of JSON content. I suspect it has something to do with the presence or absence of the `</s>` tag, which marks the EOS or ‘end of sequence’. I recall there were some error messages during training along these lines, so potentially I’ll want to look into that. I also ran into the same problem as with the Mistral model, i.e. having to run it on Modal, and was only able to make it work by specifying a different datatype when loading the model:
model = AutoPeftModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).cuda()
Thanks to Chris Levy over on the course forum for suggesting this approach. While this works for getting it to run on Modal, I’m still looking for a way to get my Mistral and Llama3 models to run locally, so I’ll probably have to investigate how to optimise the models’ memory usage further.
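One avenue I want to explore for that (untested on my machine, so treat it as a sketch) is loading the adapter with 4-bit quantisation at inference time, mirroring the `load_in_4bit: true` setting from the training config:

import torch
from peft import AutoPeftModelForCausalLM
from transformers import BitsAndBytesConfig

model = AutoPeftModelForCausalLM.from_pretrained(
    mistral_templatefree_model_id,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    device_map="auto",  # let accelerate place the layers; no .cuda() call needed
)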
Next Steps
I’m pretty happy with this set of experiments. It was exciting to see how quick it is to iterate when finetuning LLMs: a finetune of Llama3 or Mistral over four epochs took only about 35 minutes on my local machine.
Some obvious next steps for this project are:
- Figure out the model loading issue mentioned at the end of this post: how to load my models locally and what are the tradeoffs of whatever approaches are possible?
- Get more deliberate about adding in some manual evaluations: figure out some examples where I’m deliberately testing known edge cases and hard-coded outputs. This feels like something that ought to be done sooner rather than later.
- Run the evaluations I had for GPT-4 in the previous blog post against my new model(s), and see how well my finetuned models do in comparison.
- Run training jobs on all the different platforms where we have course credits in order to get a sense for how they work. I’ve only tried out Modal so far, and not even for training, just inference.
- Pick one of the base models and try some hyperparameter tuning to see which combination of parameters and config values gives the best performance.
- Think about model deployment for whichever candidate I choose as the best, then run some benchmarking / tests to see how well it performs and whether we can ever compete with the price point of something like GPT-4 (especially when we factor in the accuracy scores across my evaluations).
Having written these down, the order in which I wrote them seems like a sensible way to keep going. So my next step will be to read up on the model loading a bit and to try out some possible solutions for loading my models locally.
I’ll also add one personal note on the experience so far: I’m really enjoying being so hands-on. There are some parts of what I did that perhaps require a bit of experience to move quickly past some boring work (converting data from one format to another, and so on), but for the most part I’ve found the work of finetuning models to be really accessible to someone without much technical background. Even the conversion of datasets and construction of templates is relatively straightforward, and (beyond needing the intuition to know that it’s something you have to do) you could accomplish most of it using Claude or GPT-4 without any technical background at all.
It’s also quite empowering to see all these vistas open up before me, especially the ones enabled by being able to finetune these models on my local machine. I’m really excited about the next experiments and stages of this project, in particular how much there still is to learn!