Skip to content

Commit

Permalink
[ADD] F1: 0.48
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Mar 18, 2024
1 parent 4c058e4 commit df52251
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 20 deletions.
12 changes: 12 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for more information:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
# https://containers.dev/guide/dependabot

version: 2
updates:
- package-ecosystem: "devcontainers"
directory: "/"
schedule:
interval: weekly
12 changes: 7 additions & 5 deletions CONFIG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
# PARAMETERS FOR DATA PROCESSING #
##################################
data:
n_task_train: 100
n_task_val: 100
n_task_train: 500
n_task_val: 500
target_fs: 16000 # used in preprocessing
resample: true # used in preprocessing
denoise: False # used in preprocessing
resample: True # used in preprocessing
denoise: True # used in preprocessing
normalize: true # used in preprocessing
frame_length: 25.0 # used in preprocessing
tensor_length: 128 # used in preprocessing
n_shot: 3 # number of images PER CLASS in the support set
n_query: 2 # number of images PER CLASS in the query set
n_way: 20
overlap: 0.5 # used in preprocessing
n_subsample: 1
num_mel_bins: 128 # used in preprocessing
Expand All @@ -41,8 +42,9 @@ model:
lr: 1.0e-05
model_type: beats # beats, pann or baseline
state: train # train or validate - for which model should be loaded
model_path: None
model_path: /data/DCASE/models/BEATs/BEATs_iter3_plus_AS2M.pt
specaugment_params: null
n_way: 20
# specaugment_params:
# application_ratio: 1.0
# time_mask: 40
Expand Down
6 changes: 3 additions & 3 deletions CONFIG_PREDICT.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ model:
distance: euclidean # other option is mahalanobis
lr: 1.0e-05
model_type: beats # beats, pann or baseline
state: validate # train or validate - for which model should be loaded
state: train # train or validate - for which model should be loaded
model_path: None
specaugment_params: null
# specaugment_params:
Expand All @@ -54,11 +54,11 @@ model:
# PARAMETERS FOR MODEL PREDICTION #
###################################
predict:
wav_save: True
wav_save: False
overwrite: True
n_self_detected_supports: 0
tolerance: 0
filter_by_p_values: True # Whether we filter outliers by their pvalues
filter_by_p_values: False # Whether we filter outliers by their pvalues
n_subsample: 1 # Whether each segment should be subsampled
self_detect_support: False # Whether to use the self-training loop

Expand Down
2 changes: 1 addition & 1 deletion callbacks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pytorch_lightning.callbacks.finetuning import BaseFinetuning

class MilestonesFinetuning(BaseFinetuning):
def __init__(self, milestones: int = 10):
def __init__(self, milestones: int = 1):
super().__init__()
self.unfreeze_at_epoch = milestones

Expand Down
2 changes: 1 addition & 1 deletion datamodules/DCASEDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
set_type: str = "Training_Set",
n_shot: int = 5,
n_query: int = 10,
n_way: int = 5,
n_way: int = 20,
n_subsample: int = 1,
overlap: float = 0.5,
num_mel_bins: int = 128,
Expand Down
8 changes: 4 additions & 4 deletions dcase_fine_tune/CONFIG.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ data:
overlap: 0.5 # used in preprocessing
num_mel_bins: 128 # used in preprocessing
max_segment_length: 1.0 # used in preprocessing
status: validate # used in preprocessing, train or validate or evaluate
set_type: "Validation_Set"
status: train # used in preprocessing, train or validate or evaluate
set_type: "Training_Set"


#################################
Expand All @@ -29,12 +29,12 @@ data:

trainer:
max_epochs: 10000
default_root_dir: /data/lightning_logs/BEATs
default_root_dir: /data/lightning_logs/baseline
accelerator: gpu
gpus: 1
batch_size: 64
num_workers: 4
patience: 20
patience: 10
min_sample_per_category: 10
test_size: 0.2

Expand Down
8 changes: 6 additions & 2 deletions dcase_fine_tune/FTDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

from torch.utils.data import WeightedRandomSampler
from torchsampler import ImbalancedDatasetSampler

class TrainAudioDatasetDCASE(Dataset):
def __init__(
Expand Down Expand Up @@ -80,12 +81,15 @@ def divide_train_val(self):
samples_weight = np.array([weight[t] for t in data_frame_train["category"]])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
self.sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
self.sampler = WeightedRandomSampler(samples_weight, len(samples_weight)*10)

# Make the validation set
data_frame_validation = self.data_frame.loc[validation_indices]
data_frame_validation.reset_index(drop=True, inplace=True)

#print(data_frame_train["category"].value_counts())
#print(data_frame_validation["category"].value_counts())

# generate subset based on indices
self.train_set = TrainAudioDatasetDCASE(
data_frame=data_frame_train,
Expand All @@ -100,7 +104,7 @@ def train_dataloader(self):
num_workers=self.num_workers,
pin_memory=False,
collate_fn=self.collate_fn,
sampler=self.sampler
sampler=ImbalancedDatasetSampler(self.train_set) #self.sampler
)
return train_loader

Expand Down
2 changes: 1 addition & 1 deletion prototypicalbeats/prototraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class ProtoBEATsModel(pl.LightningModule):
def __init__(
self,
n_way: int = 5,
n_way: int = 20,
milestones: int = 5,
lr: float = 1e-5,
lr_scheduler_gamma: float = 1e-1,
Expand Down
3 changes: 2 additions & 1 deletion shell_scripts/train_beats.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/bin/bash

BASE_FOLDER=$1
#BASE_FOLDER=$1
BASE_FOLDER=/home/benjamin.cretois/data/DCASE
CONFIG_PATH="/app/CONFIG.yaml"

cd ..
Expand Down
6 changes: 4 additions & 2 deletions shell_scripts/validate_beats.sh
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#!/bin/bash

BASE_DIR=$1
#BASE_DIR=$1
BASE_DIR=/home/benjamin.cretois/data/DCASE #/data/Prosjekter3/823001_19_metodesats_analyse_23_36_cretois

cd ..

docker run -v $BASE_DIR:/data -v $PWD:/app \
--gpus all \
--shm-size=10gb \
beats \
poetry run python /app/evaluate/evaluateDCASE.py \
'model.model_type="beats"' \
'model.state="train"' \
'model.model_path="/data/models/BEATs/BEATs_iter3_plus_AS2M.pt"'
'model.model_path="/data/models/BEATs/BEATs_iter3_plus_AS2M.pt"'

0 comments on commit df52251

Please sign in to comment.