towhee
/
vit-image-embedding
copied
9 changed files with 433 additions and 25 deletions
@ -1,28 +1,35 @@ |
|||
# .gitattributes |
|||
|
|||
*.7z filter=lfs diff=lfs merge=lfs -text |
|||
*.arrow filter=lfs diff=lfs merge=lfs -text |
|||
*.bin filter=lfs diff=lfs merge=lfs -text |
|||
# Source files |
|||
# ============ |
|||
*.pxd text diff=python |
|||
*.py text diff=python |
|||
*.py3 text diff=python |
|||
*.pyw text diff=python |
|||
*.pyx text diff=python |
|||
*.pyz text diff=python |
|||
*.pyi text diff=python |
|||
|
|||
# Binary files |
|||
# ============ |
|||
*.db binary |
|||
*.p binary |
|||
*.pkl binary |
|||
*.pickle binary |
|||
*.pyc binary export-ignore |
|||
*.pyo binary export-ignore |
|||
*.pyd binary |
|||
|
|||
# Jupyter notebook |
|||
*.ipynb text |
|||
|
|||
# Model files |
|||
*.bin.* filter=lfs diff=lfs merge=lfs -text |
|||
*.bz2 filter=lfs diff=lfs merge=lfs -text |
|||
*.ftz filter=lfs diff=lfs merge=lfs -text |
|||
*.gz filter=lfs diff=lfs merge=lfs -text |
|||
*.h5 filter=lfs diff=lfs merge=lfs -text |
|||
*.joblib filter=lfs diff=lfs merge=lfs -text |
|||
*.lfs.* filter=lfs diff=lfs merge=lfs -text |
|||
*.model filter=lfs diff=lfs merge=lfs -text |
|||
*.msgpack filter=lfs diff=lfs merge=lfs -text |
|||
*.onnx filter=lfs diff=lfs merge=lfs -text |
|||
*.ot filter=lfs diff=lfs merge=lfs -text |
|||
*.parquet filter=lfs diff=lfs merge=lfs -text |
|||
*.pb filter=lfs diff=lfs merge=lfs -text |
|||
*.pt filter=lfs diff=lfs merge=lfs -text |
|||
*.pth filter=lfs diff=lfs merge=lfs -text |
|||
*.rar filter=lfs diff=lfs merge=lfs -text |
|||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text |
|||
*.tar.* filter=lfs diff=lfs merge=lfs -text |
|||
*.bin filter=lfs diff=lfs merge=lfs -text |
|||
*.h5 filter=lfs diff=lfs merge=lfs -text |
|||
*.tflite filter=lfs diff=lfs merge=lfs -text |
|||
*.tgz filter=lfs diff=lfs merge=lfs -text |
|||
*.xz filter=lfs diff=lfs merge=lfs -text |
|||
*.zip filter=lfs diff=lfs merge=lfs -text |
|||
*.zstandard filter=lfs diff=lfs merge=lfs -text |
|||
*tfevents* filter=lfs diff=lfs merge=lfs -text |
|||
*.tar.gz filter=lfs diff=lfs merge=lfs -text |
|||
*.ot filter=lfs diff=lfs merge=lfs -text |
|||
*.onnx filter=lfs diff=lfs merge=lfs -text |
|||
*.msgpack filter=lfs diff=lfs merge=lfs -text |
|||
|
@ -0,0 +1,209 @@ |
|||
### Linux ### |
|||
*~ |
|||
|
|||
# temporary files which can be created if a process still has a handle open of a deleted file |
|||
.fuse_hidden* |
|||
|
|||
# KDE directory preferences |
|||
.directory |
|||
|
|||
# Linux trash folder which might appear on any partition or disk |
|||
.Trash-* |
|||
|
|||
# .nfs files are created when an open file is removed but is still being accessed |
|||
.nfs* |
|||
|
|||
### OSX ### |
|||
# General |
|||
.DS_Store |
|||
.AppleDouble |
|||
.LSOverride |
|||
|
|||
# Icon must end with two \r |
|||
Icon |
|||
|
|||
|
|||
# Thumbnails |
|||
._* |
|||
|
|||
# Files that might appear in the root of a volume |
|||
.DocumentRevisions-V100 |
|||
.fseventsd |
|||
.Spotlight-V100 |
|||
.TemporaryItems |
|||
.Trashes |
|||
.VolumeIcon.icns |
|||
.com.apple.timemachine.donotpresent |
|||
|
|||
# Directories potentially created on remote AFP share |
|||
.AppleDB |
|||
.AppleDesktop |
|||
Network Trash Folder |
|||
Temporary Items |
|||
.apdisk |
|||
|
|||
### Python ### |
|||
# Byte-compiled / optimized / DLL files |
|||
__pycache__/ |
|||
*.py[cod] |
|||
*$py.class |
|||
|
|||
# C extensions |
|||
*.so |
|||
|
|||
# Distribution / packaging |
|||
.Python |
|||
build/ |
|||
develop-eggs/ |
|||
dist/ |
|||
downloads/ |
|||
eggs/ |
|||
.eggs/ |
|||
lib/ |
|||
lib64/ |
|||
parts/ |
|||
sdist/ |
|||
var/ |
|||
wheels/ |
|||
share/python-wheels/ |
|||
*.egg-info/ |
|||
.installed.cfg |
|||
*.egg |
|||
MANIFEST |
|||
|
|||
# PyInstaller |
|||
# Usually these files are written by a python script from a template |
|||
# before PyInstaller builds the exe, so as to inject date/other infos into it. |
|||
*.manifest |
|||
*.spec |
|||
|
|||
# Installer logs |
|||
pip-log.txt |
|||
pip-delete-this-directory.txt |
|||
|
|||
# Unit test / coverage reports |
|||
htmlcov/ |
|||
.tox/ |
|||
.nox/ |
|||
.coverage |
|||
.coverage.* |
|||
.cache |
|||
nosetests.xml |
|||
coverage.xml |
|||
*.cover |
|||
*.py,cover |
|||
.hypothesis/ |
|||
.pytest_cache/ |
|||
cover/ |
|||
|
|||
# Translations |
|||
*.mo |
|||
*.pot |
|||
|
|||
# Django stuff: |
|||
*.log |
|||
local_settings.py |
|||
db.sqlite3 |
|||
db.sqlite3-journal |
|||
|
|||
# Flask stuff: |
|||
instance/ |
|||
.webassets-cache |
|||
|
|||
# Scrapy stuff: |
|||
.scrapy |
|||
|
|||
# Sphinx documentation |
|||
docs/_build/ |
|||
|
|||
# PyBuilder |
|||
.pybuilder/ |
|||
target/ |
|||
|
|||
# Jupyter Notebook |
|||
.ipynb_checkpoints |
|||
|
|||
# IPython |
|||
profile_default/ |
|||
ipython_config.py |
|||
|
|||
# pyenv |
|||
# For a library or package, you might want to ignore these files since the code is |
|||
# intended to run in multiple environments; otherwise, check them in: |
|||
# .python-version |
|||
|
|||
# pipenv |
|||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. |
|||
# However, in case of collaboration, if having platform-specific dependencies or dependencies |
|||
# having no cross-platform support, pipenv may install dependencies that don't work, or not |
|||
# install all needed dependencies. |
|||
#Pipfile.lock |
|||
|
|||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow |
|||
__pypackages__/ |
|||
|
|||
# Celery stuff |
|||
celerybeat-schedule |
|||
celerybeat.pid |
|||
|
|||
# SageMath parsed files |
|||
*.sage.py |
|||
|
|||
# Environments |
|||
.env |
|||
.venv |
|||
env/ |
|||
venv/ |
|||
ENV/ |
|||
env.bak/ |
|||
venv.bak/ |
|||
|
|||
# Spyder project settings |
|||
.spyderproject |
|||
.spyproject |
|||
|
|||
# Rope project settings |
|||
.ropeproject |
|||
|
|||
# mkdocs documentation |
|||
/site |
|||
|
|||
# mypy |
|||
.mypy_cache/ |
|||
.dmypy.json |
|||
dmypy.json |
|||
|
|||
# Pyre type checker |
|||
.pyre/ |
|||
|
|||
# pytype static type analyzer |
|||
.pytype/ |
|||
|
|||
# Cython debug symbols |
|||
cython_debug/ |
|||
|
|||
### Windows ### |
|||
# Windows thumbnail cache files |
|||
Thumbs.db |
|||
Thumbs.db:encryptable |
|||
ehthumbs.db |
|||
ehthumbs_vista.db |
|||
|
|||
# Dump file |
|||
*.stackdump |
|||
|
|||
# Folder config file |
|||
[Dd]esktop.ini |
|||
|
|||
# Recycle Bin used on file shares |
|||
$RECYCLE.BIN/ |
|||
|
|||
# Windows Installer files |
|||
*.cab |
|||
*.msi |
|||
*.msix |
|||
*.msm |
|||
*.msp |
|||
|
|||
# Windows shortcuts |
|||
*.lnk |
@ -1,2 +1,56 @@ |
|||
# vit-image-embedding |
|||
# ViT Embedding Operator |
|||
|
|||
Authors: kyle he |
|||
|
|||
## Overview |
|||
|
|||
The ViT(Vision Transformer) is a model for image classification that employs a Transformer-like architecture over patches of the image. This includes the use of Multi-Head Attention, Scaled Dot-Product Attention and other architectural features seen in the Transformer architecture traditionally used for NLP[1], which is trained on [imagenet dataset](https://image-net.org/download.php). |
|||
|
|||
## Interface |
|||
|
|||
```python |
|||
__init__(self, model_name: str = 'vit_large_patch16_224', |
|||
framework: str = 'pytorch', weights_path: str = None) |
|||
``` |
|||
|
|||
**Args:** |
|||
|
|||
- model_name: |
|||
- the model name for embedding |
|||
- supported types: `str`, for example 'vit_large_patch16_224' |
|||
- framework: |
|||
- the framework of the model |
|||
- supported types: `str`, default is 'pytorch' |
|||
- weights_path: |
|||
- the weights path |
|||
- supported types: `str`, default is None, using pretrained weights |
|||
|
|||
```python |
|||
__call__(self, img_path: str) |
|||
``` |
|||
|
|||
**Args:** |
|||
|
|||
- img_path: |
|||
- the input image path |
|||
- supported types: `str` |
|||
|
|||
**Returns:** |
|||
|
|||
The Operator returns a tuple `Tuple[('embedding', numpy.ndarray)]` containing following fields: |
|||
|
|||
- feature_vector: |
|||
- the embedding of the image |
|||
- data type: `numpy.ndarray` |
|||
|
|||
## Requirements |
|||
|
|||
You can get the required python package by [requirements.txt](./requirements.txt). |
|||
|
|||
## How it works |
|||
|
|||
The `towhee/vit-embedding` Operator implements the function of image embedding, which can add to the pipeline. For example, it's the key Operator named embedding_model within [vit-embedding](https://hub.towhee.io/towhee/vit-embedding) pipeline. |
|||
|
|||
## Reference |
|||
|
|||
[1].https://arxiv.org/abs/2010.11929 |
|||
|
@ -0,0 +1,21 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import os |
|||
|
|||
# For requirements. |
|||
try: |
|||
import timm |
|||
except ModuleNotFoundError: |
|||
os.system('pip install timm') |
@ -0,0 +1,13 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
@ -0,0 +1,39 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
|
|||
import torch |
|||
import timm |
|||
|
|||
|
|||
class Model(): |
|||
""" |
|||
PyTorch model class |
|||
""" |
|||
def __init__(self, model_name: str, weights_path: str): |
|||
super().__init__() |
|||
if weights_path: |
|||
self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0) |
|||
else: |
|||
self._model = timm.create_model(model_name, pretrained=True, num_classes=0) |
|||
self._model.eval() |
|||
|
|||
def __call__(self, img_tensor: torch.Tensor): |
|||
return self._model(img_tensor) |
|||
|
|||
def train(self): |
|||
""" |
|||
For training model |
|||
""" |
|||
pass |
@ -0,0 +1,52 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
|
|||
import sys |
|||
from typing import NamedTuple |
|||
from pathlib import Path |
|||
from PIL import Image |
|||
import torch |
|||
from timm.data import resolve_data_config |
|||
from timm.data.transforms_factory import create_transform |
|||
|
|||
|
|||
from towhee.operator import Operator |
|||
|
|||
|
|||
class VisionTransformerEmbeddingOperator(Operator): |
|||
""" |
|||
Embedding extractor using ViT. |
|||
Args: |
|||
model_name (`string`): |
|||
Model name. |
|||
weights_path (`string`): |
|||
Path to local weights. |
|||
""" |
|||
|
|||
def __init__(self, model_name: str = 'vit_large_patch16_224', |
|||
framework: str = 'pytorch', weights_path: str = None) -> None: |
|||
super().__init__() |
|||
sys.path.append(str(Path(__file__).parent)) |
|||
if framework == 'pytorch': |
|||
from vit_embedding.pytorch.model import Model |
|||
self.model = Model(model_name, weights_path) |
|||
config = resolve_data_config({}, model=self.model._model) |
|||
self.tfms = create_transform(**config) |
|||
|
|||
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]): |
|||
Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)]) |
|||
img = self.tfms(Image.open(img_path)).unsqueeze(0) |
|||
features = self.model(img) |
|||
return Outputs(features.flatten().detach().numpy()) |
@ -0,0 +1,13 @@ |
|||
name: 'vit-embedding' |
|||
labels: |
|||
recommended_framework: pytorch1.2.0 |
|||
class: vit-embedding |
|||
others: vit |
|||
operator: 'towhee/vit-embedding' |
|||
init: |
|||
model_name: str |
|||
call: |
|||
input: |
|||
img_path: str |
|||
output: |
|||
feature_vector: numpy.ndarray |
Loading…
Reference in new issue