towhee / torch-vggish

3 changed files with 59 additions and 190 deletions
@@ -1,84 +0,0 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings

import os
import sys
import numpy
from pathlib import Path
from typing import Union

import torch

from towhee.operator.base import NNOperator
from towhee.models.vggish.torch_vggish import VGG
from towhee import register

sys.path.append(str(Path(__file__).parent))
import vggish_input

warnings.filterwarnings('ignore')
log = logging.getLogger()


@register(output_schema=['vec'])
class Vggish(NNOperator):
    """
    Audio embedding operator that converts an audio file or waveform into
    VGGish embedding vectors using a pretrained PyTorch VGGish model.
    """

    def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
        super().__init__(framework=framework)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = VGG()
        if not weights_path:
            # Default to the vggish.pth checkpoint shipped alongside this operator.
            path = str(Path(__file__).parent)
            weights_path = os.path.join(path, 'vggish.pth')
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
        audio_tensors = self.preprocess(audio, sr).to(self.device)
        features = self.model(audio_tensors)
        outs = features.to("cpu")
        return outs.detach().numpy()

    def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
        if isinstance(audio, str):
            # Path to a wav file: load it and convert to log-mel examples.
            audio_tensors = vggish_input.wavfile_to_examples(audio)
        elif isinstance(audio, numpy.ndarray):
            try:
                # Waveform shaped (channels, samples): transpose to (samples, channels).
                audio = audio.transpose()
                audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
            except Exception as e:
                log.error("Failed to load audio data.")
                raise e
        else:
            log.error(f"Invalid input audio: {type(audio)}")
            raise TypeError(f"Invalid input audio: {type(audio)}")
        return audio_tensors


# if __name__ == '__main__':
#     encoder = Vggish()
#
#     # audio_path = '/path/to/audio'
#     # vec = encoder(audio_path)
#
#     audio_data = numpy.zeros((2, 441344))
#     sample_rate = 44100
#     vec = encoder(audio_data, sample_rate)
#     print(vec)
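
For reference, here is a minimal sketch of how the single-call interface above can be exercised. It mirrors the commented-out example at the end of the file; the placeholder waveform and its shape are illustrative, and it assumes the default vggish.pth checkpoint is present next to the operator file.

# Minimal usage sketch for the single-call interface, intended to sit at the
# bottom of the operator file above (so Vggish and numpy are already in scope).
# Assumes vggish.pth resolves next to the operator, as __init__ expects.
if __name__ == '__main__':
    encoder = Vggish()

    # Placeholder stereo waveform shaped (channels, samples) at 44.1 kHz.
    audio_data = numpy.zeros((2, 441344))
    sample_rate = 44100

    vec = encoder(audio_data, sample_rate)
    print(vec.shape)  # (num_examples, embedding_dim)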
@@ -1,83 +0,0 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import warnings

import os
import sys
import numpy
from pathlib import Path
from typing import Union, List, NamedTuple

import torch

from towhee.operator.base import NNOperator
from towhee.models.vggish.torch_vggish import VGG
from towhee import register

sys.path.append(str(Path(__file__).parent))
import vggish_input

warnings.filterwarnings('ignore')
log = logging.getLogger()


AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])


class Vggish(NNOperator):
    """
    Audio embedding operator that converts batches of audio frames into
    VGGish embedding vectors using a pretrained PyTorch VGGish model.
    """

    def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
        super().__init__(framework=framework)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = VGG()
        if not weights_path:
            # Default to the vggish.pth checkpoint shipped alongside this operator.
            path = str(Path(__file__).parent)
            weights_path = os.path.join(path, 'vggish.pth')
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(self.device)

    def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
        # Concatenate the per-item audio frames into one interleaved stereo waveform.
        audios = numpy.stack([item.audio for item in datas])
        sr = datas[0].sample_rate
        audio_array = numpy.reshape(audios, (-1, 2))
        audio_tensors = self.preprocess(audio_array, sr).to(self.device)
        features = self.model(audio_tensors)
        outs = features.to("cpu")
        return [AudioOutput(outs.detach().numpy())]

    def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
        # Rescale integer PCM samples to floats in [-1.0, 1.0].
        if audio.dtype == numpy.int32:
            samples = audio / 2147483648.0
        elif audio.dtype == numpy.int16:
            samples = audio / 32768.0
        else:
            # Assume other dtypes are already normalized floats.
            samples = audio
        return vggish_input.waveform_to_examples(samples, sr, return_tensor=True)


# if __name__ == '__main__':
#     encoder = Vggish()
#
#     # audio_path = '/path/to/audio'
#     # vec = encoder(audio_path)
#
#     audio_data = numpy.zeros((2, 441344))
#     sample_rate = 44100
#     vec = encoder(audio_data, sample_rate)
#     print(vec)
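
This earlier variant wraps its inputs and outputs in NamedTuples rather than taking a waveform directly, so its own commented-out example no longer matches the __call__ signature. Below is a hedged sketch of driving the batch interface: the AudioData record name is hypothetical (any namedtuple with audio and sample_rate fields matches the schema declared in __call__), and it again assumes the default vggish.pth checkpoint resolves.

# Usage sketch for the batch interface, intended to sit at the bottom of the
# file above (so Vggish, AudioOutput, and numpy are already in scope).
if __name__ == '__main__':
    from collections import namedtuple

    # Hypothetical record matching the ('audio', 'sample_rate') schema expected by __call__.
    AudioData = namedtuple('AudioData', ['audio', 'sample_rate'])

    encoder = Vggish()

    # One second of placeholder interleaved stereo int16 PCM shaped (samples, 2),
    # so preprocess rescales it by 32768.0.
    frames = numpy.zeros((44100, 2), dtype=numpy.int16)
    datas = [AudioData(audio=frames, sample_rate=44100)]

    outputs = encoder(datas)
    print(outputs[0].vec.shape)  # (num_examples, embedding_dim)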