From 7929342d7ade6dfe0fb82429e9a1d6eb232d855d Mon Sep 17 00:00:00 2001 From: zhang chen Date: Sat, 18 Dec 2021 15:25:41 +0800 Subject: [PATCH] matching embedding test output with towhee op output --- pytorch/model.py | 4 ++++ resnet_image_embedding.py | 14 +++++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pytorch/model.py b/pytorch/model.py index 4a5fbc8..125ec3a 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -28,6 +28,10 @@ class Model(): super().__init__() model_func = getattr(torchvision.models, model_name) self._model = model_func(pretrained=True) + state_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') + self._model.load_state_dict(state_dict) + self._model.fc = torch.nn.Identity() self._model.eval() diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index d10f224..b0161b9 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -15,10 +15,12 @@ import sys import numpy import torch import torchvision - +from PIL import Image +from torchvision import transforms from pathlib import Path from typing import NamedTuple +from torchvision.transforms import InterpolationMode from towhee.operator import Operator @@ -32,8 +34,14 @@ class ResnetImageEmbedding(Operator): if framework == 'pytorch': from pytorch.model import Model self.model = Model(model_name) + self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + - def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): - embedding = self.model(img_tensor) + def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): + img = self.tfms(Image.open(img_path)).unsqueeze(0) + embedding = self.model(img) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(embedding)