diff --git a/torch_vggish.py b/torch_vggish.py index ad85496..440e21b 100644 --- a/torch_vggish.py +++ b/torch_vggish.py @@ -12,36 +12,72 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import warnings +import os import sys -import torch -from typing import NamedTuple -from pathlib import Path import numpy -import os +from pathlib import Path +from typing import Union, List, NamedTuple -from towhee.operator import Operator +import torch -import warnings -warnings.filterwarnings("ignore") +from towhee.operator.base import NNOperator +from towhee.models.vggish.torch_vggish import VGG +from towhee import register + +sys.path.append(str(Path(__file__).parent)) +import vggish_input + +warnings.filterwarnings('ignore') +log = logging.getLogger() -class TorchVggish(Operator): + +AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')]) + + +class Vggish(NNOperator): """ """ - def __init__(self, framework: str = 'pytorch', weights_path: str=None) -> None: - super().__init__() - if framework == 'pytorch': - import importlib.util - path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') - opname = os.path.basename(str(Path(__file__))).split('.')[0] - spec = importlib.util.spec_from_file_location(opname, path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self.model = module.Model(weights_path) - - def __call__(self, audio_path: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]): - audio_tensors = self.model.preprocess(audio_path) + def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None: + super().__init__(framework=framework) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = VGG() + if not weights_path: + path = str(Path(__file__).parent) + weights_path = os.path.join(path, 'vggish.pth') + state_dict = torch.load(weights_path, map_location=torch.device('cpu')) + self.model.load_state_dict(state_dict) + self.model.eval() + self.model.to(self.device) + + def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray: + audios = numpy.stack([item.audio for item in datas]) + sr = datas[0].sample_rate + audio_array = numpy.reshape(audios, (-1, 2)) + audio_tensors = self.preprocess(audio_array, sr).to(self.device) features = self.model(audio_tensors) - Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)]) - return Outputs(features.detach().numpy()) + outs = features.to("cpu") + return [AudioOutput(outs.detach().numpy())] + + def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): + if audio.dtype == numpy.int32: + samples = audio / 2147483648.0 + elif audio.dtype == numpy.int16: + samples = audio / 32768.0 + return vggish_input.waveform_to_examples(samples, sr, return_tensor=True) + + + +# if __name__ == '__main__': +# encoder = Vggish() +# +# # audio_path = '/path/to/audio' +# # vec = encoder(audio_path) +# +# audio_data = numpy.zeros((2, 441344)) +# sample_rate = 44100 +# vec = encoder(audio_data, sample_rate) +# print(vec) diff --git a/vggish.bak b/vggish.bak deleted file mode 100644 index 69c3e3f..0000000 --- a/vggish.bak +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import warnings - -import os -import sys -import numpy -from pathlib import Path -from typing import Union - -import torch - -from towhee.operator.base import NNOperator -from towhee.models.vggish.torch_vggish import VGG -from towhee import register - -sys.path.append(str(Path(__file__).parent)) -import vggish_input - -warnings.filterwarnings('ignore') -log = logging.getLogger() - - -@register(output_schema=['vec']) -class Vggish(NNOperator): - """ - """ - - def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None: - super().__init__(framework=framework) - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model = VGG() - if not weights_path: - path = str(Path(__file__).parent) - weights_path = os.path.join(path, 'vggish.pth') - state_dict = torch.load(weights_path, map_location=torch.device('cpu')) - self.model.load_state_dict(state_dict) - self.model.eval() - self.model.to(self.device) - - def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray: - audio_tensors = self.preprocess(audio, sr).to(self.device) - features = self.model(audio_tensors) - outs = features.to("cpu") - return outs.detach().numpy() - - def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): - if isinstance(audio, str): - audio_tensors = vggish_input.wavfile_to_examples(audio) - elif isinstance(audio, numpy.ndarray): - try: - audio = audio.transpose() - audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True) - except Exception as e: - log.error("Fail to load audio data.") - raise e - else: - log.error(f"Invalid input audio: {type(audio)}") - return audio_tensors - - -# if __name__ == '__main__': -# encoder = Vggish() -# -# # audio_path = '/path/to/audio' -# # vec = encoder(audio_path) -# -# audio_data = numpy.zeros((2, 441344)) -# sample_rate = 44100 -# vec = encoder(audio_data, sample_rate) -# print(vec) diff --git a/vggish.py b/vggish.py deleted file mode 100644 index 440e21b..0000000 --- a/vggish.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import warnings - -import os -import sys -import numpy -from pathlib import Path -from typing import Union, List, NamedTuple - -import torch - -from towhee.operator.base import NNOperator -from towhee.models.vggish.torch_vggish import VGG -from towhee import register - -sys.path.append(str(Path(__file__).parent)) -import vggish_input - -warnings.filterwarnings('ignore') -log = logging.getLogger() - - -AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')]) - - -class Vggish(NNOperator): - """ - """ - - def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None: - super().__init__(framework=framework) - self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model = VGG() - if not weights_path: - path = str(Path(__file__).parent) - weights_path = os.path.join(path, 'vggish.pth') - state_dict = torch.load(weights_path, map_location=torch.device('cpu')) - self.model.load_state_dict(state_dict) - self.model.eval() - self.model.to(self.device) - - def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray: - audios = numpy.stack([item.audio for item in datas]) - sr = datas[0].sample_rate - audio_array = numpy.reshape(audios, (-1, 2)) - audio_tensors = self.preprocess(audio_array, sr).to(self.device) - features = self.model(audio_tensors) - outs = features.to("cpu") - return [AudioOutput(outs.detach().numpy())] - - def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): - if audio.dtype == numpy.int32: - samples = audio / 2147483648.0 - elif audio.dtype == numpy.int16: - samples = audio / 32768.0 - return vggish_input.waveform_to_examples(samples, sr, return_tensor=True) - - - -# if __name__ == '__main__': -# encoder = Vggish() -# -# # audio_path = '/path/to/audio' -# # vec = encoder(audio_path) -# -# audio_data = numpy.zeros((2, 441344)) -# sample_rate = 44100 -# vec = encoder(audio_data, sample_rate) -# print(vec)