Image-audio ZSL training
Here’s a code example of training our image-audio siamese network.
from zsl.model import ImageAudioSiameseNetwork
from zsl.loss import TripletLoss
from zsl.dataset import ImageAudioDataset
from zsl.data_prep import prepare_zsl_split_img_audio
from zsl.transforms import get_transforms
from zsl.model_manager import ModelManager
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Build the zero-shot split. The call returns 14 values, unpacked here in
# order: first the "seen" group, then the "unseen" group. Going by the
# names, each group carries image classes / paths / labels followed by the
# audio train and test inputs (X) and labels (y).
(
    seen_img_classes,
    seen_img_path,
    seen_img_label,
    seen_audio_X_train,
    seen_audio_y_train,
    seen_audio_X_test,
    seen_audio_y_test,
    unseen_img_classes,
    unseen_img_path,
    unseen_img_label,
    unseen_audio_X_train,
    unseen_audio_y_train,
    unseen_audio_X_test,
    unseen_audio_y_test,
) = prepare_zsl_split_img_audio()
# Use the first CUDA device when one is available, otherwise stay on CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Model and loss both live on the selected device.
siamese_network = ImageAudioSiameseNetwork()
siamese_network = siamese_network.to(device)

# NOTE(review): 0.5 is presumably the triplet margin — confirm against
# zsl.loss.TripletLoss.
triplet_loss = TripletLoss(0.5)
triplet_loss = triplet_loss.to(device)

# img_transforms is indexed by 'train' / 'test' further down; mel_transform
# is passed to the datasets as their audio transform.
img_transforms, mel_transform = get_transforms()
# Arguments common to both seen-class datasets: the same image pool and the
# same audio transform are used for the training and the test split.
seen_shared_kwargs = dict(
    img_path_list=seen_img_path,
    img_label_list=seen_img_label,
    img_class_list=seen_img_classes,
    audio_transform=mel_transform,
)

# Training split: seen-class training audio with the 'train' image transform.
seen_img_audio_dataset_tr = ImageAudioDataset(
    audio_path_list=seen_audio_X_train,
    audio_label_list=seen_audio_y_train,
    **seen_shared_kwargs,
    img_transform=img_transforms['train'],
)

# Test split: seen-class test audio with the 'test' image transform.
seen_img_audio_dataset_ts = ImageAudioDataset(
    audio_path_list=seen_audio_X_test,
    audio_label_list=seen_audio_y_test,
    **seen_shared_kwargs,
    img_transform=img_transforms['test'],
)
# One loader per split, keyed by split name. Only the training loader
# shuffles; the test loader keeps the dataset order.
seen_img_audio_dataloaders = {
    'train': DataLoader(
        seen_img_audio_dataset_tr,
        batch_size=16,
        num_workers=8,
        shuffle=True,
    ),
    'test': DataLoader(
        seen_img_audio_dataset_ts,
        batch_size=16,
        num_workers=8,
        shuffle=False,
    ),
}
We include the hyperparameters of one of our working training strategies.
# Adam over every parameter of the siamese network, starting at lr=0.01.
optimizer_siamese = optim.Adam(
    siamese_network.parameters(),
    lr=0.01,
)

# Multiply the learning rate by 0.2 once the monitored quantity has stopped
# decreasing ('min' mode) for more than 5 consecutive epochs. Which quantity
# is monitored is decided by whoever calls scheduler.step(), not here.
exp_lr_scheduler = ReduceLROnPlateau(
    optimizer_siamese,
    mode='min',
    factor=0.2,
    patience=5,
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases and
    # may be rejected by newer ones — confirm against the pinned version.
    verbose=True,
)
# Bundle the network, loss, optimizer, LR scheduler and device into a single
# manager object (arguments are positional, in that order).
curr_model = ModelManager(
    siamese_network,
    triplet_loss,
    optimizer_siamese,
    exp_lr_scheduler,
    device,
)

# Run the image-audio experiment on the seen-class loaders for 200 epochs.
# NOTE(review): the meaning of validation='random' is defined inside
# ModelManager.train_model — see zsl.model_manager.
curr_model.train_model(
    exp='img_audio',
    dataloaders=seen_img_audio_dataloaders,
    validation='random',
    num_epochs=200,
)