Skip to content

Commit

Permalink
WIP: plot STFT from given audio
Browse files Browse the repository at this point in the history
related to: #5
  • Loading branch information
anthonio9 committed May 26, 2024
1 parent 6a335fd commit ce63670
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 11 deletions.
1 change: 1 addition & 0 deletions penn/plot/to_latex/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .core import *
from . import mplt

79 changes: 68 additions & 11 deletions penn/plot/to_latex/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import penn


import torchaudio
import torchutil
import jams
import plotly.express as px
import plotly.graph_objects as go
Expand Down Expand Up @@ -35,23 +35,82 @@ def from_audio(

# Concatenate results
logits = torch.cat(logits)
pitch = None
times = None

return logits, times
with torchutil.time.context('decode'):
# pitch is in Hz
predicted, pitch, periodicity = penn.postprocess(logits)
pitch = pitch.detach().numpy()[0, ...]
pitch = np.split(pitch, pitch.shape[0])
times = penn.HOPSIZE_SECONDS * np.arange(pitch[0].shape[-1])

return pitch, times


def get_ground_truth(ground_truth_file):
    """
    Load the ground truth pitch from a JAMS annotation file.
    Parameters:
        ground_truth_file - path to the JAMS annotation file
    Returns:
        the pitch dictionary produced by
        penn.data.preprocess.notes_dict_to_pitch_dict from the notes
        found in the JAMS track
    Raises:
        FileNotFoundError - if ground_truth_file is not an existing file
    """
    # Raise explicitly instead of asserting: assert statements are removed
    # when Python runs with -O, which would silently skip the check
    if not isfile(ground_truth_file):
        raise FileNotFoundError(
            f'Ground truth file not found: {ground_truth_file}')

    jams_track = jams.load(str(ground_truth_file))

    # The pitch dictionary is derived from the note annotations. The earlier
    # penn.plot.raw_data.extract_pitch call was dropped because its result
    # was overwritten right away.
    notes_dict = penn.data.preprocess.jams_to_notes(jams_track)
    return penn.data.preprocess.notes_dict_to_pitch_dict(notes_dict)


def plot_over_gt(stft, logits, gt):
    """
    Placeholder without an implementation: accepts stft, logits and gt,
    does nothing with them and returns None.
    """
def plot_over_gt_with_plotly(audio, sr, pred_freq, pred_times, gt, return_fig=False):
    """
    Plot the predicted pitch on top of the STFT of the given audio
    with plotly.
    Parameters:
        audio - source data
        sr - sampling rate
        pred_freq - iterable of predicted pitch arrays, one scatter trace
            is drawn for each of them
        pred_times - time stamps used as x values of every predicted trace
        gt - ground truth pitch data, currently unused because the ground
            truth overlay is disabled (see the TODO below)
        return_fig - if True return the figure instead of showing it
    Returns:
        the plotly figure if return_fig is True, otherwise None
    """
    stft, freqs, times = penn.plot.raw_data.extract_spectrogram(
        audio,
        sr=sr,
        window_length=2048*4,
        hop_length=penn.data.preprocess.GSET_HOPSIZE)

    fig = px.imshow(
        stft,
        color_continuous_scale="aggrnyl",
        x=times,
        y=freqs,
        aspect='auto',
        origin='lower')

    # TODO: re-enable the ground truth overlay. The disabled version iterated
    # over gt.items() and drew one "String {n}" marker trace per entry from
    # its "times" and "frequency" arrays, also feeding the axis limits below.

    # Per-trace extrema used to zoom the frequency axis on the plotted pitch
    max_pitch = []
    min_pitch = []

    for no_slice, pitch_slice in enumerate(pred_freq):
        fig = fig.add_trace(go.Scatter(
            name=f"String {no_slice}",
            x=pred_times,
            y=pitch_slice.reshape(-1),
            mode="markers",
            marker=dict(size=5)))

        # Empty slices have no extrema, .max() / .min() would raise on them
        if pitch_slice.size > 0:
            max_pitch.append(pitch_slice.max())
            min_pitch.append(pitch_slice.min())

    # Only zoom when at least one trace had data: max() / min() of an empty
    # list raises ValueError, in that case the default axis range is kept
    if max_pitch:
        ymax = max(max_pitch)
        ymin = min(min_pitch)

        # Leave a 10% margin above and below the plotted pitch
        offset = (ymax - ymin) * 0.1
        fig.update_yaxes(range=[ymin - offset, ymax + offset], autorange=False)

    if return_fig:
        return fig

    fig.show()



def from_file_to_file(audio_file, ground_truth_file, checkpoint, output_file=None, gpu=None):
# Load audio
Expand All @@ -69,11 +128,9 @@ def from_file_to_file(audio_file, ground_truth_file, checkpoint, output_file=Non
# get the stft of the audio
audio, sr = torchaudio.load(audio_file)
audio = audio.cpu().numpy()
stft, freqs, times = penn.plot.raw_data.extract_spectrogram(audio,
sr=sr,
window_length=2048*4,
hop_length=penn.data.preprocess.GSET_HOPSIZE)

# now that we have the ground truth, the STFT and the predicted pitch, plot everything with matplotlib and plotly
# well, do we have predicted pitch?
breakpoint()
# plot_over_gt_with_plotly(audio, sr, pred_freq, pred_times, gt)

penn.plot.to_latex.mplt.plot_with_matplotlib(audio, sr, pred_freq, pred_times, gt)
53 changes: 53 additions & 0 deletions penn/plot/to_latex/mplt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import penn

import matplotlib.pyplot as plt
import numpy as np


def plot_stft(axis: plt.Axes,
              audio,
              sr=penn.SAMPLE_RATE,
              window_length=2048*4,
              hop_length=penn.data.preprocess.GSET_HOPSIZE,
              freq_limits=(50, 300)):
    """
    Draw the STFT magnitude of the given audio on the given axis.
    Parameters:
        axis - matplotlib pyplot figure axis to have the STFT plot
        audio - source data
        sr - sampling rate
        window_length - length of the moving STFT window
        hop_length - hop step of the moving window in samples
        freq_limits - (low, high) limits of the frequency (y) axis,
            pass None to keep the limits chosen by matplotlib
    """

    stft, freqs, times = penn.plot.raw_data.extract_spectrogram(
        audio,
        sr=sr,
        window_length=window_length,
        hop_length=hop_length)

    # Plot the magnitude, np.abs also covers a complex valued stft
    axis.pcolormesh(times, freqs, np.abs(stft))

    # Restrict the view to the requested frequency band
    if freq_limits is not None:
        axis.set_ylim(list(freq_limits))
    axis.set_xlim([times[0], times[-1]])

    # take inspiration from this post: https://dsp.stackexchange.com/a/70136


def plot_with_matplotlib(audio, sr=penn.SAMPLE_RATE, pitch_pred=None, pred_times=None, ground_truth=None, periodicity=None, threshold=None):
    """
    Show the STFT of the given audio in a new matplotlib figure.
    Only the STFT is drawn for now: pitch_pred, pred_times, ground_truth,
    periodicity and threshold are accepted but not used yet.
    Parameters:
        audio - source data
        sr - sampling rate
    """

    # New 7x3 figure, only its axis is needed afterwards
    _, stft_axis = plt.subplots(figsize=(7, 3))

    # Hide the frame around the plot
    for side in ('top', 'right', 'bottom', 'left'):
        stft_axis.spines[side].set_visible(False)

    plot_stft(stft_axis, audio, sr)

    plt.show()

0 comments on commit ce63670

Please sign in to comment.