diff --git a/dcase_fine_tune/FTBeats.py b/dcase_fine_tune/FTBeats.py
index e27c2df..82572c9 100644
--- a/dcase_fine_tune/FTBeats.py
+++ b/dcase_fine_tune/FTBeats.py
@@ -46,10 +46,12 @@ def __init__(
         self._build_model()
 
         self.train_acc = Accuracy(
-            task="multiclass", num_classes=self.num_target_classes
+            task="multiclass",
+            num_classes=self.num_target_classes
         )
         self.valid_acc = Accuracy(
-            task="multiclass", num_classes=self.num_target_classes
+            task="multiclass",
+            num_classes=self.num_target_classes
         )
         self.save_hyperparameters()
 
@@ -62,7 +64,9 @@ def _build_model(self):
 
         # 2. Classifier
         print(f"Classifier has {self.num_target_classes} output neurons")
-        self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.num_target_classes)
+        self.beats.predictor_dropout = nn.Dropout(self.cfg.predictor_dropout)
+        self.beats.predictor = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class)
+        # self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.num_target_classes)
 
     def extract_features(self, x, padding_mask=None):
         if padding_mask != None:
@@ -81,10 +85,10 @@ def forward(self, x, padding_mask=None):
             x, _ = self.beats.extract_features(x)
 
         # Get the logits
-        x = self.fc(x)
+        # x = self.fc(x)
 
         # Mean pool the second dimension (these are the tokens)
-        x = x.mean(dim=1)
+        # x = x.mean(dim=1)
 
         return x
 
@@ -120,14 +124,15 @@ def validation_step(self, batch, batch_idx):
     def configure_optimizers(self):
         if self.ft_entire_network:
             optimizer = optim.AdamW(
-                [{"params": self.beats.parameters()}, {"params": self.fc.parameters()}],
+                self.beats.parameters(),
+                # [{"params": self.beats.parameters()}, {"params": self.fc.parameters()}],
                 lr=self.lr,
                 betas=(0.9, 0.98),
                 weight_decay=0.01,
             )
         else:
             optimizer = optim.AdamW(
-                self.fc.parameters(), lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
+                self.beats.predictor.parameters(), lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
             )
         return optimizer
 
diff --git a/dcase_fine_tune/FTDataModule.py b/dcase_fine_tune/FTDataModule.py
index 0f38db6..b96375e 100644
--- a/dcase_fine_tune/FTDataModule.py
+++ b/dcase_fine_tune/FTDataModule.py
@@ -55,15 +55,22 @@ def __init__(
         self.test_size = test_size
         self.min_sample_per_category = min_sample_per_category
 
-        self.setup()
-        self.divide_train_val()
-
-    def setup(self, stage=None):
-        # load data
+        # encode labels
         self.data_frame["category"] = LabelEncoder().fit_transform(
             self.data_frame["category"]
         )
+
+        # remove classes with too few samples
+        removals = self.data_frame['category'].value_counts().reset_index()
+        removals = removals[removals['category'] < min_sample_per_category]['index'].values
+        self.data_frame.drop(self.data_frame[self.data_frame["category"].isin(removals)].index, inplace=True)
+        self.data_frame.reset_index(inplace=True)
+
+        # init dataset and divide into train&val
         self.complete_dataset = TrainAudioDatasetDCASE(data_frame=self.data_frame)
+        self.divide_train_val()
+
+
 
     def divide_train_val(self):
         # Separate into training and validation set
@@ -144,50 +151,50 @@ def collate_fn(self, input_data):
 
         return (all_images, all_labels)
 
-class AudioDatasetDCASE(Dataset):
-    def __init__(
-        self,
-        data_frame,
-        label_dict=None,
-    ):
-        self.data_frame = data_frame
-        self.label_encoder = LabelEncoder()
-        if label_dict is not None:
-            self.label_encoder.fit(list(label_dict.keys()))
-            self.label_dict = label_dict
-        else:
-            self.label_encoder.fit(self.data_frame["category"])
-            self.label_dict = dict(
-                zip(
-                    self.label_encoder.classes_,
-                    self.label_encoder.transform(self.label_encoder.classes_),
-                )
-            )
-
-    def __len__(self):
-        return len(self.data_frame)
-
-    def get_labels(self):
-        labels = []
-
-        for i in range(0, len(self.data_frame)):
-            label = self.data_frame.iloc[i]["category"]
-            label = self.label_encoder.transform([label])[0]
-            labels.append(label)
-
-        return labels
-
-    def __getitem__(self, idx):
-        input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"])
-        label = self.data_frame.iloc[idx]["category"]
-
-        # Encode label as integer
-        label = self.label_encoder.transform([label])[0]
-
-        return input_feature, label
-
-    def get_label_dict(self):
-        return self.label_dict
+# class AudioDatasetDCASE(Dataset):
+#     def __init__(
+#         self,
+#         data_frame,
+#         label_dict=None,
+#     ):
+#         self.data_frame = data_frame
+#         self.label_encoder = LabelEncoder()
+#         if label_dict is not None:
+#             self.label_encoder.fit(list(label_dict.keys()))
+#             self.label_dict = label_dict
+#         else:
+#             self.label_encoder.fit(self.data_frame["category"])
+#             self.label_dict = dict(
+#                 zip(
+#                     self.label_encoder.classes_,
+#                     self.label_encoder.transform(self.label_encoder.classes_),
+#                 )
+#             )
+
+#     def __len__(self):
+#         return len(self.data_frame)
+
+#     def get_labels(self):
+#         labels = []
+
+#         for i in range(0, len(self.data_frame)):
+#             label = self.data_frame.iloc[i]["category"]
+#             label = self.label_encoder.transform([label])[0]
+#             labels.append(label)
+
+#         return labels
+
+#     def __getitem__(self, idx):
+#         input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"])
+#         label = self.data_frame.iloc[idx]["category"]
+
+#         # Encode label as integer
+#         label = self.label_encoder.transform([label])[0]
+
+#         return input_feature, label
+
+#     def get_label_dict(self):
+#         return self.label_dict
 
 
 class predictLoader:
diff --git a/dcase_fine_tune/FTtrain.py b/dcase_fine_tune/FTtrain.py
index a705b1b..f6b2bc5 100644
--- a/dcase_fine_tune/FTtrain.py
+++ b/dcase_fine_tune/FTtrain.py
@@ -3,7 +3,7 @@
 import hashlib
 import pandas as pd
 import json
-
+from datetime import datetime
 import pytorch_lightning as pl
 
 from dcase_fine_tune.FTBeats import BEATsTransferLearningModel
@@ -33,7 +33,14 @@ def train_model(
             pl.callbacks.LearningRateMonitor(logging_interval="step"),
             pl.callbacks.EarlyStopping(
                 monitor="train_loss", mode="min", patience=patience
-            ),
+            ),
+            pl.callbacks.ModelCheckpoint(
+                os.path.join("lightning_logs", "finetuning", "{date:%Y%m%d_%H%M%S}".format(date=datetime.now())),
+                monitor="val_loss",
+                mode="min",
+                save_top_k=1,
+                verbose=True,
+            )
         ],
         default_root_dir=root_dir,
         enable_checkpointing=True,
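
Note (dcase_fine_tune/FTBeats.py): the patch drops the standalone `self.fc` head and instead attaches the classifier to the BEATs module itself via `predictor_dropout` and `predictor`, so projection and token pooling now happen inside `self.beats.extract_features` rather than in `forward`. This also matches the `configure_optimizers` change: head-only fine-tuning now optimizes `self.beats.predictor.parameters()`. Below is a minimal PyTorch sketch of the head logic this relies on; the class name and the embed_dim/num_classes/dropout values are illustrative placeholders, not values taken from this repo.

    import torch
    import torch.nn as nn

    # Sketch of the predictor head the patch wires into BEATs:
    # dropout -> per-token linear projection -> mean-pool over tokens.
    class PredictorHead(nn.Module):
        def __init__(self, embed_dim: int = 768, num_classes: int = 10, dropout: float = 0.1):
            super().__init__()
            self.predictor_dropout = nn.Dropout(dropout)
            self.predictor = nn.Linear(embed_dim, num_classes)

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            # tokens: (batch, n_tokens, embed_dim) from the BEATs encoder
            logits = self.predictor(self.predictor_dropout(tokens))
            return logits.mean(dim=1)  # pool over the token dimension

    head = PredictorHead()
    print(head(torch.randn(2, 496, 768)).shape)  # torch.Size([2, 10])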
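
Note (dcase_fine_tune/FTDataModule.py): the rare-class filter relies on `value_counts().reset_index()`, whose output column names differ across pandas versions. With pandas < 2.0 the columns are 'index'/'category', as assumed here; with pandas >= 2.0 they become 'category'/'count', so `removals['index']` raises a KeyError. Also, `reset_index(inplace=True)` without `drop=True` inserts the old index as a new 'index' column. A version-independent sketch of the same filter, using a hypothetical `df` and threshold for illustration:

    import pandas as pd

    # Drop categories with fewer than min_sample_per_category samples,
    # without depending on the reset_index column naming.
    df = pd.DataFrame({"category": [0, 0, 0, 1, 1, 2]})
    min_sample_per_category = 2

    counts = df["category"].value_counts()
    removals = counts[counts < min_sample_per_category].index
    df = df[~df["category"].isin(removals)].reset_index(drop=True)
    print(df)  # category 2 (a single sample) has been removed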
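
Note (dcase_fine_tune/FTtrain.py): the first positional argument of `pl.callbacks.ModelCheckpoint` is `dirpath`, so the new callback saves the single best checkpoint into a timestamped folder under lightning_logs/finetuning. It monitors "val_loss", which requires the LightningModule to log that key; note that the EarlyStopping callback above still watches "train_loss", so the two callbacks track different metrics. An equivalent, more explicit sketch:

    import os
    from datetime import datetime

    import pytorch_lightning as pl

    # Same callback with the dirpath keyword spelled out; assumes the
    # model's validation_step calls self.log("val_loss", ...).
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(
            "lightning_logs", "finetuning", datetime.now().strftime("%Y%m%d_%H%M%S")
        ),
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        verbose=True,
    )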