logo
Browse Source

matching embedding test output with towhee op output

training
zhang chen 3 years ago
parent
commit
7929342d7a
  1. 4
      pytorch/model.py
  2. 14
      resnet_image_embedding.py

4
pytorch/model.py

@ -28,6 +28,10 @@ class Model():
super().__init__()
model_func = getattr(torchvision.models, model_name)
self._model = model_func(pretrained=True)
state_dict = torch.hub.load_state_dict_from_url(
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
self._model.load_state_dict(state_dict)
self._model.fc = torch.nn.Identity()
self._model.eval()

14
resnet_image_embedding.py

@ -15,10 +15,12 @@ import sys
import numpy
import torch
import torchvision
from PIL import Image
from torchvision import transforms
from pathlib import Path
from typing import NamedTuple
from torchvision.transforms import InterpolationMode
from towhee.operator import Operator
@ -32,8 +34,14 @@ class ResnetImageEmbedding(Operator):
if framework == 'pytorch':
from pytorch.model import Model
self.model = Model(model_name)
self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
embedding = self.model(img_tensor)
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
img = self.tfms(Image.open(img_path)).unsqueeze(0)
embedding = self.model(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(embedding)

Loading…
Cancel
Save