Image-audio ZSL training

Here’s a code example of training our image-audio siamese network.

from zsl.model import ImageAudioSiameseNetwork
from zsl.loss import TripletLoss
from zsl.dataset import ImageAudioDataset
from zsl.data_prep import prepare_zsl_split_img_audio
from zsl.transforms import get_transforms
from zsl.model_manager import ModelManager
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Unpack the 14 values returned by prepare_zsl_split_img_audio(): seven for
# the "seen" classes followed by seven for the "unseen" classes.
# Judging by the names, each half holds the image classes, image paths and
# image labels, plus the train/test audio inputs (X) and labels (y).
# NOTE(review): the element types are not visible here — confirm against
# zsl.data_prep before relying on them.
(
    seen_img_classes,
    seen_img_path,
    seen_img_label,
    seen_audio_X_train, 
    seen_audio_y_train, 
    seen_audio_X_test, 
    seen_audio_y_test,
    unseen_img_classes,
    unseen_img_path,
    unseen_img_label,
    unseen_audio_X_train, 
    unseen_audio_y_train, 
    unseen_audio_X_test, 
    unseen_audio_y_test,
) = prepare_zsl_split_img_audio()
# Use the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the siamese network and the triplet loss (constructed with 0.5),
# then move each of them onto the selected device.
siamese_network = ImageAudioSiameseNetwork()
siamese_network = siamese_network.to(device)
triplet_loss = TripletLoss(0.5)
triplet_loss = triplet_loss.to(device)

# get_transforms() returns a pair: the image transforms (indexed by
# 'train' / 'test' further down) and the transform applied to audio.
img_transforms, mel_transform = get_transforms()
# Keyword arguments shared by the train and test datasets: both are built
# from the same seen-class image lists and the same audio transform.
_seen_img_kwargs = dict(
    img_path_list=seen_img_path,
    img_label_list=seen_img_label,
    img_class_list=seen_img_classes,
    audio_transform=mel_transform,
)

# Training dataset: seen-class training audio with the 'train' image transform.
seen_img_audio_dataset_tr = ImageAudioDataset(
    audio_path_list=seen_audio_X_train,
    audio_label_list=seen_audio_y_train,
    img_transform=img_transforms['train'],
    **_seen_img_kwargs,
)

# Test dataset: seen-class test audio with the 'test' image transform.
seen_img_audio_dataset_ts = ImageAudioDataset(
    audio_path_list=seen_audio_X_test,
    audio_label_list=seen_audio_y_test,
    img_transform=img_transforms['test'],
    **_seen_img_kwargs,
)
# One DataLoader per phase, keyed by phase name. Both use batches of 16 and
# 8 worker processes; only the training loader shuffles its samples.
seen_img_audio_dataloaders = {
    'train': DataLoader(
        seen_img_audio_dataset_tr,
        batch_size=16,
        num_workers=8,
        shuffle=True,
    ),
    'test': DataLoader(
        seen_img_audio_dataset_ts,
        batch_size=16,
        num_workers=8,
        shuffle=False,
    ),
}

The hyperparameters below are those of one of our working training strategies.

# Adam over every parameter of the siamese network, starting at lr=0.01.
optimizer_siamese = optim.Adam(siamese_network.parameters(), lr=0.01)

# Multiply the learning rate by 0.2 once the monitored quantity (mode='min',
# i.e. lower is better) has failed to improve for more than 5 epochs.
#
# `verbose` is deliberately not passed: it is deprecated on PyTorch LR
# schedulers and passing it warns or fails on recent releases. To observe
# the current learning rate, read optimizer_siamese.param_groups[0]['lr'].
exp_lr_scheduler = ReduceLROnPlateau(
    optimizer_siamese,
    mode='min',
    factor=0.2,
    patience=5,
)
# Hand the network, loss, optimizer, LR scheduler and device to ModelManager,
# the object whose train_model() method is called below.
curr_model = ModelManager(
    siamese_network, 
    triplet_loss, 
    optimizer_siamese, 
    exp_lr_scheduler, 
    device
)
# Start training on the seen-class dataloaders ('train' and 'test' entries).
# NOTE(review): the meaning of exp='img_audio' and validation='random' is
# defined inside ModelManager.train_model and is not visible here — confirm
# there. num_epochs=200 is presumably the number of training epochs.
curr_model.train_model(
    exp='img_audio', 
    dataloaders=seen_img_audio_dataloaders, 
    validation='random', 
    num_epochs=200
)