fix(modeling): deepspeed checkpoint loading #482

Merged: 15 commits, Aug 8, 2023
Changes from 10 commits
1 change: 1 addition & 0 deletions trlx/data/configs.py
@@ -221,6 +221,7 @@ class TrainConfig:
     rollout_logging_dir: Optional[str] = None
     save_best: bool = True
     save_optimizer: bool = True
+    resume_from_checkpoint: Optional[str] = None

     tracker: Optional[str] = "wandb"
     logging_dir: Optional[str] = None
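The new field defaults to `None`, so checkpoint resumption stays opt-in and existing configs keep their behavior. A quick self-contained check of that default (a sketch that assumes a trlx checkout containing this change):

```python
# Sketch: confirm the new TrainConfig field exists and defaults to None (resumption is opt-in).
from dataclasses import fields

from trlx.data.configs import TrainConfig

defaults = {f.name: f.default for f in fields(TrainConfig)}
assert defaults.get("resume_from_checkpoint", "missing") is None
```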
30 changes: 12 additions & 18 deletions trlx/models/modeling_ilql.py
@@ -324,22 +324,19 @@ def state_dict(self, *args, **kwargs):
         Returns the state dictionary of the model. We add the state dictionary of the ilql heads
         to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`.
         """
-        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)
-        ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs)
-        for k, v in ilql_heads_state_dict.items():
-            base_model_state_dict[f"ilql_heads.{k}"] = v
-        return base_model_state_dict
+        return {
+            **self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs)),
+            **self.ilql_heads.state_dict(*args, **dict(prefix="ilql_heads.", **kwargs)),
+        }

     def post_init(self, state_dict):
         """
         We add the state dictionary of the ilql heads to the state dictionary of the wrapped model
         by preprending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the
         keys of the value head state dictionary.
         """
-        for k in list(state_dict.keys()):
-            if "ilql_heads." in k:
-                state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k)
-        self.ilql_heads.load_state_dict(state_dict, strict=False)
+        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("ilql_heads.") for k in state_dict)
+        self.load_state_dict(state_dict, strict=trlx_checkpoint)
         del state_dict
         gc.collect()

@@ -374,22 +374,19 @@ def state_dict(self, *args, **kwargs):
         Returns the state dictionary of the model. We add the state dictionary of the ilql heads
         to the state dictionary of the wrapped model by prepending the key with `ilql_heads.`.
         """
-        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)
-        ilql_heads_state_dict = self.ilql_heads.state_dict(*args, **kwargs)
-        for k, v in ilql_heads_state_dict.items():
-            base_model_state_dict[f"ilql_heads.{k}"] = v
-        return base_model_state_dict
+        return {
+            **self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs)),
+            **self.ilql_heads.state_dict(*args, **dict(prefix="ilql_heads.", **kwargs)),
+        }

     def post_init(self, state_dict):
         """
         We add the state dictionary of the ilql heads to the state dictionary of the wrapped model
         by preprending the key with `ilql_heads.`. This function removes the `ilql_heads.` prefix from the
         keys of the value head state dictionary.
         """
-        for k in list(state_dict.keys()):
-            if "ilql_heads." in k:
-                state_dict[k.replace("ilql_heads.", "")] = state_dict.pop(k)
-        self.ilql_heads.load_state_dict(state_dict, strict=False)
+        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("ilql_heads.") for k in state_dict)
+        self.load_state_dict(state_dict, strict=trlx_checkpoint)
         del state_dict
         gc.collect()

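The rewritten `state_dict`/`post_init` pair relies on `nn.Module.state_dict` accepting a `prefix` keyword that is prepended to every key, so the wrapper emits keys matching its own attribute names (`base_model.*`, `ilql_heads.*`) and can reload them later with a single `load_state_dict` call. A minimal standalone PyTorch sketch of that mechanism (toy modules, not trlx code):

```python
# Toy stand-ins for the wrapped base model and the ILQL heads; only the key shapes matter.
import torch.nn as nn

base_model = nn.Linear(4, 4)
ilql_heads = nn.Linear(4, 2)

merged = {
    **base_model.state_dict(prefix="base_model."),
    **ilql_heads.state_dict(prefix="ilql_heads."),
}
print(sorted(merged))
# ['base_model.bias', 'base_model.weight', 'ilql_heads.bias', 'ilql_heads.weight']

# The same key shapes drive the strict-loading switch in post_init: prefixed keys indicate a
# checkpoint saved by the wrapper itself, so strict loading is safe; plain HF weights are not.
trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("ilql_heads.") for k in merged)
print(trlx_checkpoint)  # True
```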
90 changes: 71 additions & 19 deletions trlx/models/modeling_ppo.py
@@ -1,5 +1,6 @@
 import gc
 import inspect
+import re
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
@@ -311,22 +312,19 @@ def state_dict(self, *args, **kwargs):
         Returns the state dictionary of the model. We add the state dictionary of the value head
         to the state dictionary of the wrapped model by prepending the key with `v_head.`.
         """
-        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)
-        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
-        for k, v in v_head_state_dict.items():
-            base_model_state_dict[f"v_head.{k}"] = v
-        return base_model_state_dict
+        return {
+            **self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs)),
+            **self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs)),
+        }

     def post_init(self, state_dict):
         """
         Adds the state dictionary of the value head to the state dictionary of the wrapped model
         by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
         keys of the value head state dictionary.
         """
-        for k in list(state_dict.keys()):
-            if "v_head." in k:
-                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
-        self.v_head.load_state_dict(state_dict, strict=False)
+        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
+        self.load_state_dict(state_dict, strict=trlx_checkpoint)
         del state_dict
         gc.collect()  # noqa: E702

@@ -350,6 +348,8 @@ def __init__(
                 self.base_model,
                 num_layers_unfrozen=self.num_layers_unfrozen,
             ).eval()
+        else:
+            self.frozen_head = None

     def forward_hydra(
         self,
@@ -393,6 +393,33 @@ def forward_hydra(
             return hydra_outputs.logits
         return hydra_outputs

+    def state_dict(self, *args, **kwargs):
+        # append the state dictionary of the frozen head to the state dictionary of the wrapped model
+        state_dict = super().state_dict(*args, **kwargs)
+        if self.frozen_head:
+            state_dict = {**state_dict, **self.frozen_head.state_dict(*args, **dict(prefix="frozen_head.", **kwargs))}
+        return state_dict
+
+    def post_init(self, state_dict):
+        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
+
+        if self.frozen_head is None:
+            for k in state_dict:
+                match = re.search(r"^frozen_head\..+\.(\d+)\.", k)
+                if match:
+                    self.num_layers_unfrozen = max(self.num_layers_unfrozen, int(match.group(1)) + 1)
+
+            config = self.base_model.config
+            branch_class = hf_get_branch_class(config)
+            self.frozen_head = branch_class(
+                self.base_model,
+                num_layers_unfrozen=self.num_layers_unfrozen,
+            ).eval()
+
+        self.load_state_dict(state_dict, strict=trlx_checkpoint)
+        del state_dict
+        gc.collect()  # noqa: E702
+

 class ModelBranch(transformers.PreTrainedModel):
     """Implements the frozen upper trunk of the pretrained reference model used
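When the model was built without a frozen head (the new `else` branch above) but the checkpoint already contains `frozen_head.*` weights, `post_init` infers how many layers the branch needs from the numeric segment of those keys before rebuilding it. A standalone sketch of that inference (the key names below are hypothetical; only their `frozen_head.<module path>.<layer index>.<param>` shape matters):

```python
# Recover num_layers_unfrozen from frozen-head parameter names, mirroring the regex above.
import re

checkpoint_keys = [
    "frozen_head.decoder_blocks.0.attn.q_proj.weight",  # hypothetical layer 0
    "frozen_head.decoder_blocks.1.attn.q_proj.weight",  # hypothetical layer 1
    "frozen_head.decoder_blocks.1.mlp.fc_out.bias",
    "v_head.0.weight",  # non-matching keys are ignored
]

num_layers_unfrozen = 0
for k in checkpoint_keys:
    match = re.search(r"^frozen_head\..+\.(\d+)\.", k)
    if match:
        num_layers_unfrozen = max(num_layers_unfrozen, int(match.group(1)) + 1)

print(num_layers_unfrozen)  # 2 -> rebuild the branch with two unfrozen layers
```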
@@ -1046,22 +1073,19 @@ def state_dict(self, *args, **kwargs):
         Returns the state dictionary of the model. We add the state dictionary of the value head
         to the state dictionary of the wrapped model by prepending the key with `v_head.`.
         """
-        base_model_state_dict = self.base_model.state_dict(*args, **kwargs)
-        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
-        for k, v in v_head_state_dict.items():
-            base_model_state_dict[f"v_head.{k}"] = v
-        return base_model_state_dict
+        return {
+            **self.base_model.state_dict(*args, **dict(prefix="base_model.", **kwargs)),
+            **self.v_head.state_dict(*args, **dict(prefix="v_head.", **kwargs)),
+        }

     def post_init(self, state_dict):
         """
-        We add the state dictionary of the value head to the state dictionary of the wrapped model
+        Adds the state dictionary of the value head to the state dictionary of the wrapped model
         by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
         keys of the value head state dictionary.
         """
-        for k in list(state_dict.keys()):
-            if "v_head." in k:
-                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
-        self.v_head.load_state_dict(state_dict, strict=False)
+        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
+        self.load_state_dict(state_dict, strict=trlx_checkpoint)
         del state_dict
         gc.collect()  # noqa: E702

@@ -1084,6 +1108,8 @@ def __init__(
                 self.base_model,
                 num_layers_unfrozen=self.num_layers_unfrozen,
            ).eval()
+        else:
+            self.frozen_head = None

     def forward_hydra(
         self,
@@ -1142,6 +1168,32 @@ def forward_hydra(
             return hydra_outputs.logits
         return hydra_outputs

+    def state_dict(self, *args, **kwargs):
+        # append the state dictionary of the frozen head to the state dictionary of the wrapped model
+        state_dict = super().state_dict(*args, **kwargs)
+        if self.frozen_head:
+            state_dict = {**state_dict, **self.frozen_head.state_dict(*args, **dict(prefix="frozen_head.", **kwargs))}
+        return state_dict
+
+    def post_init(self, state_dict):
+        trlx_checkpoint = any(k.startswith("base_model.") or k.startswith("v_head.") for k in state_dict)
+
+        if self.frozen_head is None:
+            for k in state_dict:
+                match = re.search(r"^frozen_head\..+\.(\d+)\.", k)
+                if match:
+                    self.num_layers_unfrozen = max(self.num_layers_unfrozen, int(match.group(1)))
+
+            branch_class = T5Branch  # TODO: Add support for other model branches
+            self.frozen_head = branch_class(
+                self.base_model,
+                num_layers_unfrozen=self.num_layers_unfrozen,
+            ).eval()
+
+        self.load_state_dict(state_dict, strict=trlx_checkpoint)
+        del state_dict
+        gc.collect()  # noqa: E702
+

 class T5Branch(ModelBranch):
     """Decoder only T5 branch"""
3 changes: 3 additions & 0 deletions trlx/trlx.py
@@ -123,5 +123,8 @@ def train(  # noqa: C901
     )
     trainer.add_eval_pipeline(eval_pipeline)

+    if config.train.resume_from_checkpoint and os.path.exists(config.train.resume_from_checkpoint):
+        trainer.load(config.train.resume_from_checkpoint)
+
     trainer.learn()
     return trainer
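Putting it together, a resumed run just sets the new config field and calls `trlx.train` as usual; if the path does not exist, the guard above silently skips the load and training starts from scratch. A hedged usage sketch (the `default_ppo_config` helper, the toy reward, and the checkpoint path are illustrative assumptions, not part of this PR):

```python
# Hypothetical resume flow built on the new TrainConfig.resume_from_checkpoint field.
import trlx
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.train.resume_from_checkpoint = "ckpts/checkpoint_01000"  # hypothetical path from an earlier run

trainer = trlx.train(
    reward_fn=lambda samples, **kwargs: [float(len(s)) for s in samples],  # toy reward
    prompts=["Tell me a story:"] * 8,
    config=config,
)
```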