Browse Source
        
      
      Update model
      
        Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
      
      
        training
      
      
     
    
    
    
	
		
			
				 2 changed files with 
23 additions and 
20 deletions
			 
			
		 
		
			
				- 
					
					
					 
					pytorch/model.py
				
 
			
				- 
					
					
					 
					resnet50_image_embedding.py
				
 
			
		
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -13,27 +13,27 @@ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					# limitations under the License. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from typing import NamedTuple | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torchvision | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class Resnet50(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class Model(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    PyTorch model for image embedding. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    PyTorch model class | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, model_name: str): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model_name = model_name | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, model_name): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        model_func = getattr(torchvision.models, model_name) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model = model_func(pretrained=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model.eval() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def load_model(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        For loading model | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        model_func = getattr(torchvision.models, self.model_name) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model = model_func(pretrained=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model.eval() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self.model | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __call__(self, img_tensor: torch.Tensor): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self._model(img_tensor).detach().numpy() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train_model(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def train(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        For training model | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        """ | 
				
			
			
		
	
								
							
						
					 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -26,13 +26,16 @@ class Resnet50ImageEmbedding(Operator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    PyTorch model for image embedding. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, model_name: str) -> None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, model_name: str, framework: str = 'pytorch') -> None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sys.path.append(str(Path(__file__).parent)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        from pytorch.resnet50 import Resnet50 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        resnet50_image_embedding = Resnet50(model_name) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self._model = resnet50_image_embedding.load_model() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if framework == 'pytorch': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            from pytorch.model import Model | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if framework == 'tensorflow': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            from tensorflow.model import Model | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model = Model(model_name) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('cnn', numpy.ndarray)]): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        embedding = self.model(img_tensor) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Outputs = NamedTuple('Outputs', [('cnn', numpy.ndarray)]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return Outputs(self._model(img_tensor).detach().numpy()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return Outputs(embedding) | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
					 | 
				
				 | 
				
					
  |