diff --git a/transform_image.py b/transform_image.py index f2fd873..5d5ed8e 100644 --- a/transform_image.py +++ b/transform_image.py @@ -54,9 +54,9 @@ class TransformImage(Operator): (`torch.Tensor`) The normalized image tensor. """ - if isinstance(img, str): - img_tensor = Image.open(img) - if isinstance(img, Image.Image): - img_tensor = img.convert('RGB') + if isinstance(img_path, str): + img_tensor = Image.open(img_path) + if isinstance(img_path, Image.Image): + img_tensor = img_path.convert('RGB') Outputs = NamedTuple('Outputs', [('img_transformed', torch.Tensor)]) return Outputs(self.tfms(img_tensor).unsqueeze(0))