logo
Browse Source

Update the test script

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
0339ed484f
  1. 1
      config.py
  2. 26
      test_transform_image_operator_template.py

1
config.py

@ -1 +0,0 @@
TEST_IMG = './test_data/test.jpg'

26
test_transform_image_operator_template.py

@ -1,28 +1,22 @@
import unittest
from PIL import Image
from torchvision import transforms
from transform_image_operator_template import TransformImageOperatorTemplate
from config import TEST_IMG, SIZE
class TestTransformImageOperatorTemplate(unittest.TestCase):
def get_transformed_img(size, img_path)
tfms = transforms.Compose(
[
transforms.Resize(SIZE),
transforms.Resize(size),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
img1 = tfms(test_img).unsqueeze(0)
img_transformed = tfms(img_path).unsqueeze(0)
return img_transformed
def test_transform_image(self):
op = TransformImageOperatorTemplate(SIZE)
outputs = op(TEST_IMG)
print("The output tyep of operator:", type(outputs.img_transformed))
c = (self.img1.numpy() == outputs.img_transformed.numpy())
self.assertEqual(c.all(), True)
if __name__ == '__main__':
unittest.main()
def test_transform_image(size=256, img_path='./test_data/test.jpg'):
op = TransformImageOperatorTemplate(size
outputs = op(img_path)
img_transformed = get_transformed_img(size, img_path)
c = (img_transformed.numpy() == outputs.img_transformed.numpy())
assert c.all()

Loading…
Cancel
Save