logo
Browse Source

Update the test script

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
df00bba97b
  1. 8
      test_transform_image_operator_template.py

8
test_transform_image_operator_template.py

@ -1,8 +1,10 @@
from torchvision import transforms from torchvision import transforms
from transform_image_operator_template import TransformImageOperatorTemplate from transform_image_operator_template import TransformImageOperatorTemplate
from PIL import Image
def get_transformed_img(size, img_path)
def get_transformed_img(size, img_path):
img = Image.open(img_path)
tfms = transforms.Compose( tfms = transforms.Compose(
[ [
transforms.Resize(size), transforms.Resize(size),
@ -11,11 +13,11 @@ def get_transformed_img(size, img_path)
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
] ]
) )
img_transformed = tfms(img_path).unsqueeze(0)
img_transformed = tfms(img).unsqueeze(0)
return img_transformed return img_transformed
def test_transform_image(size=256, img_path='./test_data/test.jpg'): def test_transform_image(size=256, img_path='./test_data/test.jpg'):
op = TransformImageOperatorTemplate(size
op = TransformImageOperatorTemplate(size)
outputs = op(img_path) outputs = op(img_path)
img_transformed = get_transformed_img(size, img_path) img_transformed = get_transformed_img(size, img_path)
c = (img_transformed.numpy() == outputs.img_transformed.numpy()) c = (img_transformed.numpy() == outputs.img_transformed.numpy())

Loading…
Cancel
Save