diff --git a/requirements.txt b/requirements.txt index e69de29..eda3109 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,4 @@ +torch==1.9.0 +torchvision==0.10.0 +pillow==8.3.1 +numpy==1.19.5 diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index 98cb1a0..d0c91c8 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -46,7 +46,7 @@ class ResnetImageEmbedding(Operator): def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): - img = self.tfms(Image.open(img_path)).unsqueeze(0) + img = self.tfms(Image.open(img_path).convert('RGB')).unsqueeze(0) embedding = self.model(img) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(embedding)