logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

99 lines
4.0 KiB

import numpy as np
from torch.optim import AdamW
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.data.dataset.dataset import dataset
from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
# from towhee.trainer.dataset import get_dataset
from towhee.trainer.utils.layer_freezer import LayerFreezer
from resnet_image_embedding import ResnetImageEmbedding
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from PIL import Image as PILImage
from timm.models.resnet import ResNet
from torch import nn
def _load_towhee_image(img_path):
    """Read an image file with PIL and wrap it in a towhee ``Image``.

    Args:
        img_path: Path to the image file on disk.

    Returns:
        A ``towhee.types.Image`` built from the raw pixel bytes, width, height,
        channel count, PIL mode and a NumPy array of the pixels.
    """
    # The context manager closes the underlying file handle. Every value the
    # towhee Image needs is extracted before the ``with`` block exits, because
    # a closed PIL image can no longer be read.
    with PILImage.open(img_path) as img:
        img_bytes = img.tobytes()
        # getbands() yields one entry per channel without copying each band,
        # which len(img.split()) would do.
        img_channel = len(img.getbands())
        img_array = np.array(img)
        return Image(img_bytes, img.width, img.height, img_channel, img.mode, img_array)


def main():
    """Train ``ResnetImageEmbedding`` briefly and check a reference embedding.

    Steps:
    1. Write the trainer's default config to ``default_config.yaml``.
    2. Embed a reference image.
    3. Freeze layers via ``LayerFreezer.set_slice(-1)``.
    4. Train on a small fake dataset using ``resnet_training_yaml.yaml``.
    5. Assert that the reference image's feature vector is unchanged.

    Raises:
        AssertionError: If the feature vector differs after training.
    """
    dump_default_yaml(yaml_path='default_config.yaml')

    # Reference image whose embedding is compared before and after training.
    img_path = './ILSVRC2012_val_00049771.JPEG'
    towhee_img = _load_towhee_image(img_path)

    op = ResnetImageEmbedding('resnet50', num_classes=10)
    old_out = op(towhee_img)

    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    training_config.load_from_yaml(yaml_path)

    # Tiny synthetic datasets keep the run short. An MNIST variant was used
    # previously; it needs a Lambda that repeats the single channel 3 times
    # before Normalize so the input matches the 3-channel ResNet stem.
    fake_transform = transforms.Compose([transforms.ToTensor(),
                                         RandomResizedCrop(224)])
    train_data = dataset('fake', size=20, transform=fake_transform)
    eval_data = dataset('fake', size=10, transform=fake_transform)
    training_config.output_dir = 'fake_output'
    op.model_card = ModelCard(datasets='fake dataset')

    # NOTE(review): exactly which layers set_slice(-1) freezes is defined by
    # LayerFreezer (not visible here). Confirm it covers every layer that
    # feeds the feature vector, otherwise the check below cannot hold.
    freezer = LayerFreezer(op.get_model())
    freezer.set_slice(-1)
    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)

    new_out = op(towhee_img)
    # NOTE(review): if training runs the model in train() mode, BatchNorm
    # running statistics are updated even for frozen (requires_grad=False)
    # layers. That can change embeddings despite freezing; verify.
    assert (new_out.feature_vector == old_out.feature_vector).all(), \
        'feature vector changed after training with frozen layers'


if __name__ == '__main__':
    main()