logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

152 lines
6.3 KiB

import numpy as np
import ipdb
import torch
from torch.optim import AdamW
from torch.utils import data
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
#from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.trainer import Trainer
from retinaface_face_detection import RetinafaceFaceDetection
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from towhee.trainer.utils.trainer_utils import send_to_device
from PIL import Image as PILImage
from timm.models.resnet import ResNet
from torch import nn
from pytorch.wider_face import WiderFaceDetection, detection_collate
from pytorch.multibox_loss import MultiBoxLoss
from pytorch.data_augment import preproc
from pytorch.prior_box import PriorBox
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Operator and training configuration
    # ------------------------------------------------------------------
    op = RetinafaceFaceDetection()
    training_config = TrainingConfig()
    yaml_path = 'retinaface_training_yaml.yaml'  # 'resnet_training_yaml.yaml'
    training_config.load_from_yaml(yaml_path)

    # ------------------------------------------------------------------
    # Datasets
    # ------------------------------------------------------------------
    # Both datasets are built from the same label file; only the training
    # set is given the `preproc` transform.
    # NOTE(review): evaluation therefore runs on the training annotations --
    # confirm whether a separate validation label file should be used.
    training_dataset = './data/widerface/train/label.txt'
    img_dim = 840
    rgb_mean = (104, 117, 123)
    train_data = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))
    eval_data = WiderFaceDetection(training_dataset)

    # Loader settings shared by the single-process and distributed branches.
    TRAIN_BATCH_SIZE = 32
    TRAIN_NUM_WORKERS = 4
    EVAL_BATCH_SIZE = 1
    EVAL_NUM_WORKERS = 1

    # ------------------------------------------------------------------
    # Prior boxes
    # ------------------------------------------------------------------
    # `cfg` is handed to PriorBox; `loc_weight` is additionally read by
    # RetinafaceTrainer.compute_loss to scale the first loss term.
    cfg = {
        'min_sizes': [[16, 32], [64, 128], [256, 512]],
        'steps': [8, 16, 32],
        'clip': False,
        'loc_weight': 2.0,
    }
    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
    with torch.no_grad():
        priors = priorbox.forward()

    class RetinafaceTrainer(Trainer):
        """Trainer that passes prior boxes to the loss and scores detections.

        Besides the arguments forwarded unchanged to ``Trainer.__init__``, it
        stores ``priors`` and overrides loss computation, metric computation,
        the evaluation step and both data-loader factories. The loader
        factories read the script-level ``train_data`` / ``eval_data``
        datasets rather than the datasets passed to the constructor.
        """

        def __init__(
            self,
            model: nn.Module = None,
            training_config=None,
            train_dataset=None,
            eval_dataset=None,
            model_card=None,
            train_dataloader=None,
            eval_dataloader=None,
            priors=priors,
        ):
            """Initialise the base trainer and remember the prior boxes.

            Args:
                model: Network to train.
                training_config: Configuration object for the base trainer.
                train_dataset: Training dataset forwarded to the base trainer.
                eval_dataset: Evaluation dataset forwarded to the base trainer.
                model_card: Forwarded to the base trainer.
                train_dataloader: Forwarded to the base trainer.
                eval_dataloader: Forwarded to the base trainer.
                priors: Prior boxes given to the loss; defaults to the
                    script-level ``priors`` computed from ``PriorBox``.
            """
            super().__init__(
                model,
                training_config,
                train_dataset,
                eval_dataset,
                model_card,
                train_dataloader,
                eval_dataloader,
            )
            self.priors = priors

        def compute_loss(self, model: nn.Module, inputs):
            """Return the combined loss for one batch.

            Args:
                model: Network being trained.
                inputs: Indexable batch; ``inputs[0]`` is fed to the model and
                    ``inputs[1]`` is passed to the loss as the targets.

            Returns:
                ``cfg['loc_weight'] * loss_l + loss_c + loss_landm`` where the
                three terms are what ``self.loss`` returns.
            """
            # compute_metric() calls model.cpu() and nothing in this file
            # moves the model back, so make sure it sits on the device the
            # priors (and therefore the loss) are placed on. This is a no-op
            # when the model is already there.
            model.to(self.configs.device)
            model.train()

            images, labels = inputs[0], inputs[1]
            outputs = model(images)
            device_priors = send_to_device(self.priors, self.configs.device)
            loss_l, loss_c, loss_landm = self.loss(outputs, device_priors, labels)
            return cfg['loc_weight'] * loss_l + loss_c + loss_landm

        def compute_metric(self, model, inputs):
            """Run CPU inference on one batch and update the metric.

            Args:
                model: Network to evaluate; must provide ``inference``.
                inputs: Indexable batch; ``inputs[0]`` holds the images (only
                    the first one is used) and ``inputs[1]`` the ground truth.

            Returns:
                The ``'map'`` entry of ``self.metric.compute()`` as a Python
                number, or ``None`` when no metric is configured.
            """
            model.eval()
            # NOTE(review): `phase` is left at "eval" afterwards; confirm the
            # model does not need it reset before training resumes.
            model.phase = "eval"
            # Inference is done on the CPU; compute_loss() moves the model
            # back to the training device.
            model.cpu()

            labels = inputs[1]
            img = send_to_device(inputs[0], torch.device("cpu"))
            outputs = model.inference(img[0])

            # Columns 0-3 of the first output are used as boxes and column 4
            # as the score.
            detections = outputs[0]
            # Take the ground-truth count from the stacked boxes so the label
            # vector always has one entry per box.
            gt_boxes = torch.vstack(labels)[:, :4]

            # Every detection and every ground-truth box gets class label 0.
            predicts = [dict(
                boxes=detections[:, :4],
                scores=detections[:, 4],
                labels=torch.zeros(detections.shape[0], dtype=torch.int),
            )]
            gts = [dict(
                boxes=gt_boxes,
                labels=torch.zeros(gt_boxes.shape[0], dtype=torch.int),
            )]

            epoch_metric = None
            if self.metric is not None:
                self.metric.update(predicts, gts)
                epoch_metric = self.metric.compute()['map'].item()
            return epoch_metric

        @torch.no_grad()
        def evaluate_step(self, model, inputs):
            """Evaluate one batch without gradients and return the step logs.

            The step loss is reported as the constant 0; the epoch loss and
            epoch metric come from ``self.update_metrics``.
            """
            loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False)
            return {
                "eval_step_loss": 0,
                "eval_epoch_loss": loss_metric,
                "eval_epoch_metric": epoch_metric,
            }

        def get_train_dataloader(self):
            """Build the training loader over the script-level ``train_data``.

            With more than one GPU configured a ``DistributedSampler`` wrapped
            in a ``BatchSampler`` (incomplete batches dropped) is used and kept
            on ``self.train_sampler``; otherwise a shuffling loader is returned.
            """
            if self.configs.n_gpu > 1:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
                train_batch_sampler = torch.utils.data.BatchSampler(
                    self.train_sampler, TRAIN_BATCH_SIZE, drop_last=True)
                return torch.utils.data.DataLoader(
                    train_data,
                    batch_sampler=train_batch_sampler,
                    num_workers=TRAIN_NUM_WORKERS,  # self.configs.dataloader_num_workers,
                    pin_memory=True,
                    collate_fn=detection_collate,
                )
            return data.DataLoader(
                train_data,
                TRAIN_BATCH_SIZE,
                shuffle=True,
                num_workers=TRAIN_NUM_WORKERS,
                collate_fn=detection_collate,
            )

        def get_eval_dataloader(self):
            """Build the evaluation loader over the script-level ``eval_data``.

            With more than one GPU configured a ``DistributedSampler`` wrapped
            in a ``BatchSampler`` is used together with ``detection_collate``.
            NOTE(review): the single-process branch relies on the default
            collate function instead -- confirm the difference is intended.
            """
            if self.configs.n_gpu > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)
                eval_batch_sampler = torch.utils.data.BatchSampler(
                    eval_sampler, EVAL_BATCH_SIZE, drop_last=True)
                return torch.utils.data.DataLoader(
                    eval_data,
                    batch_sampler=eval_batch_sampler,
                    num_workers=EVAL_NUM_WORKERS,  # self.configs.dataloader_num_workers,
                    pin_memory=True,
                    collate_fn=detection_collate,
                )
            return torch.utils.data.DataLoader(
                eval_data,
                EVAL_BATCH_SIZE,
                num_workers=EVAL_NUM_WORKERS,  # self.configs.dataloader_num_workers,
            )

    # ------------------------------------------------------------------
    # Wire everything together and start training
    # ------------------------------------------------------------------
    trainer = RetinafaceTrainer(
        op.get_model(), training_config, train_data, eval_data, None, None, None, priors)
    op.trainer = trainer
    # The loss is created after the trainer because it needs the trainer's
    # resolved device.
    criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device)
    trainer.set_loss(criterion, "multibox_loss")
    op.train()