Browse Source
        
      
      Update
      
        Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 2 changed files with 
7 additions and 
5 deletions
			 
			
		 
		
			
				- 
					
					
					 
					pytorch/__init__.py
				
 
			
				- 
					
					
					 
					vit_image_embedding.py
				
 
			
		
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -18,4 +18,7 @@ import os | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					try: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    import timm | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					except ModuleNotFoundError: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    os.system('pip install timm') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    os.system('pip install timm') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from timm.data import resolve_data_config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from timm.data.transforms_factory import create_transform | 
				
			
			
		
	
								
							
						
					 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -20,6 +20,7 @@ from PIL import Image | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from towhee.operator import Operator | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -39,11 +40,9 @@ class VitImageEmbedding(Operator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sys.path.append(str(Path(__file__).parent)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if framework == 'pytorch': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            from pytorch.model import Model | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            from timm.data import resolve_data_config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            from timm.data.transforms_factory import create_transform | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model = Model(model_name, weights_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        config = resolve_data_config({}, model=self.model._model) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.tfms = create_transform(**config) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        config = pytorch.resolve_data_config({}, model=self.model._model) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.tfms = pytorch.create_transform(**config) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |