Skip to content

Commit

Permalink
Merge pull request #3349 from jonahpearl/zarr_folder_suffix
Browse files Browse the repository at this point in the history
Fix zarr folder suffix handling
  • Loading branch information
samuelgarcia committed Aug 29, 2024
2 parents 007b6ef + c992ca6 commit 9238023
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
5 changes: 2 additions & 3 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .globals import get_global_tmp_folder, is_set_global_tmp_folder
from .core_tools import (
check_json,
clean_zarr_folder_name,
is_dict_extractor,
SIJsonEncoder,
make_paths_relative,
Expand Down Expand Up @@ -1061,9 +1062,7 @@ def save_to_zarr(
print(f"Use zarr_path={zarr_path}")
else:
if storage_options is None:
folder = Path(folder)
if folder.suffix != ".zarr":
folder = folder.parent / f"{folder.stem}.zarr"
folder = clean_zarr_folder_name(folder)
if folder.is_dir() and overwrite:
shutil.rmtree(folder)
zarr_path = folder
Expand Down
7 changes: 7 additions & 0 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def check_json(dictionary: dict) -> dict:
return json.loads(json_string)


def clean_zarr_folder_name(folder):
folder = Path(folder)
if folder.suffix != ".zarr":
folder = folder.parent / f"{folder.stem}.zarr"
return folder


def add_suffix(file_path, possible_suffix):
file_path = Path(file_path)
if isinstance(possible_suffix, str):
Expand Down
26 changes: 18 additions & 8 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .base import load_extractor
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match
from .core_tools import check_json, retrieve_importing_provenance, is_path_remote
from .core_tools import check_json, retrieve_importing_provenance, is_path_remote, clean_zarr_folder_name
from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging
from .job_tools import split_job_kwargs
from .numpyextractors import NumpySorting
Expand Down Expand Up @@ -111,6 +111,8 @@ def create_sorting_analyzer(
sparsity off (or give external sparsity) like this.
"""
if format != "memory":
if format == "zarr":
folder = clean_zarr_folder_name(folder)
if Path(folder).is_dir():
if not overwrite:
raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.")
Expand Down Expand Up @@ -162,6 +164,8 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto"):
The loaded SortingAnalyzer
"""
if format == "zarr":
folder = clean_zarr_folder_name(folder)
return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format)


Expand Down Expand Up @@ -269,6 +273,8 @@ def create(
sorting_analyzer = cls.load_from_binary_folder(folder, recording=recording)
sorting_analyzer.folder = Path(folder)
elif format == "zarr":
assert folder is not None, "For format='zarr' folder must be provided"
folder = clean_zarr_folder_name(folder)
cls.create_zarr(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None)
sorting_analyzer = cls.load_from_zarr(folder, recording=recording)
sorting_analyzer.folder = Path(folder)
Expand Down Expand Up @@ -487,10 +493,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at
import zarr
import numcodecs

folder = Path(folder)
# force zarr sufix
if folder.suffix != ".zarr":
folder = folder.parent / f"{folder.stem}.zarr"
folder = clean_zarr_folder_name(folder)

if folder.is_dir():
raise ValueError(f"Folder already exists {folder}")
Expand Down Expand Up @@ -768,9 +771,7 @@ def _save_or_select_or_merge(

elif format == "zarr":
assert folder is not None, "For format='zarr' folder must be provided"
folder = Path(folder)
if folder.suffix != ".zarr":
folder = folder.parent / f"{folder.stem}.zarr"
folder = clean_zarr_folder_name(folder)
SortingAnalyzer.create_zarr(
folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes
)
Expand Down Expand Up @@ -829,6 +830,8 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer":
format : "memory" | "binary_folder" | "zarr", default: "memory"
The new backend format to use
"""
if format == "zarr":
folder = clean_zarr_folder_name(folder)
return self._save_or_select_or_merge(format=format, folder=folder)

def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer":
Expand All @@ -854,6 +857,8 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz
The newly create sorting_analyzer with the selected units
"""
# TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!!
if format == "zarr":
folder = clean_zarr_folder_name(folder)
return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids)

def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "SortingAnalyzer":
Expand All @@ -880,6 +885,8 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin
"""
# TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!!
unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)]
if format == "zarr":
folder = clean_zarr_folder_name(folder)
return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids)

def merge_units(
Expand Down Expand Up @@ -938,6 +945,9 @@ def merge_units(
The newly create `SortingAnalyzer` with the selected units
"""

if format == "zarr":
folder = clean_zarr_folder_name(folder)

assert merging_mode in ["soft", "hard"], "Merging mode should be either soft or hard"

if len(merge_unit_groups) == 0:
Expand Down

0 comments on commit 9238023

Please sign in to comment.