logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

149 lines
6.3 KiB

3 years ago
import numpy as np
import torch
3 years ago
from torch.optim import AdamW
from torch.utils import data
3 years ago
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
#from towhee.trainer.modelcard import ModelCard
3 years ago
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.trainer import Trainer
from retinaface_face_detection import RetinafaceFaceDetection
3 years ago
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from towhee.trainer.utils.trainer_utils import send_to_device
3 years ago
from PIL import Image as PILImage
from timm.models.resnet import ResNet
from torch import nn
from pytorch.wider_face import WiderFaceDetection, detection_collate
from pytorch.multibox_loss import MultiBoxLoss
from pytorch.data_augment import preproc
from pytorch.prior_box import PriorBox
3 years ago
if __name__ == '__main__':
    # Operator whose model is fine-tuned further down in this script.
    op = RetinafaceFaceDetection()

    # Training hyper-parameters are loaded from a YAML file.
    training_config = TrainingConfig()
    yaml_path = 'retinaface_training_yaml.yaml' # 'resnet_training_yaml.yaml'
    training_config.load_from_yaml(yaml_path)

    # NOTE(review): `mnist_transform` is not referenced anywhere else in the
    # visible script; kept only so the set of script-level names is unchanged.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        RandomResizedCrop(224),
        Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Datasets. The training set is built with `preproc(img_dim, rgb_mean)`;
    # the evaluation set is built without it.
    training_dataset = './data/widerface/train/label.txt'
    img_dim = 840
    rgb_mean = (104, 117, 123)
    train_data = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))
    # NOTE(review): evaluation reads the same label file as training --
    # confirm whether a separate validation split was intended.
    eval_data = WiderFaceDetection(training_dataset)

    # Settings handed to PriorBox; `loc_weight` is also read by
    # RetinafaceTrainer.compute_loss when it combines the loss terms.
    cfg = {
        'min_sizes': [[16, 32], [64, 128], [256, 512]],
        'steps': [8, 16, 32],
        'clip': False,
        'loc_weight': 2.0,
    }

    # The priors are produced once, outside of autograd, and shared with the
    # trainer created later in the script.
    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
    with torch.no_grad():
        priors = priorbox.forward()
class RetinafaceTrainer(Trainer):
    """Trainer specialised for RetinaFace: custom loss, metric and data loaders.

    Besides its own state, the methods read the script-level names ``cfg``,
    ``train_data`` and ``eval_data`` when they are called, so those names must
    exist before training starts.
    """

    def __init__(
        self,
        model: nn.Module = None,
        training_config = None,
        train_dataset = None,
        eval_dataset = None,
        model_card = None,
        train_dataloader = None,
        eval_dataloader = None,
        priors = priors
    ):
        """Forward the standard arguments to ``Trainer`` and keep the priors.

        Args:
            model: network to train.
            training_config: configuration object passed to ``Trainer``.
            train_dataset: passed through to ``Trainer``.
            eval_dataset: passed through to ``Trainer``.
            model_card: passed through to ``Trainer``.
            train_dataloader: passed through to ``Trainer``.
            eval_dataloader: passed through to ``Trainer``.
            priors: prior boxes given to the loss in ``compute_loss``. The
                default is bound when the class statement executes, so a
                script-level ``priors`` must already be defined at that point.
        """
        super().__init__(model, training_config, train_dataset, eval_dataset,
                         model_card, train_dataloader, eval_dataloader)
        self.priors = priors

    def compute_loss(self, model: nn.Module, inputs):
        """Return the weighted training loss for one batch.

        ``inputs[0]`` is fed to the model and ``inputs[1]`` is used as the
        labels. The three terms returned by ``self.loss`` are combined as
        ``cfg['loc_weight'] * loss_l + loss_c + loss_landm``.
        """
        model.train()
        labels = inputs[1]
        outputs = model(inputs[0])
        # The priors are moved to the configured device on every call.
        device_priors = send_to_device(self.priors, self.configs.device)
        loss_l, loss_c, loss_landm = self.loss(outputs, device_priors, labels)
        return cfg['loc_weight'] * loss_l + loss_c + loss_landm

    def compute_metric(self, model, inputs):
        """Update ``self.metric`` with one evaluation sample.

        Returns:
            The ``'map'`` entry of ``self.metric.compute()`` as a Python
            number, or ``None`` when no metric is configured.
        """
        model.eval()
        model.phase = "eval"
        # Use this trainer's own configuration rather than a module-level
        # `trainer` variable, so the method does not depend on how (or
        # whether) the instance is named globally.
        device = self.configs.device

        labels = inputs[1]
        images = inputs[0]
        outputs = model.inference(images[0])
        detections = outputs[0]
        # Both counts are taken before `labels` is possibly stacked below.
        npreds = detections.shape[0]
        ngts = labels[0].shape[0]

        # Columns 0-3 of each detection row are used as the box and column 4
        # as the score; every detection gets class id 0.
        predicts = [dict(
            boxes=detections[:, :4],
            scores=detections[:, 4],
            labels=torch.zeros(npreds, dtype=torch.int32, device=device),
        )]

        # Labels may arrive as a list of tensors; stack them so both cases
        # can be reshaped the same way.
        if isinstance(labels, list):
            labels = torch.vstack(labels)
        # Labels are viewed as rows of 15 values, the first four being the box.
        gts = [dict(
            boxes=labels.reshape(-1, 15)[:, :4],
            labels=torch.zeros(ngts, dtype=torch.int32, device=device),
        )]

        if self.metric is None:
            return None
        self.metric.update(predicts, gts)
        return self.metric.compute()['map'].item()

    @torch.no_grad()
    def evaluate_step(self, model, inputs):
        """Run one evaluation step without gradients and return its log dict."""
        loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False)
        return {
            "eval_step_loss": 0,
            "eval_epoch_loss": loss_metric,
            "eval_epoch_metric": epoch_metric,
        }

    def get_train_dataloader(self):
        """Build the training loader over the script-level ``train_data``.

        With more than one GPU a ``DistributedSampler`` wrapped in a
        ``BatchSampler`` (batch size 32, incomplete batches dropped) is used;
        otherwise a shuffled loader with batch size 32.
        """
        if self.configs.n_gpu > 1:
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
            train_batch_sampler = torch.utils.data.BatchSampler(
                self.train_sampler, 32, drop_last=True)
            return torch.utils.data.DataLoader(
                train_data,
                batch_sampler=train_batch_sampler,
                num_workers=4,  # self.configs.dataloader_num_workers,
                pin_memory=True,
                collate_fn=detection_collate,
            )
        return data.DataLoader(train_data, 32, shuffle=True, num_workers=4,
                               collate_fn=detection_collate)

    def get_eval_dataloader(self):
        """Build the evaluation loader over the script-level ``eval_data``.

        Batch size is 1 in both branches. Only the multi-GPU branch passes
        ``detection_collate``; the single-GPU branch uses the DataLoader's
        default collation.
        """
        if self.configs.n_gpu > 1:
            eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)
            eval_batch_sampler = torch.utils.data.BatchSampler(
                eval_sampler, 1, drop_last=True)
            return torch.utils.data.DataLoader(
                eval_data,
                batch_sampler=eval_batch_sampler,
                num_workers=1,  # self.configs.dataloader_num_workers,
                pin_memory=True,
                collate_fn=detection_collate,
            )
        return torch.utils.data.DataLoader(
            eval_data, 1,
            num_workers=1,  # self.configs.dataloader_num_workers,
        )
if __name__ == '__main__':
    # Guarded so that training is only launched when the file is executed as
    # a script. Relies on `op`, `training_config`, `train_data`, `eval_data`
    # and `priors` created by the setup section earlier in the file.
    trainer = RetinafaceTrainer(
        model=op.get_model(),
        training_config=training_config,
        train_dataset=train_data,
        eval_dataset=eval_data,
        priors=priors,
    )
    op.trainer = trainer

    # The loss is created after the trainer because its last argument is the
    # device taken from the trainer's configuration.
    criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device)
    trainer.set_loss(criterion, "multibox_loss")

    op.train()