diff --git a/pytorch/model.py b/pytorch/model.py
index 4a5fbc8..125ec3a 100644
--- a/pytorch/model.py
+++ b/pytorch/model.py
@@ -28,6 +28,10 @@ class Model():
         super().__init__()
         model_func = getattr(torchvision.models, model_name)
         self._model = model_func(pretrained=True)
+        state_dict = torch.hub.load_state_dict_from_url(
+            'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
+        self._model.load_state_dict(state_dict)
+        self._model.fc = torch.nn.Identity()
         self._model.eval()
diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py
index d10f224..b0161b9 100644
--- a/resnet_image_embedding.py
+++ b/resnet_image_embedding.py
@@ -15,10 +15,12 @@ import sys
 import numpy
 import torch
 import torchvision
-
+from PIL import Image
+from torchvision import transforms
 from pathlib import Path
 from typing import NamedTuple
+from torchvision.transforms import InterpolationMode
 
 from towhee.operator import Operator
@@ -32,8 +34,14 @@ class ResnetImageEmbedding(Operator):
         if framework == 'pytorch':
             from pytorch.model import Model
             self.model = Model(model_name)
+            self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC),
+                                            transforms.CenterCrop(224),
+                                            transforms.ToTensor(),
+                                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+
 
-    def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
-        embedding = self.model(img_tensor)
+    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
+        img = self.tfms(Image.open(img_path)).unsqueeze(0)
+        embedding = self.model(img)
         Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
         return Outputs(embedding)
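
For reference, a minimal usage sketch of the operator after this change. It is not part of the diff: the constructor call (model_name plus a framework argument defaulting to 'pytorch') is assumed from the __init__ shown above, and the image path is purely illustrative.

# Hedged sketch -- 'resnet101' and the image path are illustrative assumptions, not part of the diff.
from resnet_image_embedding import ResnetImageEmbedding

op = ResnetImageEmbedding('resnet101')      # assumed constructor signature; loads the timm A1h weights
out = op('./test_data/example.jpg')         # __call__ now takes an image path instead of a tensor
vec = out.feature_vector                    # shape [1, 2048]: ResNet-101 pooled features with fc -> Identity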