towhee
/
vit-image-embedding
copied
9 changed files with 433 additions and 25 deletions
@ -1,28 +1,35 @@ |
|||||
|
# .gitattributes |
||||
|
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text |
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text |
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text |
|
||||
|
# Source files |
||||
|
# ============ |
||||
|
*.pxd text diff=python |
||||
|
*.py text diff=python |
||||
|
*.py3 text diff=python |
||||
|
*.pyw text diff=python |
||||
|
*.pyx text diff=python |
||||
|
*.pyz text diff=python |
||||
|
*.pyi text diff=python |
||||
|
|
||||
|
# Binary files |
||||
|
# ============ |
||||
|
*.db binary |
||||
|
*.p binary |
||||
|
*.pkl binary |
||||
|
*.pickle binary |
||||
|
*.pyc binary export-ignore |
||||
|
*.pyo binary export-ignore |
||||
|
*.pyd binary |
||||
|
|
||||
|
# Jupyter notebook |
||||
|
*.ipynb text |
||||
|
|
||||
|
# Model files |
||||
*.bin.* filter=lfs diff=lfs merge=lfs -text |
*.bin.* filter=lfs diff=lfs merge=lfs -text |
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text |
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text |
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text |
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text |
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text |
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text |
*.lfs.* filter=lfs diff=lfs merge=lfs -text |
||||
*.model filter=lfs diff=lfs merge=lfs -text |
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text |
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text |
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text |
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text |
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text |
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text |
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text |
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text |
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text |
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text |
|
||||
|
*.bin filter=lfs diff=lfs merge=lfs -text |
||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text |
||||
*.tflite filter=lfs diff=lfs merge=lfs -text |
*.tflite filter=lfs diff=lfs merge=lfs -text |
||||
*.tgz filter=lfs diff=lfs merge=lfs -text |
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text |
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text |
|
||||
*.zstandard filter=lfs diff=lfs merge=lfs -text |
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text |
|
||||
|
*.tar.gz filter=lfs diff=lfs merge=lfs -text |
||||
|
*.ot filter=lfs diff=lfs merge=lfs -text |
||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text |
||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text |
||||
|
@ -0,0 +1,209 @@ |
|||||
|
### Linux ### |
||||
|
*~ |
||||
|
|
||||
|
# temporary files which can be created if a process still has a handle open of a deleted file |
||||
|
.fuse_hidden* |
||||
|
|
||||
|
# KDE directory preferences |
||||
|
.directory |
||||
|
|
||||
|
# Linux trash folder which might appear on any partition or disk |
||||
|
.Trash-* |
||||
|
|
||||
|
# .nfs files are created when an open file is removed but is still being accessed |
||||
|
.nfs* |
||||
|
|
||||
|
### OSX ### |
||||
|
# General |
||||
|
.DS_Store |
||||
|
.AppleDouble |
||||
|
.LSOverride |
||||
|
|
||||
|
# Icon must end with two \r |
||||
|
Icon |
||||
|
|
||||
|
|
||||
|
# Thumbnails |
||||
|
._* |
||||
|
|
||||
|
# Files that might appear in the root of a volume |
||||
|
.DocumentRevisions-V100 |
||||
|
.fseventsd |
||||
|
.Spotlight-V100 |
||||
|
.TemporaryItems |
||||
|
.Trashes |
||||
|
.VolumeIcon.icns |
||||
|
.com.apple.timemachine.donotpresent |
||||
|
|
||||
|
# Directories potentially created on remote AFP share |
||||
|
.AppleDB |
||||
|
.AppleDesktop |
||||
|
Network Trash Folder |
||||
|
Temporary Items |
||||
|
.apdisk |
||||
|
|
||||
|
### Python ### |
||||
|
# Byte-compiled / optimized / DLL files |
||||
|
__pycache__/ |
||||
|
*.py[cod] |
||||
|
*$py.class |
||||
|
|
||||
|
# C extensions |
||||
|
*.so |
||||
|
|
||||
|
# Distribution / packaging |
||||
|
.Python |
||||
|
build/ |
||||
|
develop-eggs/ |
||||
|
dist/ |
||||
|
downloads/ |
||||
|
eggs/ |
||||
|
.eggs/ |
||||
|
lib/ |
||||
|
lib64/ |
||||
|
parts/ |
||||
|
sdist/ |
||||
|
var/ |
||||
|
wheels/ |
||||
|
share/python-wheels/ |
||||
|
*.egg-info/ |
||||
|
.installed.cfg |
||||
|
*.egg |
||||
|
MANIFEST |
||||
|
|
||||
|
# PyInstaller |
||||
|
# Usually these files are written by a python script from a template |
||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it. |
||||
|
*.manifest |
||||
|
*.spec |
||||
|
|
||||
|
# Installer logs |
||||
|
pip-log.txt |
||||
|
pip-delete-this-directory.txt |
||||
|
|
||||
|
# Unit test / coverage reports |
||||
|
htmlcov/ |
||||
|
.tox/ |
||||
|
.nox/ |
||||
|
.coverage |
||||
|
.coverage.* |
||||
|
.cache |
||||
|
nosetests.xml |
||||
|
coverage.xml |
||||
|
*.cover |
||||
|
*.py,cover |
||||
|
.hypothesis/ |
||||
|
.pytest_cache/ |
||||
|
cover/ |
||||
|
|
||||
|
# Translations |
||||
|
*.mo |
||||
|
*.pot |
||||
|
|
||||
|
# Django stuff: |
||||
|
*.log |
||||
|
local_settings.py |
||||
|
db.sqlite3 |
||||
|
db.sqlite3-journal |
||||
|
|
||||
|
# Flask stuff: |
||||
|
instance/ |
||||
|
.webassets-cache |
||||
|
|
||||
|
# Scrapy stuff: |
||||
|
.scrapy |
||||
|
|
||||
|
# Sphinx documentation |
||||
|
docs/_build/ |
||||
|
|
||||
|
# PyBuilder |
||||
|
.pybuilder/ |
||||
|
target/ |
||||
|
|
||||
|
# Jupyter Notebook |
||||
|
.ipynb_checkpoints |
||||
|
|
||||
|
# IPython |
||||
|
profile_default/ |
||||
|
ipython_config.py |
||||
|
|
||||
|
# pyenv |
||||
|
# For a library or package, you might want to ignore these files since the code is |
||||
|
# intended to run in multiple environments; otherwise, check them in: |
||||
|
# .python-version |
||||
|
|
||||
|
# pipenv |
||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. |
||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies |
||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not |
||||
|
# install all needed dependencies. |
||||
|
#Pipfile.lock |
||||
|
|
||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow |
||||
|
__pypackages__/ |
||||
|
|
||||
|
# Celery stuff |
||||
|
celerybeat-schedule |
||||
|
celerybeat.pid |
||||
|
|
||||
|
# SageMath parsed files |
||||
|
*.sage.py |
||||
|
|
||||
|
# Environments |
||||
|
.env |
||||
|
.venv |
||||
|
env/ |
||||
|
venv/ |
||||
|
ENV/ |
||||
|
env.bak/ |
||||
|
venv.bak/ |
||||
|
|
||||
|
# Spyder project settings |
||||
|
.spyderproject |
||||
|
.spyproject |
||||
|
|
||||
|
# Rope project settings |
||||
|
.ropeproject |
||||
|
|
||||
|
# mkdocs documentation |
||||
|
/site |
||||
|
|
||||
|
# mypy |
||||
|
.mypy_cache/ |
||||
|
.dmypy.json |
||||
|
dmypy.json |
||||
|
|
||||
|
# Pyre type checker |
||||
|
.pyre/ |
||||
|
|
||||
|
# pytype static type analyzer |
||||
|
.pytype/ |
||||
|
|
||||
|
# Cython debug symbols |
||||
|
cython_debug/ |
||||
|
|
||||
|
### Windows ### |
||||
|
# Windows thumbnail cache files |
||||
|
Thumbs.db |
||||
|
Thumbs.db:encryptable |
||||
|
ehthumbs.db |
||||
|
ehthumbs_vista.db |
||||
|
|
||||
|
# Dump file |
||||
|
*.stackdump |
||||
|
|
||||
|
# Folder config file |
||||
|
[Dd]esktop.ini |
||||
|
|
||||
|
# Recycle Bin used on file shares |
||||
|
$RECYCLE.BIN/ |
||||
|
|
||||
|
# Windows Installer files |
||||
|
*.cab |
||||
|
*.msi |
||||
|
*.msix |
||||
|
*.msm |
||||
|
*.msp |
||||
|
|
||||
|
# Windows shortcuts |
||||
|
*.lnk |
@ -1,2 +1,56 @@ |
|||||
# vit-image-embedding |
|
||||
|
# ViT Embedding Operator |
||||
|
|
||||
|
Authors: kyle he |
||||
|
|
||||
|
## Overview |
||||
|
|
||||
|
The ViT(Vision Transformer) is a model for image classification that employs a Transformer-like architecture over patches of the image. This includes the use of Multi-Head Attention, Scaled Dot-Product Attention and other architectural features seen in the Transformer architecture traditionally used for NLP[1], which is trained on [imagenet dataset](https://image-net.org/download.php). |
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
```python |
||||
|
__init__(self, model_name: str = 'vit_large_patch16_224', |
||||
|
framework: str = 'pytorch', weights_path: str = None) |
||||
|
``` |
||||
|
|
||||
|
**Args:** |
||||
|
|
||||
|
- model_name: |
||||
|
- the model name for embedding |
||||
|
- supported types: `str`, for example 'vit_large_patch16_224' |
||||
|
- framework: |
||||
|
- the framework of the model |
||||
|
- supported types: `str`, default is 'pytorch' |
||||
|
- weights_path: |
||||
|
- the weights path |
||||
|
- supported types: `str`, default is None, using pretrained weights |
||||
|
|
||||
|
```python |
||||
|
__call__(self, img_path: str) |
||||
|
``` |
||||
|
|
||||
|
**Args:** |
||||
|
|
||||
|
- img_path: |
||||
|
- the input image path |
||||
|
- supported types: `str` |
||||
|
|
||||
|
**Returns:** |
||||
|
|
||||
|
The Operator returns a tuple `Tuple[('embedding', numpy.ndarray)]` containing following fields: |
||||
|
|
||||
|
- feature_vector: |
||||
|
- the embedding of the image |
||||
|
- data type: `numpy.ndarray` |
||||
|
|
||||
|
## Requirements |
||||
|
|
||||
|
You can get the required python package by [requirements.txt](./requirements.txt). |
||||
|
|
||||
|
## How it works |
||||
|
|
||||
|
The `towhee/vit-embedding` Operator implements the function of image embedding, which can add to the pipeline. For example, it's the key Operator named embedding_model within [vit-embedding](https://hub.towhee.io/towhee/vit-embedding) pipeline. |
||||
|
|
||||
|
## Reference |
||||
|
|
||||
|
[1].https://arxiv.org/abs/2010.11929 |
||||
|
@ -0,0 +1,21 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import os |
||||
|
|
||||
|
# For requirements. |
||||
|
try: |
||||
|
import timm |
||||
|
except ModuleNotFoundError: |
||||
|
os.system('pip install timm') |
@ -0,0 +1,13 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
@ -0,0 +1,39 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
|
||||
|
import torch |
||||
|
import timm |
||||
|
|
||||
|
|
||||
|
class Model(): |
||||
|
""" |
||||
|
PyTorch model class |
||||
|
""" |
||||
|
def __init__(self, model_name: str, weights_path: str): |
||||
|
super().__init__() |
||||
|
if weights_path: |
||||
|
self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0) |
||||
|
else: |
||||
|
self._model = timm.create_model(model_name, pretrained=True, num_classes=0) |
||||
|
self._model.eval() |
||||
|
|
||||
|
def __call__(self, img_tensor: torch.Tensor): |
||||
|
return self._model(img_tensor) |
||||
|
|
||||
|
def train(self): |
||||
|
""" |
||||
|
For training model |
||||
|
""" |
||||
|
pass |
@ -0,0 +1,52 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
|
||||
|
import sys |
||||
|
from typing import NamedTuple |
||||
|
from pathlib import Path |
||||
|
from PIL import Image |
||||
|
import torch |
||||
|
from timm.data import resolve_data_config |
||||
|
from timm.data.transforms_factory import create_transform |
||||
|
|
||||
|
|
||||
|
from towhee.operator import Operator |
||||
|
|
||||
|
|
||||
|
class VisionTransformerEmbeddingOperator(Operator): |
||||
|
""" |
||||
|
Embedding extractor using ViT. |
||||
|
Args: |
||||
|
model_name (`string`): |
||||
|
Model name. |
||||
|
weights_path (`string`): |
||||
|
Path to local weights. |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, model_name: str = 'vit_large_patch16_224', |
||||
|
framework: str = 'pytorch', weights_path: str = None) -> None: |
||||
|
super().__init__() |
||||
|
sys.path.append(str(Path(__file__).parent)) |
||||
|
if framework == 'pytorch': |
||||
|
from vit_embedding.pytorch.model import Model |
||||
|
self.model = Model(model_name, weights_path) |
||||
|
config = resolve_data_config({}, model=self.model._model) |
||||
|
self.tfms = create_transform(**config) |
||||
|
|
||||
|
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]): |
||||
|
Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) |
||||
|
img = self.tfms(Image.open(img_path)).unsqueeze(0) |
||||
|
features = self.model(img) |
||||
|
return Outputs(features.flatten().detach().numpy()) |
@ -0,0 +1,13 @@ |
|||||
|
name: 'vit-embedding' |
||||
|
labels: |
||||
|
recommended_framework: pytorch1.2.0 |
||||
|
class: vit-embedding |
||||
|
others: vit |
||||
|
operator: 'towhee/vit-embedding' |
||||
|
init: |
||||
|
model_name: str |
||||
|
call: |
||||
|
input: |
||||
|
img_path: str |
||||
|
output: |
||||
|
feature_vector: numpy.ndarray |
Loading…
Reference in new issue