Skip to content

Commit

Permalink
[FIX] resample NEG supports
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Mar 18, 2024
1 parent 16a48d7 commit 4c058e4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
1 change: 0 additions & 1 deletion evaluate/_utils_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from datamodules.TestDCASEDataModule import DCASEDataModule, AudioDatasetDCASE

import pytorch_lightning as pl
pl.utilities.seed.seed_everything(42, workers=True)

def to_dataframe(features, labels):
# Load the saved array and map the features and labels into a single dataframe
Expand Down
27 changes: 13 additions & 14 deletions evaluate/evaluateDCASE.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datamodules.TestDCASEDataModule import DCASEDataModule, AudioDatasetDCASE

import pytorch_lightning as pl
pl.utilities.seed.seed_everything(42, workers=True)
pl.utilities.seed.seed_everything(0, workers=True)

from callbacks.callbacks import MilestonesFinetuning

Expand Down Expand Up @@ -61,6 +61,7 @@ def compute(
n_shot=3,
n_query=2,
n_subsample=cfg["data"]["n_subsample"],
n_task_train=cfg["data"]["n_task_train"]
)
label_dic = custom_dcasedatamodule.get_label_dic()
pos_index = label_dic["POS"]
Expand Down Expand Up @@ -162,22 +163,20 @@ def compute(
df_extension_pos["category"] = "POS"

# # Detect NEG samples
# detected_neg_indices = np.where(p_values_pos == 0)[0]
df_neg = df_support[df_support["category"] == "NEG"]
num_pos_samples = len(detected_pos_indices)

# # Randomly sample NEG samples to match the number of POS samples
# num_pos_samples = len(detected_pos_indices)

# if num_pos_samples > 0 and len(detected_neg_indices) > num_pos_samples:
# sampled_neg_indices = np.random.choice(detected_neg_indices, size=num_pos_samples, replace=False)
# else:
# sampled_neg_indices = detected_neg_indices

# df_extension_neg = df_query.iloc[sampled_neg_indices].copy()
# df_extension_neg["category"] = "NEG"
if num_pos_samples > 0 and len(df_neg) > num_pos_samples:
sampled_neg_indices = np.random.choice(range(0, len(df_neg)), size=num_pos_samples, replace=False)
df_extension_neg = df_query.iloc[sampled_neg_indices].copy()
df_extension_neg["category"] = "NEG"
else:
print(df_neg)
df_extension_neg = df_neg

# Append both POS and NEG samples to the support set
# df_support_extended = df_support.append([df_extension_pos, df_extension_neg], ignore_index=True)
df_support_extended = df_support.append([df_extension_pos], ignore_index=True)
df_support_extended = df_support.append([df_extension_pos, df_extension_neg], ignore_index=True)

########################
# RECALCULATE THE ECDF #
########################
Expand Down

0 comments on commit 4c058e4

Please sign in to comment.