logo
Browse Source

Debug for mono input & support all dtypes

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
8adefb352a
  1. 41
      torch_vggish.py
  2. 8
      vggish_input.py

41
torch_vggish.py

@ -53,16 +53,41 @@ class Vggish(NNOperator):
self.model.eval() self.model.eval()
self.model.to(self.device) self.model.to(self.device)
def __call__(self,
             datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int'), ('layout', 'str')])]):
    """Embed a batch of audio frames with VGGish and return the features.

    All frames are assumed to share one sample rate and one channel
    layout, so both are read from the first item only.
    """
    frame_arrays = [item.audio for item in datas]
    sr = datas[0].sample_rate
    layout = datas[0].layout
    # Preprocess into log-mel example tensors and move to the model device.
    audio_tensors = self.preprocess(frame_arrays, sr, layout).to(self.device)
    features = self.model(audio_tensors)
    outs = features.to("cpu")
    return [AudioOutput(outs.detach().numpy())]
def preprocess(self, frames: List[numpy.ndarray], sr, layout):
    """Stack raw audio frames and convert them into VGGish input tensors.

    Args:
        frames: list of integer-dtype audio arrays (one per input frame).
        sr: sample rate of the audio in Hz.
        layout: 'stereo' to treat samples as interleaved 2-channel data;
            any other value is handled as mono.
            # assumes stereo samples are interleaved L/R -- TODO confirm

    Returns:
        Tensor of examples produced by vggish_input.waveform_to_examples.

    Raises:
        Re-raises whatever waveform_to_examples raises (logged first).
    """
    audio = numpy.hstack(frames)
    if layout == 'stereo':
        # Interleaved stereo: fold the flat stream into (n_samples, 2).
        audio = audio.reshape(-1, 2)
    audio = self.int2float(audio)
    try:
        # NOTE(review): for stereo this yields shape (2, n_samples), but
        # waveform_to_examples averages over axis=1 to down-mix -- verify
        # the channel axis convention against vggish_input.
        audio = audio.transpose()
        audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
        return audio_tensors
    except Exception:
        log.error("Fail to load audio data.")
        raise  # bare raise keeps the original traceback intact
def int2float(self, wav: numpy.ndarray, dtype: str = 'float64'):
    """
    Convert integer PCM audio samples to floats in roughly [-1.0, 1.0).

    The input dtype must be a signed or unsigned integer type.
    The output dtype is controlled by the parameter `dtype`, defaults to 'float64'.
    The code is inspired by https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py

    Raises:
        TypeError: if `wav` is not an integer array, or `dtype` is not a
            floating-point type.
    """
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would let invalid input slip through silently.
    if wav.dtype.kind not in 'iu':
        raise TypeError('expected integer audio samples, got dtype %s' % wav.dtype)
    out_dtype = numpy.dtype(dtype)
    if out_dtype.kind != 'f':
        raise TypeError('output dtype must be floating point, got %s' % out_dtype)
    ii = numpy.iinfo(wav.dtype)
    abs_max = 2 ** (ii.bits - 1)   # half of the dtype's value range
    offset = ii.min + abs_max      # 0 for signed ints, 2**(bits-1) for unsigned
    return (wav.astype(out_dtype) - offset) / abs_max

8
vggish_input.py

@ -28,8 +28,8 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
"""Converts audio waveform into an array of examples for VGGish. """Converts audio waveform into an array of examples for VGGish.
Args: Args:
data: np.array of either one dimension (mono) or two dimensions
(multi-channel, with the outer dimension representing channels).
data:
np.array of 2 dimension, second of which is number of channels.
Each sample is generally expected to lie in the range [-1.0, +1.0], Each sample is generally expected to lie in the range [-1.0, +1.0],
although this is not required. although this is not required.
sample_rate: Sample rate of data. sample_rate: Sample rate of data.
@ -43,8 +43,10 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
""" """
# Convert to mono. # Convert to mono.
if len(data.shape) > 1:
if data.shape[1] > 1:
data = np.mean(data, axis=1) data = np.mean(data, axis=1)
else:
data = data.squeeze(1)
# Resample to the rate assumed by VGGish. # Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE: if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)

Loading…
Cancel
Save