|
|
@ -15,10 +15,12 @@ import sys |
|
|
|
import numpy |
|
|
|
import torch |
|
|
|
import torchvision |
|
|
|
|
|
|
|
from PIL import Image |
|
|
|
from torchvision import transforms |
|
|
|
from pathlib import Path |
|
|
|
from typing import NamedTuple |
|
|
|
|
|
|
|
from torchvision.transforms import InterpolationMode |
|
|
|
from towhee.operator import Operator |
|
|
|
|
|
|
|
|
|
|
@ -32,8 +34,14 @@ class ResnetImageEmbedding(Operator): |
|
|
|
if framework == 'pytorch': |
|
|
|
from pytorch.model import Model |
|
|
|
self.model = Model(model_name) |
|
|
|
self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), |
|
|
|
transforms.CenterCrop(224), |
|
|
|
transforms.ToTensor(), |
|
|
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) |
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, img_tensor: torch.Tensor) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): |
|
|
|
embedding = self.model(img_tensor) |
|
|
|
def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): |
|
|
|
img = self.tfms(Image.open(img_path)).unsqueeze(0) |
|
|
|
embedding = self.model(img) |
|
|
|
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) |
|
|
|
return Outputs(embedding) |
|
|
|