|
|
|
from torchvision import transforms
|
|
|
|
from transform_image_operator_template import TransformImageOperatorTemplate
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
def get_transformed_img(size, img_path):
    """Load an image and apply the reference preprocessing pipeline.

    Args:
        size: target size passed to ``transforms.Resize`` (int or (h, w)).
        img_path: filesystem path of the image to load.

    Returns:
        A float tensor with a leading batch dimension of 1: the image
        resized, center-cropped to 224x224, converted to a tensor, and
        normalized with the ImageNet channel statistics. The 3-channel
        Normalize stats assume a 3-channel (RGB) input image.
    """
    tfms = transforms.Compose(
        [
            transforms.Resize(size),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # ImageNet means/stds — must match the operator under test.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    # Image.open is lazy and holds the file open; use a context manager so
    # the handle is closed. The transform forces the pixel load inside the
    # with-block, so the result is unchanged.
    with Image.open(img_path) as img:
        img_transformed = tfms(img).unsqueeze(0)
    return img_transformed
|
|
|
|
|
|
|
|
def test_transform_image(size=256, img_path='./test_data/test.jpg'):
    """Verify the operator reproduces the reference preprocessing exactly.

    Runs ``TransformImageOperatorTemplate`` on the test image and asserts
    that its ``img_transformed`` output is element-wise identical to the
    tensor produced by ``get_transformed_img``.
    """
    operator = TransformImageOperatorTemplate(size)
    actual = operator(img_path).img_transformed.numpy()
    expected = get_transformed_img(size, img_path).numpy()
    assert (expected == actual).all()