Browse Source
        
      
      Update
      
        Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
      
      
        training
      
      
     
    
    
    
	
		
			
				 1 changed files with 
2 additions and 
2 deletions
			 
			
		 
		
			
				- 
					
					
					 
					resnet_image_embedding.py
				
 
			
		
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				| 
					
					
						
							
						
					
					
				 | 
				@ -26,12 +26,12 @@ class ResnetImageEmbedding(Operator): | 
			
		
		
	
		
			
				 | 
				 | 
				    """ | 
				 | 
				 | 
				    """ | 
			
		
		
	
		
			
				 | 
				 | 
				    PyTorch model for image embedding. | 
				 | 
				 | 
				    PyTorch model for image embedding. | 
			
		
		
	
		
			
				 | 
				 | 
				    """ | 
				 | 
				 | 
				    """ | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, model_name: str, framework: str = 'pytorch', weights_path: str = None) -> None: | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    def __init__(self, model_name: str, framework: str = 'pytorch') -> None: | 
			
		
		
	
		
			
				 | 
				 | 
				        super().__init__() | 
				 | 
				 | 
				        super().__init__() | 
			
		
		
	
		
			
				 | 
				 | 
				        sys.path.append(str(Path(__file__).parent)) | 
				 | 
				 | 
				        sys.path.append(str(Path(__file__).parent)) | 
			
		
		
	
		
			
				 | 
				 | 
				        if framework == 'pytorch': | 
				 | 
				 | 
				        if framework == 'pytorch': | 
			
		
		
	
		
			
				 | 
				 | 
				            from pytorch.model import Model | 
				 | 
				 | 
				            from pytorch.model import Model | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model = Model(model_name, weights_path) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        self.model = Model(model_name) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): | 
				 | 
				 | 
				    def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): | 
			
		
		
	
		
			
				 | 
				 | 
				        embedding = self.model(img_tensor) | 
				 | 
				 | 
				        embedding = self.model(img_tensor) | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |