towhee
/
resnet-image-embedding
copied
11 changed files with 205 additions and 2 deletions
@ -0,0 +1,35 @@ |
|||
# .gitattributes |
|||
|
|||
# Source files |
|||
# ============ |
|||
*.pxd text diff=python |
|||
*.py text diff=python |
|||
*.py3 text diff=python |
|||
*.pyw text diff=python |
|||
*.pyx text diff=python |
|||
*.pyz text diff=python |
|||
*.pyi text diff=python |
|||
|
|||
# Binary files |
|||
# ============ |
|||
*.db binary |
|||
*.p binary |
|||
*.pkl binary |
|||
*.pickle binary |
|||
*.pyc binary export-ignore |
|||
*.pyo binary export-ignore |
|||
*.pyd binary |
|||
|
|||
# Jupyter notebook |
|||
*.ipynb text |
|||
|
|||
# Model files |
|||
*.bin.* filter=lfs diff=lfs merge=lfs -text |
|||
*.lfs.* filter=lfs diff=lfs merge=lfs -text |
|||
*.bin filter=lfs diff=lfs merge=lfs -text |
|||
*.h5 filter=lfs diff=lfs merge=lfs -text |
|||
*.tflite filter=lfs diff=lfs merge=lfs -text |
|||
*.tar.gz filter=lfs diff=lfs merge=lfs -text |
|||
*.ot filter=lfs diff=lfs merge=lfs -text |
|||
*.onnx filter=lfs diff=lfs merge=lfs -text |
|||
*.msgpack filter=lfs diff=lfs merge=lfs -text |
@ -1,3 +1,20 @@ |
|||
# resnet50-image-embedding |
|||
# Image Embedding Operator with Resnet50 |
|||
|
|||
This is another test repo |
|||
Authors: name or github-name(email) |
|||
|
|||
## Overview |
|||
|
|||
Introduce the functions of op and the model used. |
|||
|
|||
## Interface |
|||
|
|||
The interface of all the functions in op. (input & output) |
|||
|
|||
## How to use |
|||
|
|||
- Requirements from requirements.txt |
|||
- How it works in some typical pipelines and the yaml example. |
|||
|
|||
## Reference |
|||
|
|||
Model paper link. |
|||
|
@ -0,0 +1,13 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
@ -0,0 +1,13 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
@ -0,0 +1,2 @@ |
|||
torch>=1.2.0 |
|||
torchvision>=0.4.0 |
@ -0,0 +1,40 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
|
|||
import torch |
|||
import torchvision |
|||
|
|||
|
|||
class Resnet50(): |
|||
""" |
|||
PyTorch model for image embedding. |
|||
""" |
|||
def __init__(self, model_name: str): |
|||
self.model_name = model_name |
|||
|
|||
def load_model(self): |
|||
""" |
|||
For loading model |
|||
""" |
|||
model_func = getattr(torchvision.models, self.model_name) |
|||
self.model = model_func(pretrained=True) |
|||
self.model.eval() |
|||
return self.model |
|||
|
|||
def train_model(self): |
|||
""" |
|||
For training model |
|||
""" |
|||
pass |
@ -0,0 +1 @@ |
|||
numpy>=1.19.5 |
@ -0,0 +1,37 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
|
|||
from typing import NamedTuple |
|||
|
|||
import numpy |
|||
import torch |
|||
import torchvision |
|||
|
|||
from towhee.operator import Operator |
|||
from pytorch.resnet50 import Resnet50 |
|||
|
|||
|
|||
class Resnet50ImageEmbedding(Operator): |
|||
""" |
|||
PyTorch model for image embedding. |
|||
""" |
|||
def __init__(self, model_name: str) -> None: |
|||
super().__init__() |
|||
resnet50_image_embedding = Resnet50(model_name) |
|||
self._model = resnet50_image_embedding.load_model() |
|||
|
|||
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('cnn', numpy.ndarray)]): |
|||
Outputs = NamedTuple('Outputs', [('cnn', numpy.ndarray)]) |
|||
return Outputs(self._model(img_tensor).detach().numpy()) |
@ -0,0 +1,13 @@ |
|||
name: 'resnet50-image-embedding' |
|||
labels: |
|||
recommended_framework: pytorch1.2.0 |
|||
class: image-embedding |
|||
others: resnet50 |
|||
operator: 'towhee/resnet50-image-embedding' |
|||
init: |
|||
model_name: str |
|||
call: |
|||
input: |
|||
img_tensor: torch.Tensor |
|||
output: |
|||
cnn: numpy.ndarray |
After Width: | Height: | Size: 178 KiB |
@ -0,0 +1,32 @@ |
|||
import os |
|||
import unittest |
|||
from PIL import Image |
|||
from torchvision import transforms |
|||
from resnet50_image_embedding import Resnet50ImageEmbedding |
|||
|
|||
|
|||
class TestResnet50ImageEmbedding(unittest.TestCase): |
|||
""" |
|||
Simple operator test |
|||
""" |
|||
def test_image_embedding(self): |
|||
test_img = './test_data/test.jpg' |
|||
img = Image.open(test_img) |
|||
tfms = transforms.Compose( |
|||
[ |
|||
transforms.Resize(256), |
|||
transforms.CenterCrop(224), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
|||
] |
|||
) |
|||
img_tensor = tfms(img).unsqueeze(0) |
|||
model_name = 'resnet50' |
|||
dimension = 1000 |
|||
op = Resnet50ImageEmbedding(model_name) |
|||
print("The output shape of operator:", op(img_tensor)[0].shape) |
|||
self.assertEqual((1, dimension), op(img_tensor)[0].shape) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
unittest.main() |
Loading…
Reference in new issue