| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -18,10 +18,11 @@ from typing import NamedTuple | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from pathlib import Path | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from PIL import Image | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import torch | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from torch import nn as nn | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import numpy | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import os | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from towhee.operator import Operator | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from towhee.operator import Operator, NNOperator | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from timm.data import resolve_data_config | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from timm.data.transforms_factory import create_transform | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from towhee.utils.pil_utils import to_pil | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -29,7 +30,7 @@ from towhee.utils.pil_utils import to_pil | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import warnings | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					warnings.filterwarnings("ignore") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class VitImageEmbedding(Operator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class VitImageEmbedding(NNOperator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    Embedding extractor using ViT. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    Args: | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -39,7 +40,7 @@ class VitImageEmbedding(Operator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            Path to local weights. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, model_name: str = 'vit_large_patch16_224', | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, model_name: str = 'vit_large_patch16_224', num_classes: int = 1000, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                 framework: str = 'pytorch', weights_path: str = None) -> None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if framework == 'pytorch': | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -49,7 +50,7 @@ class VitImageEmbedding(Operator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            spec = importlib.util.spec_from_file_location(opname, path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            module = importlib.util.module_from_spec(spec) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            spec.loader.exec_module(module) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model = module.Model(model_name, weights_path) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.model = module.Model(model_name, weights_path, num_classes=num_classes) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        config = resolve_data_config({}, model=self.model._model) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.tfms = create_transform(**config) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -57,4 +58,7 @@ class VitImageEmbedding(Operator): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        img = self.tfms(to_pil(image)).unsqueeze(0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        features = self.model(img) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return Outputs(features.flatten().detach().numpy()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return Outputs(features) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def get_model(self) -> nn.Module: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return self.model._model |