This repository provides code and checkpoints for extracting audio and text representations using BLAT (Bootstrapping Language-Audio pre-training based on Tag-guided synthetic data) models.
First install the dependencies: pip install -r requirements.txt. Then download the pre-trained weights:
$ wget https://zenodo.org/record/8192397/files/blat_cnn14_bertm.pth -O checkpoints/blat_cnn14_bertm/model.pth
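Note that wget -O does not create the target directory. If checkpoints/blat_cnn14_bertm does not exist yet (the layout assumed by the command above), create it before downloading:
$ mkdir -p checkpoints/blat_cnn14_bertm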
Refer to inference.py for usage:
from inference import load_blat, encode_audio, encode_text
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model, its text tokenizer and the maximum text length from the checkpoint directory
ckpt_dir = "./checkpoints/blat_cnn14_bertm"
model, text_tokenizer, max_length = load_blat(ckpt_dir, device)

audio = "./example.wav"
text = ['a dog barks', 'a man is speaking', 'birds are chirping']

with torch.no_grad():
    # Encode the audio clip and the candidate texts into the shared embedding space
    audio_emb = encode_audio(model, audio, device)
    text_emb = encode_text(model, text_tokenizer, text, device, max_length)

# Audio-text similarity: one row per audio clip, one column per text
sim = np.matmul(audio_emb, text_emb.T)
print(sim)  # [[0.56612206 0.18251741 0.15569025]]
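As a follow-up, here is a minimal sketch of zero-shot retrieval that reuses sim and text from the snippet above: the candidate texts are simply ranked by their similarity to the audio clip. Only NumPy is needed; no additional BLAT API is assumed.

# Rank candidate texts by similarity to the audio clip, highest score first
order = np.argsort(-sim[0])
for idx in order:
    print(f"{sim[0][idx]:.4f}  {text[idx]}")
# With the example scores above, "a dog barks" would be ranked first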
If you find the model useful, please cite this paper:
@inproceedings{xu2023blat,
  title={{BLAT}: Bootstrapping language-audio pre-training based on {AudioSet} tag-guided synthetic data},
  author={Xu, Xuenan and Zhang, Zhiling and Zhou, Zelin and Zhang, Pingyue and Xie, Zeyu and Wu, Mengyue and Zhu, Kenny Q},
  booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
  pages={2756--2764},
  year={2023}
}