Commit 8440511
FIX: get finetuning pipeline working
femke-sintef committed Apr 10, 2024
1 parent 0f5f092 commit 8440511
Showing 3 changed files with 77 additions and 58 deletions.
19 changes: 12 additions & 7 deletions dcase_fine_tune/FTBeats.py
@@ -46,10 +46,12 @@ def __init__(
        self._build_model()

        self.train_acc = Accuracy(
-            task="multiclass", num_classes=self.num_target_classes
+            task="multiclass",
+            num_classes=self.num_target_classes
        )
        self.valid_acc = Accuracy(
-            task="multiclass", num_classes=self.num_target_classes
+            task="multiclass",
+            num_classes=self.num_target_classes
        )
        self.save_hyperparameters()

@@ -62,7 +64,9 @@ def _build_model(self):

        # 2. Classifier
        print(f"Classifier has {self.num_target_classes} output neurons")
-        self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.num_target_classes)
+        self.beats.predictor_dropout = nn.Dropout(self.cfg.predictor_dropout)
+        self.beats.predictor = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.predictor_class)
+        # self.fc = nn.Linear(self.cfg.encoder_embed_dim, self.num_target_classes)

    def extract_features(self, x, padding_mask=None):
        if padding_mask != None:
@@ -81,10 +85,10 @@ def forward(self, x, padding_mask=None):
        x, _ = self.beats.extract_features(x)

        # Get the logits
-        x = self.fc(x)
+        # x = self.fc(x)

        # Mean pool the second dimension (these are the tokens)
-        x = x.mean(dim=1)
+        # x = x.mean(dim=1)

        return x

@@ -120,14 +124,15 @@ def validation_step(self, batch, batch_idx):
    def configure_optimizers(self):
        if self.ft_entire_network:
            optimizer = optim.AdamW(
-                [{"params": self.beats.parameters()}, {"params": self.fc.parameters()}],
+                self.beats.parameters(),
+                # [{"params": self.beats.parameters()}, {"params": self.fc.parameters()}],
                lr=self.lr,
                betas=(0.9, 0.98),
                weight_decay=0.01,
            )
        else:
            optimizer = optim.AdamW(
-                self.fc.parameters(), lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
+                self.beats.predictor.parameters(), lr=self.lr, betas=(0.9, 0.98), weight_decay=0.01
            )

        return optimizer
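
With this change, classification runs through the dropout-plus-linear head attached to the BEATs module itself (self.beats.predictor_dropout, self.beats.predictor) rather than a separate self.fc layer, so forward() can return the head's pooled output directly and configure_optimizers() can target just the head. A minimal, self-contained sketch of that pattern; ToyEncoder, its dimensions, and the learning rate are illustrative stand-ins, not the real BEATs API:

import torch
import torch.nn as nn
from torch import optim

class ToyEncoder(nn.Module):
    """Stand-in for BEATs: token embeddings -> dropout -> linear predictor."""

    def __init__(self, in_dim=16, embed_dim=768, num_classes=10, dropout=0.1):
        super().__init__()
        self.backbone = nn.Linear(in_dim, embed_dim)        # placeholder feature extractor
        self.predictor_dropout = nn.Dropout(dropout)        # mirrors beats.predictor_dropout
        self.predictor = nn.Linear(embed_dim, num_classes)  # mirrors beats.predictor

    def extract_features(self, x):
        tokens = self.backbone(x)  # (batch, tokens, embed_dim)
        # The head is applied inside the encoder and mean-pooled over tokens,
        # so the caller no longer needs its own fc layer or pooling step.
        logits = self.predictor(self.predictor_dropout(tokens)).mean(dim=1)
        return logits, None

encoder = ToyEncoder()
logits, _ = encoder.extract_features(torch.randn(4, 100, 16))
print(logits.shape)  # torch.Size([4, 10])

# Fine-tune the entire network, or only the classification head,
# as in the two branches of configure_optimizers():
full_opt = optim.AdamW(encoder.parameters(), lr=1e-5, betas=(0.9, 0.98), weight_decay=0.01)
head_opt = optim.AdamW(encoder.predictor.parameters(), lr=1e-5, betas=(0.9, 0.98), weight_decay=0.01)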
105 changes: 56 additions & 49 deletions dcase_fine_tune/FTDataModule.py
@@ -55,15 +55,22 @@ def __init__(
        self.test_size = test_size
        self.min_sample_per_category = min_sample_per_category

        self.setup()
-        self.divide_train_val()

    def setup(self, stage=None):
-        # load data
+        # encode labels
        self.data_frame["category"] = LabelEncoder().fit_transform(
            self.data_frame["category"]
        )
+        # remove classes with too few samples
+        removals = self.data_frame['category'].value_counts().reset_index()
+        removals = removals[removals['category'] < self.min_sample_per_category]['index'].values
+        self.data_frame.drop(self.data_frame[self.data_frame["category"].isin(removals)].index, inplace=True)
+        self.data_frame.reset_index(inplace=True)
+
+        # init dataset and divide into train&val
        self.complete_dataset = TrainAudioDatasetDCASE(data_frame=self.data_frame)
+        self.divide_train_val()

    def divide_train_val(self):
        # Separate into training and validation set
@@ -144,50 +151,50 @@ def collate_fn(self, input_data):
        return (all_images, all_labels)


-class AudioDatasetDCASE(Dataset):
-    def __init__(
-        self,
-        data_frame,
-        label_dict=None,
-    ):
-        self.data_frame = data_frame
-        self.label_encoder = LabelEncoder()
-        if label_dict is not None:
-            self.label_encoder.fit(list(label_dict.keys()))
-            self.label_dict = label_dict
-        else:
-            self.label_encoder.fit(self.data_frame["category"])
-            self.label_dict = dict(
-                zip(
-                    self.label_encoder.classes_,
-                    self.label_encoder.transform(self.label_encoder.classes_),
-                )
-            )
-
-    def __len__(self):
-        return len(self.data_frame)
-
-    def get_labels(self):
-        labels = []
-
-        for i in range(0, len(self.data_frame)):
-            label = self.data_frame.iloc[i]["category"]
-            label = self.label_encoder.transform([label])[0]
-            labels.append(label)
-
-        return labels
-
-    def __getitem__(self, idx):
-        input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"])
-        label = self.data_frame.iloc[idx]["category"]
-
-        # Encode label as integer
-        label = self.label_encoder.transform([label])[0]
-
-        return input_feature, label
-
-    def get_label_dict(self):
-        return self.label_dict
+# class AudioDatasetDCASE(Dataset):
+#     def __init__(
+#         self,
+#         data_frame,
+#         label_dict=None,
+#     ):
+#         self.data_frame = data_frame
+#         self.label_encoder = LabelEncoder()
+#         if label_dict is not None:
+#             self.label_encoder.fit(list(label_dict.keys()))
+#             self.label_dict = label_dict
+#         else:
+#             self.label_encoder.fit(self.data_frame["category"])
+#             self.label_dict = dict(
+#                 zip(
+#                     self.label_encoder.classes_,
+#                     self.label_encoder.transform(self.label_encoder.classes_),
+#                 )
+#             )
+
+#     def __len__(self):
+#         return len(self.data_frame)
+
+#     def get_labels(self):
+#         labels = []
+
+#         for i in range(0, len(self.data_frame)):
+#             label = self.data_frame.iloc[i]["category"]
+#             label = self.label_encoder.transform([label])[0]
+#             labels.append(label)
+
+#         return labels
+
+#     def __getitem__(self, idx):
+#         input_feature = torch.Tensor(self.data_frame.iloc[idx]["feature"])
+#         label = self.data_frame.iloc[idx]["category"]
+
+#         # Encode label as integer
+#         label = self.label_encoder.transform([label])[0]
+
+#         return input_feature, label
+
+#     def get_label_dict(self):
+#         return self.label_dict


class predictLoader:
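The new setup() step drops every encoded category with fewer than min_sample_per_category examples before the train/validation split. A standalone sketch of that logic on a toy dataframe; the column layout of value_counts().reset_index(), an 'index' column holding the category and a 'category' column holding the count, assumes pandas < 2.0:

import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Toy frame standing in for self.data_frame.
data_frame = pd.DataFrame({"category": ["a"] * 5 + ["b"] * 2 + ["c"] * 7})
min_sample_per_category = 3

# encode labels
data_frame["category"] = LabelEncoder().fit_transform(data_frame["category"])

# remove classes with too few samples; with pandas < 2.0, reset_index()
# names the columns 'index' (the category) and 'category' (the count)
removals = data_frame["category"].value_counts().reset_index()
removals = removals[removals["category"] < min_sample_per_category]["index"].values
data_frame.drop(data_frame[data_frame["category"].isin(removals)].index, inplace=True)
data_frame.reset_index(inplace=True)

print(data_frame["category"].value_counts())  # the class with only 2 samples is gone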
11 changes: 9 additions & 2 deletions dcase_fine_tune/FTtrain.py
@@ -3,7 +3,7 @@
import hashlib
import pandas as pd
import json
-
+from datetime import datetime
import pytorch_lightning as pl

from dcase_fine_tune.FTBeats import BEATsTransferLearningModel
@@ -33,7 +33,14 @@ def train_model(
            pl.callbacks.LearningRateMonitor(logging_interval="step"),
            pl.callbacks.EarlyStopping(
                monitor="train_loss", mode="min", patience=patience
-            ),
+            ),
+            pl.callbacks.ModelCheckpoint(
+                os.path.join("lightning_logs", "finetuning", "{date:%Y%m%d_%H%M%S}".format(date=datetime.now())),
+                monitor="val_loss",
+                mode="min",
+                save_top_k=1,
+                verbose=True,
+            )
        ],
        default_root_dir=root_dir,
        enable_checkpointing=True,
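The added ModelCheckpoint callback gives every fine-tuning run its own timestamped directory under lightning_logs/finetuning and keeps only the best checkpoint by validation loss; this assumes the LightningModule logs a "val_loss" metric. The same callback in isolation:

import os
from datetime import datetime

import pytorch_lightning as pl

# One timestamped directory per run, e.g. lightning_logs/finetuning/20240410_153012
ckpt_dir = os.path.join(
    "lightning_logs",
    "finetuning",
    "{date:%Y%m%d_%H%M%S}".format(date=datetime.now()),
)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=ckpt_dir,    # passed positionally in the commit
    monitor="val_loss",  # requires the module to log "val_loss"
    mode="min",
    save_top_k=1,        # keep only the best checkpoint
    verbose=True,
)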
