| 
					
					
						
							
						
					
					
				 | 
				@ -19,10 +19,11 @@ from pathlib import Path | 
			
		
		
	
		
			
				 | 
				 | 
				from PIL import Image | 
				 | 
				 | 
				from PIL import Image | 
			
		
		
	
		
			
				 | 
				 | 
				import torch | 
				 | 
				 | 
				import torch | 
			
		
		
	
		
			
				 | 
				 | 
				import numpy | 
				 | 
				 | 
				import numpy | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				import os | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				from towhee.operator import Operator | 
				 | 
				 | 
				from towhee.operator import Operator | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				from timm.data import resolve_data_config | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				from timm.data.transforms_factory import create_transform | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				class VitImageEmbedding(Operator): | 
				 | 
				 | 
				class VitImageEmbedding(Operator): | 
			
		
		
	
		
			
				 | 
				 | 
				    """ | 
				 | 
				 | 
				    """ | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -37,13 +38,16 @@ class VitImageEmbedding(Operator): | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, model_name: str = 'vit_large_patch16_224', | 
				 | 
				 | 
				    def __init__(self, model_name: str = 'vit_large_patch16_224', | 
			
		
		
	
		
			
				 | 
				 | 
				                 framework: str = 'pytorch', weights_path: str = None) -> None: | 
				 | 
				 | 
				                 framework: str = 'pytorch', weights_path: str = None) -> None: | 
			
		
		
	
		
			
				 | 
				 | 
				        super().__init__() | 
				 | 
				 | 
				        super().__init__() | 
			
		
		
	
		
			
				 | 
				 | 
				        sys.path.append(str(Path(__file__).parent)) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				        if framework == 'pytorch': | 
				 | 
				 | 
				        if framework == 'pytorch': | 
			
		
		
	
		
			
				 | 
				 | 
				            import pytorch | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				            from pytorch.model import Model | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model = Model(model_name, weights_path) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				        config = pytorch.resolve_data_config({}, model=self.model._model) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				        self.tfms = pytorch.create_transform(**config) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            import importlib.util | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            opname = os.path.basename(str(Path(__file__))).split('.')[0] | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            spec = importlib.util.spec_from_file_location(opname, path) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            module = importlib.util.module_from_spec(spec) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            spec.loader.exec_module(module) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        self.model = module.Model(model_name, weights_path) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        config = resolve_data_config({}, model=self.model._model) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        self.tfms = create_transform(**config) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): | 
				 | 
				 | 
				    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): | 
			
		
		
	
		
			
				 | 
				 | 
				        Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) | 
				 | 
				 | 
				        Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |