import numpy as np
from torch.optim import AdamW
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.data.dataset.dataset import dataset
from towhee.trainer.modelcard import ModelCard

from towhee.trainer.training_config import TrainingConfig
# from towhee.trainer.dataset import get_dataset
from towhee.trainer.utils.layer_freezer import LayerFreezer

from resnet_image_embedding import ResnetImageEmbedding
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from PIL import Image as PILImage
from timm.models.resnet import ResNet
from torch import nn

if __name__ == '__main__':
    dump_default_yaml(yaml_path='default_config.yaml')
    # img = torch.rand([1, 3, 224, 224])
    img_path = './ILSVRC2012_val_00049771.JPEG'
    # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
    img = PILImage.open(img_path)
    img_bytes = img.tobytes()
    img_width = img.width
    img_height = img.height
    img_channel = len(img.split())
    img_mode = img.mode
    img_array = np.array(img)
    array_size = np.array(img).shape
    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    op = ResnetImageEmbedding('resnet50', num_classes=10)
    # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
    old_out = op(towhee_img)
    # print(old_out.feature_vector[0][:10])
    print(old_out.feature_vector[:10])
    # print(old_out.feature_vector.shape)

    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    # dump_default_yaml(yaml_path=yaml_path)
    training_config.load_from_yaml(yaml_path)
    #     output_dir='./temp_output',
    #     overwrite_output_dir=True,
    #     epoch_num=2,
    #     per_gpu_train_batch_size=16,
    #     prediction_loss_only=True,
    #     metric='Accuracy'
    #     # device_str='cuda',
    #     # n_gpu=4
    # )
    #
    mnist_transform = transforms.Compose([transforms.ToTensor(),
                                          RandomResizedCrop(224),
                                          Lambda(lambda x: x.repeat(3, 1, 1)),
                                          transforms.Normalize(mean=[0.5], std=[0.5])])
    train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
    eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
    # fake_transform = transforms.Compose([transforms.ToTensor(),
    #                                       RandomResizedCrop(224),])
    # train_data = get_dataset('fake', size=20, transform=fake_transform)
    #
    # op.change_before_train(num_classes=10)
    # # trainer = op.setup_trainer()
    # print(op.get_model())
    # # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
    # # op.setup_trainer()
    #
    # # trainer.add_callback()
    # # trainer.set_optimizer()
    #
    # # op.trainer.set_optimizer(my_optimimzer)
    # # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
    #
    # # my_loss = nn.BCELoss()
    # # trainer.set_loss(my_loss, 'my_loss111')
    # # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
    # # op.trainer._create_optimizer()
    # # op.trainer.set_optimizer()
    # # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
    #
    # # freezer = LayerFreezer(op.get_model())
    # # freezer.by_idx([-1])
    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
    # # op.trainer.run_train()
    # # training_config.num_epoch = 3
    # # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
    #
    # # op.save('./test_save')
    # # op.load('./test_save')
    # # new_out = op(towhee_img)
    #
    # # assert (new_out[0]!=old_out[0]).all()