re-download DAT models if filesize < 200 bytes
probably an LFS pointer
w-e-w committed Jul 31, 2024
1 parent 2f74cd6 commit 11b5249
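
For context on the "probably an LFS pointer" note: when a model hosted via Git LFS is downloaded through a path that does not resolve the LFS object, what lands on disk is the small text pointer file rather than the actual weights. Such a pointer is only a few lines (on the order of 130 bytes), so a downloaded "model" under 200 bytes is almost certainly a pointer and worth re-downloading. A minimal illustrative sketch, with a made-up oid and size:

    # Shape of a Git LFS pointer file; this is what ends up on disk when the
    # real weights were not fetched. It is far smaller than any real model.
    example_pointer = (
        "version https://git-lfs.github.com/spec/v1\n"
        "oid sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n"
        "size 155338752\n"
    )
    assert len(example_pointer.encode("utf-8")) < 200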
Showing 2 changed files with 17 additions and 2 deletions.
10 changes: 10 additions & 0 deletions modules/dat_model.py
@@ -51,6 +51,16 @@ def load_model(self, path):
                         model_dir=self.model_download_path,
                         hash_prefix=scaler.sha256,
                     )
+
+                    if os.path.getsize(scaler.local_data_path) < 200:
+                        # Re-download if the file is too small, probably an LFS pointer
+                        scaler.local_data_path = modelloader.load_file_from_url(
+                            scaler.data_path,
+                            model_dir=self.model_download_path,
+                            hash_prefix=scaler.sha256,
+                            re_download=True,
+                        )
+
                 if not os.path.exists(scaler.local_data_path):
                     raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
                 return scaler
9 changes: 7 additions & 2 deletions modules/modelloader.py
@@ -24,17 +24,22 @@ def load_file_from_url(
     progress: bool = True,
     file_name: str | None = None,
     hash_prefix: str | None = None,
+    re_download: bool = False,
 ) -> str:
     """Download a file from `url` into `model_dir`, using the file present if possible.
     Returns the path to the downloaded file.
     file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
     hash_prefix: if provided, the hash of the downloaded file will be checked against it.
+    re_download: re-download the file even if it already exists.
     """
     os.makedirs(model_dir, exist_ok=True)
     if not file_name:
         parts = urlparse(url)
         file_name = os.path.basename(parts.path)
     cached_file = os.path.abspath(os.path.join(model_dir, file_name))
-    if not os.path.exists(cached_file):
+    if re_download or not os.path.exists(cached_file):
         print(f'Downloading: "{url}" to {cached_file}\n')
         from torch.hub import download_url_to_file
         download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
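
A minimal usage sketch for the new re_download flag (the URL and model_dir values below are placeholders, not taken from this commit); it mirrors how dat_model.py now forces a fresh download when the cached file looks like an LFS pointer:

    from modules import modelloader

    # Force a fresh download even if a file with this name is already cached.
    local_path = modelloader.load_file_from_url(
        "https://example.com/DAT_x4.pth",  # placeholder URL
        model_dir="models/DAT",            # placeholder directory
        re_download=True,
    )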
