diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index ba18e2d..f11d958 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -45,7 +45,7 @@ class ResnetImageEmbedding(Operator): def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): - img = self.tfms(Image.open(img_path)).unsqueeze(0) + img = self.tfms(Image.open(img_path).convert('RGB')).unsqueeze(0) embedding = self.model(img) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(embedding)