diff --git a/torch_vggish.py b/torch_vggish.py
index 007745e..23499cc 100644
--- a/torch_vggish.py
+++ b/torch_vggish.py
@@ -64,12 +64,14 @@ class Vggish(NNOperator):
         return [AudioOutput(outs.detach().numpy())]
 
     def preprocess(self, frames: List[numpy.ndarray], sr, layout):
-        audio = numpy.hstack(frames)
         if layout == 'stereo':
-            audio = audio.reshape(-1, 2)
+            frames = [frame.reshape(-1, 2) for frame in frames]
+            audio = numpy.vstack(frames)
+        else:
+            audio = numpy.hstack(frames)
+            audio = audio.transpose()
         audio = self.int2float(audio)
         try:
-            audio = audio.transpose()
             audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
             return audio_tensors
         except Exception as e:
diff --git a/vggish_input.py b/vggish_input.py
index 7e8cfd3..4297ba8 100644
--- a/vggish_input.py
+++ b/vggish_input.py
@@ -28,8 +28,8 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
   """Converts audio waveform into an array of examples for VGGish.
 
   Args:
-    data:
-      np.array of 2 dimension, second of which is number of channels.
+    data: np.array of either one dimension (mono) or two dimensions
+      (multi-channel, with the outer dimension representing channels).
       Each sample is generally expected to lie in the range [-1.0, +1.0],
       although this is not required.
     sample_rate: Sample rate of data.
@@ -43,10 +43,8 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
 
   """
   # Convert to mono.
-  if data.shape[1] > 1:
+  if len(data.shape) > 1:
     data = np.mean(data, axis=1)
-  else:
-    data = data.squeeze(1)
   # Resample to the rate assumed by VGGish.
   if sample_rate != vggish_params.SAMPLE_RATE:
     data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
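
For reviewers, a minimal standalone sketch of the new frame-flattening behavior in preprocess(). The flatten_frames helper and the interleaved-int16 layout of stereo frames are assumptions for illustration, not part of the operator:

import numpy

def flatten_frames(frames, layout):
    if layout == 'stereo':
        # Assumes each stereo frame holds interleaved samples; reshape puts
        # channels on axis 1, then vstack concatenates frames along the time
        # axis, giving a (n_samples, 2) array.
        frames = [frame.reshape(-1, 2) for frame in frames]
        return numpy.vstack(frames)
    # Non-stereo frames are concatenated and transposed; transpose is a
    # no-op on a 1-D array and turns a (1, n) planar buffer into (n, 1).
    audio = numpy.hstack(frames)
    return audio.transpose()

stereo_frames = [numpy.arange(8, dtype=numpy.int16) for _ in range(3)]
print(flatten_frames(stereo_frames, 'stereo').shape)  # (12, 2)

mono_frames = [numpy.arange(8, dtype=numpy.int16) for _ in range(3)]
print(flatten_frames(mono_frames, 'mono').shape)      # (24,)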
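Likewise, a sketch of why the restored mono-conversion guard in waveform_to_examples() matters: the old data.shape[1] check raises IndexError on 1-D mono input, while len(data.shape) > 1 accepts both the 1-D mono and (n_samples, n_channels) arrays that preprocess() now produces. Note the mean is taken over axis 1, so channels are expected along the second axis. to_mono is a hypothetical helper for illustration:

import numpy as np

def to_mono(data):
    # Restored upstream guard: passes 1-D mono input through (where the old
    # data.shape[1] indexing raised IndexError) and averages the channel
    # axis (axis 1) for 2-D input.
    if len(data.shape) > 1:
        data = np.mean(data, axis=1)
    return data

print(to_mono(np.zeros(16000)).shape)       # (16000,) -- mono unchanged
print(to_mono(np.zeros((16000, 2))).shape)  # (16000,) -- channels averaged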