logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

66 lines
2.5 KiB

import numpy as np
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.dataset import get_dataset
from resnet_image_embedding import ResnetImageEmbedding
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from PIL import Image as PILImage
from timm.models.resnet import ResNet
# Sample image used to smoke-check conversion into a towhee Image.
IMG_PATH = './ILSVRC2012_val_00049771.JPEG'
# Training settings are read from this YAML file.
YAML_PATH = 'resnet_training_yaml.yaml'
# Directory where the MNIST dataset is downloaded/cached.
DATA_ROOT = 'data'
# MNIST has ten digit classes.
NUM_CLASSES = 10


def _load_towhee_image(img_path):
    """Read an image file and wrap its pixels in a towhee ``Image``.

    The file is opened in a ``with`` block so its handle is released once the
    pixel data has been copied out (PIL opens files lazily and otherwise keeps
    them open).

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        A towhee ``Image`` built from the raw bytes, size, channel count, mode
        and a NumPy copy of the pixels.
    """
    with PILImage.open(img_path) as img:
        img_bytes = img.tobytes()
        img_array = np.array(img)  # convert once; the original did this twice
        # getbands() yields the same count as len(img.split()) without
        # materializing a copy of every band.
        img_channel = len(img.getbands())
        return Image(img_bytes, img.width, img.height, img_channel, img.mode, img_array)


def _gray_to_rgb(tensor):
    """Repeat a single-channel tensor into three identical channels.

    Assumes a ``(1, H, W)`` tensor, as ``ToTensor`` produces for grayscale
    MNIST digits. Defined as a module-level function rather than a ``lambda``
    so the transform stays picklable should data-loader workers need it.
    """
    return tensor.repeat(3, 1, 1)


def _build_mnist_transform():
    """Return the preprocessing pipeline that adapts MNIST digits for ResNet.

    Steps: convert to a tensor, random-resized-crop to 224x224, expand to three
    channels, then normalize with mean/std 0.5 (maps ToTensor's [0, 1] range
    to [-1, 1]).
    """
    return transforms.Compose([
        transforms.ToTensor(),
        RandomResizedCrop(224),
        Lambda(_gray_to_rgb),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])


def main():
    """Fine-tune a ResNet-34 image-embedding operator on MNIST."""
    # Smoke check: fails early if the sample image is missing or cannot be
    # converted to a towhee Image. The result is not used by training.
    _load_towhee_image(IMG_PATH)

    op = ResnetImageEmbedding('resnet34')
    op.model_card = ModelCard(model_details="resnet test modelcard",
                              training_data="use resnet test data")

    # To create a starting YAML file, dump_default_yaml(yaml_path=YAML_PATH)
    # can be run once before this script.
    training_config = TrainingConfig()
    training_config.load_from_yaml(YAML_PATH)

    train_data = get_dataset('mnist', transform=_build_mnist_transform(),
                             download=True, root=DATA_ROOT)

    # NOTE(review): presumably adapts the model's output head to NUM_CLASSES
    # before training -- confirm against ResnetImageEmbedding.
    op.change_before_train(NUM_CLASSES)
    op.train(training_config, train_dataset=train_data)


if __name__ == '__main__':
    main()