Skip to content

Commit

Permalink
FEATURE: new hp eval, and forcing beats into eval mode
Browse files Browse the repository at this point in the history
  • Loading branch information
femke-sintef committed Apr 9, 2024
1 parent d323ccc commit 0f5f092
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 162 deletions.
6 changes: 5 additions & 1 deletion BEATs/BEATs.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def extract_features(
padding_mask: Optional[torch.Tensor] = None,
fbank_mean: float = 15.41663,
fbank_std: float = 6.55582,
skip_dropout = False
):
# start NOTE FBG: changed input to preprocessed
# fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
Expand All @@ -200,8 +201,11 @@ def extract_features(

if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
if not skip_dropout:
x = self.dropout_input(features)
else:
x = features

x = self.dropout_input(features)

x, layer_results = self.encoder(
x,
Expand Down
12 changes: 8 additions & 4 deletions CONFIG_PREDICT.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,18 @@ model:
# PARAMETERS FOR MODEL PREDICTION #
###################################
predict:
wav_save: False
wav_save: True
overwrite: True
n_self_detected_supports: 0
tolerance: 0
tolerance: 1
n_subsample: 1 # Whether each segment should be subsampled
self_detect_support: False # Whether to use the self-training loop
filter_by_p_value: False # Whether we filter outliers by their pvalues
self_detect_support: True # Whether to use the self-training loop
filter_by_p_value: True # Whether we filter outliers by their pvalues
threshold_p_value: 0.1
self_detect_threshold_p_value: 0.95
occurence_threshold: 1 # min number of consequetive frames for a postive to be included
distribution: ecdf # name of distribution, ecdf or norm
repetitions: 1 # number of times to feed Beats same feature (no longer useful)

plot:
tsne: True
Expand Down
51 changes: 38 additions & 13 deletions evaluate/_utils_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from tqdm import tqdm
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
from statsmodels.distributions.empirical_distribution import ECDF
from scipy.stats import norm

from prototypicalbeats.prototraining import ProtoBEATsModel
from datamodules.TestDCASEDataModule import DCASEDataModule, AudioDatasetDCASE
Expand All @@ -23,7 +25,7 @@ def to_dataframe(features, labels):

def get_proto_coordinates(model, model_type, support_data, support_labels, n_way):
if model_type == "beats":
z_supports, _ = model.get_embeddings(support_data, padding_mask=None)
z_supports, _ = model.get_embeddings(support_data, padding_mask=None, skip_dropout=True)
else:
z_supports = model.get_embeddings(support_data, padding_mask=None)

Expand Down Expand Up @@ -71,13 +73,16 @@ def compute_scores(predicted_labels, gt_labels):
return acc, recall, precision, f1score


def merge_preds(df, tolerence, tensor_length, frame_shift):
def merge_preds(df, tolerence, tensor_length, frame_shift, occurence_threshold):
df["group"] = (
df["Starttime"]
> (
df["Endtime"] + tolerence * tensor_length * frame_shift / 1000 + 0.00001
).shift()
).cumsum()
ids, occurence = np.unique(df["group"], return_counts=True)
ids_too_short_segments = ids[occurence<occurence_threshold]
df.drop(df[df.group.isin(ids_too_short_segments)].index, inplace=True)
result = df.groupby("group").agg({"Starttime": "min", "Endtime": "max"})
return result

Expand Down Expand Up @@ -167,6 +172,7 @@ def predict_labels_query(
frame_shift,
overlap,
pos_index,
repetitions=1
):
"""
- l_segment to know the length of the segment
Expand All @@ -189,10 +195,27 @@ def predict_labels_query(
feature, label = data
feature = feature.to("cuda")

if model_type == "beats":
q_embedding, _ = model.get_embeddings(feature, padding_mask=None)
else:
q_embedding = model.get_embeddings(feature, padding_mask=None)
l_dists = torch.empty(size=(repetitions,2))
l_classification_scores = torch.empty(size=(repetitions,2))
for rep_i in range(repetitions):
if model_type == "beats":
q_embedding, _ = model.get_embeddings(feature, padding_mask=None, skip_dropout=True)
else:
q_embedding = model.get_embeddings(feature, padding_mask=None)

# Get the scores:
classification_scores, dists = calculate_distance(
model_type, q_embedding, prototypes
)

if model_type != "beats":
dists = dists.squeeze()
classification_scores = classification_scores.squeeze()

l_dists[rep_i] = dists
l_classification_scores[rep_i] = classification_scores
dists = torch.mean(l_dists, dim=0)
classification_scores=torch.mean(l_classification_scores, dim=0)

# Calculate beginTime and endTime for each segment
# We multiply by 1000 to get the time in seconds
Expand All @@ -203,14 +226,7 @@ def predict_labels_query(
begin = i * tensor_length * frame_shift * overlap / 1000
end = begin + tensor_length * frame_shift / 1000

# Get the scores:
classification_scores, dists = calculate_distance(
model_type, q_embedding, prototypes
)

if model_type != "beats":
dists = dists.squeeze()
classification_scores = classification_scores.squeeze()

# Get the labels (either POS or NEG):
predicted_labels = torch.max(classification_scores, 0)[
Expand Down Expand Up @@ -246,3 +262,12 @@ def filter_outliers_by_p_values(Y, p_values, target_class=1, upper_threshold=0.0
Y[outlier_indices] = 0

return Y

def obtain_cdf(d_supports_to_POS_prototypes, distribution_name):
if distribution_name == "ecdf":
cdf = ECDF(d_supports_to_POS_prototypes)
elif distribution_name == "norm":
cdf = norm(loc=np.mean(d_supports_to_POS_prototypes), scale=np.std(d_supports_to_POS_prototypes)).cdf
else:
raise
return cdf
Loading

0 comments on commit 0f5f092

Please sign in to comment.