From dad377dd4504eca3070615a5f6f5147abcdcc7bc Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Wed, 15 Dec 2021 19:20:58 +0800 Subject: [PATCH] Update Signed-off-by: shiyu22 --- .gitattributes | 55 ++++++----- .gitignore | 209 +++++++++++++++++++++++++++++++++++++++ README.md | 56 ++++++++++- __init__.py | 21 ++++ pytorch/__init__.py | 13 +++ pytorch/model.py | 39 ++++++++ requirements.txt | 0 vit_image_embedding.py | 52 ++++++++++ vit_image_embedding.yaml | 13 +++ 9 files changed, 433 insertions(+), 25 deletions(-) create mode 100644 .gitignore create mode 100644 __init__.py create mode 100644 pytorch/__init__.py create mode 100644 pytorch/model.py create mode 100644 requirements.txt create mode 100644 vit_image_embedding.py create mode 100644 vit_image_embedding.yaml diff --git a/.gitattributes b/.gitattributes index ad2c207..705e663 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,28 +1,35 @@ +# .gitattributes -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text +# Source files +# ============ +*.pxd text diff=python +*.py text diff=python +*.py3 text diff=python +*.pyw text diff=python +*.pyx text diff=python +*.pyz text diff=python +*.pyi text diff=python + +# Binary files +# ============ +*.db binary +*.p binary +*.pkl binary +*.pickle binary +*.pyc binary export-ignore +*.pyo binary export-ignore +*.pyd binary + +# Jupyter notebook +*.ipynb text + +# Model files *.bin.* filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text *.lfs.* filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text *.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zstandard filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*.tar.gz filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..32030ff --- /dev/null +++ b/.gitignore @@ -0,0 +1,209 @@ +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### OSX ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk \ No newline at end of file diff --git a/README.md b/README.md index a339cff..68d28c0 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,56 @@ -# vit-image-embedding +# ViT Embedding Operator +Authors: kyle he + +## Overview + +The ViT(Vision Transformer) is a model for image classification that employs a Transformer-like architecture over patches of the image. This includes the use of Multi-Head Attention, Scaled Dot-Product Attention and other architectural features seen in the Transformer architecture traditionally used for NLP[1], which is trained on [imagenet dataset](https://image-net.org/download.php). + +## Interface + +```python +__init__(self, model_name: str = 'vit_large_patch16_224', + framework: str = 'pytorch', weights_path: str = None) +``` + +**Args:** + +- model_name: + - the model name for embedding + - supported types: `str`, for example 'vit_large_patch16_224' +- framework: + - the framework of the model + - supported types: `str`, default is 'pytorch' +- weights_path: + - the weights path + - supported types: `str`, default is None, using pretrained weights + +```python +__call__(self, img_path: str) +``` + +**Args:** + +- img_path: + - the input image path + - supported types: `str` + +**Returns:** + +The Operator returns a tuple `Tuple[('embedding', numpy.ndarray)]` containing following fields: + +- feature_vector: + - the embedding of the image + - data type: `numpy.ndarray` + +## Requirements + +You can get the required python package by [requirements.txt](./requirements.txt). + +## How it works + +The `towhee/vit-embedding` Operator implements the function of image embedding, which can add to the pipeline. For example, it's the key Operator named embedding_model within [vit-embedding](https://hub.towhee.io/towhee/vit-embedding) pipeline. + +## Reference + +[1].https://arxiv.org/abs/2010.11929 diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..8516979 --- /dev/null +++ b/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +# For requirements. +try: + import timm +except ModuleNotFoundError: + os.system('pip install timm') diff --git a/pytorch/__init__.py b/pytorch/__init__.py new file mode 100644 index 0000000..37f5bd7 --- /dev/null +++ b/pytorch/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pytorch/model.py b/pytorch/model.py new file mode 100644 index 0000000..007b645 --- /dev/null +++ b/pytorch/model.py @@ -0,0 +1,39 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import timm + + +class Model(): + """ + PyTorch model class + """ + def __init__(self, model_name: str, weights_path: str): + super().__init__() + if weights_path: + self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0) + else: + self._model = timm.create_model(model_name, pretrained=True, num_classes=0) + self._model.eval() + + def __call__(self, img_tensor: torch.Tensor): + return self._model(img_tensor) + + def train(self): + """ + For training model + """ + pass diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/vit_image_embedding.py b/vit_image_embedding.py new file mode 100644 index 0000000..5d4706e --- /dev/null +++ b/vit_image_embedding.py @@ -0,0 +1,52 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import sys +from typing import NamedTuple +from pathlib import Path +from PIL import Image +import torch +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform + + +from towhee.operator import Operator + + +class VisionTransformerEmbeddingOperator(Operator): + """ + Embedding extractor using ViT. + Args: + model_name (`string`): + Model name. + weights_path (`string`): + Path to local weights. + """ + + def __init__(self, model_name: str = 'vit_large_patch16_224', + framework: str = 'pytorch', weights_path: str = None) -> None: + super().__init__() + sys.path.append(str(Path(__file__).parent)) + if framework == 'pytorch': + from vit_embedding.pytorch.model import Model + self.model = Model(model_name, weights_path) + config = resolve_data_config({}, model=self.model._model) + self.tfms = create_transform(**config) + + def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]): + Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) + img = self.tfms(Image.open(img_path)).unsqueeze(0) + features = self.model(img) + return Outputs(features.flatten().detach().numpy()) diff --git a/vit_image_embedding.yaml b/vit_image_embedding.yaml new file mode 100644 index 0000000..fa903c1 --- /dev/null +++ b/vit_image_embedding.yaml @@ -0,0 +1,13 @@ +name: 'vit-embedding' +labels: + recommended_framework: pytorch1.2.0 + class: vit-embedding + others: vit +operator: 'towhee/vit-embedding' +init: + model_name: str +call: + input: + img_path: str + output: + feature_vector: numpy.ndarray