diff --git a/torch_vggish.py b/torch_vggish.py index 440e21b..deae00e 100644 --- a/torch_vggish.py +++ b/torch_vggish.py @@ -62,15 +62,12 @@ class Vggish(NNOperator): outs = features.to("cpu") return [AudioOutput(outs.detach().numpy())] - def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): - if audio.dtype == numpy.int32: - samples = audio / 2147483648.0 - elif audio.dtype == numpy.int16: - samples = audio / 32768.0 + def preprocess(self, audio: numpy.ndarray, sr: int = None): + ii = numpy.iinfo(audio.dtype) + samples = 2 * audio / (ii.max - ii.min + 1) return vggish_input.waveform_to_examples(samples, sr, return_tensor=True) - # if __name__ == '__main__': # encoder = Vggish() #