diff --git a/torch_vggish.py b/torch_vggish.py
index 6c2bbf0..007745e 100644
--- a/torch_vggish.py
+++ b/torch_vggish.py
@@ -53,16 +53,44 @@ class Vggish(NNOperator):
         self.model.eval()
         self.model.to(self.device)

-    def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
-        audios = numpy.hstack([item.audio for item in datas])
+    def __call__(self,
+                 datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int'), ('layout', 'str')])]):
+        audios = [item.audio for item in datas]
         sr = datas[0].sample_rate
-        audio_array = numpy.reshape(audios, (-1, 1))
-        audio_tensors = self.preprocess(audio_array, sr).to(self.device)
+        layout = datas[0].layout
+        audio_tensors = self.preprocess(audios, sr, layout).to(self.device)
         features = self.model(audio_tensors)
         outs = features.to("cpu")
         return [AudioOutput(outs.detach().numpy())]

-    def preprocess(self, audio: numpy.ndarray, sr: int = None):
-        ii = numpy.iinfo(audio.dtype)
-        samples = 2 * audio / (ii.max - ii.min + 1)
-        return vggish_input.waveform_to_examples(samples, sr, return_tensor=True)
+    def preprocess(self, frames: List[numpy.ndarray], sr, layout):
+        audio = numpy.hstack(frames)
+        if layout == 'stereo':
+            # De-interleave packed stereo samples into (n_samples, 2).
+            audio = audio.reshape(-1, 2)
+        else:
+            # Mono: guarantee the (n_samples, 1) shape waveform_to_examples expects.
+            audio = audio.reshape(-1, 1)
+        audio = self.int2float(audio)
+        try:
+            audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
+            return audio_tensors
+        except Exception as e:
+            log.error("Failed to load audio data.")
+            raise e
+
+    def int2float(self, wav: numpy.ndarray, dtype: str = 'float64'):
+        """
+        Convert audio data from int to float.
+        The input dtype must be an integer type.
+        The output dtype is controlled by the parameter `dtype`, which defaults to 'float64'.
+        The code is inspired by https://github.com/mgeier/python-audio/blob/master/audio-files/utility.py
+        """
+        assert wav.dtype.kind in 'iu'
+        dtype = numpy.dtype(dtype)
+        assert dtype.kind == 'f'
+
+        ii = numpy.iinfo(wav.dtype)
+        abs_max = 2 ** (ii.bits - 1)
+        offset = ii.min + abs_max
+        return (wav.astype(dtype) - offset) / abs_max
diff --git a/vggish_input.py b/vggish_input.py
index 4297ba8..7e8cfd3 100644
--- a/vggish_input.py
+++ b/vggish_input.py
@@ -28,8 +28,8 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
   """Converts audio waveform into an array of examples for VGGish.

   Args:
-    data: np.array of either one dimension (mono) or two dimensions
-      (multi-channel, with the outer dimension representing channels).
+    data:
+      np.array of 2 dimensions, the second of which is the number of channels.
       Each sample is generally expected to lie in the range [-1.0, +1.0],
       although this is not required.
     sample_rate: Sample rate of data.
@@ -43,8 +43,10 @@

   """
   # Convert to mono.
-  if len(data.shape) > 1:
+  if data.shape[1] > 1:
     data = np.mean(data, axis=1)
+  else:
+    data = data.squeeze(1)
   # Resample to the rate assumed by VGGish.
   if sample_rate != vggish_params.SAMPLE_RATE:
     data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
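
A quick sanity check of the two invariants the patch relies on (not part of the patch; a minimal sketch using plain numpy with hypothetical sample values): the int2float formula maps the full integer range onto [-1.0, +1.0), and reshape(-1, 2) turns interleaved stereo samples into the (n_samples, n_channels) layout the revised waveform_to_examples expects.

import numpy

# int16 path through int2float: abs_max = 2**(16-1) = 32768 and
# offset = -32768 + 32768 = 0, so int16 maps onto [-1.0, +1.0).
wav = numpy.array([-32768, 0, 32767], dtype=numpy.int16)
ii = numpy.iinfo(wav.dtype)
abs_max = 2 ** (ii.bits - 1)
offset = ii.min + abset = ii.min + abs_max if False else ii.min + abs_max
out = (wav.astype('float64') - offset) / abs_max
assert out[0] == -1.0 and out[1] == 0.0 and out[2] < 1.0

# Stereo path through preprocess: interleaved samples [L0, R0, L1, R1, ...]
# de-interleave into (n_samples, n_channels), which waveform_to_examples
# then averages down to mono along axis 1.
interleaved = numpy.arange(6, dtype=numpy.int16)
stereo = interleaved.reshape(-1, 2)
assert stereo.shape == (3, 2)
assert (stereo[:, 0] == [0, 2, 4]).all()  # left channel
assert (stereo[:, 1] == [1, 3, 5]).all()  # right channel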