towhee
/
2 changed files with 10 additions and 17 deletions
@ -1 +0,0 @@ |
|||
TEST_IMG = './test_data/test.jpg' |
@ -1,28 +1,22 @@ |
|||
import unittest |
|||
from PIL import Image |
|||
from torchvision import transforms |
|||
from transform_image_operator_template import TransformImageOperatorTemplate |
|||
from config import TEST_IMG, SIZE |
|||
|
|||
|
|||
class TestTransformImageOperatorTemplate(unittest.TestCase): |
|||
def get_transformed_img(size, img_path) |
|||
tfms = transforms.Compose( |
|||
[ |
|||
transforms.Resize(SIZE), |
|||
transforms.Resize(size), |
|||
transforms.CenterCrop(224), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
|||
] |
|||
) |
|||
img1 = tfms(test_img).unsqueeze(0) |
|||
img_transformed = tfms(img_path).unsqueeze(0) |
|||
return img_transformed |
|||
|
|||
def test_transform_image(self): |
|||
op = TransformImageOperatorTemplate(SIZE) |
|||
outputs = op(TEST_IMG) |
|||
print("The output tyep of operator:", type(outputs.img_transformed)) |
|||
c = (self.img1.numpy() == outputs.img_transformed.numpy()) |
|||
self.assertEqual(c.all(), True) |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
unittest.main() |
|||
def test_transform_image(size=256, img_path='./test_data/test.jpg'): |
|||
op = TransformImageOperatorTemplate(size |
|||
outputs = op(img_path) |
|||
img_transformed = get_transformed_img(size, img_path) |
|||
c = (img_transformed.numpy() == outputs.img_transformed.numpy()) |
|||
assert c.all() |
|||
|
Loading…
Reference in new issue