Skip to content

Commit

Permalink
fix(modeling_base): loading sharded safetensors
Browse files Browse the repository at this point in the history
  • Loading branch information
maxreciprocate committed Dec 4, 2023
1 parent 5862d01 commit 9879889
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
2 changes: 1 addition & 1 deletion tests/test_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
import transformers
from peft import get_peft_config, get_peft_model
from peft.utils.config import PeftType, TaskType
from peft.utils import PeftType, TaskType
from transformers import AutoConfig, AutoModelForCausalLM

from trlx.data.configs import TokenizerConfig
Expand Down
71 changes: 36 additions & 35 deletions trlx/models/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
import torch.nn as nn
import transformers
from huggingface_hub import hf_hub_download
import huggingface_hub

import trlx.utils.logging as logging
from trlx.utils import is_peft_available
Expand Down Expand Up @@ -155,8 +155,10 @@ def from_pretrained( # noqa: max-complexity
call (e.g. `transformers.AutoModelForCausalLM.from_pretrained`) and the specific
instance of the wrapped model.
NOTE: You must pass in arguments specific to the wrapped model as keyword arguments.
"""

if kwargs is not None:
peft_from_pretrained_kwargs = kwargs.pop("peft_from_pretrained_kwargs", {})
peft_int8_kwargs = kwargs.pop("peft_int8_kwargs", {})
Expand Down Expand Up @@ -273,42 +275,41 @@ def from_pretrained( # noqa: max-complexity
model = cls(base_model, **wrapped_model_kwargs)

if isinstance(pretrained_model_name_or_path, str):
filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
is_sharded = False

if not os.path.exists(filename):
if not os.path.exists(pretrained_model_name_or_path):
try:
filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin", revision=revision)
# Sharded
except Exception:
if os.path.exists(sharded_index_filename):
index_file_name = sharded_index_filename
else:
index_file_name = hf_hub_download(
pretrained_model_name_or_path,
"pytorch_model.bin.index.json",
revision=revision,
)
with open(index_file_name, "r") as f:
index = json.load(f)

# Load all weights from the shards
files_to_download = set(index["weight_map"].values())
is_sharded = True

if is_sharded:
# Merge each shard into a state dict
# TODO: Optimize this to avoid wasting RAM
state_dict = {}
for shard_file in files_to_download:
filename = os.path.join(pretrained_model_name_or_path, shard_file)
# Download if shard file doesn't exist locally
if not os.path.exists(filename):
filename = hf_hub_download(pretrained_model_name_or_path, shard_file, revision=revision)
state_dict.update(torch.load(filename, map_location="cpu"))
pretrained_model_name_or_path = huggingface_hub.snapshot_download(pretrained_model_name_or_path, revision=revision)
except huggingface_hub.utils._errors.RepositoryNotFoundError:
raise ValueError("Invalid `pretrained_model_name_or_path`. It should be a local path or a repository name.")

sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
sharded_safetensors_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")

if os.path.exists(sharded_index_filename):
with open(shared_index_filename, "r") as f:
index = json.load(f)
shards = set(index["weight_map"].values())
elif os.path.exists(sharded_safetensors_index_filename):
with open(sharded_safetensors_index_filename, "r") as f:
index = json.load(f)
shards = set(index["weight_map"].values())
else:
state_dict = torch.load(filename, map_location="cpu")
shard_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
shard_safetensors_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors")
if os.path.exists(shard_filename):
shards = [shard_filename]
elif os.path.exists(shard_safetensors_filename):
shards = [shard_safetensors_filename]

state_dict = {}
for shard in shards:
if shard.endswith(".safetensors"):
import safetensors
with safetensors.safe_open(os.path.join(pretrained_model_name_or_path, shard), framework="pt") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
else:
state_dict.update(torch.load(shard, map_location="cpu"))

else:
state_dict = pretrained_model_name_or_path.state_dict()

Expand Down

0 comments on commit 9879889

Please sign in to comment.