towhee
/
            
              vit-image-embedding
              
                
                
            
          copied
				 9 changed files with 433 additions and 25 deletions
			
			
		@ -1,28 +1,35 @@ | 
			
		|||||
 | 
				# .gitattributes | 
			
		||||
 | 
				
 | 
			
		||||
*.7z filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.arrow filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.bin filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
 | 
				# Source files | 
			
		||||
 | 
				# ============ | 
			
		||||
 | 
				*.pxd    text diff=python | 
			
		||||
 | 
				*.py     text diff=python | 
			
		||||
 | 
				*.py3    text diff=python | 
			
		||||
 | 
				*.pyw    text diff=python | 
			
		||||
 | 
				*.pyx    text diff=python | 
			
		||||
 | 
				*.pyz    text diff=python | 
			
		||||
 | 
				*.pyi    text diff=python | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Binary files | 
			
		||||
 | 
				# ============ | 
			
		||||
 | 
				*.db     binary | 
			
		||||
 | 
				*.p      binary | 
			
		||||
 | 
				*.pkl    binary | 
			
		||||
 | 
				*.pickle binary | 
			
		||||
 | 
				*.pyc    binary export-ignore | 
			
		||||
 | 
				*.pyo    binary export-ignore | 
			
		||||
 | 
				*.pyd    binary | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Jupyter notebook | 
			
		||||
 | 
				*.ipynb  text | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Model files | 
			
		||||
*.bin.* filter=lfs diff=lfs merge=lfs -text | 
				*.bin.* filter=lfs diff=lfs merge=lfs -text | 
			
		||||
*.bz2 filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.ftz filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.gz filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.h5 filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.joblib filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text | 
				*.lfs.* filter=lfs diff=lfs merge=lfs -text | 
			
		||||
*.model filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.msgpack filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.onnx filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.ot filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.parquet filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.pb filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.pt filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.pth filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.rar filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text  | 
				 | 
			
		||||
*.tar.* filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
 | 
				*.bin filter=lfs diff=lfs merge=lfs -text | 
			
		||||
 | 
				*.h5 filter=lfs diff=lfs merge=lfs -text | 
			
		||||
*.tflite filter=lfs diff=lfs merge=lfs -text | 
				*.tflite filter=lfs diff=lfs merge=lfs -text | 
			
		||||
*.tgz filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.xz filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.zip filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*.zstandard filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
*tfevents* filter=lfs diff=lfs merge=lfs -text | 
				 | 
			
		||||
 | 
				*.tar.gz filter=lfs diff=lfs merge=lfs -text | 
			
		||||
 | 
				*.ot filter=lfs diff=lfs merge=lfs -text | 
			
		||||
 | 
				*.onnx filter=lfs diff=lfs merge=lfs -text | 
			
		||||
 | 
				*.msgpack filter=lfs diff=lfs merge=lfs -text | 
			
		||||
 | 
			
		|||||
@ -0,0 +1,209 @@ | 
			
		|||||
 | 
				### Linux ### | 
			
		||||
 | 
				*~ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# temporary files which can be created if a process still has a handle open of a deleted file | 
			
		||||
 | 
				.fuse_hidden* | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# KDE directory preferences | 
			
		||||
 | 
				.directory | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Linux trash folder which might appear on any partition or disk | 
			
		||||
 | 
				.Trash-* | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# .nfs files are created when an open file is removed but is still being accessed | 
			
		||||
 | 
				.nfs* | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				### OSX ### | 
			
		||||
 | 
				# General | 
			
		||||
 | 
				.DS_Store | 
			
		||||
 | 
				.AppleDouble | 
			
		||||
 | 
				.LSOverride | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Icon must end with two \r | 
			
		||||
 | 
				Icon | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Thumbnails | 
			
		||||
 | 
				._* | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Files that might appear in the root of a volume | 
			
		||||
 | 
				.DocumentRevisions-V100 | 
			
		||||
 | 
				.fseventsd | 
			
		||||
 | 
				.Spotlight-V100 | 
			
		||||
 | 
				.TemporaryItems | 
			
		||||
 | 
				.Trashes | 
			
		||||
 | 
				.VolumeIcon.icns | 
			
		||||
 | 
				.com.apple.timemachine.donotpresent | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Directories potentially created on remote AFP share | 
			
		||||
 | 
				.AppleDB | 
			
		||||
 | 
				.AppleDesktop | 
			
		||||
 | 
				Network Trash Folder | 
			
		||||
 | 
				Temporary Items | 
			
		||||
 | 
				.apdisk | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				### Python ### | 
			
		||||
 | 
				# Byte-compiled / optimized / DLL files | 
			
		||||
 | 
				__pycache__/ | 
			
		||||
 | 
				*.py[cod] | 
			
		||||
 | 
				*$py.class | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# C extensions | 
			
		||||
 | 
				*.so | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Distribution / packaging | 
			
		||||
 | 
				.Python | 
			
		||||
 | 
				build/ | 
			
		||||
 | 
				develop-eggs/ | 
			
		||||
 | 
				dist/ | 
			
		||||
 | 
				downloads/ | 
			
		||||
 | 
				eggs/ | 
			
		||||
 | 
				.eggs/ | 
			
		||||
 | 
				lib/ | 
			
		||||
 | 
				lib64/ | 
			
		||||
 | 
				parts/ | 
			
		||||
 | 
				sdist/ | 
			
		||||
 | 
				var/ | 
			
		||||
 | 
				wheels/ | 
			
		||||
 | 
				share/python-wheels/ | 
			
		||||
 | 
				*.egg-info/ | 
			
		||||
 | 
				.installed.cfg | 
			
		||||
 | 
				*.egg | 
			
		||||
 | 
				MANIFEST | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# PyInstaller | 
			
		||||
 | 
				#  Usually these files are written by a python script from a template | 
			
		||||
 | 
				#  before PyInstaller builds the exe, so as to inject date/other infos into it. | 
			
		||||
 | 
				*.manifest | 
			
		||||
 | 
				*.spec | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Installer logs | 
			
		||||
 | 
				pip-log.txt | 
			
		||||
 | 
				pip-delete-this-directory.txt | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Unit test / coverage reports | 
			
		||||
 | 
				htmlcov/ | 
			
		||||
 | 
				.tox/ | 
			
		||||
 | 
				.nox/ | 
			
		||||
 | 
				.coverage | 
			
		||||
 | 
				.coverage.* | 
			
		||||
 | 
				.cache | 
			
		||||
 | 
				nosetests.xml | 
			
		||||
 | 
				coverage.xml | 
			
		||||
 | 
				*.cover | 
			
		||||
 | 
				*.py,cover | 
			
		||||
 | 
				.hypothesis/ | 
			
		||||
 | 
				.pytest_cache/ | 
			
		||||
 | 
				cover/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Translations | 
			
		||||
 | 
				*.mo | 
			
		||||
 | 
				*.pot | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Django stuff: | 
			
		||||
 | 
				*.log | 
			
		||||
 | 
				local_settings.py | 
			
		||||
 | 
				db.sqlite3 | 
			
		||||
 | 
				db.sqlite3-journal | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Flask stuff: | 
			
		||||
 | 
				instance/ | 
			
		||||
 | 
				.webassets-cache | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Scrapy stuff: | 
			
		||||
 | 
				.scrapy | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Sphinx documentation | 
			
		||||
 | 
				docs/_build/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# PyBuilder | 
			
		||||
 | 
				.pybuilder/ | 
			
		||||
 | 
				target/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Jupyter Notebook | 
			
		||||
 | 
				.ipynb_checkpoints | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# IPython | 
			
		||||
 | 
				profile_default/ | 
			
		||||
 | 
				ipython_config.py | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# pyenv | 
			
		||||
 | 
				#   For a library or package, you might want to ignore these files since the code is | 
			
		||||
 | 
				#   intended to run in multiple environments; otherwise, check them in: | 
			
		||||
 | 
				# .python-version | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# pipenv | 
			
		||||
 | 
				#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | 
			
		||||
 | 
				#   However, in case of collaboration, if having platform-specific dependencies or dependencies | 
			
		||||
 | 
				#   having no cross-platform support, pipenv may install dependencies that don't work, or not | 
			
		||||
 | 
				#   install all needed dependencies. | 
			
		||||
 | 
				#Pipfile.lock | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# PEP 582; used by e.g. github.com/David-OConnor/pyflow | 
			
		||||
 | 
				__pypackages__/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Celery stuff | 
			
		||||
 | 
				celerybeat-schedule | 
			
		||||
 | 
				celerybeat.pid | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# SageMath parsed files | 
			
		||||
 | 
				*.sage.py | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Environments | 
			
		||||
 | 
				.env | 
			
		||||
 | 
				.venv | 
			
		||||
 | 
				env/ | 
			
		||||
 | 
				venv/ | 
			
		||||
 | 
				ENV/ | 
			
		||||
 | 
				env.bak/ | 
			
		||||
 | 
				venv.bak/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Spyder project settings | 
			
		||||
 | 
				.spyderproject | 
			
		||||
 | 
				.spyproject | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Rope project settings | 
			
		||||
 | 
				.ropeproject | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# mkdocs documentation | 
			
		||||
 | 
				/site | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# mypy | 
			
		||||
 | 
				.mypy_cache/ | 
			
		||||
 | 
				.dmypy.json | 
			
		||||
 | 
				dmypy.json | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Pyre type checker | 
			
		||||
 | 
				.pyre/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# pytype static type analyzer | 
			
		||||
 | 
				.pytype/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Cython debug symbols | 
			
		||||
 | 
				cython_debug/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				### Windows ### | 
			
		||||
 | 
				# Windows thumbnail cache files | 
			
		||||
 | 
				Thumbs.db | 
			
		||||
 | 
				Thumbs.db:encryptable | 
			
		||||
 | 
				ehthumbs.db | 
			
		||||
 | 
				ehthumbs_vista.db | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Dump file | 
			
		||||
 | 
				*.stackdump | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Folder config file | 
			
		||||
 | 
				[Dd]esktop.ini | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Recycle Bin used on file shares | 
			
		||||
 | 
				$RECYCLE.BIN/ | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Windows Installer files | 
			
		||||
 | 
				*.cab | 
			
		||||
 | 
				*.msi | 
			
		||||
 | 
				*.msix | 
			
		||||
 | 
				*.msm | 
			
		||||
 | 
				*.msp | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# Windows shortcuts | 
			
		||||
 | 
				*.lnk | 
			
		||||
@ -1,2 +1,56 @@ | 
			
		|||||
# vit-image-embedding | 
				 | 
			
		||||
 | 
				# ViT Embedding Operator | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				Authors: kyle he | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				## Overview | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				The ViT (Vision Transformer) is a model for image classification that employs a Transformer-like architecture over patches of the image. This includes the use of Multi-Head Attention, Scaled Dot-Product Attention and other architectural features seen in the Transformer architecture traditionally used for NLP [1]. The model is trained on the [ImageNet dataset](https://image-net.org/download.php). | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				## Interface | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				```python | 
			
		||||
 | 
				__init__(self, model_name: str = 'vit_large_patch16_224', | 
			
		||||
 | 
				                 framework: str = 'pytorch', weights_path: str = None) | 
			
		||||
 | 
				``` | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				**Args:** | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				- model_name: | 
			
		||||
 | 
				  - the model name for embedding | 
			
		||||
 | 
				  - supported types: `str`, for example 'vit_large_patch16_224' | 
			
		||||
 | 
				- framework: | 
			
		||||
 | 
				  - the framework of the model | 
			
		||||
 | 
				  - supported types: `str`, default is 'pytorch' | 
			
		||||
 | 
				- weights_path: | 
			
		||||
 | 
				  - the weights path | 
			
		||||
 | 
				  - supported types: `str`, default is None, using pretrained weights | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				```python | 
			
		||||
 | 
				__call__(self, img_path: str) | 
			
		||||
 | 
				``` | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				**Args:** | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				- img_path: | 
			
		||||
 | 
				  - the input image path | 
			
		||||
 | 
				  - supported types: `str` | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				**Returns:** | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				The Operator returns a tuple `Tuple[('embedding', numpy.ndarray)]` containing the following fields: | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				- embedding: | 
			
		||||
 | 
				  - the embedding of the image | 
			
		||||
 | 
				  - data type: `numpy.ndarray` | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				## Requirements | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				You can install the required Python packages from [requirements.txt](./requirements.txt). | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				## How it works | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				The `towhee/vit-embedding` Operator implements image embedding and can be added to a pipeline. For example, it is the key Operator, named embedding_model, within the [vit-embedding](https://hub.towhee.io/towhee/vit-embedding) pipeline. | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				## Reference | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				[1] https://arxiv.org/abs/2010.11929 | 
			
		||||
 | 
			
		|||||
@ -0,0 +1,21 @@ | 
			
		|||||
 | 
				# Copyright 2021 Zilliz. All rights reserved. | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Licensed under the Apache License, Version 2.0 (the "License"); | 
			
		||||
 | 
				# you may not use this file except in compliance with the License. | 
			
		||||
 | 
				# You may obtain a copy of the License at | 
			
		||||
 | 
				# | 
			
		||||
 | 
				#     http://www.apache.org/licenses/LICENSE-2.0 | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Unless required by applicable law or agreed to in writing, software | 
			
		||||
 | 
				# distributed under the License is distributed on an "AS IS" BASIS, | 
			
		||||
 | 
				# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
			
		||||
 | 
				# See the License for the specific language governing permissions and | 
			
		||||
 | 
				# limitations under the License. | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				import os | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				# For requirements. | 
			
		||||
 | 
				try:
				    import timm
				except ModuleNotFoundError:
				    # Install with the interpreter that is running this module: a bare
				    # `pip` on PATH may belong to a different Python environment, in which
				    # case timm would still be missing here after a "successful" install.
				    import subprocess
				    import sys

				    # Best effort, as before: the return code is ignored, so a failed
				    # install only surfaces when timm is imported later on.
				    subprocess.call([sys.executable, '-m', 'pip', 'install', 'timm'])
			
		||||
@ -0,0 +1,13 @@ | 
			
		|||||
 | 
				# Copyright 2021 Zilliz. All rights reserved. | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Licensed under the Apache License, Version 2.0 (the "License"); | 
			
		||||
 | 
				# you may not use this file except in compliance with the License. | 
			
		||||
 | 
				# You may obtain a copy of the License at | 
			
		||||
 | 
				# | 
			
		||||
 | 
				#     http://www.apache.org/licenses/LICENSE-2.0 | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Unless required by applicable law or agreed to in writing, software | 
			
		||||
 | 
				# distributed under the License is distributed on an "AS IS" BASIS, | 
			
		||||
 | 
				# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
			
		||||
 | 
				# See the License for the specific language governing permissions and | 
			
		||||
 | 
				# limitations under the License. | 
			
		||||
@ -0,0 +1,39 @@ | 
			
		|||||
 | 
				# Copyright 2021 Zilliz. All rights reserved. | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Licensed under the Apache License, Version 2.0 (the "License"); | 
			
		||||
 | 
				# you may not use this file except in compliance with the License. | 
			
		||||
 | 
				# You may obtain a copy of the License at | 
			
		||||
 | 
				# | 
			
		||||
 | 
				#     http://www.apache.org/licenses/LICENSE-2.0 | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Unless required by applicable law or agreed to in writing, software | 
			
		||||
 | 
				# distributed under the License is distributed on an "AS IS" BASIS, | 
			
		||||
 | 
				# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
			
		||||
 | 
				# See the License for the specific language governing permissions and | 
			
		||||
 | 
				# limitations under the License. | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				import torch | 
			
		||||
 | 
				import timm | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				class Model:
				    """
				    Thin wrapper around a model created with `timm.create_model`.

				    Args:
				        model_name (`str`):
				            Name handed to `timm.create_model`.
				        weights_path (`str`):
				            Checkpoint path. When falsy, `pretrained=True` is requested instead.
				    """

				    def __init__(self, model_name: str, weights_path: str):
				        super().__init__()
				        # Exactly one weight source is forwarded; num_classes=0 is passed
				        # in both cases.
				        if not weights_path:
				            weight_source = {'pretrained': True}
				        else:
				            weight_source = {'checkpoint_path': weights_path}
				        self._model = timm.create_model(model_name, num_classes=0, **weight_source)
				        # The wrapped model is switched to eval mode right after creation.
				        self._model.eval()

				    def __call__(self, img_tensor: torch.Tensor):
				        """Run the wrapped model on `img_tensor` and return its output."""
				        backbone = self._model
				        return backbone(img_tensor)

				    def train(self):
				        """
				        Placeholder for model training; currently does nothing.
				        """
				        return None
			
		||||
@ -0,0 +1,52 @@ | 
			
		|||||
 | 
				# Copyright 2021 Zilliz. All rights reserved. | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Licensed under the Apache License, Version 2.0 (the "License"); | 
			
		||||
 | 
				# you may not use this file except in compliance with the License. | 
			
		||||
 | 
				# You may obtain a copy of the License at | 
			
		||||
 | 
				# | 
			
		||||
 | 
				#     http://www.apache.org/licenses/LICENSE-2.0 | 
			
		||||
 | 
				# | 
			
		||||
 | 
				# Unless required by applicable law or agreed to in writing, software | 
			
		||||
 | 
				# distributed under the License is distributed on an "AS IS" BASIS, | 
			
		||||
 | 
				# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
			
		||||
 | 
				# See the License for the specific language governing permissions and | 
			
		||||
 | 
				# limitations under the License. | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				import sys | 
			
		||||
 | 
				from typing import NamedTuple | 
			
		||||
 | 
				from pathlib import Path | 
			
		||||
 | 
				from PIL import Image | 
			
		||||
 | 
				import torch | 
			
		||||
 | 
				from timm.data import resolve_data_config | 
			
		||||
 | 
				from timm.data.transforms_factory import create_transform | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				from towhee.operator import Operator | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				class VisionTransformerEmbeddingOperator(Operator):
				    """
				    Embedding extractor using ViT.

				    Args:
				        model_name (`string`):
				            Model name.
				        framework (`string`):
				            Framework of the model. Only 'pytorch' is supported.
				        weights_path (`string`):
				            Path to local weights. When not given, pretrained weights are used.

				    Raises:
				        ValueError: If `framework` is not 'pytorch'.
				    """

				    def __init__(self, model_name: str = 'vit_large_patch16_224',
				                 framework: str = 'pytorch', weights_path: str = None) -> None:
				        super().__init__()
				        # Make the operator's own directory importable so that the
				        # `vit_embedding.pytorch.model` import below resolves; skip the
				        # append when it is already there so repeated instantiation does
				        # not keep growing sys.path.
				        operator_dir = str(Path(__file__).parent)
				        if operator_dir not in sys.path:
				            sys.path.append(operator_dir)
				        # Fail early with a clear message instead of hitting an unbound
				        # `Model` name further down.
				        if framework != 'pytorch':
				            raise ValueError(
				                f"Unsupported framework {framework!r}; only 'pytorch' is supported.")
				        from vit_embedding.pytorch.model import Model
				        self.model = Model(model_name, weights_path)
				        # Build the preprocessing transform from the wrapped timm model.
				        config = resolve_data_config({}, model=self.model._model)
				        self.tfms = create_transform(**config)

				    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]):
				        """
				        Compute the embedding of the image stored at `img_path`.

				        Returns:
				            An `Outputs` named tuple whose `embedding` field holds the
				            flattened model output converted with `.numpy()`.
				            NOTE(review): the annotation declares `torch.Tensor`, but the
				            value is a NumPy array - kept as is for interface compatibility.
				        """
				        Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
				        # Close the file handle once the pixels are read, and normalise the
				        # mode so grayscale / RGBA / palette images do not break the transform.
				        # NOTE(review): assumes the model takes 3-channel input - confirm
				        # before using weights trained with a different channel count.
				        with Image.open(img_path) as img:
				            batch = self.tfms(img.convert('RGB')).unsqueeze(0)
				        # Inference only: avoid building an autograd graph.
				        with torch.no_grad():
				            features = self.model(batch)
				        return Outputs(features.flatten().detach().numpy())
			
		||||
@ -0,0 +1,13 @@ | 
			
		|||||
 | 
				name: 'vit-embedding' | 
			
		||||
 | 
				labels:  | 
			
		||||
 | 
				  recommended_framework: pytorch1.2.0 | 
			
		||||
 | 
				  class: vit-embedding | 
			
		||||
 | 
				  others: vit | 
			
		||||
 | 
				operator: 'towhee/vit-embedding' | 
			
		||||
 | 
				init: | 
			
		||||
 | 
				  model_name: str | 
			
		||||
 | 
				call: | 
			
		||||
 | 
				  input: | 
			
		||||
 | 
				    img_path: str | 
			
		||||
 | 
				  output: | 
			
		||||
 | 
				    feature_vector: numpy.ndarray | 
			
		||||
					Loading…
					
					
				
		Reference in new issue