|  |  |  | # Copyright 2021 Zilliz. All rights reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import logging | 
					
						
							|  |  |  | import warnings | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | import numpy | 
					
						
							|  |  |  | from pathlib import Path | 
					
						
							|  |  |  | from typing import Union, List, NamedTuple | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from towhee.operator.base import NNOperator | 
					
						
							|  |  |  | from towhee.models.vggish.torch_vggish import VGG | 
					
						
							|  |  |  | from towhee import register | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | sys.path.append(str(Path(__file__).parent)) | 
					
						
							|  |  |  | import vggish_input | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | warnings.filterwarnings('ignore') | 
					
						
							|  |  |  | log = logging.getLogger() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Vggish(NNOperator): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None: | 
					
						
							|  |  |  |         super().__init__(framework=framework) | 
					
						
							|  |  |  |         self.device = "cuda" if torch.cuda.is_available() else "cpu" | 
					
						
							|  |  |  |         self.model = VGG() | 
					
						
							|  |  |  |         if not weights_path: | 
					
						
							|  |  |  |             path = str(Path(__file__).parent) | 
					
						
							|  |  |  |             weights_path = os.path.join(path, 'vggish.pth') | 
					
						
							|  |  |  |         state_dict = torch.load(weights_path, map_location=torch.device('cpu')) | 
					
						
							|  |  |  |         self.model.load_state_dict(state_dict) | 
					
						
							|  |  |  |         self.model.eval() | 
					
						
							|  |  |  |         self.model.to(self.device) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray: | 
					
						
							|  |  |  |         audios = numpy.hstack([item.audio for item in datas]) | 
					
						
							|  |  |  |         sr = datas[0].sample_rate | 
					
						
							|  |  |  |         audio_array = numpy.reshape(audios, (-1, 1)) | 
					
						
							|  |  |  |         audio_tensors = self.preprocess(audio_array, sr).to(self.device) | 
					
						
							|  |  |  |         features = self.model(audio_tensors) | 
					
						
							|  |  |  |         outs = features.to("cpu") | 
					
						
							|  |  |  |         return [AudioOutput(outs.detach().numpy())] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def preprocess(self, audio: numpy.ndarray, sr: int = None): | 
					
						
							|  |  |  |         ii = numpy.iinfo(audio.dtype) | 
					
						
							|  |  |  |         samples = 2 * audio / (ii.max - ii.min + 1) | 
					
						
							|  |  |  |         return vggish_input.waveform_to_examples(samples, sr, return_tensor=True) |