logo
Browse Source

matching embedding test output with towhee op output

training
zhang chen 3 years ago
parent
commit
7929342d7a
  1. 4
      pytorch/model.py
  2. 14
      resnet_image_embedding.py

4
pytorch/model.py

@ -28,6 +28,10 @@ class Model():
super().__init__() super().__init__()
model_func = getattr(torchvision.models, model_name) model_func = getattr(torchvision.models, model_name)
self._model = model_func(pretrained=True) self._model = model_func(pretrained=True)
state_dict = torch.hub.load_state_dict_from_url(
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
self._model.load_state_dict(state_dict)
self._model.fc = torch.nn.Identity() self._model.fc = torch.nn.Identity()
self._model.eval() self._model.eval()

14
resnet_image_embedding.py

@ -15,10 +15,12 @@ import sys
import numpy import numpy
import torch import torch
import torchvision import torchvision
from PIL import Image
from torchvision import transforms
from pathlib import Path from pathlib import Path
from typing import NamedTuple from typing import NamedTuple
from torchvision.transforms import InterpolationMode
from towhee.operator import Operator from towhee.operator import Operator
@ -32,8 +34,14 @@ class ResnetImageEmbedding(Operator):
if framework == 'pytorch': if framework == 'pytorch':
from pytorch.model import Model from pytorch.model import Model
self.model = Model(model_name) self.model = Model(model_name)
self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
embedding = self.model(img_tensor)
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
img = self.tfms(Image.open(img_path)).unsqueeze(0)
embedding = self.model(img)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
return Outputs(embedding) return Outputs(embedding)

Loading…
Cancel
Save