diff --git a/README.md b/README.md
index 6a7c5fc..8393c4e 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,41 @@
-# torch-vggish
+# VGGish Embedding Operator (PyTorch)
-This is another test repo
\ No newline at end of file
+
+Authors: Jael Gu
+
+## Overview
+
+This operator reads the waveform of an audio file and applies VGGish to extract features. The original VGGish model is built on top of TensorFlow [1]; this operator ports VGGish to **PyTorch**. Given an input audio file, it generates a set of vectors, where each vector represents the features of a non-overlapping clip with a fixed length of 0.96 s, composed of 96 frames over 64 mel bands. The model is pre-trained on the large-scale audio dataset [AudioSet](https://research.google.com/audioset). As its authors suggest, the model is suitable for extracting high-level features or for warming up a larger model.
+
+## Interface
+
+```python
+__call__(self, audio_path: str)
+```
+
+**Args:**
+
+- audio_path:
+  - path to the input audio file
+  - supported types: str
+
+**Returns:**
+
+The operator returns a named tuple `Tuple[('embs', numpy.ndarray)]` containing the following fields:
+
+- embs:
+  - embeddings of the audio
+  - data type: `numpy.ndarray`
+  - shape: (num_clips, 128)
+
+## Requirements
+
+You can install the required Python packages from [requirements.txt](./requirements.txt).
+
+## How it works
+
+The `towhee/torch-vggish` operator implements audio embedding and can be added to a Towhee pipeline. For example, it is the key operator of the pipeline [audio-embedding-vggish](https://hub.towhee.io/towhee/audio-embedding-vggish).
+
+## Reference
+
+[1] https://github.com/tensorflow/models/tree/master/research/audioset/vggish
+[2] https://tfhub.dev/google/vggish/1
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..37f5bd7
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pytorch/__init__.py b/pytorch/__init__.py
new file mode 100644
index 0000000..b661573
--- /dev/null
+++ b/pytorch/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import subprocess
+import sys
+
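+# For requirements: timm is not pinned in requirements.txt, so the guarded
+# import below installs it on demand into the current environment.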
+try:
+    import timm
+except ModuleNotFoundError:
+    # Install with the interpreter that is running this module so the package
+    # lands in the correct environment (a bare `pip` on PATH may not).
+    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'timm'])
+    import timm
+
+from timm.data import resolve_data_config
+from timm.data.transforms_factory import create_transform
\ No newline at end of file
diff --git a/pytorch/mel_features.py b/pytorch/mel_features.py
new file mode 100644
index 0000000..ac58fb5
--- /dev/null
+++ b/pytorch/mel_features.py
@@ -0,0 +1,223 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Defines routines to compute mel spectrogram features from audio waveform."""
+
+import numpy as np
+
+
+def frame(data, window_length, hop_length):
+  """Convert array into a sequence of successive possibly overlapping frames.
+
+  An n-dimensional array of shape (num_samples, ...) is converted into an
+  (n+1)-D array of shape (num_frames, window_length, ...), where each frame
+  starts hop_length points after the preceding one.
+
+  This is accomplished using stride_tricks, so the original data is not
+  copied. However, there is no zero-padding, so any incomplete frames at the
+  end are not included.
+
+  Args:
+    data: np.array of dimension N >= 1.
+    window_length: Number of samples in each frame.
+    hop_length: Advance (in samples) between each window.
+
+  Returns:
+    (N+1)-D np.array with as many rows as there are complete frames that can be
+    extracted.
+  """
+  num_samples = data.shape[0]
+  num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
+  shape = (num_frames, window_length) + data.shape[1:]
+  strides = (data.strides[0] * hop_length,) + data.strides
+  return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
+
+
+def periodic_hann(window_length):
+  """Calculate a "periodic" Hann window.
+
+  The classic Hann window is defined as a raised cosine that starts and
+  ends on zero, and where every value appears twice, except the middle
+  point for an odd-length window. Matlab calls this a "symmetric" window
+  and np.hanning() returns it. However, for Fourier analysis, this
+  actually represents just over one cycle of a period N-1 cosine, and
+  thus is not compactly expressed on a length-N Fourier basis. Instead,
+  it's better to use a raised cosine that ends just before the final
+  zero value - i.e. a complete cycle of a period-N cosine. Matlab
+  calls this a "periodic" window. This routine calculates it.
+
+  Args:
+    window_length: The number of points in the returned window.
+
+  Returns:
+    A 1D np.array containing the periodic hann window.
+  """
+  return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
+                             np.arange(window_length)))
+
+
+def stft_magnitude(signal, fft_length,
+                   hop_length=None,
+                   window_length=None):
+  """Calculate the short-time Fourier transform magnitude.
+
+  Args:
+    signal: 1D np.array of the input time-domain signal.
+    fft_length: Size of the FFT to apply.
+    hop_length: Advance (in samples) between each frame passed to FFT.
+    window_length: Length of each block of samples to pass to FFT.
+
+  Returns:
+    2D np.array where each row contains the magnitudes of the fft_length/2+1
+    unique values of the FFT for the corresponding frame of input samples.
+  """
+  frames = frame(signal, window_length, hop_length)
+  # Apply frame window to each frame. We use a periodic Hann (cosine of period
+  # window_length) instead of the symmetric Hann of np.hanning (period
+  # window_length-1).
+  window = periodic_hann(window_length)
+  windowed_frames = frames * window
+  return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
+
+
+# Mel spectrum constants and functions.
+_MEL_BREAK_FREQUENCY_HERTZ = 700.0
+_MEL_HIGH_FREQUENCY_Q = 1127.0
+
+
+def hertz_to_mel(frequencies_hertz):
+  """Convert frequencies to mel scale using HTK formula.
+
+  Args:
+    frequencies_hertz: Scalar or np.array of frequencies in hertz.
+
+  Returns:
+    Object of same size as frequencies_hertz containing corresponding values
+    on the mel scale.
+  """
+  return _MEL_HIGH_FREQUENCY_Q * np.log(
+      1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
+
+
+def spectrogram_to_mel_matrix(num_mel_bins=20,
+                              num_spectrogram_bins=129,
+                              audio_sample_rate=8000,
+                              lower_edge_hertz=125.0,
+                              upper_edge_hertz=3800.0):
+  """Return a matrix that can post-multiply spectrogram rows to make mel.
+
+  Returns a np.array matrix A that can be used to post-multiply a matrix S of
+  spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
+  "mel spectrogram" M of frames x num_mel_bins. M = S A.
+
+  The classic HTK algorithm exploits the complementarity of adjacent mel bands
+  to multiply each FFT bin by only one mel weight, then add it, with positive
+  and negative signs, to the two adjacent mel bands to which that bin
+  contributes. Here, by expressing this operation as a matrix multiply, we go
+  from num_fft multiplies per frame (plus around 2*num_fft adds) to around
+  num_fft^2 multiplies and adds. However, because these are all presumably
+  accomplished in a single call to np.dot(), it's not clear which approach is
+  faster in Python. The matrix multiplication has the attraction of being more
+  general and flexible, and much easier to read.
+
+  Args:
+    num_mel_bins: How many bands in the resulting mel spectrum. This is
+      the number of columns in the output matrix.
+    num_spectrogram_bins: How many bins there are in the source spectrogram
+      data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
+      only contains the nonredundant FFT bins.
+    audio_sample_rate: Samples per second of the audio at the input to the
+      spectrogram. We need this to figure out the actual frequencies for
+      each spectrogram bin, which dictates how they are mapped into mel.
+    lower_edge_hertz: Lower bound on the frequencies to be included in the mel
+      spectrum. This corresponds to the lower edge of the lowest triangular
+      band.
+    upper_edge_hertz: The desired top edge of the highest frequency band.
+
+  Returns:
+    An np.array with shape (num_spectrogram_bins, num_mel_bins).
+
+  Raises:
+    ValueError: if frequency edges are incorrectly ordered or out of range.
+  """
+  nyquist_hertz = audio_sample_rate / 2.
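+  # Validate the band edges before building the weight matrix: each mel bin
+  # below is a triangular weighting of spectrogram bins whose legs are linear
+  # in the mel domain, not in hertz.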
+  if lower_edge_hertz < 0.0:
+    raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
+  if lower_edge_hertz >= upper_edge_hertz:
+    raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
+                     (lower_edge_hertz, upper_edge_hertz))
+  if upper_edge_hertz > nyquist_hertz:
+    raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
+                     (upper_edge_hertz, nyquist_hertz))
+  spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
+  spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
+  # The i'th mel band (starting from i=1) has center frequency
+  # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
+  # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
+  # the band_edges_mel arrays.
+  band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
+                               hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
+  # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
+  # of spectrogram values.
+  mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
+  for i in range(num_mel_bins):
+    lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
+    # Calculate lower and upper slopes for every spectrogram bin.
+    # Line segments are linear in the *mel* domain, not hertz.
+    lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
+                   (center_mel - lower_edge_mel))
+    upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
+                   (upper_edge_mel - center_mel))
+    # .. then intersect them with each other and zero.
+    mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
+                                                          upper_slope))
+  # HTK excludes the spectrogram DC bin; make sure it always gets a zero
+  # coefficient.
+  mel_weights_matrix[0, :] = 0.0
+  return mel_weights_matrix
+
+
+def log_mel_spectrogram(data,
+                        audio_sample_rate=8000,
+                        log_offset=0.0,
+                        window_length_secs=0.025,
+                        hop_length_secs=0.010,
+                        **kwargs):
+  """Convert waveform to a log magnitude mel-frequency spectrogram.
+
+  Args:
+    data: 1D np.array of waveform data.
+    audio_sample_rate: The sampling rate of data.
+    log_offset: Add this to values when taking log to avoid -Infs.
+    window_length_secs: Duration of each window to analyze.
+    hop_length_secs: Advance between successive analysis windows.
+    **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
+
+  Returns:
+    2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
+    magnitudes for successive frames.
+  """
+  window_length_samples = int(round(audio_sample_rate * window_length_secs))
+  hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
+  fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
+  spectrogram = stft_magnitude(
+      data,
+      fft_length=fft_length,
+      hop_length=hop_length_samples,
+      window_length=window_length_samples)
+  mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
+      num_spectrogram_bins=spectrogram.shape[1],
+      audio_sample_rate=audio_sample_rate, **kwargs))
+  return np.log(mel_spectrogram + log_offset)
diff --git a/pytorch/model.py b/pytorch/model.py
new file mode 100644
index 0000000..02bff05
--- /dev/null
+++ b/pytorch/model.py
@@ -0,0 +1,71 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).parent))
+
+import vggish_input
+
+
+class Model(nn.Module):
+    """
+    PyTorch VGGish model: a VGG-style CNN over log mel spectrogram patches,
+    followed by fully connected layers producing 128-D embeddings.
+    """
+    def __init__(self):
+        super().__init__()
+        self.features = nn.Sequential(
+            nn.Conv2d(1, 64, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(64, 128, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(128, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(256, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2))
+        self.embeddings = nn.Sequential(
+            nn.Linear(512 * 24, 4096),
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, 128),
+            # nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.features(x).permute(0, 2, 3, 1).contiguous()
+        x = x.view(x.size(0), -1)
+        x = self.embeddings(x)
+        return x
+
+    def preprocess(self, audio_path: str):
+        audio_tensors = vggish_input.wavfile_to_examples(audio_path)
+        return audio_tensors
+
+    def train(self, mode: bool = True):
+        """
+        Placeholder for training; delegates to nn.Module.train so that
+        model.eval() (which calls train(False)) keeps working.
+        """
+        return super().train(mode)
diff --git a/pytorch/vggish.pth b/pytorch/vggish.pth
new file mode 100644
index 0000000..5e57801
Binary files /dev/null and b/pytorch/vggish.pth differ
diff --git a/pytorch/vggish_input.py b/pytorch/vggish_input.py
new file mode 100644
index 0000000..256f0d1
--- /dev/null
+++ b/pytorch/vggish_input.py
@@ -0,0 +1,98 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Compute input examples for VGGish from audio waveform."""
+
+# Modification: Return torch tensors rather than numpy arrays
+import torch
+
+import numpy as np
+import resampy
+
+import mel_features
+import vggish_params
+
+import soundfile as sf
+
+
+def waveform_to_examples(data, sample_rate, return_tensor=True):
+  """Converts audio waveform into an array of examples for VGGish.
+
+  Args:
+    data: np.array of either one dimension (mono) or two dimensions
+      (multi-channel, with the outer dimension representing channels).
+      Each sample is generally expected to lie in the range [-1.0, +1.0],
+      although this is not required.
+    sample_rate: Sample rate of data.
+    return_tensor: Return data as a PyTorch tensor ready for VGGish.
+
+  Returns:
+    3-D np.array of shape [num_examples, num_frames, num_bands] which represents
+    a sequence of examples, each of which contains a patch of log mel
+    spectrogram, covering num_frames frames of audio and num_bands mel frequency
+    bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
+    If return_tensor is True, the examples are instead returned as a 4-D
+    torch.Tensor of shape [num_examples, 1, num_frames, num_bands].
+  """
+  # Convert to mono.
+  if len(data.shape) > 1:
+    data = np.mean(data, axis=1)
+  # Resample to the rate assumed by VGGish.
+  if sample_rate != vggish_params.SAMPLE_RATE:
+    data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
+
+  # Compute log mel spectrogram features.
+  log_mel = mel_features.log_mel_spectrogram(
+      data,
+      audio_sample_rate=vggish_params.SAMPLE_RATE,
+      log_offset=vggish_params.LOG_OFFSET,
+      window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
+      hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
+      num_mel_bins=vggish_params.NUM_MEL_BINS,
+      lower_edge_hertz=vggish_params.MEL_MIN_HZ,
+      upper_edge_hertz=vggish_params.MEL_MAX_HZ)
+
+  # Frame features into examples.
+  features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
+  example_window_length = int(round(
+      vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
+  example_hop_length = int(round(
+      vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
+  log_mel_examples = mel_features.frame(
+      log_mel,
+      window_length=example_window_length,
+      hop_length=example_hop_length)
+
+  if return_tensor:
+    log_mel_examples = torch.tensor(
+        log_mel_examples, requires_grad=True)[:, None, :, :].float()
+
+  return log_mel_examples
+
+
+def wavfile_to_examples(wav_file, return_tensor=True):
+  """Convenience wrapper around waveform_to_examples() for a common WAV format.
+
+  Args:
+    wav_file: String path to a file, or a file-like object. The file
+      is assumed to contain WAV audio data with signed 16-bit PCM samples.
+    return_tensor: Return data as a PyTorch tensor ready for VGGish.
+
+  Returns:
+    See waveform_to_examples.
+  """
+  wav_data, sr = sf.read(wav_file, dtype='int16')
+  assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
+  samples = wav_data / 32768.0  # Convert to [-1.0, +1.0]
+  return waveform_to_examples(samples, sr, return_tensor)
diff --git a/pytorch/vggish_params.py b/pytorch/vggish_params.py
new file mode 100644
index 0000000..526784b
--- /dev/null
+++ b/pytorch/vggish_params.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Global parameters for the VGGish model.
+
+See vggish_slim.py in the original TensorFlow repository for more information.
+"""
+
+# Architectural constants.
+NUM_FRAMES = 96  # Frames in input mel-spectrogram patch.
+NUM_BANDS = 64  # Frequency bands in input mel-spectrogram patch.
+EMBEDDING_SIZE = 128  # Size of embedding layer.
+
+# Hyperparameters used in feature and example generation.
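+# With a 10 ms STFT hop, one 0.96 s example spans 0.96 / 0.010 = 96 frames,
+# which is exactly NUM_FRAMES above.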
+SAMPLE_RATE = 16000
+STFT_WINDOW_LENGTH_SECONDS = 0.025
+STFT_HOP_LENGTH_SECONDS = 0.010
+NUM_MEL_BINS = NUM_BANDS
+MEL_MIN_HZ = 125
+MEL_MAX_HZ = 7500
+LOG_OFFSET = 0.01  # Offset used for stabilized log of input mel-spectrogram.
+EXAMPLE_WINDOW_SECONDS = 0.96  # Each example contains 96 10ms frames
+EXAMPLE_HOP_SECONDS = 0.96  # with zero overlap.
+
+# Parameters used for embedding postprocessing.
+PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
+PCA_MEANS_NAME = 'pca_means'
+QUANTIZE_MIN_VAL = -2.0
+QUANTIZE_MAX_VAL = +2.0
+
+# Hyperparameters used in training.
+INIT_STDDEV = 0.01  # Standard deviation used to initialize weights.
+LEARNING_RATE = 1e-4  # Learning rate for the Adam optimizer.
+ADAM_EPSILON = 1e-8  # Epsilon for the Adam optimizer.
+
+# Names of ops, tensors, and features.
+INPUT_OP_NAME = 'vggish/input_features'
+INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
+OUTPUT_OP_NAME = 'vggish/embedding'
+OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
+AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..bd7fe4b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch==1.9.0
+numpy==1.19.5
+resampy
+soundfile
\ No newline at end of file
diff --git a/torch_vggish.py b/torch_vggish.py
new file mode 100644
index 0000000..388880d
--- /dev/null
+++ b/torch_vggish.py
@@ -0,0 +1,50 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+from pathlib import Path
+from typing import NamedTuple
+
+import numpy
+import torch
+
+from towhee.operator import Operator
+
+warnings.filterwarnings("ignore")
+
+
+class TorchVggish(Operator):
+    """
+    Audio embedding operator: converts an audio file into VGGish embeddings,
+    one 128-D vector per non-overlapping 0.96 s clip.
+    """
+
+    def __init__(self, framework: str = 'pytorch') -> None:
+        super().__init__()
+        if framework == 'pytorch':
+            import importlib.util
+            path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
+            opname = os.path.basename(str(Path(__file__))).split('.')[0]
+            spec = importlib.util.spec_from_file_location(opname, path)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            self.model = module.Model()
+
+        path = str(Path(__file__).parent)
+        self.model.load_state_dict(torch.load(path + '/pytorch/vggish.pth', map_location=torch.device('cpu')))
+        self.model.eval()
+
+    def __call__(self, audio_path: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]):
+        audio_tensors = self.model.preprocess(audio_path)
+        with torch.no_grad():
+            features = self.model(audio_tensors)
+        Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)])
+        return Outputs(features.detach().numpy())
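
For reviewers: a minimal usage sketch of the operator added in this diff, run from the repository root with `towhee` installed. The path `audio.wav` is a placeholder for any 16-bit PCM WAV file.

```python
import numpy as np

from torch_vggish import TorchVggish

# Instantiating the operator loads pytorch/vggish.pth onto the CPU.
op = TorchVggish()

# __call__ returns NamedTuple('Outputs', [('embs', numpy.ndarray)]).
out = op('audio.wav')  # placeholder path to a 16-bit PCM WAV file

# One 128-D embedding per non-overlapping 0.96 s clip of the input.
print(out.embs.shape)  # (num_clips, 128)
print(out.embs.dtype)  # float32
```

Note that audio shorter than 0.96 s yields zero complete clips, since `mel_features.frame` drops incomplete frames rather than zero-padding.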