towhee / retinaface-face-detection
import numpy as np
import torch
from torch.optim import AdamW
from torch.utils import data
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
# from towhee.trainer.modelcard import ModelCard

from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.trainer import Trainer
from retinaface_face_detection import RetinafaceFaceDetection
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from towhee.trainer.utils.trainer_utils import send_to_device
from PIL import Image as PILImage
from timm.models.resnet import ResNet
from torch import nn
from pytorch.wider_face import WiderFaceDetection, detection_collate
from pytorch.multibox_loss import MultiBoxLoss
from pytorch.data_augment import preproc
from pytorch.prior_box import PriorBox
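
# Note: the `pytorch.*` modules above are not part of the torch package; they are
# the RetinaFace training utilities (WIDER FACE dataset loader, multi-box loss,
# data augmentation and prior-box generation) presumably bundled with this
# operator repository.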


if __name__ == '__main__':
    op = RetinafaceFaceDetection()
    training_config = TrainingConfig()
    yaml_path = 'retinaface_training_yaml.yaml'  # 'resnet_training_yaml.yaml'

    training_config.load_from_yaml(yaml_path)
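
    # A template for the YAML above can be generated with dump_default_yaml
    # (imported but unused here). This is a sketch, assuming the helper takes
    # the output path and writes towhee's default TrainingConfig fields:
    #
    #     dump_default_yaml('retinaface_training_yaml.yaml')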

    # Torchvision transform pipeline; defined here but not used below, since the
    # WIDER FACE samples are preprocessed by `preproc` instead.
    mnist_transform = transforms.Compose([transforms.ToTensor(),
                                          RandomResizedCrop(224),
                                          Lambda(lambda x: x.repeat(3, 1, 1)),
                                          transforms.Normalize(mean=[0.5], std=[0.5])])

    training_dataset = './data/widerface/train/label.txt'

    img_dim = 840
    rgb_mean = (104, 117, 123)
    train_data = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))
    eval_data = WiderFaceDetection(training_dataset)

    cfg = {}
    cfg['min_sizes'] = [[16, 32], [64, 128], [256, 512]]
    cfg['steps'] = [8, 16, 32]
    cfg['clip'] = False
    cfg['loc_weight'] = 2.0

    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))

    with torch.no_grad():
        priors = priorbox.forward()
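
    # `priors` holds the fixed anchor (prior) boxes generated for an 840x840 input
    # at strides 8, 16 and 32 with the min_sizes configured above; they are matched
    # against ground-truth boxes inside the multi-box loss during training.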

    class RetinafaceTrainer(Trainer):
        def __init__(
            self,
            model: nn.Module = None,
            training_config=None,
            train_dataset=None,
            eval_dataset=None,
            model_card=None,
            train_dataloader=None,
            eval_dataloader=None,
            priors=priors
        ):
            super(RetinafaceTrainer, self).__init__(model, training_config, train_dataset, eval_dataset,
                                                    model_card, train_dataloader, eval_dataloader)
            self.priors = priors

        def compute_loss(self, model: nn.Module, inputs):
            model.train()
            labels = inputs[1]
            outputs = model(inputs[0])
            priors = send_to_device(self.priors, self.configs.device)
            loss_l, loss_c, loss_landm = self.loss(outputs, priors, labels)
            # Total loss: weighted box-regression term + classification term + landmark term.
            loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
            return loss

        def compute_metric(self, model, inputs):
            model.eval()
            model.phase = "eval"
            epoch_metric = None
            labels = inputs[1]
            predicts = []
            gts = []

            img = inputs[0]
            outputs = model.inference(img[0])

            npreds = outputs[0].shape[0]
            ngts = labels[0].shape[0]
            predicts.append(dict(boxes=outputs[0][:, :4],
                                 scores=outputs[0][:, 4],
                                 labels=torch.IntTensor([0 for i in range(npreds)]).to(self.configs.device)))
            if isinstance(labels, list):
                labels = torch.vstack(labels)
            gts.append(dict(boxes=labels.reshape(-1, 15)[:, :4],
                            labels=torch.IntTensor([0 for i in range(ngts)]).to(self.configs.device)))
            if self.metric is not None:
                self.metric.update(predicts, gts)
                epoch_metric = self.metric.compute()['map'].item()
            return epoch_metric

        @torch.no_grad()
        def evaluate_step(self, model, inputs):
            loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False)
            step_logs = {"eval_step_loss": 0, "eval_epoch_loss": loss_metric,
                         "eval_epoch_metric": epoch_metric}
            return step_logs

        def get_train_dataloader(self):
            if self.configs.n_gpu > 1:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
                train_batch_sampler = torch.utils.data.BatchSampler(
                    self.train_sampler, 32, drop_last=True)
                training_loader = torch.utils.data.DataLoader(train_data,
                                                              batch_sampler=train_batch_sampler,
                                                              num_workers=4,  # self.configs.dataloader_num_workers,
                                                              pin_memory=True,
                                                              collate_fn=detection_collate)
            else:
                training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4,
                                                  collate_fn=detection_collate)
            return training_loader

        def get_eval_dataloader(self):
            if self.configs.n_gpu > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)
                eval_batch_sampler = torch.utils.data.BatchSampler(
                    eval_sampler, 1, drop_last=True)
                eval_loader = torch.utils.data.DataLoader(eval_data,
                                                          batch_sampler=eval_batch_sampler,
                                                          num_workers=1,  # self.configs.dataloader_num_workers,
                                                          pin_memory=True,
                                                          collate_fn=detection_collate)
            else:
                eval_loader = torch.utils.data.DataLoader(eval_data, 1,
                                                          num_workers=1)  # self.configs.dataloader_num_workers
            return eval_loader

    trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors)
    op.trainer = trainer

    # RetinaFace multi-task (classification + box + landmark) loss, computed against
    # the priors built above.
    criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device)
    trainer.set_loss(criterion, "multibox_loss")
    op.train()
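
    # Hypothetical post-training usage sketch (not part of the original script):
    # the exact call signature and return format depend on the
    # RetinafaceFaceDetection implementation, but `Image`, `PILImage` and `np`
    # are imported above for this kind of use.
    #
    #     pil_img = PILImage.open('./test.jpg')
    #     img = Image(np.array(pil_img), 'RGB')
    #     bboxes = op(img)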