diff --git a/__init__.py b/__init__.py
index 37f5bd7..0a4a066 100644
--- a/__init__.py
+++ b/__init__.py
@@ -11,3 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .vggish import Vggish
+
+
+def vggish(weights_path: str = None, framework: str = 'pytorch'):
+    return Vggish(weights_path, framework)
diff --git a/pytorch/mel_features.py b/mel_features.py
similarity index 100%
rename from pytorch/mel_features.py
rename to mel_features.py
diff --git a/pytorch/.gitattributes b/pytorch/.gitattributes
deleted file mode 100644
index 05ef9ad..0000000
--- a/pytorch/.gitattributes
+++ /dev/null
@@ -1 +0,0 @@
-vggish.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/pytorch/__init__.py b/pytorch/__init__.py
deleted file mode 100644
index b661573..0000000
--- a/pytorch/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-# For requirements.
-try:
-    import timm
-except ModuleNotFoundError:
-    os.system('pip install timm')
-
-from timm.data import resolve_data_config
-from timm.data.transforms_factory import create_transform
\ No newline at end of file
diff --git a/pytorch/model.py b/pytorch/model.py
deleted file mode 100644
index 7c8cdf6..0000000
--- a/pytorch/model.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from torch import nn
-import torch
-import sys
-from pathlib import Path
-
-from towhee.models.vggish.torch_vggish import VGG
-
-sys.path.append(str(Path(__file__).parent))
-
-import vggish_input
-
-class Model(nn.Module):
-    """
-    PyTorch model class
-    """
-    def __init__(self, weights_path: str=None):
-        super().__init__()
-        self._model = VGG()
-        if not weights_path:
-            path = str(Path(__file__).parent)
-            weights_path = path + '/vggish.pth'
-        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
-        self._model.load_state_dict(state_dict)
-        self._model.eval()
-
-    def forward(self, x):
-        return self._model(x)
-
-    def preprocess(self, audio_path: str):
-        audio_tensors = vggish_input.wavfile_to_examples(audio_path)
-        return audio_tensors
-
-    def train(self):
-        """
-        For training model
-        """
-        pass
diff --git a/vggish.bak b/vggish.bak
new file mode 100644
index 0000000..69c3e3f
--- /dev/null
+++ b/vggish.bak
@@ -0,0 +1,84 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import warnings
+
+import os
+import sys
+import numpy
+from pathlib import Path
+from typing import Union
+
+import torch
+
+from towhee.operator.base import NNOperator
+from towhee.models.vggish.torch_vggish import VGG
+from towhee import register
+
+sys.path.append(str(Path(__file__).parent))
+import vggish_input
+
+warnings.filterwarnings('ignore')
+log = logging.getLogger()
+
+
+@register(output_schema=['vec'])
+class Vggish(NNOperator):
+    """
+    """
+
+    def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
+        super().__init__(framework=framework)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = VGG()
+        if not weights_path:
+            path = str(Path(__file__).parent)
+            weights_path = os.path.join(path, 'vggish.pth')
+        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
+        self.model.load_state_dict(state_dict)
+        self.model.eval()
+        self.model.to(self.device)
+
+    def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
+        audio_tensors = self.preprocess(audio, sr).to(self.device)
+        features = self.model(audio_tensors)
+        outs = features.to("cpu")
+        return outs.detach().numpy()
+
+    def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
+        if isinstance(audio, str):
+            audio_tensors = vggish_input.wavfile_to_examples(audio)
+        elif isinstance(audio, numpy.ndarray):
+            try:
+                audio = audio.transpose()
+                audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
+            except Exception as e:
+                log.error("Fail to load audio data.")
+                raise e
+        else:
+            log.error(f"Invalid input audio: {type(audio)}")
+        return audio_tensors
+
+
+# if __name__ == '__main__':
+#     encoder = Vggish()
+#
+#     # audio_path = '/path/to/audio'
+#     # vec = encoder(audio_path)
+#
+#     audio_data = numpy.zeros((2, 441344))
+#     sample_rate = 44100
+#     vec = encoder(audio_data, sample_rate)
+#     print(vec)
diff --git a/pytorch/vggish.pth b/vggish.pth
similarity index 100%
rename from pytorch/vggish.pth
rename to vggish.pth
diff --git a/vggish.py b/vggish.py
new file mode 100644
index 0000000..440e21b
--- /dev/null
+++ b/vggish.py
@@ -0,0 +1,83 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import warnings
+
+import os
+import sys
+import numpy
+from pathlib import Path
+from typing import List, NamedTuple
+
+import torch
+
+from towhee.operator.base import NNOperator
+from towhee.models.vggish.torch_vggish import VGG
+from towhee import register
+
+sys.path.append(str(Path(__file__).parent))
+import vggish_input
+
+warnings.filterwarnings('ignore')
+log = logging.getLogger()
+
+
+AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])
+
+
+class Vggish(NNOperator):
+    """
+    """
+
+    def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
+        super().__init__(framework=framework)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = VGG()
+        if not weights_path:
+            path = str(Path(__file__).parent)
+            weights_path = os.path.join(path, 'vggish.pth')
+        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
+        self.model.load_state_dict(state_dict)
+        self.model.eval()
+        self.model.to(self.device)
+
+    def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> List[AudioOutput]:
+        audios = numpy.stack([item.audio for item in datas])
+        sr = datas[0].sample_rate
+        audio_array = numpy.reshape(audios, (-1, 2))
+        audio_tensors = self.preprocess(audio_array, sr).to(self.device)
+        features = self.model(audio_tensors)
+        outs = features.to("cpu")
+        return [AudioOutput(outs.detach().numpy())]
+
+    def preprocess(self, audio: numpy.ndarray, sr: int = None):
+        if audio.dtype == numpy.int32:
+            samples = audio / 2147483648.0
+        elif audio.dtype == numpy.int16:
+            samples = audio / 32768.0
+        else:
+            samples = audio  # assume a float waveform already scaled to [-1.0, +1.0]
+        return vggish_input.waveform_to_examples(samples, sr, return_tensor=True)
+
+
+# if __name__ == '__main__':
+#     AudioData = NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])
+#     encoder = Vggish()
+#
+#     # 10 seconds of silent 16-bit stereo audio at 44.1 kHz
+#     audio_data = numpy.zeros((441344, 2), dtype=numpy.int16)
+#     sample_rate = 44100
+#     vec = encoder([AudioData(audio_data, sample_rate)])
+#     print(vec)
diff --git a/pytorch/vggish_input.py b/vggish_input.py
similarity index 92%
rename from pytorch/vggish_input.py
rename to vggish_input.py
index 256f0d1..856a406 100644
--- a/pytorch/vggish_input.py
+++ b/vggish_input.py
@@ -17,14 +17,13 @@
 # Modification: Return torch tensors rather than numpy arrays
 
 import torch
-
 import numpy as np
 import resampy
 
 import mel_features
 import vggish_params
 
-import soundfile as sf
+import torchaudio
 
 
 def waveform_to_examples(data, sample_rate, return_tensor=True):
@@ -92,7 +91,6 @@
   Returns:
     See waveform_to_examples.
   """
-  wav_data, sr = sf.read(wav_file, dtype='int16')
-  assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
-  samples = wav_data / 32768.0  # Convert to [-1.0, +1.0]
-  return waveform_to_examples(samples, sr, return_tensor)
+  data, sr = torchaudio.load(wav_file)
+  wav_data = data.detach().numpy().transpose()
+  return waveform_to_examples(wav_data, sr, return_tensor)
diff --git a/pytorch/vggish_params.py b/vggish_params.py
similarity index 100%
rename from pytorch/vggish_params.py
rename to vggish_params.py
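
Usage sketch for the reworked batch interface (not part of the diff; a minimal example assuming the repo root is importable so that vggish.py resolves as module `vggish`, and defining a hypothetical `AudioData` NamedTuple whose `audio`/`sample_rate` fields match the new `Vggish.__call__` signature above):

    from typing import NamedTuple

    import numpy

    from vggish import Vggish  # placeholder import path for this operator repo

    # Batch item schema expected by Vggish.__call__: a stereo waveform plus
    # its sample rate; field type names are strings, matching towhee's style.
    AudioData = NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])

    encoder = Vggish()  # loads the bundled vggish.pth and picks CPU or CUDA

    # Ten seconds of silent 16-bit stereo audio at 44.1 kHz.
    item = AudioData(numpy.zeros((441344, 2), dtype=numpy.int16), 44100)
    outs = encoder([item])    # returns [AudioOutput(vec=ndarray)]
    print(outs[0].vec.shape)  # one 128-d row per 0.96 s analysis window

Each returned AudioOutput.vec stacks one 128-dimensional VGGish embedding per 0.96-second log-mel window, so the ten-second clip above should yield a (10, 128) array.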