Reorganize slightly
cg123 committed Jan 14, 2024
1 parent 76ee979 commit 206102b
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions mergekit/io/tensor_writer.py
@@ -61,45 +61,13 @@ def save_tensor(self, name: str, tensor: torch.Tensor, clone: bool = False):
         self.current_shard[name] = tensor
         self.current_shard_size += tensor_size

-    def get_name_components(self):
-        if self.safe_serialization:
-            return "model", "safetensors"
-        return "pytorch_model", "bin"
-
-    def _save_st(self, shard_path: str):
-        def _do_save():
-            safetensors.torch.save_file(
-                self.current_shard,
-                shard_path,
-                metadata={"format": "pt"},
-            )
-
-        try:
-            _do_save()
-        except RuntimeError as e:
-            if (
-                len(e.args) > 0
-                and isinstance(e.args[0], str)
-                and "share memory" in e.args[0]
-            ):
-                logging.warning(
-                    "Your model has duplicated tensors but the --clone-tensors "
-                    "flag is not set."
-                )
-                self.current_shard = {
-                    key: self.current_shard[key].clone() for key in self.current_shard
-                }
-                _do_save()
-            else:
-                raise
-
     def flush_current_shard(self):
         if not self.current_shard:
             return

         logging.info(f"writing shard #{self.shards_written+1} to disk")

-        prefix, extension = self.get_name_components()
+        prefix, extension = self._get_name_components()
         shard_name = f"{prefix}-{self.shards_written+1}.{extension}"
         for key in self.current_shard:
             self.weight_map[key] = shard_name
@@ -117,7 +85,7 @@ def flush_current_shard(self):
     def finalize(self):
         self.flush_current_shard()

-        prefix, extension = self.get_name_components()
+        prefix, extension = self._get_name_components()

         # standardize shard names to hf format
         total_shards = self.shards_written
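
The comment in the hunk above notes that finalize() renames shards to the Hugging Face format. As a rough, standalone illustration of that naming convention (not mergekit's actual renaming code, which is collapsed in this diff; the prefix/extension values mirror what _get_name_components returns):

    prefix, extension = "model", "safetensors"
    total_shards = 3
    for idx in range(1, total_shards + 1):
        # prints e.g. model-00001-of-00003.safetensors
        print(f"{prefix}-{idx:05d}-of-{total_shards:05d}.{extension}")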
@@ -148,3 +116,35 @@ def finalize(self):
                 },
                 file,
             )
+
+    def _get_name_components(self):
+        if self.safe_serialization:
+            return "model", "safetensors"
+        return "pytorch_model", "bin"
+
+    def _save_st(self, shard_path: str):
+        def _do_save():
+            safetensors.torch.save_file(
+                self.current_shard,
+                shard_path,
+                metadata={"format": "pt"},
+            )
+
+        try:
+            _do_save()
+        except RuntimeError as e:
+            if (
+                len(e.args) > 0
+                and isinstance(e.args[0], str)
+                and "share memory" in e.args[0]
+            ):
+                logging.warning(
+                    "Your model has duplicated tensors but the --clone-tensors "
+                    "flag is not set."
+                )
+                self.current_shard = {
+                    key: self.current_shard[key].clone() for key in self.current_shard
+                }
+                _do_save()
+            else:
+                raise
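
The _save_st fallback moved in this commit exists because safetensors refuses to serialize tensors that share underlying storage. A minimal standalone sketch of the same clone-and-retry pattern (illustration only, not part of mergekit; the tensors and output path are made up):

    import safetensors.torch
    import torch

    # Two dict entries backed by the same storage; safetensors raises a
    # RuntimeError complaining that the tensors share memory.
    base = torch.zeros(4)
    tensors = {"weight": base, "weight_view": base[:2]}

    try:
        safetensors.torch.save_file(tensors, "shard.safetensors", metadata={"format": "pt"})
    except RuntimeError:
        # Same recovery as _save_st: clone each tensor so every entry owns its
        # own memory, then retry the save.
        tensors = {key: value.clone() for key, value in tensors.items()}
        safetensors.torch.save_file(tensors, "shard.safetensors", metadata={"format": "pt"})

Setting the --clone-tensors flag mentioned in the warning presumably performs this cloning up front, so the first save attempt never fails.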
