Skip to content

Commit

Permalink
small fixes in SETS
Browse files Browse the repository at this point in the history
  • Loading branch information
JHoelli committed Dec 15, 2023
1 parent 3f83b38 commit f683895
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import ContractedST, sets, utils
__all__ = ["ContractedST", "sets", "utils"]
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def sets_explain(
X_train_knn = np.swapaxes(X_train_knn, 1, 2)
knns[c].fit(X_train_knn)

orig_c = int(np.argmax(model(to_tff(instance_x))))
orig_c = int(np.argmax(model.predict(to_tff(instance_x))))
if len(target) > 1:
target.remove(orig_c)
for target_c in target:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def __init__(
max_shapelets_to_store_per_class=30,
remove_self_similar=remove_self_similar,
random_state=self.random_state,
remove_self_similar=True,
)
# Fit multivaraite transformer
st_transformer = MultivariateTransformer(shapelet_transform)
Expand Down
4 changes: 2 additions & 2 deletions TSInterpret/InterpretabilityModels/counterfactual/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import CF, COMTE, NativeGuideCF, TSEvo, TSEvoCF, COMTECF
from . import CF, COMTE, NativeGuideCF, TSEvo, TSEvoCF, COMTECF, SETSCF

__all__ = ["CF", "COMTE", "NativeGuideCF", "TSEvoCF", "TSEvo", "COMTECF"]
__all__ = ["CF", "COMTE", "NativeGuideCF", "TSEvoCF", "TSEvo", "COMTECF","SETSCF"]
2 changes: 1 addition & 1 deletion TSInterpret/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION = (0, 4, 1)
VERSION = (0, 4, 2)
__version__ = ".".join(map(str, VERSION)) # noqa: F401
124 changes: 78 additions & 46 deletions docs/Notebooks/Sets_tensorflow.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"pymop",
"deap",
"wheel",
"sktime"
]

dev_packages = base_packages + [
Expand Down

0 comments on commit f683895

Please sign in to comment.