Skip to content

Commit

Permalink
WIP: fix inter module errors in plot.to_latex
Browse files Browse the repository at this point in the history
related to: #5
  • Loading branch information
anthonio9 committed May 8, 2024
1 parent 06a78c0 commit 6a335fd
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
12 changes: 6 additions & 6 deletions penn/plot/to_latex/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ def parse_args():
type=Path,
help='The audio file to plot the logits of')
parser.add_argument(
'--output_file',
'--ground_truth_file',
type=Path,
help='The jpg file to save the plot')
help='The ground truth file')
parser.add_argument(
'--checkpoint',
type=Path,
help='The checkpoint file to use for inference')
parser.add_argument(
'--output_file',
type=Path,
help='The jpg file to save the plot')
parser.add_argument(
'--gpu',
type=int,
help='The index of the GPU to use for inference')
parser.add_argument(
'--iters',
type=int,
help='Number of dummy iterations on the loader before extracting the data')
return parser.parse_known_args()[0]


Expand Down
25 changes: 13 additions & 12 deletions penn/plot/to_latex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import penn


import torchaudio
import jams
import plotly.express as px
import plotly.graph_objects as go

Expand Down Expand Up @@ -32,14 +35,16 @@ def from_audio(

# Concatenate results
logits = torch.cat(logits)
times = None

return logits
return logits, times


def get_ground_truth(ground_truth_file):
    """Load the ground-truth pitch annotation from a JAMS file.

    Args:
        ground_truth_file: Path to the JAMS annotation file. The command-line
            entry point passes a ``pathlib.Path``.

    Returns:
        The value produced by ``penn.plot.raw_data.extract_pitch`` for the
        loaded JAMS track (referred to as ``pitch_dict`` in this module).

    Raises:
        AssertionError: If ``ground_truth_file`` is not an existing file.
    """
    assert isfile(ground_truth_file)

    # The track must be loaded before pitch extraction; the path is converted
    # to ``str`` before being handed to ``jams.load``.
    jams_track = jams.load(str(ground_truth_file))

    # Use the fully qualified helper; the bare ``extract_pitch`` name is not
    # defined in this module.
    pitch_dict = penn.plot.raw_data.extract_pitch(jams_track)
    return pitch_dict


Expand All @@ -55,24 +60,20 @@ def from_file_to_file(audio_file, ground_truth_file, checkpoint, output_file=Non
if checkpoint is None:
return

checkpoint = torch.load(checkpoint, map_location='cpu')

# Initialize model
model = penn.Model()

# Load from disk
model.load_state_dict(checkpoint['model'])

# get logits
logits = from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu)
pred_freq, pred_times = from_audio(audio, penn.SAMPLE_RATE, checkpoint, gpu)

# get the ground truth
gt = get_ground_truth(ground_truth_file)

# get the stft of the audio
audio, sr = torchaudio.load(audio_file)
audio = audio.cpu().numpy()
stft, freqs, times = extract_spectrogram(audio,
stft, freqs, times = penn.plot.raw_data.extract_spectrogram(audio,
sr=sr,
window_length=2048*4,
hop_length=penn.data.preprocess.GSET_HOPSIZE)

# now that we have the ground truth, the STFT and the predicted pitch, plot them all with matplotlib and plotly
# well, do we have predicted pitch?
breakpoint()

0 comments on commit 6a335fd

Please sign in to comment.