|
|
@ -1,4 +1,5 @@ |
|
|
|
import numpy as np |
|
|
|
from torch.optim import AdamW |
|
|
|
from torchvision import transforms |
|
|
|
from torchvision.transforms import RandomResizedCrop, Lambda |
|
|
|
from towhee.trainer.modelcard import ModelCard |
|
|
@ -10,8 +11,10 @@ from towhee.types import Image |
|
|
|
from towhee.trainer.training_config import dump_default_yaml |
|
|
|
from PIL import Image as PILImage |
|
|
|
from timm.models.resnet import ResNet |
|
|
|
from torch import nn |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
dump_default_yaml(yaml_path='default_config.yaml') |
|
|
|
# img = torch.rand([1, 3, 224, 224]) |
|
|
|
img_path = './ILSVRC2012_val_00049771.JPEG' |
|
|
|
# # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') |
|
|
@ -28,7 +31,6 @@ if __name__ == '__main__': |
|
|
|
op = ResnetImageEmbedding('resnet34') |
|
|
|
op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") |
|
|
|
# old_out = op(towhee_img) |
|
|
|
|
|
|
|
# print(old_out.feature_vector[0]) |
|
|
|
|
|
|
|
training_config = TrainingConfig() |
|
|
@ -45,7 +47,6 @@ if __name__ == '__main__': |
|
|
|
# # n_gpu=4 |
|
|
|
# ) |
|
|
|
|
|
|
|
|
|
|
|
mnist_transform = transforms.Compose([transforms.ToTensor(), |
|
|
|
RandomResizedCrop(224), |
|
|
|
Lambda(lambda x: x.repeat(3, 1, 1)), |
|
|
@ -57,6 +58,23 @@ if __name__ == '__main__': |
|
|
|
# train_data = get_dataset('fake', size=20, transform=fake_transform) |
|
|
|
|
|
|
|
op.change_before_train(10) |
|
|
|
trainer = op.setup_trainer() |
|
|
|
|
|
|
|
|
|
|
|
my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) |
|
|
|
# op.setup_trainer() |
|
|
|
|
|
|
|
# trainer.add_callback() |
|
|
|
# trainer.set_optimizer() |
|
|
|
|
|
|
|
op.trainer.set_optimizer(my_optimimzer) |
|
|
|
trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') |
|
|
|
|
|
|
|
# my_loss = nn.BCELoss() |
|
|
|
# trainer.set_loss(my_loss, 'my_loss111') |
|
|
|
# trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') |
|
|
|
# op.trainer._create_optimizer() |
|
|
|
# op.trainer.set_optimizer() |
|
|
|
op.train(training_config, train_dataset=train_data) |
|
|
|
training_config.num_epoch = 3 |
|
|
|
op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') |
|
|
|