Commit 897dc42

FEATURE: finetuning and fp detection
femke-sintef committed Apr 19, 2024
1 parent e457e6e commit 897dc42
Showing 12 changed files with 256 additions and 270 deletions.
84 changes: 55 additions & 29 deletions data_utils/DCASEfewshot.py
@@ -32,7 +32,7 @@
 
 PLOT = False
 PLOT_TOO_SHORT_SAMPLES = False
-PLOT_SUPPORT = False
+PLOT_SUPPORT = True
 
 
 def normalize_mono(samples):
@@ -42,14 +42,24 @@ def normalize_mono(samples):
 
 
 def denoise_signal(samples, sr):
-    denoised_signal_samples = nr.reduce_noise(
-        y=np.squeeze(samples),
-        sr=sr,
-        prop_decrease=0.95,
-        stationary=True,
-        time_mask_smooth_ms=25,
-        freq_mask_smooth_hz=1000,
-    )
+    try:
+        denoised_signal_samples = nr.reduce_noise(
+            y=np.squeeze(samples),
+            sr=sr,
+            prop_decrease=0.95,
+            stationary=True,
+            time_mask_smooth_ms=25,
+            freq_mask_smooth_hz=1000,
+        )
+    except:
+        denoised_signal_samples = nr.reduce_noise(
+            y=np.squeeze(samples),
+            sr=sr,
+            prop_decrease=0.95,
+            stationary=True,
+            time_mask_smooth_ms=43,
+            freq_mask_smooth_hz=1000,
+        )
     return denoised_signal_samples
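A note on this hunk: nr.reduce_noise can raise when the 25 ms time-mask smoothing window does not fit a given clip (short signals or unusual sample rates are the likely trigger, though the commit does not say), so the new except branch retries with a wider 43 ms mask. A minimal sketch of the same fallback with a non-bare except, assuming the failure really is recoverable by widening the mask; denoise_signal_safe is a hypothetical name:

import numpy as np
import noisereduce as nr

def denoise_signal_safe(samples, sr):
    # Try the 25 ms time mask first, then fall back to 43 ms, as in the
    # hunk above. `except Exception` (rather than a bare `except:`) avoids
    # swallowing KeyboardInterrupt and SystemExit.
    for time_mask_ms in (25, 43):
        try:
            return nr.reduce_noise(
                y=np.squeeze(samples),
                sr=sr,
                prop_decrease=0.95,
                stationary=True,
                time_mask_smooth_ms=time_mask_ms,
                freq_mask_smooth_hz=1000,
            )
        except Exception:
            continue
    raise RuntimeError("noisereduce failed for both mask widths")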


@@ -84,7 +94,6 @@ def preprocess(
 
 def prepare_training_val_data(
     status,
-    set_type,
     overwrite,
     tensor_length=128,
     frame_length=25.0,
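With set_type removed from the signature, callers pass one argument fewer; the value is now derived from status inside the function (see the @@ -250 hunk below), and the call at the bottom of this file drops cfg["data"]["set_type"] to match:

# call-site change implied by this hunk
# before: prepare_training_val_data(status, set_type, overwrite, tensor_length, ...)
# after:  prepare_training_val_data(status, overwrite, tensor_length, ...)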
@@ -250,8 +259,19 @@ def preprocess_df(df):
             )
             break
 
+    # Select 'set_type' depending on chosen status
+
+    if status == "train":
+        set_type = "Training_Set"
+
+    elif status == "validate":
+        set_type = "Validation_Set"
+
+    else:
+        set_type = "Evaluation_Set"
+
     # Root directory of data to be processed
-    root_dir = "/data/DCASE/Development_Set"
+    root_dir = "/data/DCASE/Development_Set_24"
 
     # Create directories for saving
     my_hash_dict = {
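The hunk above moves the status-to-set_type mapping out of the config and into the function. An equivalent lookup-table sketch (SET_TYPE_BY_STATUS is a hypothetical name), shown only to make the default explicit: any status other than "train" or "validate" falls through to the evaluation set:

SET_TYPE_BY_STATUS = {"train": "Training_Set", "validate": "Validation_Set"}
set_type = SET_TYPE_BY_STATUS.get(status, "Evaluation_Set")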
@@ -265,12 +285,18 @@ def preprocess_df(df):
         "num_mel_bins": num_mel_bins,
         "max_segment_length": max_segment_length,
     }
+    if "24" in root_dir:
+        my_hash_dict["24_data"] = True
     if resample:
         my_hash_dict["target_fs"] = target_fs
     hash_dir_name = hashlib.sha1(
         json.dumps(my_hash_dict, sort_keys=True).encode()
     ).hexdigest()
     target_path = os.path.join("/data/DCASEfewshot", status, hash_dir_name)
+    if os.path.exists(os.path.join(target_path, "audio")):
+        print("Warning, audio path already exists.")
+    if os.path.exists(os.path.join(target_path, "plots")):
+        print("Warning, plots path already exists.")
     if overwrite:
         if os.path.exists(target_path):
             shutil.rmtree(os.path.join(target_path, "audio"))
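The cache directory name is a SHA-1 over the canonical (sorted-key) JSON dump of my_hash_dict, so the new "24_data" entry gives 2024-format data a cache location distinct from earlier preprocessing runs. A self-contained check (hash_dir is a hypothetical helper):

import hashlib
import json

def hash_dir(params):
    # Same scheme as above: sha1 over a sorted-key JSON dump.
    return hashlib.sha1(json.dumps(params, sort_keys=True).encode()).hexdigest()

base = {"status": "train", "tensor_length": 128}
print(hash_dir(base) == hash_dir({**base, "24_data": True}))  # prints False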
@@ -308,18 +334,20 @@ def preprocess_df(df):
         split_list = file.split("/")
         glob_cls_name = split_list[split_list.index(set_type) + 1]
         file_name = split_list[split_list.index(set_type) + 2]
-        df = pd.read_csv(file, header=0, index_col=False)
+        df_all = pd.read_csv(file, header=0, index_col=False)
 
         # read audio file into y
         audio_path = file.replace("csv", "wav")
         print("Processing file name {}".format(audio_path))
         y, fs = librosa.load(audio_path, sr=None, mono=True)
-        if not resample:  # or my_hash_dict["target_fs"] > fs:
+        if not resample or my_hash_dict["target_fs"] > fs:
             target_fs = fs
-        # else:
-        #     target_fs = my_hash_dict["target_fs"]
-        df = df[(df == "POS").any(axis=1)]
+        else:
+            target_fs = my_hash_dict["target_fs"]
+        df = df_all[(df_all == "POS").any(axis=1)]
         df = df.reset_index()
+        df_UNK = df_all[(df_all == "UNK").any(axis=1)]
+        df_UNK = df_UNK.reset_index()
 
         # For csv files with a column name Call, pick up the global class name
         if "CALL" in df.columns:
@@ -404,6 +432,9 @@ def preprocess_df(df):
         interval_array = pd.arrays.IntervalArray.from_arrays(
             df["Starttime"].values, df["Endtime"].values
         )
+        interval_array_UNK = pd.arrays.IntervalArray.from_arrays(
+            df_UNK["Starttime"].values, df_UNK["Endtime"].values
+        )
         segment_end_ind = 0
         while segment_end_ind < data.shape[1]:
             # add feature
@@ -430,9 +461,14 @@ def preprocess_df(df):
                 segment_start_ind * frame_shift / 1000,
                 segment_end_ind * frame_shift / 1000,
             )
-            is_included = np.any(interval_array.overlaps(segment_interval))
 
             # add label
-            label = "POS" if is_included else "NEG"
+            if np.any(interval_array.overlaps(segment_interval)):
+                label = "POS"
+            elif np.any(interval_array_UNK.overlaps(segment_interval)):
+                label = "UNK"
+            else:
+                label = "NEG"
             labels.append(label)
             segment_ind += 1
             if PLOT:
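Segment labelling now has a three-way precedence: POS if the segment overlaps any positive annotation, otherwise UNK if it overlaps an unknown one, otherwise NEG. A toy demonstration of the precedence using the same pandas interval machinery:

import numpy as np
import pandas as pd

interval_array = pd.arrays.IntervalArray.from_arrays([1.0], [2.0])      # POS events
interval_array_UNK = pd.arrays.IntervalArray.from_arrays([3.0], [4.0])  # UNK events
segment_interval = pd.Interval(1.5, 3.5)  # overlaps both

if np.any(interval_array.overlaps(segment_interval)):
    label = "POS"
elif np.any(interval_array_UNK.overlaps(segment_interval)):
    label = "UNK"
else:
    label = "NEG"
print(label)  # POS: positive overlap wins over unknown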
@@ -474,6 +510,7 @@ def preprocess_df(df):
         # CREATE SUPPORT SETS
         # reduce df to 5 lines and class list
         df = df.head(5)
+        assert np.all(df["Q"].values == "POS")
         neg_starttimes = []
         neg_endtimes = []
         last_pos_end_time = 0.0
@@ -591,19 +628,8 @@ def preprocess_df(df):
         + "'."
     )
 
-    # Select 'set_type' depending on chosen status
-    if cfg["data"]["status"] == "train":
-        cfg["data"]["set_type"] = "Training_Set"
-
-    elif cfg["data"]["status"] == "validate":
-        cfg["data"]["set_type"] = "Validation_Set"
-
-    else:
-        cfg["data"]["set_type"] = "Evaluation_Set"
-
     prepare_training_val_data(
         cfg["data"]["status"],
-        cfg["data"]["set_type"],
         cli_args.overwrite,
         cfg["data"]["tensor_length"],
         cfg["data"]["frame_length"],
4 changes: 3 additions & 1 deletion datamodules/DCASEDataModule.py
@@ -139,6 +139,8 @@ def setup(self, stage=None):
             "num_mel_bins": self.num_mel_bins,
             "max_segment_length": self.max_segment_length,
         }
+
+        my_hash_dict["24_data"] = True
         if self.resample:
             my_hash_dict["target_fs"] = self.target_fs
         hash_dir_name = hashlib.sha1(
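Unlike DCASEfewshot.py, where the "24_data" key is set only when "24" appears in root_dir, the datamodule sets it unconditionally, so it always resolves to the 2024-data cache directory. If that asymmetry is unintended, a shared helper could keep the two computations in sync (a hypothetical sketch; build_hash_dict is not in this commit):

def build_hash_dict(base, is_24_data, resample, target_fs):
    # One place to build the dict that both DCASEfewshot.py and the
    # datamodules hash into a cache-directory name.
    d = dict(base)
    if is_24_data:
        d["24_data"] = True
    if resample:
        d["target_fs"] = target_fs
    return d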
@@ -154,7 +156,7 @@ def setup(self, stage=None):
         data_frame = pd.DataFrame({"feature": list_input_features, "category": labels})
 
         complete_dataset = AudioDatasetDCASE(
-            data_frame=data_frame,
+            data_frame=data_frame,
         )
         # Separate into training and validation set
         train_indices, validation_indices, _, _ = train_test_split(
2 changes: 1 addition & 1 deletion datamodules/TestDCASEDataModule.py
@@ -125,7 +125,7 @@ def __init__(
     def setup(self, stage=None):
         # load data
         self.complete_dataset = AudioDatasetDCASE(
-            data_frame=self.data_frame,
+            data_frame=self.data_frame
         )
 
     def train_dataloader(self):
82 changes: 41 additions & 41 deletions dcase_fine_tune/FTDataModule.py
@@ -214,46 +214,46 @@ def collate_fn(self, input_data):
     # return self.label_dict
 
 
-class predictLoader:
-    def __init__(
-        self, data_frame=pd.DataFrame, batch_size=1, num_workers=4, tensor_length=128
-    ):
-        self.data_frame = data_frame
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.tensor_length = tensor_length
-        self.setup()
-
-    def setup(self, stage=None):
-        # load data
-        self.complete_dataset = AudioDatasetDCASE(
-            data_frame=self.data_frame,
-        )
-
-    def pred_dataloader(self):
-        pred_loader = DataLoader(
-            self.complete_dataset,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=False,
-            shuffle=True,
-            collate_fn=self.collate_fn,
-        )
-        return pred_loader
-
-    def collate_fn(self, input_data):
-        true_class_ids = list({x[1] for x in input_data})
-        new_input = []
-        for x in input_data:
-            if x[0].shape[1] > self.tensor_length:
-                rand_start = torch.randint(0, x[0].shape[1] - self.tensor_length, (1,))
-                new_input.append(
-                    (x[0][:, rand_start : rand_start + self.tensor_length], x[1])
-                )
-            else:
-                new_input.append(x)
+# class predictLoader:
+#     def __init__(
+#         self, data_frame=pd.DataFrame, batch_size=1, num_workers=4, tensor_length=128
+#     ):
+#         self.data_frame = data_frame
+#         self.batch_size = batch_size
+#         self.num_workers = num_workers
+#         self.tensor_length = tensor_length
+#         self.setup()
+
+#     def setup(self, stage=None):
+#         # load data
+#         self.complete_dataset = AudioDatasetDCASE(
+#             data_frame=self.data_frame,
+#         )
+
+#     def pred_dataloader(self):
+#         pred_loader = DataLoader(
+#             self.complete_dataset,
+#             batch_size=self.batch_size,
+#             num_workers=self.num_workers,
+#             pin_memory=False,
+#             shuffle=True,
+#             collate_fn=self.collate_fn,
+#         )
+#         return pred_loader
+
+#     def collate_fn(self, input_data):
+#         true_class_ids = list({x[1] for x in input_data})
+#         new_input = []
+#         for x in input_data:
+#             if x[0].shape[1] > self.tensor_length:
+#                 rand_start = torch.randint(0, x[0].shape[1] - self.tensor_length, (1,))
+#                 new_input.append(
+#                     (x[0][:, rand_start : rand_start + self.tensor_length], x[1])
+#                 )
+#             else:
+#                 new_input.append(x)
 
-        all_images = torch.cat([x[0].unsqueeze(0) for x in new_input])
-        all_labels = torch.tensor([true_class_ids.index(x[1]) for x in input_data])
+#         all_images = torch.cat([x[0].unsqueeze(0) for x in new_input])
+#         all_labels = torch.tensor([true_class_ids.index(x[1]) for x in input_data])
 
-        return (all_images, all_labels)
+#         return (all_images, all_labels)
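predictLoader is commented out rather than deleted; prediction batching presumably moves into one of the changed files not shown on this page. If a plain prediction loader is ever needed again, a minimal stand-in might look like this (hypothetical; note shuffle=False, since the old shuffle=True is unusual for prediction, where output order should match input order):

from torch.utils.data import DataLoader

def make_pred_dataloader(dataset, batch_size=1, num_workers=4, collate_fn=None):
    # Minimal stand-in for the commented-out predictLoader.pred_dataloader.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=False,
        shuffle=False,
        collate_fn=collate_fn,
    )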