Word-audio ZSL evaluation#

To evaluate our trained model, we first construct the model and load the saved weights.

1. Load the trained model#

from zsl.model import WordAudioSiameseNetwork
from zsl.loss import TripletLoss
from zsl.dataset import WordAudioDataset
from zsl.data_prep import prepare_zsl_split_word_audio
from zsl.transforms import get_transforms
from zsl.model_manager import ModelManager
from zsl.utils import *  # provides nearest_neighbor_clf_k and draw_pca used below
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
siamese_network = WordAudioSiameseNetwork().to(device)
triplet_loss = TripletLoss(0.5).to(device)
_, mel_transform = get_transforms()
# The optimizer and LR scheduler are passed to the ModelManager below,
# but no training takes place in this evaluation notebook.
optimizer_siamese = optim.Adam(siamese_network.parameters(), lr=0.01)
exp_lr_scheduler = ReduceLROnPlateau(
    optimizer_siamese, 
    mode='min', 
    factor=0.2, 
    patience=5, 
    verbose=True
)
curr_model = ModelManager(
    siamese_network, 
    triplet_loss, 
    optimizer_siamese, 
    exp_lr_scheduler, 
    device
)
curr_model.load_model('./checkpoints/word_audio_221129/w_ep-00195_l-0.0308.pth')
Loading model done: ./checkpoints/word_audio_221129/w_ep-00195_l-0.0308.pth
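
Depending on how ModelManager.load_model and infer_all are implemented, the network may or may not already be switched to evaluation mode; doing it explicitly before inference is a harmless optional safeguard (this line is an addition for clarity, not part of the original pipeline). It only matters for layers such as dropout or batch normalization that behave differently at training and inference time.

siamese_network.eval()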

2. Prepare our data splits#

(
    seen_word_classes,
    seen_audio_X_train, 
    seen_audio_y_train, 
    seen_audio_X_test, 
    seen_audio_y_test,
    unseen_word_classes,
    unseen_audio_X_train, 
    unseen_audio_y_train, 
    unseen_audio_X_test, 
    unseen_audio_y_test,
    inst_word_emb_dict
) = prepare_zsl_split_word_audio()
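
A quick look at the split sizes is an easy optional sanity check; it only uses the lists unpacked above.

print(len(seen_word_classes), 'seen classes /', len(unseen_word_classes), 'unseen classes')
print(len(seen_audio_X_train), 'seen train audio files /', len(seen_audio_X_test), 'seen test audio files')
print(len(unseen_audio_X_train) + len(unseen_audio_X_test), 'unseen audio files in total')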

3. Seen class prediction (supervised learning)#

3-1) Prepare the audio and word data from the seen classes.#

seen_word_audio_dataset_tr = WordAudioDataset(
    audio_path_list = seen_audio_X_train,
    audio_label_list = seen_audio_y_train,
    audio_transform = mel_transform,
    curr_word_classes = seen_word_classes,
    word_emb_dict = inst_word_emb_dict
)
seen_word_audio_dataset_ts = WordAudioDataset(
    audio_path_list = seen_audio_X_test,
    audio_label_list = seen_audio_y_test,
    audio_transform = mel_transform,
    curr_word_classes = seen_word_classes,
    word_emb_dict = inst_word_emb_dict
)
seen_word_audio_dataloaders = {}
seen_word_audio_dataloaders['train'] = DataLoader(
    seen_word_audio_dataset_tr,
    batch_size=16, 
    num_workers=8,
    shuffle=True
)
seen_word_audio_dataloaders['test'] = DataLoader(
    seen_word_audio_dataset_ts,
    batch_size=16, 
    num_workers=8,
    shuffle=True
)

3-2) Compute the audio embeddings using the ZSL model#

(
    seen_ts_audio_embs, 
    seen_ts_word_embs, 
    seen_ts_labels, 
    seen_ts_loss
) = curr_model.infer_all(seen_word_audio_dataloaders, 'test')
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 437/437 [00:20<00:00, 21.14it/s]
test Loss: 0.0315
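
infer_all returns the audio embeddings, the corresponding word embeddings, the ground-truth labels, and the loss over the requested split. An optional shape check (np.shape works whether the returned embeddings are arrays or lists) confirms there is one embedding row per test example:

print(np.shape(seen_ts_audio_embs), np.shape(seen_ts_word_embs), len(seen_ts_labels))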

3-3) Get the word embeddings of the seen classes in the ZSL embedding space.#

seen_class_to_word_emb = dict()
# For each seen class, take the audio embedding of its first test example
# as that class's reference vector in the shared embedding space.
for _c in list(set(seen_audio_y_train)):
    for i in range(seen_ts_audio_embs.shape[0]):
        if seen_ts_labels[i] == _c:
            seen_class_to_word_emb[_c] = seen_ts_audio_embs[i]
            break
# Collect the word embeddings per class as well; they are identical for
# every instance of a class.
seen_class_to_word_embs = defaultdict(list)
for i in tqdm(range(seen_ts_audio_embs.shape[0])):
    seen_class_to_word_embs[seen_ts_labels[i]].append(seen_ts_word_embs[i])

seen_class_to_mean_word_emb = {}
for _c in seen_class_to_word_embs:
    seen_class_to_mean_word_emb[_c] = seen_class_to_word_embs[_c][0]  # same embedding for all instances of a class
100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 437/437 [00:00<00:00, 455947.97it/s]

3-4) Evaluate the test set audio on seen classes (supervised learning evaluation)#

correct = 0
# Classify each seen-class test audio embedding by its nearest class reference embedding.
for idx in range(seen_ts_audio_embs.shape[0]):
    curr_pred_indices, _ = nearest_neighbor_clf_k(np.expand_dims(seen_ts_audio_embs[idx], 0),
                                                  seen_class_to_word_emb,
                                                  seen_word_classes,
                                                  1)
    if seen_ts_labels[idx] in curr_pred_indices:
        correct += 1
print('Top-1 acc :', correct / seen_ts_audio_embs.shape[0])
Top-1 acc : 0.8283752860411899
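
nearest_neighbor_clf_k is imported from zsl.utils via the wildcard import above. For readers without that source at hand, here is a minimal sketch of what such a top-k nearest-neighbor classifier over class reference embeddings could look like; it assumes Euclidean distance and is an illustration, not the actual zsl.utils implementation.

def nearest_neighbor_clf_k_sketch(query_emb, class_to_emb, class_names, k):
    # Rank every class reference embedding by its distance to the query embedding
    # and return the indices (and names) of the k closest classes.
    class_ids = sorted(class_to_emb.keys())
    refs = np.stack([class_to_emb[c] for c in class_ids])   # (num_classes, emb_dim)
    dists = np.linalg.norm(refs - query_emb, axis=1)        # query_emb has shape (1, emb_dim)
    top_k = np.argsort(dists)[:k]
    pred_indices = [class_ids[i] for i in top_k]
    pred_names = [class_names[i] for i in pred_indices]
    return pred_indices, pred_names
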
draw_pca(seen_ts_audio_embs, seen_ts_labels, seen_word_classes)
../_images/zsl_eval_word_audio_21_0.png

4. Unseen class prediction#

4-1) Prepare the audio and word data from the unseen classes.#

Note that, unlike many works in the literature, we do not use a train/test split for the labeled audio of the unseen classes.

Instead, we merge the train and test sets of the unseen classes.

unseen_audio_X = unseen_audio_X_train + unseen_audio_X_test
unseen_audio_y = unseen_audio_y_train + unseen_audio_y_test
unseen_word_audio_dataset = WordAudioDataset(
    audio_path_list = unseen_audio_X,
    audio_label_list = unseen_audio_y,
    audio_transform = mel_transform,
    curr_word_classes = unseen_word_classes,
    word_emb_dict = inst_word_emb_dict
)
unseen_word_audio_dataloaders = {}
unseen_word_audio_dataloaders['all'] = DataLoader(
    unseen_word_audio_dataset,
    batch_size=16, 
    num_workers=8,
    shuffle=False
)

4-2) Compute the audio embeddings using the ZSL model#

(
    unseen_audio_embs, 
    unseen_word_embs, 
    unseen_labels, 
    unseen_loss
) = curr_model.infer_all(unseen_word_audio_dataloaders, 'all')
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 738/738 [00:33<00:00, 21.85it/s]
all Loss: 0.4285

4-3) Compute the word embeddings using the ZSL model#

unseen_class_to_word_emb = dict()
# For each unseen class, take the audio embedding of its first example
# as that class's reference vector in the shared embedding space.
for _c in list(set(unseen_audio_y)):
    for i in range(unseen_audio_embs.shape[0]):
        if unseen_labels[i] == _c:
            unseen_class_to_word_emb[_c] = unseen_audio_embs[i]
            break
unseen_class_to_word_emb.keys()
dict_keys([0, 1, 2, 3, 4])
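
As an optional sanity check, every unseen class should now have a reference embedding:

assert set(unseen_class_to_word_emb) == set(unseen_audio_y)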

4-4) Evaluate the audio on unseen classes (zero-shot evaluation)#

correct = 0
# Classify each unseen audio embedding by its nearest unseen-class reference embedding.
for idx in range(unseen_audio_embs.shape[0]):
    curr_pred_indices, _ = nearest_neighbor_clf_k(np.expand_dims(unseen_audio_embs[idx], 0),
                                                  unseen_class_to_word_emb,
                                                  unseen_word_classes,
                                                  1)
    if unseen_labels[idx] in curr_pred_indices:
        correct += 1
print('Top-1 acc :', correct / unseen_audio_embs.shape[0])
Top-1 acc : 0.3807588075880759
draw_pca(unseen_audio_embs, unseen_labels, unseen_word_classes)
../_images/zsl_eval_word_audio_33_0.png

5. Generalized ZSL evaluation#

5-1) Concatenate the seen and unseen classes#

all_word_classes = seen_word_classes + unseen_word_classes

5-2) Relabel the audio from the unseen classes#

added_unseen_labels = [_n + 9 for _n in unseen_labels]  # offset unseen labels so they do not collide with seen-class indices

5-3) Get the word embedding vector for all classes#

all_class_to_word_embs = {}
# Seen classes keep their original indices; unseen classes are appended
# using the same offset as above.
for _c in list(seen_class_to_word_emb.keys()):
    all_class_to_word_embs[_c] = seen_class_to_word_emb[_c]

for _c in list(unseen_class_to_word_emb.keys()):
    all_class_to_word_embs[_c+9] = unseen_class_to_word_emb[_c]
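
An optional sanity check: every key in all_class_to_word_embs should address a valid entry of all_word_classes. This relies on the offset applied above keeping the seen and unseen indices disjoint, which is an assumption about the split rather than something verified earlier.

assert len(all_class_to_word_embs) == len(seen_class_to_word_emb) + len(unseen_class_to_word_emb)
assert all(0 <= _c < len(all_word_classes) for _c in all_class_to_word_embs)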

5-4) Evaluate the audio on seen + unseen classes (generalized zero-shot evaluation)#

  • audio : the whole set of unseen-class audio (train + test merged above)

  • label : seen + unseen classes

correct = 0
# Classify each unseen audio embedding against the reference embeddings of
# all (seen + unseen) classes, using the shifted unseen labels as ground truth.
for idx in range(unseen_audio_embs.shape[0]):
    curr_pred_indices, _ = nearest_neighbor_clf_k(np.expand_dims(unseen_audio_embs[idx], 0),
                                                  all_class_to_word_embs,
                                                  all_word_classes,
                                                  1)
    if added_unseen_labels[idx] in curr_pred_indices:
        correct += 1
print('Top-1 acc :', correct / unseen_audio_embs.shape[0])
Top-1 acc : 0.13143631436314362

6. Visualize#

  • audio : the seen-class test set + the whole unseen-class set

  • label : seen + unseen classes

all_audio_embs = np.concatenate([seen_ts_audio_embs, unseen_audio_embs])
all_labels = seen_ts_labels + added_unseen_labels
draw_pca(all_audio_embs, all_labels, all_word_classes)
../_images/zsl_eval_word_audio_46_0.png
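
draw_pca also comes from zsl.utils. A minimal sketch of a comparable 2-D PCA scatter plot colored by class is given below; it assumes scikit-learn and matplotlib and is an illustration, not the actual zsl.utils code.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def draw_pca_sketch(embs, labels, class_names):
    # Project the embeddings onto the first two principal components
    # and draw one scatter group per class.
    coords = PCA(n_components=2).fit_transform(np.asarray(embs))
    labels = np.asarray(labels)
    for _c in np.unique(labels):
        mask = labels == _c
        plt.scatter(coords[mask, 0], coords[mask, 1], s=8, label=class_names[_c])
    plt.xlabel('PC 1')
    plt.ylabel('PC 2')
    plt.legend(fontsize=8)
    plt.show()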