Training a Few-Shot Instrument Classifier

In this tutorial, we will train a few-shot learning model for music information retrieval, using the PyTorch framework and the PyTorch Lightning library.

In previous chapters, we introduced several essential concepts for training a few-shot learning model. We learned how to create a class-conditional dataset for few-shot learning, using the TinySOL dataset. We also learned how to construct few-shot learning episodes from a class-conditional dataset, using an episode dataset. Finally, we learned how to create a Prototypical Network, given any backbone model architecture.

Now, it’s time to put all these pieces together and train our few-shot model.

We will train a few-shot instrument classifier to solve 5-way, 5-shot classification tasks. This means that the model will be trained to classify 5 different instrument classes at a time, using 5 support examples per class.

Requirements

%%capture
!pip install torch
!pip install pytorch-lightning
!pip install numpy
!pip install --no-cache-dir --upgrade music-fsl
import torch
import numpy as np
from torch import nn
import pytorch_lightning as pl
from torchmetrics import Accuracy

from music_fsl.backbone import Backbone
from music_fsl.data import TinySOL, EpisodeDataset
from music_fsl.protonet import PrototypicalNet

Hyperparameters

We’ll define some hyperparameters below.

sample_rate = 16000 # sample rate of the audio
n_way = 5 # number of classes per episode
n_support = 5 # number of support examples per class
n_query = 20 # number of query examples per class
n_train_episodes = 50000 # number of episodes to generate for training
n_val_episodes = 100 # number of episodes to generate for validation
num_workers = 10 # number of workers to use for data loading
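To make these numbers concrete, here is a short sketch in plain Python (independent of music_fsl) of how many clips a single episode contains with these settings: n_way × n_support support clips and n_way × n_query query clips. The helper episode_sizes is our own, for illustration only.

```python
# Values copied from the hyperparameter cell above.
n_way = 5       # classes per episode
n_support = 5   # support examples per class
n_query = 20    # query examples per class


def episode_sizes(n_way: int, n_support: int, n_query: int) -> tuple:
    """Return (number of support clips, number of query clips) in one episode."""
    return n_way * n_support, n_way * n_query


support_size, query_size = episode_sizes(n_way, n_support, n_query)
print(f"support clips per episode: {support_size}")  # 25
print(f"query clips per episode:   {query_size}")    # 100
```

So every training episode asks the model to classify 100 query clips against 5 class prototypes, each built from 5 support clips.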

Data

Split the dataset into train and test sets

Since we’re training a few-shot model to generalize to unseen instrument classes, we’ll need to make a class-conditional split of the TinySOL dataset. This means we’ll keep most of the instrument classes in the training set, and leave out a few for the test set.

We’ll use an arbitrary split, as shown below.

TRAIN_INSTRUMENTS = [
    'French Horn', 
    'Violin', 
    'Flute', 
    'Contrabass', 
    'Trombone', 
    'Cello', 
    'Clarinet in Bb', 
    'Oboe',
    'Accordion'
]

TEST_INSTRUMENTS = [
    'Bassoon',
    'Viola',
    'Trumpet in C',
    'Bass Tuba',
    'Alto Saxophone'
]
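Because the point of this split is to evaluate on instruments the model never sees during training, it's worth checking that the two lists are disjoint and that each split has enough classes to build an n_way-way episode. A quick sanity check in plain Python (n_way is copied from the hyperparameters above):

```python
TRAIN_INSTRUMENTS = [
    'French Horn', 'Violin', 'Flute', 'Contrabass', 'Trombone',
    'Cello', 'Clarinet in Bb', 'Oboe', 'Accordion',
]
TEST_INSTRUMENTS = [
    'Bassoon', 'Viola', 'Trumpet in C', 'Bass Tuba', 'Alto Saxophone',
]
n_way = 5

# No instrument may appear in both splits.
overlap = set(TRAIN_INSTRUMENTS) & set(TEST_INSTRUMENTS)
assert not overlap, f"instruments in both splits: {sorted(overlap)}"

# Each split needs at least n_way classes to build an n_way-way episode.
assert len(TRAIN_INSTRUMENTS) >= n_way
assert len(TEST_INSTRUMENTS) >= n_way

print(len(TRAIN_INSTRUMENTS), len(TEST_INSTRUMENTS))  # 9 5
```

Note that the test split has exactly five instruments, so, assuming each episode draws n_way distinct classes, every 5-way validation episode will contain all of them; only the sampled clips change from one episode to the next.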

Load the datasets

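Let’s load the train and test sets, using the class-conditional TinySOL dataset class we implemented in the previous chapter. During training, we’ll use the held-out test instruments to monitor validation performance, which is why the second dataset is named val_data.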

# initialize the datasets
train_data = TinySOL(
    instruments=TRAIN_INSTRUMENTS, 
    sample_rate=sample_rate
)

val_data = TinySOL(
    instruments=TEST_INSTRUMENTS, 
    sample_rate=sample_rate
)
INFO: Downloading ['audio', 'annotations'] to /home/hugo/mir_datasets/tinysol
INFO: [audio] downloading TinySOL.tar.gz
INFO: /home/hugo/mir_datasets/tinysol/audio/TinySOL.tar.gz already exists and will not be downloaded. Rerun with force_overwrite=True to delete this file and force the download.
INFO: [annotations] downloading TinySOL_metadata.csv
INFO: /home/hugo/mir_datasets/tinysol/annotation/TinySOL_metadata.csv already exists and will not be downloaded. Rerun with force_overwrite=True to delete this file and force the download.
INFO: Downloading ['audio', 'annotations'] to /home/hugo/mir_datasets/tinysol
INFO: [audio] downloading TinySOL.tar.gz
INFO: /home/hugo/mir_datasets/tinysol/audio/TinySOL.tar.gz already exists and will not be downloaded. Rerun with force_overwrite=True to delete this file and force the download.
INFO: [annotations] downloading TinySOL_metadata.csv
INFO: /home/hugo/mir_datasets/tinysol/annotation/TinySOL_metadata.csv already exists and will not be downloaded. Rerun with force_overwrite=True to delete this file and force the download.

Create the Episode Datasets

Next, we’ll initialize the episode datasets for the train and test sets.

As we learned in the previous chapter, we can use the EpisodeDataset class to create few-shot learning episodes from a dataset. The EpisodeDataset wraps around the ClassConditionalDataset to retrieve few-shot learning episodes, given the dataset, the number of classes per episode (n_way), the number of support and query examples per class, and the total number of episodes to generate.

# initialize the episode datasets
train_episodes = EpisodeDataset(
    dataset=train_data, 
    n_way=n_way, 
    n_support=n_support,
    n_query=n_query, 
    n_episodes=n_train_episodes
)

val_episodes = EpisodeDataset(
    dataset=val_data, 
    n_way=n_way, 
    n_support=n_support,
    n_query=n_query, 
    n_episodes=n_val_episodes
)
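For intuition about what happens when an episode is built, here is a minimal sketch of episodic sampling. The function sample_episode is hypothetical and is not music_fsl's EpisodeDataset implementation; it assumes the dataset can be viewed as a list of (example, class name) pairs and that every class has at least n_support + n_query examples. The key idea it illustrates is that class names are remapped to episode-local labels 0 to n_way - 1, which is what the model predicts within each episode.

```python
import random
from collections import defaultdict


def sample_episode(items, n_way, n_support, n_query, rng=random):
    """Hypothetical episodic sampler (a sketch, not music_fsl's EpisodeDataset).

    items: list of (example, class_name) pairs.
    Returns (support, query): lists of (example, episode_label) pairs, where
    episode_label is an integer in range(n_way).
    """
    # group the examples by class
    by_class = defaultdict(list)
    for example, class_name in items:
        by_class[class_name].append(example)

    # pick n_way distinct classes for this episode
    classes = rng.sample(sorted(by_class), n_way)

    support, query = [], []
    for episode_label, class_name in enumerate(classes):
        # draw n_support + n_query distinct examples, then split them
        examples = rng.sample(by_class[class_name], n_support + n_query)
        support += [(x, episode_label) for x in examples[:n_support]]
        query += [(x, episode_label) for x in examples[n_support:]]
    return support, query


# toy dataset: 7 classes ("A" to "G") with 30 examples each
items = [(f"{c}_{i}", c) for c in "ABCDEFG" for i in range(30)]
support, query = sample_episode(items, n_way=5, n_support=5, n_query=20,
                                rng=random.Random(0))
print(len(support), len(query))  # 25 100
```

Because support and query examples for a class are drawn together without replacement, no clip appears in both sets of the same episode.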

Dataloaders

We can pass the episode datasets to a PyTorch DataLoader to create dataloaders for the train and test sets. Since each episode already contains a batch of examples in its support and query sets, we set the batch size to None. This disables the DataLoader's automatic batching, so each iteration yields exactly one episode.

# initialize the dataloaders
from torch.utils.data import DataLoader
train_loader = DataLoader(
    train_episodes, 
    batch_size=None,
    num_workers=num_workers
)

val_loader = DataLoader(
    val_episodes, 
    batch_size=None,
    num_workers=num_workers
)

Build the Prototypical Network

Let’s instantiate the prototypical network we coded up in the last chapter. As a reminder, the prototypical network takes the support and query sets as input and returns, for each query example, one logit per class in the episode.

# build models
backbone = Backbone(sample_rate=sample_rate)
protonet = PrototypicalNet(backbone)

protonet
PrototypicalNet(
  (backbone): Backbone(
    (melspec): MelSpectrogram(
      (spectrogram): Spectrogram()
      (mel_scale): MelScale()
    )
    (conv1): ConvBlock(
      (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (gn): GroupNorm(8, 32, eps=1e-05, affine=True)
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv2): ConvBlock(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (gn): GroupNorm(16, 64, eps=1e-05, affine=True)
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv3): ConvBlock(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (gn): GroupNorm(32, 128, eps=1e-05, affine=True)
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv4): ConvBlock(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (gn): GroupNorm(64, 256, eps=1e-05, affine=True)
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv5): ConvBlock(
      (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), padding=same)
      (gn): GroupNorm(128, 512, eps=1e-05, affine=True)
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
    )
  )
)
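As a reminder of where those logits come from, the sketch below implements the standard Prototypical Network rule in plain Python: each class prototype is the mean of that class's support embeddings, and a query's logit for a class is the negative squared Euclidean distance to the class prototype. This follows the original Prototypical Networks formulation; the exact distance used by music_fsl's PrototypicalNet may differ, and the toy 2-D vectors stand in for the backbone's embeddings.

```python
from typing import List

Vector = List[float]


def prototypes(embeddings: List[Vector], labels: List[int], n_way: int) -> List[Vector]:
    """Average the support embeddings of each class into one prototype."""
    dim = len(embeddings[0])
    sums = [[0.0] * dim for _ in range(n_way)]
    counts = [0] * n_way
    for emb, y in zip(embeddings, labels):
        counts[y] += 1
        for d in range(dim):
            sums[y][d] += emb[d]
    return [[s / counts[c] for s in sums[c]] for c in range(n_way)]


def logits(query: List[Vector], protos: List[Vector]) -> List[List[float]]:
    """One row per query: negative squared Euclidean distance to each prototype."""
    return [
        [-sum((q - p) ** 2 for q, p in zip(qv, pv)) for pv in protos]
        for qv in query
    ]


# a 2-way, 2-shot toy episode with 2-D embeddings
support_emb = [[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]]
support_labels = [0, 0, 1, 1]
protos = prototypes(support_emb, support_labels, n_way=2)
print(protos)  # [[0.0, 1.0], [4.0, 1.0]]

query_emb = [[1.0, 1.0], [3.0, 1.0]]
print(logits(query_emb, protos))  # [[-1.0, -9.0], [-9.0, -1.0]]
```

Since the nearest prototype receives the largest logit, taking the argmax recovers the nearest-prototype prediction, and the logits can be passed directly to a cross-entropy loss.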

Set Up a LightningModule for Training

Next, we will define a PyTorch Lightning LightningModule to train our few-shot learning model. We will name this module FewShotLearner. The LightningModule is a PyTorch Lightning class that provides several useful methods for training, validation, and testing.

Because there is an abundance of fantastic PyTorch Lightning tutorials, we will not go into too much detail about the LightningModule. If you are interested in learning more about PyTorch Lightning, check out the PyTorch Lightning Tutorials.

Setting up the LightningModule

In this step, we define the FewShotLearner class, the LightningModule responsible for training our few-shot learning model. Its constructor takes a few arguments, including the PrototypicalNet model we built earlier and a learning rate for the optimizer. The constructor also defines the loss function and evaluation metrics: we use cross-entropy as the loss and classification accuracy as the metric.

class FewShotLearner(pl.LightningModule):

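    def __init__(self,
        protonet: nn.Module,
        learning_rate: float = 1e-3,
        n_way: int = 5,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.protonet = protonet
        self.learning_rate = learning_rate

        self.loss = nn.CrossEntropyLoss()
        # recent torchmetrics releases require the task type and the number
        # of classes; each episode has n_way classes
        self.metrics = nn.ModuleDict({
            'accuracy': Accuracy(task="multiclass", num_classes=n_way)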
        })
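To see what the loss and metric compute on the query logits, here is a plain-Python sketch of both. cross_entropy mirrors nn.CrossEntropyLoss with its default mean reduction (log-softmax followed by the negative log-likelihood of the target class), and accuracy is the fraction of query examples whose highest logit matches the target, i.e. a micro-averaged multiclass accuracy. The function names are our own.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int]) -> float:
    """Mean over examples of -log softmax(logits)[target], i.e. what
    nn.CrossEntropyLoss computes with its default reduction='mean'."""
    total = 0.0
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the max before exponentiating, for stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[t]
    return total / len(targets)


def accuracy(logits: List[List[float]], targets: List[int]) -> float:
    """Fraction of examples whose highest-scoring class equals the target."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


example_logits = [[-1.0, -9.0], [-9.0, -1.0]]
print(accuracy(example_logits, [0, 1]))  # 1.0
print(accuracy(example_logits, [1, 1]))  # 0.5
print(cross_entropy([[0.0, 0.0]], [0]))  # log(2), roughly 0.693
```

With uniform logits the loss equals log(number of classes), which is a handy reference point: a freshly initialized 5-way model should start near log(5).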

The Training (And Eval) Step

In the FewShotLearner class, we’ll define a step method that holds the logic shared by training, validation, and testing. It takes a batch of data, the batch index, and a string tag indicating which stage the step belongs to. It unpacks the batch into the support and query sets, uses the PrototypicalNet to compute logits for the query set, computes the loss and evaluation metrics against the query targets, and logs each value under a name such as loss/train or accuracy/val.

The training_step, validation_step, and test_step methods simply call step with the appropriate tag. Keeping the logic in a single step method means it is written once and reused across all three stages.

def step(self, batch, batch_idx, tag: str):
    support, query = batch

    logits = self.protonet(support, query)
    loss = self.loss(logits, query["target"])

    output = {"loss": loss}
    for k, metric in self.metrics.items():
        output[k] = metric(logits, query["target"])

    for k, v in output.items():
        self.log(f"{k}/{tag}", v)
    return output

def training_step(self, batch, batch_idx):
    return self.step(batch, batch_idx, "train")

def validation_step(self, batch, batch_idx):
    return self.step(batch, batch_idx, "val")

def test_step(self, batch, batch_idx):
    return self.step(batch, batch_idx, "test")
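The same shared-step pattern can be seen without any framework. In the sketch below, MiniLearner is a hypothetical stand-in for FewShotLearner: its log method only records names instead of calling LightningModule.log, and the compute callable replaces the PrototypicalNet forward pass, loss, and metrics. It shows the metric names produced by the f"{k}/{tag}" format.

```python
from typing import Callable, Dict, List, Tuple


class MiniLearner:
    """Hypothetical, framework-free stand-in for FewShotLearner."""

    def __init__(self, compute: Callable[[object], Dict[str, float]]):
        self.compute = compute  # plays the role of protonet + loss + metrics
        self.logged: List[Tuple[str, float]] = []

    def log(self, name: str, value: float) -> None:
        # stand-in for LightningModule.log: just remember what was logged
        self.logged.append((name, value))

    def step(self, batch, batch_idx: int, tag: str) -> Dict[str, float]:
        output = self.compute(batch)
        for k, v in output.items():
            self.log(f"{k}/{tag}", v)
        return output

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "val")


mini = MiniLearner(lambda batch: {"loss": 0.5, "accuracy": 0.8})
mini.training_step(batch=None, batch_idx=0)
mini.validation_step(batch=None, batch_idx=0)
print([name for name, _ in mini.logged])
# ['loss/train', 'accuracy/train', 'loss/val', 'accuracy/val']
```

Putting the metric name before the slash also lets TensorBoard group the train and validation curves of each metric under a common prefix (loss, accuracy).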

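The full implementation of the FewShotLearner class is shown below. In addition to the pieces above, it defines a configure_optimizers method that sets up an Adam optimizer with our learning rate.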

class FewShotLearner(pl.LightningModule):

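    def __init__(self,
        protonet: nn.Module,
        learning_rate: float = 1e-3,
        n_way: int = 5,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.protonet = protonet
        self.learning_rate = learning_rate

        self.loss = nn.CrossEntropyLoss()
        # recent torchmetrics releases require the task type and the number
        # of classes; each episode has n_way classes
        self.metrics = nn.ModuleDict({
            'accuracy': Accuracy(task="multiclass", num_classes=n_way)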
        })

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def step(self, batch, batch_idx, tag: str):
        support, query = batch

        logits = self.protonet(support, query)
        loss = self.loss(logits, query["target"])

        output = {"loss": loss}
        for k, metric in self.metrics.items():
            output[k] = metric(logits, query["target"])

        for k, v in output.items():
            self.log(f"{k}/{tag}", v)
        return output

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "train")
    
    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, "test")

Putting it all together – Training the Model

Now that we have defined the FewShotLearner class, we can instantiate it and train the model. We’ll use the Trainer class from PyTorch Lightning to train the model.

learner = FewShotLearner(protonet)
print(learner)
FewShotLearner(
  (protonet): PrototypicalNet(
    (backbone): Backbone(
      (melspec): MelSpectrogram(
        (spectrogram): Spectrogram()
        (mel_scale): MelScale()
      )
      (conv1): ConvBlock(
        (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (gn): GroupNorm(8, 32, eps=1e-05, affine=True)
        (relu): ReLU()
        (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv2): ConvBlock(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (gn): GroupNorm(16, 64, eps=1e-05, affine=True)
        (relu): ReLU()
        (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv3): ConvBlock(
        (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (gn): GroupNorm(32, 128, eps=1e-05, affine=True)
        (relu): ReLU()
        (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv4): ConvBlock(
        (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=same)
        (gn): GroupNorm(64, 256, eps=1e-05, affine=True)
        (relu): ReLU()
        (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv5): ConvBlock(
        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), padding=same)
        (gn): GroupNorm(128, 512, eps=1e-05, affine=True)
        (relu): ReLU()
        (maxpool): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
      )
    )
  )
  (loss): CrossEntropyLoss()
  (metrics): ModuleDict(
    (accuracy): Accuracy()
  )
)
/home/hugo/conda/envs/hugo/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py:262: UserWarning: Attribute 'protonet' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['protonet'])`.
  rank_zero_warn(

The code cell below trains the model for one pass over the training episode dataset, taking one training step per episode. On a GPU, training should take anywhere from 20 minutes to an hour.

Because we pass a TensorBoardLogger to the Trainer, the loss and metrics we record with self.log are written to TensorBoard logs. You can view them by running the following command in a terminal:

tensorboard --logdir logs/
# set up the trainer
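from pytorch_lightning.loggers import TensorBoardLogger
# SimpleProfiler lives in pytorch_lightning.profilers in recent releases
# (older releases exposed it as pytorch_lightning.profiler)
from pytorch_lightning.profilers import SimpleProfiler

trainer = pl.Trainer(
    # train on one GPU if available, otherwise on the CPU
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=1,
    log_every_n_steps=1,
    val_check_interval=50,  # validate every 50 training episodes
    profiler=SimpleProfiler(
        filename="profile.txt",
    ),
    logger=TensorBoardLogger(
        save_dir=".",
        name="logs"
    ),
)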

# train!
trainer.fit(learner, train_loader, val_dataloaders=val_loader)
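It helps to know how much work this configuration implies. Because batch_size=None, every episode is one batch and therefore one optimizer step, and an integer val_check_interval tells Lightning to run validation after every that many training batches. A back-of-the-envelope count in plain Python (ignoring Lightning's short sanity-check validation run before training starts):

```python
# Values copied from the hyperparameter and Trainer cells above.
n_train_episodes = 50000
n_val_episodes = 100
max_epochs = 1
val_check_interval = 50  # an int: validate every 50 training batches

# With batch_size=None, one episode == one batch == one optimizer step.
train_steps = n_train_episodes * max_epochs
val_runs = train_steps // val_check_interval
val_episodes_total = val_runs * n_val_episodes

print(train_steps, val_runs, val_episodes_total)  # 50000 1000 100000
```

Validation therefore processes twice as many episodes as training does (forward passes only, with no backward pass); raising val_check_interval is an easy way to reduce that overhead if training feels slow.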

Once your model has finished training, the Trainer will save the model checkpoint to the logs directory. In the final chapter of this coding tutorial, we’ll load our trained model, evaluate it on our evaluation set, and visualize the embedding space of our prototypical network.