diff --git a/torch_vggish.py b/torch_vggish.py index deae00e..671239e 100644 --- a/torch_vggish.py +++ b/torch_vggish.py @@ -56,7 +56,7 @@ class Vggish(NNOperator): def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray: audios = numpy.stack([item.audio for item in datas]) sr = datas[0].sample_rate - audio_array = numpy.reshape(audios, (-1, 2)) + audio_array = numpy.reshape(audios, (-1, 1)) audio_tensors = self.preprocess(audio_array, sr).to(self.device) features = self.model(audio_tensors) outs = features.to("cpu") @@ -66,15 +66,3 @@ class Vggish(NNOperator): ii = numpy.iinfo(audio.dtype) samples = 2 * audio / (ii.max - ii.min + 1) return vggish_input.waveform_to_examples(samples, sr, return_tensor=True) - - -# if __name__ == '__main__': -# encoder = Vggish() -# -# # audio_path = '/path/to/audio' -# # vec = encoder(audio_path) -# -# audio_data = numpy.zeros((2, 441344)) -# sample_rate = 44100 -# vec = encoder(audio_data, sample_rate) -# print(vec)