towhee
/
            
              vit-image-embedding
              
                
                
            
          copied
				 9 changed files with 433 additions and 25 deletions
			
			
		@ -1,28 +1,35 @@ | 
				
			|||
# .gitattributes | 
				
			|||
 | 
				
			|||
*.7z filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.arrow filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.bin filter=lfs diff=lfs merge=lfs -text | 
				
			|||
# Source files | 
				
			|||
# ============ | 
				
			|||
*.pxd    text diff=python | 
				
			|||
*.py     text diff=python | 
				
			|||
*.py3    text diff=python | 
				
			|||
*.pyw    text diff=python | 
				
			|||
*.pyx    text diff=python | 
				
			|||
*.pyz    text diff=python | 
				
			|||
*.pyi    text diff=python | 
				
			|||
 | 
				
			|||
# Binary files | 
				
			|||
# ============ | 
				
			|||
*.db     binary | 
				
			|||
*.p      binary | 
				
			|||
*.pkl    binary | 
				
			|||
*.pickle binary | 
				
			|||
*.pyc    binary export-ignore | 
				
			|||
*.pyo    binary export-ignore | 
				
			|||
*.pyd    binary | 
				
			|||
 | 
				
			|||
# Jupyter notebook | 
				
			|||
*.ipynb  text | 
				
			|||
 | 
				
			|||
# Model files | 
				
			|||
*.bin.* filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.bz2 filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.ftz filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.gz filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.h5 filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.joblib filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.lfs.* filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.model filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.msgpack filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.onnx filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.ot filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.parquet filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.pb filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.pt filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.pth filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.rar filter=lfs diff=lfs merge=lfs -text | 
				
			|||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text  | 
				
			|||
*.tar.* filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.bin filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.h5 filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.tflite filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.tgz filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.xz filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.zip filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.zstandard filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*tfevents* filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.tar.gz filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.ot filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.onnx filter=lfs diff=lfs merge=lfs -text | 
				
			|||
*.msgpack filter=lfs diff=lfs merge=lfs -text | 
				
			|||
 | 
				
			|||
@ -0,0 +1,209 @@ | 
				
			|||
### Linux ### | 
				
			|||
*~ | 
				
			|||
 | 
				
			|||
# temporary files which can be created if a process still has a handle open of a deleted file | 
				
			|||
.fuse_hidden* | 
				
			|||
 | 
				
			|||
# KDE directory preferences | 
				
			|||
.directory | 
				
			|||
 | 
				
			|||
# Linux trash folder which might appear on any partition or disk | 
				
			|||
.Trash-* | 
				
			|||
 | 
				
			|||
# .nfs files are created when an open file is removed but is still being accessed | 
				
			|||
.nfs* | 
				
			|||
 | 
				
			|||
### OSX ### | 
				
			|||
# General | 
				
			|||
.DS_Store | 
				
			|||
.AppleDouble | 
				
			|||
.LSOverride | 
				
			|||
 | 
				
			|||
# Icon must end with two \r | 
				
			|||
Icon | 
				
			|||
 | 
				
			|||
 | 
				
			|||
# Thumbnails | 
				
			|||
._* | 
				
			|||
 | 
				
			|||
# Files that might appear in the root of a volume | 
				
			|||
.DocumentRevisions-V100 | 
				
			|||
.fseventsd | 
				
			|||
.Spotlight-V100 | 
				
			|||
.TemporaryItems | 
				
			|||
.Trashes | 
				
			|||
.VolumeIcon.icns | 
				
			|||
.com.apple.timemachine.donotpresent | 
				
			|||
 | 
				
			|||
# Directories potentially created on remote AFP share | 
				
			|||
.AppleDB | 
				
			|||
.AppleDesktop | 
				
			|||
Network Trash Folder | 
				
			|||
Temporary Items | 
				
			|||
.apdisk | 
				
			|||
 | 
				
			|||
### Python ### | 
				
			|||
# Byte-compiled / optimized / DLL files | 
				
			|||
__pycache__/ | 
				
			|||
*.py[cod] | 
				
			|||
*$py.class | 
				
			|||
 | 
				
			|||
# C extensions | 
				
			|||
*.so | 
				
			|||
 | 
				
			|||
# Distribution / packaging | 
				
			|||
.Python | 
				
			|||
build/ | 
				
			|||
develop-eggs/ | 
				
			|||
dist/ | 
				
			|||
downloads/ | 
				
			|||
eggs/ | 
				
			|||
.eggs/ | 
				
			|||
lib/ | 
				
			|||
lib64/ | 
				
			|||
parts/ | 
				
			|||
sdist/ | 
				
			|||
var/ | 
				
			|||
wheels/ | 
				
			|||
share/python-wheels/ | 
				
			|||
*.egg-info/ | 
				
			|||
.installed.cfg | 
				
			|||
*.egg | 
				
			|||
MANIFEST | 
				
			|||
 | 
				
			|||
# PyInstaller | 
				
			|||
#  Usually these files are written by a python script from a template | 
				
			|||
#  before PyInstaller builds the exe, so as to inject date/other infos into it. | 
				
			|||
*.manifest | 
				
			|||
*.spec | 
				
			|||
 | 
				
			|||
# Installer logs | 
				
			|||
pip-log.txt | 
				
			|||
pip-delete-this-directory.txt | 
				
			|||
 | 
				
			|||
# Unit test / coverage reports | 
				
			|||
htmlcov/ | 
				
			|||
.tox/ | 
				
			|||
.nox/ | 
				
			|||
.coverage | 
				
			|||
.coverage.* | 
				
			|||
.cache | 
				
			|||
nosetests.xml | 
				
			|||
coverage.xml | 
				
			|||
*.cover | 
				
			|||
*.py,cover | 
				
			|||
.hypothesis/ | 
				
			|||
.pytest_cache/ | 
				
			|||
cover/ | 
				
			|||
 | 
				
			|||
# Translations | 
				
			|||
*.mo | 
				
			|||
*.pot | 
				
			|||
 | 
				
			|||
# Django stuff: | 
				
			|||
*.log | 
				
			|||
local_settings.py | 
				
			|||
db.sqlite3 | 
				
			|||
db.sqlite3-journal | 
				
			|||
 | 
				
			|||
# Flask stuff: | 
				
			|||
instance/ | 
				
			|||
.webassets-cache | 
				
			|||
 | 
				
			|||
# Scrapy stuff: | 
				
			|||
.scrapy | 
				
			|||
 | 
				
			|||
# Sphinx documentation | 
				
			|||
docs/_build/ | 
				
			|||
 | 
				
			|||
# PyBuilder | 
				
			|||
.pybuilder/ | 
				
			|||
target/ | 
				
			|||
 | 
				
			|||
# Jupyter Notebook | 
				
			|||
.ipynb_checkpoints | 
				
			|||
 | 
				
			|||
# IPython | 
				
			|||
profile_default/ | 
				
			|||
ipython_config.py | 
				
			|||
 | 
				
			|||
# pyenv | 
				
			|||
#   For a library or package, you might want to ignore these files since the code is | 
				
			|||
#   intended to run in multiple environments; otherwise, check them in: | 
				
			|||
# .python-version | 
				
			|||
 | 
				
			|||
# pipenv | 
				
			|||
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | 
				
			|||
#   However, in case of collaboration, if having platform-specific dependencies or dependencies | 
				
			|||
#   having no cross-platform support, pipenv may install dependencies that don't work, or not | 
				
			|||
#   install all needed dependencies. | 
				
			|||
#Pipfile.lock | 
				
			|||
 | 
				
			|||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | 
				
			|||
__pypackages__/ | 
				
			|||
 | 
				
			|||
# Celery stuff | 
				
			|||
celerybeat-schedule | 
				
			|||
celerybeat.pid | 
				
			|||
 | 
				
			|||
# SageMath parsed files | 
				
			|||
*.sage.py | 
				
			|||
 | 
				
			|||
# Environments | 
				
			|||
.env | 
				
			|||
.venv | 
				
			|||
env/ | 
				
			|||
venv/ | 
				
			|||
ENV/ | 
				
			|||
env.bak/ | 
				
			|||
venv.bak/ | 
				
			|||
 | 
				
			|||
# Spyder project settings | 
				
			|||
.spyderproject | 
				
			|||
.spyproject | 
				
			|||
 | 
				
			|||
# Rope project settings | 
				
			|||
.ropeproject | 
				
			|||
 | 
				
			|||
# mkdocs documentation | 
				
			|||
/site | 
				
			|||
 | 
				
			|||
# mypy | 
				
			|||
.mypy_cache/ | 
				
			|||
.dmypy.json | 
				
			|||
dmypy.json | 
				
			|||
 | 
				
			|||
# Pyre type checker | 
				
			|||
.pyre/ | 
				
			|||
 | 
				
			|||
# pytype static type analyzer | 
				
			|||
.pytype/ | 
				
			|||
 | 
				
			|||
# Cython debug symbols | 
				
			|||
cython_debug/ | 
				
			|||
 | 
				
			|||
### Windows ### | 
				
			|||
# Windows thumbnail cache files | 
				
			|||
Thumbs.db | 
				
			|||
Thumbs.db:encryptable | 
				
			|||
ehthumbs.db | 
				
			|||
ehthumbs_vista.db | 
				
			|||
 | 
				
			|||
# Dump file | 
				
			|||
*.stackdump | 
				
			|||
 | 
				
			|||
# Folder config file | 
				
			|||
[Dd]esktop.ini | 
				
			|||
 | 
				
			|||
# Recycle Bin used on file shares | 
				
			|||
$RECYCLE.BIN/ | 
				
			|||
 | 
				
			|||
# Windows Installer files | 
				
			|||
*.cab | 
				
			|||
*.msi | 
				
			|||
*.msix | 
				
			|||
*.msm | 
				
			|||
*.msp | 
				
			|||
 | 
				
			|||
# Windows shortcuts | 
				
			|||
*.lnk | 
				
			|||
@ -1,2 +1,56 @@ | 
				
			|||
# vit-image-embedding | 
				
			|||
# ViT Embedding Operator | 
				
			|||
 | 
				
			|||
Authors: kyle he | 
				
			|||
 | 
				
			|||
## Overview | 
				
			|||
 | 
				
			|||
The ViT (Vision Transformer) is an image classification model that applies a Transformer-like architecture to patches of the image. It uses Multi-Head Attention, Scaled Dot-Product Attention, and other architectural features of the Transformer architecture traditionally used for NLP [1]. The model used here is trained on the [ImageNet dataset](https://image-net.org/download.php). | 
				
			|||
 | 
				
			|||
## Interface | 
				
			|||
 | 
				
			|||
```python | 
				
			|||
__init__(self, model_name: str = 'vit_large_patch16_224', | 
				
			|||
                 framework: str = 'pytorch', weights_path: str = None) | 
				
			|||
``` | 
				
			|||
 | 
				
			|||
**Args:** | 
				
			|||
 | 
				
			|||
- model_name: | 
				
			|||
  - the model name for embedding | 
				
			|||
  - supported types: `str`, for example 'vit_large_patch16_224' | 
				
			|||
- framework: | 
				
			|||
  - the framework of the model | 
				
			|||
  - supported types: `str`, default is 'pytorch' | 
				
			|||
- weights_path: | 
				
			|||
  - the weights path | 
				
			|||
  - supported types: `str`, default is None, using pretrained weights | 
				
			|||
 | 
				
			|||
```python | 
				
			|||
__call__(self, img_path: str) | 
				
			|||
``` | 
				
			|||
 | 
				
			|||
**Args:** | 
				
			|||
 | 
				
			|||
- img_path: | 
				
			|||
  - the input image path | 
				
			|||
  - supported types: `str` | 
				
			|||
 | 
				
			|||
**Returns:** | 
				
			|||
 | 
				
			|||
The Operator returns a tuple `Tuple[('embedding', numpy.ndarray)]` containing the following fields: | 
				
			|||
 | 
				
			|||
- embedding: | 
				
			|||
  - the embedding of the image | 
				
			|||
  - data type: `numpy.ndarray` | 
				
			|||
 | 
				
			|||
## Requirements | 
				
			|||
 | 
				
			|||
You can install the required Python packages from [requirements.txt](./requirements.txt). | 
				
			|||
 | 
				
			|||
## How it works | 
				
			|||
 | 
				
			|||
The `towhee/vit-embedding` Operator implements image embedding and can be added to a pipeline. For example, it is the key Operator, named embedding_model, within the [vit-embedding](https://hub.towhee.io/towhee/vit-embedding) pipeline. | 
				
			|||
 | 
				
			|||
## Reference | 
				
			|||
 | 
				
			|||
[1].https://arxiv.org/abs/2010.11929 | 
				
			|||
 | 
				
			|||
@ -0,0 +1,21 @@ | 
				
			|||
# Copyright 2021 Zilliz. All rights reserved. | 
				
			|||
# | 
				
			|||
# Licensed under the Apache License, Version 2.0 (the "License"); | 
				
			|||
# you may not use this file except in compliance with the License. | 
				
			|||
# You may obtain a copy of the License at | 
				
			|||
# | 
				
			|||
#     http://www.apache.org/licenses/LICENSE-2.0 | 
				
			|||
# | 
				
			|||
# Unless required by applicable law or agreed to in writing, software | 
				
			|||
# distributed under the License is distributed on an "AS IS" BASIS, | 
				
			|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
				
			|||
# See the License for the specific language governing permissions and | 
				
			|||
# limitations under the License. | 
				
			|||
 | 
				
			|||
import os | 
				
			|||
 | 
				
			|||
# For requirements. | 
				
			|||
try:
    import timm
except ModuleNotFoundError:
    # Best-effort install of the missing dependency. Go through the running
    # interpreter (`python -m pip`) rather than whatever `pip` executable is
    # first on PATH, so the package lands in the environment that is
    # actually importing this module.
    import subprocess
    import sys

    # check=False keeps the original semantics: a failed install is not
    # raised here; the later `import timm` in the model module will fail
    # instead.
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'timm'], check=False)
				
			|||
@ -0,0 +1,13 @@ | 
				
			|||
# Copyright 2021 Zilliz. All rights reserved. | 
				
			|||
# | 
				
			|||
# Licensed under the Apache License, Version 2.0 (the "License"); | 
				
			|||
# you may not use this file except in compliance with the License. | 
				
			|||
# You may obtain a copy of the License at | 
				
			|||
# | 
				
			|||
#     http://www.apache.org/licenses/LICENSE-2.0 | 
				
			|||
# | 
				
			|||
# Unless required by applicable law or agreed to in writing, software | 
				
			|||
# distributed under the License is distributed on an "AS IS" BASIS, | 
				
			|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
				
			|||
# See the License for the specific language governing permissions and | 
				
			|||
# limitations under the License. | 
				
			|||
@ -0,0 +1,39 @@ | 
				
			|||
# Copyright 2021 Zilliz. All rights reserved. | 
				
			|||
# | 
				
			|||
# Licensed under the Apache License, Version 2.0 (the "License"); | 
				
			|||
# you may not use this file except in compliance with the License. | 
				
			|||
# You may obtain a copy of the License at | 
				
			|||
# | 
				
			|||
#     http://www.apache.org/licenses/LICENSE-2.0 | 
				
			|||
# | 
				
			|||
# Unless required by applicable law or agreed to in writing, software | 
				
			|||
# distributed under the License is distributed on an "AS IS" BASIS, | 
				
			|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
				
			|||
# See the License for the specific language governing permissions and | 
				
			|||
# limitations under the License. | 
				
			|||
 | 
				
			|||
 | 
				
			|||
import torch | 
				
			|||
import timm | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class Model:
    """
    Thin PyTorch wrapper around a model created by `timm.create_model`.

    Args:
        model_name (`str`):
            Name forwarded to `timm.create_model`.
        weights_path (`str`):
            Checkpoint path forwarded as `checkpoint_path`. When falsy
            (e.g. None or ''), `pretrained=True` is requested instead.
    """
    def __init__(self, model_name: str, weights_path: str):
        super().__init__()
        # num_classes=0 is passed in both branches (presumably so the model
        # yields features instead of class logits; confirm against timm docs).
        if not weights_path:
            net = timm.create_model(model_name, pretrained=True, num_classes=0)
        else:
            net = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0)
        self._model = net
        # Switch the wrapped model to evaluation mode right after creation.
        self._model.eval()

    def __call__(self, img_tensor: torch.Tensor):
        """
        Forward `img_tensor` through the wrapped model and return its output.
        """
        output = self._model(img_tensor)
        return output

    def train(self):
        """
        Placeholder for model training; currently does nothing.
        """
        return None
				
			|||
@ -0,0 +1,52 @@ | 
				
			|||
# Copyright 2021 Zilliz. All rights reserved. | 
				
			|||
# | 
				
			|||
# Licensed under the Apache License, Version 2.0 (the "License"); | 
				
			|||
# you may not use this file except in compliance with the License. | 
				
			|||
# You may obtain a copy of the License at | 
				
			|||
# | 
				
			|||
#     http://www.apache.org/licenses/LICENSE-2.0 | 
				
			|||
# | 
				
			|||
# Unless required by applicable law or agreed to in writing, software | 
				
			|||
# distributed under the License is distributed on an "AS IS" BASIS, | 
				
			|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
				
			|||
# See the License for the specific language governing permissions and | 
				
			|||
# limitations under the License. | 
				
			|||
 | 
				
			|||
 | 
				
			|||
import sys | 
				
			|||
from typing import NamedTuple | 
				
			|||
from pathlib import Path | 
				
			|||
from PIL import Image | 
				
			|||
import torch | 
				
			|||
from timm.data import resolve_data_config | 
				
			|||
from timm.data.transforms_factory import create_transform | 
				
			|||
 | 
				
			|||
 | 
				
			|||
from towhee.operator import Operator | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class VisionTransformerEmbeddingOperator(Operator):
    """
    Embedding extractor using ViT.

    Args:
        model_name (`string`):
            Model name.
        framework (`string`):
            Framework of the model. Only 'pytorch' is supported.
        weights_path (`string`):
            Path to local weights. When None, pretrained weights are used.

    Raises:
        ValueError: If `framework` is not 'pytorch'.
    """

    def __init__(self, model_name: str = 'vit_large_patch16_224',
                 framework: str = 'pytorch', weights_path: str = None) -> None:
        super().__init__()
        # Fail with a clear message instead of hitting an unbound `Model`
        # name further down when an unsupported framework is requested.
        if framework != 'pytorch':
            raise ValueError(
                f"Unsupported framework '{framework}'; only 'pytorch' is supported.")

        # Make the operator's own directory importable so the bundled model
        # module can be resolved; avoid adding a duplicate entry every time
        # an operator instance is created.
        operator_dir = str(Path(__file__).parent)
        if operator_dir not in sys.path:
            sys.path.append(operator_dir)
        from vit_embedding.pytorch.model import Model

        self.model = Model(model_name, weights_path)
        # Build the preprocessing transform from the data config that timm
        # resolves for the wrapped model.
        config = resolve_data_config({}, model=self.model._model)
        self.tfms = create_transform(**config)

    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]):
        """
        Compute the embedding of the image stored at `img_path`.

        Args:
            img_path (`str`):
                Path of the input image.

        Returns:
            An `Outputs` named tuple whose `embedding` field holds the
            flattened model output converted with `.numpy()`.
        """
        # NOTE(review): the field is annotated as torch.Tensor but the value
        # returned below is the result of `.numpy()`; annotation kept as-is
        # for interface compatibility.
        Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
        # Close the file handle deterministically; `convert` returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        # Converting to RGB lets grayscale/RGBA/palette images pass through
        # the transform (assumes the model expects 3-channel input, which is
        # the case when no channel count is passed at creation; confirm).
        with Image.open(img_path) as raw_img:
            rgb_img = raw_img.convert('RGB')
        img = self.tfms(rgb_img).unsqueeze(0)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            features = self.model(img)
        return Outputs(features.flatten().detach().numpy())
				
			|||
@ -0,0 +1,13 @@ | 
				
			|||
name: 'vit-embedding' | 
				
			|||
labels:  | 
				
			|||
  recommended_framework: pytorch1.2.0 | 
				
			|||
  class: vit-embedding | 
				
			|||
  others: vit | 
				
			|||
operator: 'towhee/vit-embedding' | 
				
			|||
init: | 
				
			|||
  model_name: str | 
				
			|||
call: | 
				
			|||
  input: | 
				
			|||
    img_path: str | 
				
			|||
  output: | 
				
			|||
    feature_vector: numpy.ndarray | 
				
			|||
					Loading…
					
					
				
		Reference in new issue