Dpo training (Do not merge) #63

Open · wants to merge 26 commits into main
4 changes: 4 additions & 0 deletions mttl/config.py
@@ -346,3 +346,7 @@ def _set_defaults(self):
self.soft_prompt_learn_kv: bool = False
self.prompt_placement: str = "prefix"
self.add_routing_token: bool = False

# RL training defaults; beta scales the preference margin in the DPO loss
self.rl_training: str = "dpo"
self.beta: float = 0.5
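
Reviewer note: for context, the `beta` default above is the coefficient that scales the policy/reference log-ratio margin in the DPO objective. A minimal sketch of such a loss follows (illustrative only; the trainer that consumes `rl_training` and `beta` is not part of this file):

import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.5):
    # beta scales the margin between the policy/reference log-ratios of the
    # preferred and dispreferred responses; larger beta sharpens the preference.
    margin = (policy_chosen_logps - ref_chosen_logps) - (
        policy_rejected_logps - ref_rejected_logps
    )
    return -F.logsigmoid(beta * margin).mean()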
9 changes: 9 additions & 0 deletions mttl/datamodule/base.py
@@ -655,6 +655,10 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
WinograndeMultiChoiceDataModule,
)

from mttl.datamodule.ultrafeedback_data_module import (
UltrafeedbackSFTmodule,
)

# refactor all the common arguments below into a dict common kwargs
dataset = args.dataset if not dataset_override else dataset_override

@@ -737,6 +741,11 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
augment_few_shot=args.augment_few_shot,
)
dm = FlatMultiTaskModule(config, for_generation=for_generation)
elif "ultrachat" in dataset:
config = DatasetConfig(
**common_kwargs,
)
dm = UltrafeedbackSFTmodule(config, for_generation=for_generation)
elif "mmlu" in dataset:
config = MMLUDataConfig(
**common_kwargs,
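
Reviewer note: a hedged sketch of exercising the new routing branch directly; the config fields mirror the `__main__` example at the bottom of the new data module and are assumptions, not verified defaults.

from mttl.datamodule.base import DatasetConfig
from mttl.datamodule.ultrafeedback_data_module import UltrafeedbackSFTmodule

# Any dataset name containing "ultrachat" is routed to UltrafeedbackSFTmodule
# by the branch added above; here the module is constructed directly.
config = DatasetConfig(model="microsoft/Phi-3-mini-4k-instruct")
dm = UltrafeedbackSFTmodule(config, for_generation=False)
batch = next(iter(dm.train_dataloader()))
print(batch["sources_texts"][0])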
110 changes: 110 additions & 0 deletions mttl/datamodule/preference_data_module.py
@@ -0,0 +1,110 @@
from dataclasses import dataclass

import torch

from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
from mttl.models.library.expert_library import DatasetLibrary


@dataclass
class DataCollatorForDPO(DefaultCollator):
def __call__(self, batch):
prompts = ["Instruct: " + item["prompt"] + "\n" for item in batch]
chosen_responses = ["Output: " + item["chosen"] for item in batch]
rejected_responses = ["Output: " + item["rejected"] for item in batch]

prompt_ids = self.tokenizer.batch_encode_plus(
prompts,
padding=True,
return_tensors="pt",
max_length=self.max_input_length,
truncation=True,
)["input_ids"]

prefered_tokenize = self.tokenizer.batch_encode_plus(
chosen_responses,
padding=True,
return_tensors="pt",
max_length=self.max_input_length,
truncation=True,
)
prefered_ids = prefered_tokenize["input_ids"]

disprefered_tokenize = self.tokenizer.batch_encode_plus(
rejected_responses,
padding=True,
return_tensors="pt",
max_length=self.max_input_length,
truncation=True,
)
disprefered_ids = disprefered_tokenize["input_ids"]

prompt_prefered_ids = torch.cat([prompt_ids, prefered_ids], dim=-1)
prompt_disprefered_ids = torch.cat([prompt_ids, disprefered_ids], dim=-1)

prompt_prefered_mask = torch.cat(
[torch.ones_like(prompt_ids), torch.zeros_like(prefered_ids)], dim=-1
)
# number of non-padding tokens in each preferred / dispreferred response
prefered_y_len = prefered_tokenize["attention_mask"].sum(dim=1)
disprefered_y_len = disprefered_tokenize["attention_mask"].sum(dim=1)

prompt_disprefered_mask = torch.cat(
[torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1
)

return {
"prompt_prefered_ids": prompt_prefered_ids,
"prompt_disprefered_ids": prompt_disprefered_ids,
"prompt_prefered_mask": prompt_prefered_mask,
"prompt_disprefered_mask": prompt_disprefered_mask,
"prefered_y_len": prefered_y_len,
"disprefered_y_len": disprefered_y_len,
}


@dataclass
class Preferencemodule(DefaultDataModule):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def setup_dataset(self):
train_dataset = DatasetLibrary.pull_dataset_with_retry(
"jondurbin/truthy-dpo-v0.1"
)["train"]

self.train_dataset, self.dev_dataset = self.create_train_valid_split(
train_dataset, 0.1
)
self.test_dataset = self.dev_dataset

self.print_infos()

@property
def collate_fn(self):
return DataCollatorForDPO(
tokenizer=self.tokenizer,
padding="longest",
max_input_length=self.config.max_input_length,
max_output_length=self.config.max_output_length,
return_tensors="pt",
model_family=self.config.model_family,
for_generation=self.for_generation,
)


if __name__ == "__main__":
config = DatasetConfig(model="microsoft/phi-2")
datamodule = Preferencemodule(config)
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()
for batch in val_dataloader:
prompt_prefered_mask = batch["prompt_prefered_mask"]
prompt_disprefered_mask = batch["prompt_disprefered_mask"]

# get the length of the response
prefered_y_len = batch["prefered_y_len"]
disprefered_y_len = batch["disprefered_y_len"]
print(prefered_y_len, disprefered_y_len)
breakpoint()
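
Reviewer note: a minimal sketch of how a DPO step could turn the collated ids into sequence-level log-probabilities; `sequence_log_prob` and the masking convention are illustrative assumptions, not code from this PR. Note that `prompt_prefered_mask` above marks prompt positions with 1 and response positions with 0, so the response mask used below would be its complement restricted to non-padding positions.

import torch
import torch.nn.functional as F

def sequence_log_prob(logits, input_ids, response_mask):
    # logits: (B, T, V); input_ids: (B, T); response_mask: (B, T) with 1s on the
    # response tokens whose log-probability should contribute to the loss.
    # Shift so that the logits at position t score the token at position t + 1.
    logps = F.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:].unsqueeze(-1)
    token_logps = torch.gather(logps, dim=-1, index=targets).squeeze(-1)
    return (token_logps * response_mask[:, 1:]).sum(dim=-1)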
199 changes: 199 additions & 0 deletions mttl/datamodule/ultrafeedback_data_module.py
@@ -0,0 +1,199 @@
from dataclasses import dataclass
from typing import Any
import torch

from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule
from mttl.models.library.expert_library import DatasetLibrary


def is_openai_format(messages: Any) -> bool:
"""
Check if the input messages are in OpenAI format.
Args:
messages (`Any`):
Messages to check.
Returns:
`bool`: Whether the messages are in OpenAI format.
"""
if isinstance(messages, list) and all(
isinstance(message, dict) for message in messages
):
return all("role" in message and "content" in message for message in messages)
return False


@dataclass
class UltrafeedbackDPOCollator(DefaultCollator):
def __call__(self, batch):

# For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
# We therefore need to extract the N-1 turns to form the prompt
prompts = []
chosen_responses = []
rejected_responses = []
for example in batch:
if "prompt" in example and is_openai_format(example["prompt"]):
prompt_messages = example["prompt"]
chosen_messages = example["chosen"]
rejected_messages = example["rejected"]
else:
prompt_messages = example["chosen"][:-1]
# Now we extract the final turn to define chosen/rejected responses
chosen_messages = example["chosen"][-1:]
rejected_messages = example["rejected"][-1:]
prompts.append(
self.tokenizer.apply_chat_template(prompt_messages, tokenize=False)
)
chosen_responses.append(
self.tokenizer.apply_chat_template(chosen_messages, tokenize=False)
)
rejected_responses.append(
self.tokenizer.apply_chat_template(rejected_messages, tokenize=False)
)

prompt_ids = self.tokenizer.batch_encode_plus(
prompts,
padding=True,
return_tensors="pt",
max_length=self.max_input_length,
truncation=True,
)["input_ids"]

prefered_tokenize = self.tokenizer.batch_encode_plus(
chosen_responses,
padding=True,
return_tensors="pt",
max_length=self.max_input_length,
truncation=True,
)
prefered_ids = prefered_tokenize["input_ids"]

disprefered_tokenize = self.tokenizer.batch_encode_plus(
rejected_responses,
padding=True,
return_tensors="pt",
max_length=self.max_input_length,
truncation=True,
)
disprefered_ids = disprefered_tokenize["input_ids"]

prompt_prefered_ids = torch.cat([prompt_ids, prefered_ids], dim=-1)
prompt_disprefered_ids = torch.cat([prompt_ids, disprefered_ids], dim=-1)

prompt_prefered_mask = torch.cat(
[torch.ones_like(prompt_ids), torch.zeros_like(prefered_ids)], dim=-1
)
# number of non-padding tokens in each preferred / dispreferred response
prefered_y_len = prefered_tokenize["attention_mask"].sum(dim=1)
disprefered_y_len = disprefered_tokenize["attention_mask"].sum(dim=1)

prompt_disprefered_mask = torch.cat(
[torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1
)

return {
"prompt_prefered_ids": prompt_prefered_ids,
"prompt_disprefered_ids": prompt_disprefered_ids,
"prompt_prefered_mask": prompt_prefered_mask,
"prompt_disprefered_mask": prompt_disprefered_mask,
"prefered_y_len": prefered_y_len,
"disprefered_y_len": disprefered_y_len,
}


@dataclass
class UltrafeedbackDPOmodule(DefaultDataModule):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def setup_dataset(self):
dataset = DatasetLibrary.pull_dataset_with_retry(
"princeton-nlp/gemma2-ultrafeedback-armorm"
)

# the ultrafeedback splits are already in chat format; assign train/test and reuse test for validation
self.train_dataset = dataset["train"]
self.test_dataset = dataset["test"]
self.dev_dataset = self.test_dataset

self.print_infos()

@property
def collate_fn(self):
return UltrafeedbackDPOCollator(
tokenizer=self.tokenizer,
padding="longest",
max_input_length=self.config.max_input_length,
max_output_length=self.config.max_output_length,
return_tensors="pt",
model_family=self.config.model_family,
for_generation=self.for_generation,
)


@dataclass
class UltrafeedbackSFTCollator(DefaultCollator):
def __call__(self, batch):

# For SFT, the inputs are pairs of (prompt, messages); the prompt is kept as-is
# and the full conversation in `messages` is rendered with the chat template as the target
prompts = []
messages = []
for example in batch:
prompt_messages = example["prompt"]
chosen_messages = example["messages"]
prompts.append(prompt_messages)
messages.append(
self.tokenizer.apply_chat_template(chosen_messages, tokenize=False)
)

return {
"sources_texts": prompts,
"labels_texts": messages,
}


@dataclass
class UltrafeedbackSFTmodule(DefaultDataModule):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def setup_dataset(self):
dataset = DatasetLibrary.pull_dataset_with_retry("HuggingFaceH4/ultrachat_200k")

# ultrachat_200k already ships SFT splits; reuse the test split for validation
self.train_dataset = dataset["train_sft"]
self.test_dataset = dataset["test_sft"]
self.dev_dataset = self.test_dataset

self.print_infos()

@property
def collate_fn(self):
return UltrafeedbackSFTCollator(
tokenizer=self.tokenizer,
padding="longest",
max_input_length=self.config.max_input_length,
max_output_length=self.config.max_output_length,
return_tensors="pt",
model_family=self.config.model_family,
for_generation=self.for_generation,
)


if __name__ == "__main__":
config = DatasetConfig(model="microsoft/Phi-3-mini-4k-instruct")
datamodule = UltrafeedbackSFTmodule(config)
train_dataloader = datamodule.train_dataloader()
val_dataloader = datamodule.val_dataloader()
for batch in val_dataloader:
# inspect a raw SFT batch (sources_texts / labels_texts)
print(batch)
breakpoint()
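
Reviewer note: for reference, a small illustrative example of the OpenAI-style message format that `is_openai_format` accepts and that `apply_chat_template` renders; the records below are made up, not taken from the datasets above.

chosen = [
    {"role": "user", "content": "Summarize DPO in one sentence."},
    {"role": "assistant", "content": "DPO fine-tunes a model directly on preference pairs without a separate reward model."},
]
rejected = [
    {"role": "user", "content": "Summarize DPO in one sentence."},
    {"role": "assistant", "content": "It is a kind of training."},
]

assert is_openai_format(chosen) and is_openai_format(rejected)
# The DPO collator treats everything up to the last turn as the prompt and the
# final assistant turn as the chosen/rejected response:
prompt_messages, chosen_final, rejected_final = chosen[:-1], chosen[-1:], rejected[-1:]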