| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -12,12 +12,13 @@ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					# See the License for the specific language governing permissions and | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					# limitations under the License. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from torch import nn | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch.nn as nn | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy as np | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import sys | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from pathlib import Path | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from towhee.models.vggish.torch_vggish import VGG | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					sys.path.append(str(Path(__file__).parent)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import vggish_input | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -26,39 +27,18 @@ class Model(nn.Module): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    PyTorch model class | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, weights_path: str=None): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.features = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Conv2d(1, 64, 3, 1, 1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.MaxPool2d(2, 2), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Conv2d(64, 128, 3, 1, 1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.MaxPool2d(2, 2), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Conv2d(128, 256, 3, 1, 1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Conv2d(256, 256, 3, 1, 1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.MaxPool2d(2, 2), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Conv2d(256, 512, 3, 1, 1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Conv2d(512, 512, 3, 1, 1), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.MaxPool2d(2, 2)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.embeddings = nn.Sequential( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(512 * 24, 4096), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(4096, 4096), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.ReLU(inplace=True), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            nn.Linear(4096, 128), | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            #nn.ReLU(inplace=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model = VGG() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if not weights_path: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            path = str(Path(__file__).parent) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            weights_path = path + '/vggish.pth' | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        state_dict = torch.load(weights_path, map_location=torch.device('cpu')) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model.load_state_dict(state_dict) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model.eval() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def forward(self, x): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = self.features(x).permute(0, 2, 3, 1).contiguous() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = x.view(x.size(0), -1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        x = self.embeddings(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return x | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self._model(x) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def preprocess(self, audio_path: str): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        audio_tensors = vggish_input.wavfile_to_examples(audio_path) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |