logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

36 lines
1.4 KiB

import os
import numpy
from pathlib import Path
from typing import NamedTuple
from torchvision import transforms
from towhee.operator import Operator
from towhee.utils.pil_utils import to_pil
from towhee.types.image import Image
import warnings
warnings.filterwarnings("ignore")
class AnimeganStyleTransfer(Operator):
    """
    Operator that runs an image through an AnimeGAN style-transfer model.

    The model implementation is not imported statically; it is loaded at
    construction time from ``pytorch/model.py`` located next to this file.
    """

    def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
        """
        Load the style-transfer model and build the input transform.

        Args:
            model_name: Forwarded unchanged to ``Model(...)`` defined in
                ``pytorch/model.py``.
            framework: Backend to use. Only ``'pytorch'`` is supported.

        Raises:
            ValueError: If ``framework`` is not ``'pytorch'``. Without this
                check the operator would be created without ``self.model``
                and only fail later, inside ``__call__``.
        """
        super().__init__()
        if framework != 'pytorch':
            raise ValueError(
                f"Unsupported framework {framework!r}; only 'pytorch' is supported."
            )

        import importlib.util

        this_file = Path(__file__)
        # The model code lives in <this directory>/pytorch/model.py.
        path = os.path.join(str(this_file.parent), 'pytorch', 'model.py')
        # Name the dynamically loaded module after this file (text before the
        # first dot of the file name).
        opname = os.path.basename(str(this_file)).split('.')[0]
        spec = importlib.util.spec_from_file_location(opname, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        self.model = module.Model(model_name)

        # Input preprocessing: only a tensor conversion is applied.
        self.tfms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('styled_image', numpy.ndarray)]):
        """
        Apply the loaded model to a single image.

        Args:
            image: The input image; it is passed through ``to_pil`` before
                the tensor transform.

        Returns:
            A one-field named tuple ``Outputs(styled_image=...)`` holding
            whatever the model returns (annotated as ``numpy.ndarray``;
            the actual type depends on ``pytorch/model.py`` - TODO confirm).
        """
        # unsqueeze(0) adds a leading dimension (presumably the batch axis
        # expected by the model).
        img = self.tfms(to_pil(image)).unsqueeze(0)
        styled_image = self.model(img)
        Outputs = NamedTuple('Outputs', [('styled_image', numpy.ndarray)])
        return Outputs(styled_image)