diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..ab2832a --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +# .gitattributes + +# Source files +# ============ +*.pxd text diff=python +*.py text diff=python +*.py3 text diff=python +*.pyw text diff=python +*.pyx text diff=python +*.pyz text diff=python +*.pyi text diff=python + +# Binary files +# ============ +*.db binary +*.p binary +*.pkl binary +*.pickle binary +*.pyc binary export-ignore +*.pyo binary export-ignore +*.pyd binary + +# Jupyter notebook +*.ipynb text + +# Model files +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tar.gz filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/README.md b/README.md index 54bf8c4..4ce0c37 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,20 @@ -# resnet50-image-embedding +# Image Embedding Operator with Resnet50 -This is another test repo \ No newline at end of file +Authors: name or github-name(email) + +## Overview + +Introduce the functions of op and the model used. + +## Interface + +The interface of all the functions in op. (input & output) + +## How to use + +- Requirements from requirements.txt +- How it works in some typical pipelines and the yaml example. + +## Reference + +Model paper link. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..37f5bd7 --- /dev/null +++ b/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pytorch/__init__.py b/pytorch/__init__.py new file mode 100644 index 0000000..37f5bd7 --- /dev/null +++ b/pytorch/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/pytorch/requirements.txt b/pytorch/requirements.txt new file mode 100644 index 0000000..ac3ffc1 --- /dev/null +++ b/pytorch/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.2.0 +torchvision>=0.4.0 diff --git a/pytorch/resnet50.py b/pytorch/resnet50.py new file mode 100644 index 0000000..7044957 --- /dev/null +++ b/pytorch/resnet50.py @@ -0,0 +1,40 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torchvision + + +class Resnet50(): + """ + PyTorch model for image embedding. + """ + def __init__(self, model_name: str): + self.model_name = model_name + + def load_model(self): + """ + For loading model + """ + model_func = getattr(torchvision.models, self.model_name) + self.model = model_func(pretrained=True) + self.model.eval() + return self.model + + def train_model(self): + """ + For training model + """ + pass diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b0e64dd --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +numpy>=1.19.5 diff --git a/resnet50_image_embedding.py b/resnet50_image_embedding.py new file mode 100644 index 0000000..ac484c0 --- /dev/null +++ b/resnet50_image_embedding.py @@ -0,0 +1,37 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import NamedTuple + +import numpy +import torch +import torchvision + +from towhee.operator import Operator +from pytorch.resnet50 import Resnet50 + + +class Resnet50ImageEmbedding(Operator): + """ + PyTorch model for image embedding. + """ + def __init__(self, model_name: str) -> None: + super().__init__() + resnet50_image_embedding = Resnet50(model_name) + self._model = resnet50_image_embedding.load_model() + + def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('cnn', numpy.ndarray)]): + Outputs = NamedTuple('Outputs', [('cnn', numpy.ndarray)]) + return Outputs(self._model(img_tensor).detach().numpy()) diff --git a/resnet50_image_embedding.yaml b/resnet50_image_embedding.yaml new file mode 100644 index 0000000..7d1a2ba --- /dev/null +++ b/resnet50_image_embedding.yaml @@ -0,0 +1,13 @@ +name: 'resnet50-image-embedding' +labels: + recommended_framework: pytorch1.2.0 + class: image-embedding + others: resnet50 +operator: 'towhee/resnet50-image-embedding' +init: + model_name: str +call: + input: + img_tensor: torch.Tensor + output: + cnn: numpy.ndarray diff --git a/test_data/test.jpg b/test_data/test.jpg new file mode 100755 index 0000000..8fdc2b3 Binary files /dev/null and b/test_data/test.jpg differ diff --git a/test_resnet50_image_embedding.py b/test_resnet50_image_embedding.py new file mode 100644 index 0000000..ea6bc08 --- /dev/null +++ b/test_resnet50_image_embedding.py @@ -0,0 +1,32 @@ +import os +import unittest +from PIL import Image +from torchvision import transforms +from resnet50_image_embedding import Resnet50ImageEmbedding + + +class TestResnet50ImageEmbedding(unittest.TestCase): + """ + Simple operator test + """ + def test_image_embedding(self): + test_img = './test_data/test.jpg' + img = Image.open(test_img) + tfms = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + img_tensor = tfms(img).unsqueeze(0) + model_name = 'resnet50' + dimension = 1000 + op = Resnet50ImageEmbedding(model_name) + print("The output shape of operator:", op(img_tensor)[0].shape) + self.assertEqual((1, dimension), op(img_tensor)[0].shape) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file