logo
Browse Source

Debug for mono input & support all dtypes

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
7b9825108e
  1. 8
      torch_vggish.py
  2. 8
      vggish_input.py

8
torch_vggish.py

@ -64,12 +64,14 @@ class Vggish(NNOperator):
return [AudioOutput(outs.detach().numpy())]
def preprocess(self, frames: List[numpy.ndarray], sr, layout):
audio = numpy.hstack(frames)
if layout == 'stereo':
audio = audio.reshape(-1, 2)
frames = [frame.reshape(-1, 2) for frame in frames]
audio = numpy.vstack(frames)
else:
audio = numpy.hstack(frames)
audio = audio.transpose()
audio = self.int2float(audio)
try:
audio = audio.transpose()
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
return audio_tensors
except Exception as e:

8
vggish_input.py

@ -28,8 +28,8 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
"""Converts audio waveform into an array of examples for VGGish.
Args:
data:
np.array of 2 dimension, second of which is number of channels.
data: np.array of either one dimension (mono) or two dimensions
(multi-channel, with the outer dimension representing channels).
Each sample is generally expected to lie in the range [-1.0, +1.0],
although this is not required.
sample_rate: Sample rate of data.
@ -43,10 +43,8 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
"""
# Convert to mono.
if data.shape[1] > 1:
if len(data.shape) > 1:
data = np.mean(data, axis=1)
else:
data = data.squeeze(1)
# Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)

Loading…
Cancel
Save