import numpy as np from torchvision import transforms from torchvision.transforms import RandomResizedCrop, Lambda from towhee.trainer.modelcard import ModelCard from towhee.trainer.training_config import TrainingConfig from towhee.trainer.dataset import get_dataset from resnet_image_embedding import ResnetImageEmbedding from towhee.types import Image from towhee.trainer.training_config import dump_default_yaml from PIL import Image as PILImage from timm.models.resnet import ResNet if __name__ == '__main__': # img = torch.rand([1, 3, 224, 224]) img_path = './ILSVRC2012_val_00049771.JPEG' # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') img = PILImage.open(img_path) img_bytes = img.tobytes() img_width = img.width img_height = img.height img_channel = len(img.split()) img_mode = img.mode img_array = np.array(img) array_size = np.array(img).shape towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) op = ResnetImageEmbedding('resnet34') op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") # old_out = op(towhee_img) # print(old_out.feature_vector[0]) training_config = TrainingConfig() yaml_path = 'resnet_training_yaml.yaml' # dump_default_yaml(yaml_path=yaml_path) training_config.load_from_yaml(yaml_path) # output_dir='./temp_output', # overwrite_output_dir=True, # epoch_num=2, # per_gpu_train_batch_size=16, # prediction_loss_only=True, # metric='Accuracy' # # device_str='cuda', # # n_gpu=4 # ) mnist_transform = transforms.Compose([transforms.ToTensor(), RandomResizedCrop(224), Lambda(lambda x: x.repeat(3, 1, 1)), transforms.Normalize(mean=[0.5], std=[0.5])]) train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data') # fake_transform = transforms.Compose([transforms.ToTensor(), # RandomResizedCrop(224),]) # train_data = get_dataset('fake', size=20, transform=fake_transform) op.change_before_train(10) op.train(training_config, train_dataset=train_data) training_config.num_epoch = 3 op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') # op.save('./test_save') # op.load('./test_save') # new_out = op(towhee_img) # assert (new_out[0]!=old_out[0]).all()