towhee
/
resnet-image-embedding
copied
11 changed files with 205 additions and 2 deletions
@ -0,0 +1,35 @@ |
|||||
|
# .gitattributes |
||||
|
|
||||
|
# Source files |
||||
|
# ============ |
||||
|
*.pxd text diff=python |
||||
|
*.py text diff=python |
||||
|
*.py3 text diff=python |
||||
|
*.pyw text diff=python |
||||
|
*.pyx text diff=python |
||||
|
*.pyz text diff=python |
||||
|
*.pyi text diff=python |
||||
|
|
||||
|
# Binary files |
||||
|
# ============ |
||||
|
*.db binary |
||||
|
*.p binary |
||||
|
*.pkl binary |
||||
|
*.pickle binary |
||||
|
*.pyc binary export-ignore |
||||
|
*.pyo binary export-ignore |
||||
|
*.pyd binary |
||||
|
|
||||
|
# Jupyter notebook |
||||
|
*.ipynb text |
||||
|
|
||||
|
# Model files |
||||
|
*.bin.* filter=lfs diff=lfs merge=lfs -text |
||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text |
||||
|
*.bin filter=lfs diff=lfs merge=lfs -text |
||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text |
||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text |
||||
|
*.tar.gz filter=lfs diff=lfs merge=lfs -text |
||||
|
*.ot filter=lfs diff=lfs merge=lfs -text |
||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text |
||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text |
@ -1,3 +1,20 @@ |
|||||
# resnet50-image-embedding |
|
||||
|
# Image Embedding Operator with Resnet50 |
||||
|
|
||||
This is another test repo |
|
||||
|
Authors: name or github-name(email) |
||||
|
|
||||
|
## Overview |
||||
|
|
||||
|
Introduce the functions of op and the model used. |
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
The interface of all the functions in op. (input & output) |
||||
|
|
||||
|
## How to use |
||||
|
|
||||
|
- Requirements from requirements.txt |
||||
|
- How it works in some typical pipelines and the yaml example. |
||||
|
|
||||
|
## Reference |
||||
|
|
||||
|
Model paper link. |
||||
|
@ -0,0 +1,13 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
@ -0,0 +1,13 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
@ -0,0 +1,2 @@ |
|||||
|
torch>=1.2.0 |
||||
|
torchvision>=0.4.0 |
@ -0,0 +1,40 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
|
||||
|
import torch |
||||
|
import torchvision |
||||
|
|
||||
|
|
||||
|
class Resnet50(): |
||||
|
""" |
||||
|
PyTorch model for image embedding. |
||||
|
""" |
||||
|
def __init__(self, model_name: str): |
||||
|
self.model_name = model_name |
||||
|
|
||||
|
def load_model(self): |
||||
|
""" |
||||
|
For loading model |
||||
|
""" |
||||
|
model_func = getattr(torchvision.models, self.model_name) |
||||
|
self.model = model_func(pretrained=True) |
||||
|
self.model.eval() |
||||
|
return self.model |
||||
|
|
||||
|
def train_model(self): |
||||
|
""" |
||||
|
For training model |
||||
|
""" |
||||
|
pass |
@ -0,0 +1 @@ |
|||||
|
numpy>=1.19.5 |
@ -0,0 +1,37 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
|
||||
|
from typing import NamedTuple |
||||
|
|
||||
|
import numpy |
||||
|
import torch |
||||
|
import torchvision |
||||
|
|
||||
|
from towhee.operator import Operator |
||||
|
from pytorch.resnet50 import Resnet50 |
||||
|
|
||||
|
|
||||
|
class Resnet50ImageEmbedding(Operator): |
||||
|
""" |
||||
|
PyTorch model for image embedding. |
||||
|
""" |
||||
|
def __init__(self, model_name: str) -> None: |
||||
|
super().__init__() |
||||
|
resnet50_image_embedding = Resnet50(model_name) |
||||
|
self._model = resnet50_image_embedding.load_model() |
||||
|
|
||||
|
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('cnn', numpy.ndarray)]): |
||||
|
Outputs = NamedTuple('Outputs', [('cnn', numpy.ndarray)]) |
||||
|
return Outputs(self._model(img_tensor).detach().numpy()) |
@ -0,0 +1,13 @@ |
|||||
|
name: 'resnet50-image-embedding' |
||||
|
labels: |
||||
|
recommended_framework: pytorch1.2.0 |
||||
|
class: image-embedding |
||||
|
others: resnet50 |
||||
|
operator: 'towhee/resnet50-image-embedding' |
||||
|
init: |
||||
|
model_name: str |
||||
|
call: |
||||
|
input: |
||||
|
img_tensor: torch.Tensor |
||||
|
output: |
||||
|
cnn: numpy.ndarray |
After Width: | Height: | Size: 178 KiB |
@ -0,0 +1,32 @@ |
|||||
|
import os |
||||
|
import unittest |
||||
|
from PIL import Image |
||||
|
from torchvision import transforms |
||||
|
from resnet50_image_embedding import Resnet50ImageEmbedding |
||||
|
|
||||
|
|
||||
|
class TestResnet50ImageEmbedding(unittest.TestCase): |
||||
|
""" |
||||
|
Simple operator test |
||||
|
""" |
||||
|
def test_image_embedding(self): |
||||
|
test_img = './test_data/test.jpg' |
||||
|
img = Image.open(test_img) |
||||
|
tfms = transforms.Compose( |
||||
|
[ |
||||
|
transforms.Resize(256), |
||||
|
transforms.CenterCrop(224), |
||||
|
transforms.ToTensor(), |
||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
||||
|
] |
||||
|
) |
||||
|
img_tensor = tfms(img).unsqueeze(0) |
||||
|
model_name = 'resnet50' |
||||
|
dimension = 1000 |
||||
|
op = Resnet50ImageEmbedding(model_name) |
||||
|
print("The output shape of operator:", op(img_tensor)[0].shape) |
||||
|
self.assertEqual((1, dimension), op(img_tensor)[0].shape) |
||||
|
|
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
unittest.main() |
Loading…
Reference in new issue