import numpy as np
import torch
from torch.optim import AdamW
from torch.utils import data
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
# from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.trainer import Trainer
from retinaface_face_detection import RetinafaceFaceDetection
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from towhee.trainer.utils.trainer_utils import send_to_device
from PIL import Image as PILImage
from timm.models.resnet import ResNet
from torch import nn
from pytorch.wider_face import WiderFaceDetection, detection_collate
from pytorch.multibox_loss import MultiBoxLoss
from pytorch.data_augment import preproc
from pytorch.prior_box import PriorBox

if __name__ == '__main__':
    # Build the RetinaFace operator and load its training configuration.
    op = RetinafaceFaceDetection()
    training_config = TrainingConfig()
    yaml_path = 'retinaface_training_yaml.yaml'  # 'resnet_training_yaml.yaml'
    training_config.load_from_yaml(yaml_path)

    # WiderFace datasets. Training samples go through `preproc` (resize to
    # img_dim, RGB-mean subtraction); eval reuses the same label file without
    # preprocessing.
    training_dataset = './data/widerface/train/label.txt'
    img_dim = 840
    rgb_mean = (104, 117, 123)
    train_data = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))
    eval_data = WiderFaceDetection(training_dataset)

    # RetinaFace anchor (prior-box) configuration; `loc_weight` weights the
    # localization term of the multibox loss.
    cfg = {
        'min_sizes': [[16, 32], [64, 128], [256, 512]],
        'steps': [8, 16, 32],
        'clip': False,
        'loc_weight': 2.0,
    }
    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
    with torch.no_grad():
        priors = priorbox.forward()

    class RetinafaceTrainer(Trainer):
        """Trainer specialized for RetinaFace.

        Computes the multibox loss (localization + classification + landmark)
        against the precomputed prior boxes and, during evaluation, feeds
        detections and ground truth into the configured detection metric.
        """

        def __init__(
                self,
                model: nn.Module = None,
                training_config=None,
                train_dataset=None,
                eval_dataset=None,
                model_card=None,
                train_dataloader=None,
                eval_dataloader=None,
                priors=priors):
            super().__init__(model, training_config, train_dataset,
                             eval_dataset, model_card, train_dataloader,
                             eval_dataloader)
            # Anchor boxes shared by every loss computation.
            self.priors = priors

        def compute_loss(self, model: nn.Module, inputs):
            """Return the weighted multibox training loss for one batch.

            `inputs` is a (images, targets) pair as produced by
            `detection_collate`. `self.loss` is the MultiBoxLoss installed via
            `trainer.set_loss(...)`.
            """
            model.train()
            labels = inputs[1]
            outputs = model(inputs[0])
            priors = send_to_device(self.priors, self.configs.device)
            loss_l, loss_c, loss_landm = self.loss(outputs, priors, labels)
            loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
            return loss

        def compute_metric(self, model, inputs):
            """Run inference on one eval batch and update the detection metric.

            Returns the running epoch metric (the 'map' entry of the metric's
            compute() result) or None when no metric is configured.
            """
            model.eval()
            model.phase = "eval"
            epoch_metric = None
            labels = inputs[1]
            predicts = []
            gts = []
            img = inputs[0]
            # Single-image inference; outputs[0] rows are [x1, y1, x2, y2,
            # score, ...landmarks] — assumed from the slicing below, confirm
            # against the model's inference() implementation.
            outputs = model.inference(img[0])
            npreds = outputs[0].shape[0]
            ngts = labels[0].shape[0]
            # Fix: was the module-level global `trainer`, which only exists
            # after the trainer is constructed; use this instance's config.
            device = self.configs.device
            predicts.append(dict(
                boxes=outputs[0][:, :4],
                scores=outputs[0][:, 4],
                labels=torch.IntTensor([0] * npreds).to(device)))
            if isinstance(labels, list):
                labels = torch.vstack(labels)
            # WiderFace annotation rows are 15 values; the first 4 are the box.
            gts.append(dict(
                boxes=labels.reshape(-1, 15)[:, :4],
                labels=torch.IntTensor([0] * ngts).to(device)))
            if self.metric is not None:
                self.metric.update(predicts, gts)
                epoch_metric = self.metric.compute()['map'].item()
            return epoch_metric

        @torch.no_grad()
        def evaluate_step(self, model, inputs):
            """One evaluation step; eval loss is not computed (reported as 0)."""
            loss_metric, epoch_metric = self.update_metrics(
                model, inputs, 0, training=False)
            step_logs = {
                "eval_step_loss": 0,
                "eval_epoch_loss": loss_metric,
                "eval_epoch_metric": epoch_metric,
            }
            return step_logs

        def get_train_dataloader(self):
            """Training dataloader; distributed sampler when n_gpu > 1."""
            if self.configs.n_gpu > 1:
                self.train_sampler = \
                    torch.utils.data.distributed.DistributedSampler(train_data)
                train_batch_sampler = torch.utils.data.BatchSampler(
                    self.train_sampler, 32, drop_last=True)
                training_loader = torch.utils.data.DataLoader(
                    train_data,
                    batch_sampler=train_batch_sampler,
                    num_workers=4,  # self.configs.dataloader_num_workers,
                    pin_memory=True,
                    collate_fn=detection_collate,
                )
            else:
                training_loader = data.DataLoader(
                    train_data, 32, shuffle=True, num_workers=4,
                    collate_fn=detection_collate)
            return training_loader

        def get_eval_dataloader(self):
            """Eval dataloader (batch size 1); distributed sampler when n_gpu > 1."""
            if self.configs.n_gpu > 1:
                eval_sampler = \
                    torch.utils.data.distributed.DistributedSampler(eval_data)
                eval_batch_sampler = torch.utils.data.BatchSampler(
                    eval_sampler, 1, drop_last=True)
                eval_loader = torch.utils.data.DataLoader(
                    eval_data,
                    batch_sampler=eval_batch_sampler,
                    num_workers=1,  # self.configs.dataloader_num_workers,
                    pin_memory=True,
                    collate_fn=detection_collate,
                )
            else:
                # NOTE(review): unlike the distributed branch this omits
                # collate_fn=detection_collate. compute_metric() handles both
                # list and tensor labels, so this may be deliberate for
                # batch-size-1 eval — confirm before unifying the two branches.
                eval_loader = torch.utils.data.DataLoader(
                    eval_data, 1,
                    num_workers=1,  # self.configs.dataloader_num_workers,
                )
            return eval_loader

    trainer = RetinafaceTrainer(op.get_model(), training_config, train_data,
                                eval_data, None, None, None, priors)
    op.trainer = trainer
    # MultiBoxLoss positional args follow the RetinaFace reference
    # implementation: (num_classes, overlap_thresh, prior_for_matching,
    # bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, ...).
    criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True,
                             trainer.configs.device)
    trainer.set_loss(criterion, "multibox_loss")
    op.train()