diff --git a/README.md b/README.md
index 8016b02..347fa07 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,49 @@
-# retinaface-face-detection
+# RetinaFace Face Detection (PyTorch)
+
+Authors: wxywb
+
+## Overview
+
+This operator detects faces in images using the RetinaFace detector [1]. It returns the bounding boxes, five facial keypoints, and the cropped face images from the original image. This repo is an adaptation of [2].
+
+## Interface
+
+```python
+__call__(self, image: 'towhee.types.Image')
+```
+
+**Args:**
+
+- image:
+  - the image in which to detect faces.
+  - supported types: towhee.types.Image
+
+**Returns:**
+
+The operator returns a tuple Tuple[('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)] containing the following fields:
+
+- boxes:
+  - boxes of human faces.
+  - data type: `numpy.ndarray`
+  - shape: (num_faces, 4)
+- keypoints:
+  - keypoints of human faces.
+  - data type: `numpy.ndarray`
+  - shape: (10)
+- cropped_imgs:
+  - cropped face images.
+  - data type: `numpy.ndarray`
+  - shape: (h, w, 3)
+
+## Requirements
+
+The required Python packages are listed in [requirements.txt](./requirements.txt).
+
+## How it works
+The `towhee/retinaface-face-detection` operator implements face detection. An example pipeline can be found in [face-embedding-retinaface-inceptionresnetv1](https://towhee.io/towhee/face-embedding-retinaface-inceptionresnetv1).
+
+## Reference
+
+[1]. https://arxiv.org/abs/1905.00641
+[2]. https://github.com/biubug6/Pytorch_Retinaface
diff --git a/pytorch/data_augment.py b/pytorch/data_augment.py
new file mode 100644
index 0000000..ead5480
--- /dev/null
+++ b/pytorch/data_augment.py
@@ -0,0 +1,237 @@
+import cv2
+import numpy as np
+import random
+from .box_utils import matrix_iof
+
+
+def _crop(image, boxes, labels, landm, img_dim):
+    height, width, _ = image.shape
+    pad_image_flag = True
+
+    for _ in range(250):
+        """
+        if random.uniform(0, 1) <= 0.2:
+            scale = 1.0
+        else:
+            scale = random.uniform(0.3, 1.0)
+        """
+        PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
+        scale = random.choice(PRE_SCALES)
+        short_side = min(width, height)
+        w = int(scale * short_side)
+        h = w
+
+        if width == w:
+            l = 0
+        else:
+            l = random.randrange(width - w)
+        if height == h:
+            t = 0
+        else:
+            t = random.randrange(height - h)
+        roi = np.array((l, t, l + w, t + h))
+
+        value = matrix_iof(boxes, roi[np.newaxis])
+        flag = (value >= 1)
+        if not flag.any():
+            continue
+
+        centers = (boxes[:, :2] + boxes[:, 2:]) / 2
+        mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
+        boxes_t = boxes[mask_a].copy()
+        labels_t = labels[mask_a].copy()
+        landms_t = landm[mask_a].copy()
+        landms_t = landms_t.reshape([-1, 5, 2])
+
+        if boxes_t.shape[0] == 0:
+            continue
+
+        image_t = image[roi[1]:roi[3], roi[0]:roi[2]]
+
+        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
+        boxes_t[:, :2] -= roi[:2]
+        boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
+        boxes_t[:, 2:] -= roi[:2]
+
+        # landm
+        landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
+        landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
+        landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
+        landms_t = landms_t.reshape([-1, 10])
+
+
+        # make sure that the cropped image contains at least one face > 16 pixel at training image scale
+        b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
+        b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
+        mask_b = np.minimum(b_w_t, b_h_t) > 0.0
+        boxes_t = boxes_t[mask_b]
+        labels_t = labels_t[mask_b]
+
landms_t = landms_t[mask_b] + + if boxes_t.shape[0] == 0: + continue + + pad_image_flag = False + + return image_t, boxes_t, labels_t, landms_t, pad_image_flag + return image, boxes, labels, landm, pad_image_flag + + +def _distort(image): + + def _convert(image, alpha=1, beta=0): + tmp = image.astype(float) * alpha + beta + tmp[tmp < 0] = 0 + tmp[tmp > 255] = 255 + image[:] = tmp + + image = image.copy() + + if random.randrange(2): + + #brightness distortion + if random.randrange(2): + _convert(image, beta=random.uniform(-32, 32)) + + #contrast distortion + if random.randrange(2): + _convert(image, alpha=random.uniform(0.5, 1.5)) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + #saturation distortion + if random.randrange(2): + _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) + + #hue distortion + if random.randrange(2): + tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) + tmp %= 180 + image[:, :, 0] = tmp + + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + + else: + + #brightness distortion + if random.randrange(2): + _convert(image, beta=random.uniform(-32, 32)) + + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + + #saturation distortion + if random.randrange(2): + _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) + + #hue distortion + if random.randrange(2): + tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) + tmp %= 180 + image[:, :, 0] = tmp + + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + + #contrast distortion + if random.randrange(2): + _convert(image, alpha=random.uniform(0.5, 1.5)) + + return image + + +def _expand(image, boxes, fill, p): + if random.randrange(2): + return image, boxes + + height, width, depth = image.shape + + scale = random.uniform(1, p) + w = int(scale * width) + h = int(scale * height) + + left = random.randint(0, w - width) + top = random.randint(0, h - height) + + boxes_t = boxes.copy() + boxes_t[:, :2] += (left, top) + boxes_t[:, 2:] += (left, top) + expand_image = np.empty( + (h, w, depth), + dtype=image.dtype) + expand_image[:, :] = fill + expand_image[top:top + height, left:left + width] = image + image = expand_image + + return image, boxes_t + + +def _mirror(image, boxes, landms): + _, width, _ = image.shape + if random.randrange(2): + image = image[:, ::-1] + boxes = boxes.copy() + boxes[:, 0::2] = width - boxes[:, 2::-2] + + # landm + landms = landms.copy() + landms = landms.reshape([-1, 5, 2]) + landms[:, :, 0] = width - landms[:, :, 0] + tmp = landms[:, 1, :].copy() + landms[:, 1, :] = landms[:, 0, :] + landms[:, 0, :] = tmp + tmp1 = landms[:, 4, :].copy() + landms[:, 4, :] = landms[:, 3, :] + landms[:, 3, :] = tmp1 + landms = landms.reshape([-1, 10]) + + return image, boxes, landms + + +def _pad_to_square(image, rgb_mean, pad_image_flag): + if not pad_image_flag: + return image + height, width, _ = image.shape + long_side = max(width, height) + image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) + image_t[:, :] = rgb_mean + image_t[0:0 + height, 0:0 + width] = image + return image_t + + +def _resize_subtract_mean(image, insize, rgb_mean): + interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] + interp_method = interp_methods[random.randrange(5)] + image = cv2.resize(image, (insize, insize), interpolation=interp_method) + image = image.astype(np.float32) + image -= rgb_mean + return image.transpose(2, 0, 1) + + +class preproc(object): + + def __init__(self, img_dim, rgb_means): + self.img_dim = img_dim + self.rgb_means = rgb_means + + 
def __call__(self, image, targets): + assert targets.shape[0] > 0, "this image does not have gt" + + boxes = targets[:, :4].copy() + labels = targets[:, -1].copy() + landm = targets[:, 4:-1].copy() + + image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim) + image_t = _distort(image_t) + image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag) + image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t) + height, width, _ = image_t.shape + image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means) + boxes_t[:, 0::2] /= width + boxes_t[:, 1::2] /= height + + landm_t[:, 0::2] /= width + landm_t[:, 1::2] /= height + + labels_t = np.expand_dims(labels_t, 1) + targets_t = np.hstack((boxes_t, landm_t, labels_t)) + + return image_t, targets_t diff --git a/pytorch/main.py b/pytorch/main.py new file mode 100644 index 0000000..a94befe --- /dev/null +++ b/pytorch/main.py @@ -0,0 +1,4 @@ +import torch + +torch.load('pytorch_retinaface_mobilenet_widerface.pth') + diff --git a/pytorch/multibox_loss.py b/pytorch/multibox_loss.py index c34b1df..15e78ee 100644 --- a/pytorch/multibox_loss.py +++ b/pytorch/multibox_loss.py @@ -27,7 +27,7 @@ class MultiBoxLoss(nn.Module): See: https://arxiv.org/pdf/1512.02325.pdf for more details. """ - def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target): + def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu=False, device=None): super(MultiBoxLoss, self).__init__() self.num_classes = num_classes self.threshold = overlap_thresh @@ -38,7 +38,8 @@ class MultiBoxLoss(nn.Module): self.negpos_ratio = neg_pos self.neg_overlap = neg_overlap self.variance = [0.1, 0.2] - self.GPU = False + self.GPU = use_gpu + self.device = device def forward(self, predictions, priors, targets): """Multibox Loss @@ -68,9 +69,9 @@ class MultiBoxLoss(nn.Module): defaults = priors.data match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx) if self.GPU: - loc_t = loc_t.cuda() - conf_t = conf_t.cuda() - landm_t = landm_t.cuda() + loc_t = loc_t.to(self.device) + conf_t = conf_t.to(self.device) + landm_t = landm_t.to(self.device) zeros = torch.tensor(0).cuda() # landm Loss (Smooth L1) @@ -114,7 +115,8 @@ class MultiBoxLoss(nn.Module): targets_weighted = conf_t[(pos+neg).gt(0)] loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum') - # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N N = max(num_pos.data.sum().float(), 1) + # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N + N = max(num_pos.data.sum().float(), 1) loss_l /= N loss_c /= N loss_landm /= N1 diff --git a/pytorch/prior_box.py b/pytorch/prior_box.py new file mode 100644 index 0000000..80c7f85 --- /dev/null +++ b/pytorch/prior_box.py @@ -0,0 +1,34 @@ +import torch +from itertools import product as product +import numpy as np +from math import ceil + + +class PriorBox(object): + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps] + self.name = "s" + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in 
product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output diff --git a/pytorch/pytorch_retinaface_mobilenet_widerface.pth b/pytorch/pytorch_retinaface_mobilenet_widerface.pth index 287f82f..769ae52 100644 Binary files a/pytorch/pytorch_retinaface_mobilenet_widerface.pth and b/pytorch/pytorch_retinaface_mobilenet_widerface.pth differ diff --git a/pytorch/wider_face.py b/pytorch/wider_face.py new file mode 100644 index 0000000..8c485b1 --- /dev/null +++ b/pytorch/wider_face.py @@ -0,0 +1,102 @@ +import os +import os.path +import sys +import torch +import torch.utils.data as data +import cv2 +import numpy as np + +class WiderFaceDetection(data.Dataset): + def __init__(self, txt_path, preproc=None): + self.preproc = preproc + self.imgs_path = [] + self.words = [] + f = open(txt_path,'r') + lines = f.readlines() + isFirst = True + labels = [] + for line in lines: + line = line.rstrip() + if line.startswith('#'): + if isFirst is True: + isFirst = False + else: + labels_copy = labels.copy() + self.words.append(labels_copy) + labels.clear() + path = line[2:] + path = txt_path.replace('label.txt','images/') + path + self.imgs_path.append(path) + else: + line = line.split(' ') + label = [float(x) for x in line] + labels.append(label) + + self.words.append(labels) + + def __len__(self): + return len(self.imgs_path) + + def __getitem__(self, index): + img = cv2.imread(self.imgs_path[index]) + height, width, _ = img.shape + + labels = self.words[index] + annotations = np.zeros((0, 15)) + if len(labels) == 0: + return annotations + for idx, label in enumerate(labels): + annotation = np.zeros((1, 15)) + # bbox + annotation[0, 0] = label[0] # x1 + annotation[0, 1] = label[1] # y1 + annotation[0, 2] = label[0] + label[2] # x2 + annotation[0, 3] = label[1] + label[3] # y2 + + # landmarks + annotation[0, 4] = label[4] # l0_x + annotation[0, 5] = label[5] # l0_y + annotation[0, 6] = label[7] # l1_x + annotation[0, 7] = label[8] # l1_y + annotation[0, 8] = label[10] # l2_x + annotation[0, 9] = label[11] # l2_y + annotation[0, 10] = label[13] # l3_x + annotation[0, 11] = label[14] # l3_y + annotation[0, 12] = label[16] # l4_x + annotation[0, 13] = label[17] # l4_y + if (annotation[0, 4]<0): + annotation[0, 14] = -1 + else: + annotation[0, 14] = 1 + + annotations = np.append(annotations, annotation, axis=0) + target = np.array(annotations) + if self.preproc is not None: + img, target = self.preproc(img, target) + + return torch.from_numpy(img), target + +def detection_collate(batch): + """Custom collate fn for dealing with batches of images that have a different + number of associated object annotations (bounding boxes). 
+ + Arguments: + batch: (tuple) A tuple of tensor images and lists of annotations + + Return: + A tuple containing: + 1) (tensor) batch of images stacked on their 0 dim + 2) (list of tensors) annotations for a given image are stacked on 0 dim + """ + targets = [] + imgs = [] + for _, sample in enumerate(batch): + for _, tup in enumerate(sample): + if torch.is_tensor(tup): + imgs.append(tup) + elif isinstance(tup, type(np.empty(0))): + annos = torch.from_numpy(tup).float() + targets.append(annos) + + return (torch.stack(imgs, 0), targets) + diff --git a/retinaface_face_detection.py b/retinaface_face_detection.py index 97630a4..21bf298 100644 --- a/retinaface_face_detection.py +++ b/retinaface_face_detection.py @@ -15,6 +15,7 @@ from typing import NamedTuple, List from PIL import Image import torch +from torch import nn from torchvision import transforms import sys import towhee @@ -23,11 +24,12 @@ import numpy from towhee.operator import Operator from towhee.utils.pil_utils import to_pil +from towhee.operator import NNOperator from timm.data import resolve_data_config from timm.data.transforms_factory import create_transform import os -class RetinafaceFaceDetection(Operator): +class RetinafaceFaceDetection(NNOperator): """ Embedding extractor using efficientnet. Args: @@ -71,3 +73,6 @@ class RetinafaceFaceDetection(Operator): output = Outputs(bboxes[i], keypoints[i,:], croppeds[i]) outputs.append(output) return outputs + + def get_model(self) -> nn.Module: + return self.model._model diff --git a/retinaface_training_yaml.yaml b/retinaface_training_yaml.yaml index 34129b5..b0d9189 100644 --- a/retinaface_training_yaml.yaml +++ b/retinaface_training_yaml.yaml @@ -1,13 +1,14 @@ device: - device_str: null - n_gpu: -1 + device_str: cuda + n_gpu: 2 sync_bn: true metrics: - metric: Accuracy + metric: MeanAveragePrecision train: batch_size: 32 overwrite_output_dir: true epoch_num: 2 + eval_strategy: eval_epoch learning: optimizer: name_: SGD diff --git a/train.py b/train.py index 23d52ca..31d6df5 100644 --- a/train.py +++ b/train.py @@ -1,84 +1,152 @@ import numpy as np +import ipdb +import torch from torch.optim import AdamW +from torch.utils import data from torchvision import transforms from torchvision.transforms import RandomResizedCrop, Lambda -from towhee.trainer.modelcard import ModelCard +#from towhee.trainer.modelcard import ModelCard from towhee.trainer.training_config import TrainingConfig -from towhee.trainer.dataset import get_dataset -from resnet_image_embedding import ResnetImageEmbedding +from towhee.trainer.trainer import Trainer +from retinaface_face_detection import RetinafaceFaceDetection from towhee.types import Image from towhee.trainer.training_config import dump_default_yaml +from towhee.trainer.utils.trainer_utils import send_to_device from PIL import Image as PILImage from timm.models.resnet import ResNet from torch import nn +from pytorch.wider_face import WiderFaceDetection, detection_collate +from pytorch.multibox_loss import MultiBoxLoss +from pytorch.data_augment import preproc +from pytorch.prior_box import PriorBox + if __name__ == '__main__': - dump_default_yaml(yaml_path='default_config.yaml') - # img = torch.rand([1, 3, 224, 224]) - img_path = './ILSVRC2012_val_00049771.JPEG' - # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') - img = PILImage.open(img_path) - img_bytes = img.tobytes() - img_width = img.width - img_height = img.height - img_channel = len(img.split()) - img_mode = img.mode - img_array = 
np.array(img) - array_size = np.array(img).shape - towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) - - op = ResnetImageEmbedding('resnet34') - # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") - # old_out = op(towhee_img) - # print(old_out.feature_vector[0]) + op = RetinafaceFaceDetection() training_config = TrainingConfig() - yaml_path = 'resnet_training_yaml.yaml' - # dump_default_yaml(yaml_path=yaml_path) - training_config.load_from_yaml(yaml_path) - # output_dir='./temp_output', - # overwrite_output_dir=True, - # epoch_num=2, - # per_gpu_train_batch_size=16, - # prediction_loss_only=True, - # metric='Accuracy' - # # device_str='cuda', - # # n_gpu=4 - # ) + yaml_path = 'retinaface_training_yaml.yaml' # 'resnet_training_yaml.yaml' + training_config.load_from_yaml(yaml_path) + mnist_transform = transforms.Compose([transforms.ToTensor(), RandomResizedCrop(224), Lambda(lambda x: x.repeat(3, 1, 1)), transforms.Normalize(mean=[0.5], std=[0.5])]) - train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True) - eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False) - # fake_transform = transforms.Compose([transforms.ToTensor(), - # RandomResizedCrop(224),]) - # train_data = get_dataset('fake', size=20, transform=fake_transform) - - op.change_before_train(10) - trainer = op.setup_trainer() - # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) - # op.setup_trainer() - - # trainer.add_callback() - # trainer.set_optimizer() - - # op.trainer.set_optimizer(my_optimimzer) - # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') - - # my_loss = nn.BCELoss() - # trainer.set_loss(my_loss, 'my_loss111') - # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') - # op.trainer._create_optimizer() - # op.trainer.set_optimizer() - op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) - # training_config.num_epoch = 3 - # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') - - # op.save('./test_save') - # op.load('./test_save') - # new_out = op(towhee_img) - - # assert (new_out[0]!=old_out[0]).all() + + training_dataset = './data/widerface/train/label.txt' + + img_dim = 840 + rgb_mean = (104, 117, 123) + train_data = WiderFaceDetection( training_dataset,preproc(img_dim, rgb_mean)) + eval_data = WiderFaceDetection( training_dataset) + + cfg = {} + cfg['min_sizes'] = [[16, 32], [64, 128], [256, 512]] + cfg['steps'] = [8, 16, 32] + cfg['clip'] = False + cfg['loc_weight'] = 2.0 + + priorbox = PriorBox(cfg, image_size=(img_dim, img_dim)) + + with torch.no_grad(): + priors = priorbox.forward() + + class RetinafaceTrainer(Trainer): + def __init__( + self, + model: nn.Module = None, + training_config = None, + train_dataset = None, + eval_dataset = None, + model_card = None, + train_dataloader = None, + eval_dataloader = None, + priors = priors + ): + super(RetinafaceTrainer, self).__init__(model, training_config, train_dataset, eval_dataset, model_card, train_dataloader, eval_dataloader) + self.priors = priors + + def compute_loss(self, model: nn.Module, inputs): + model.train() + labels = inputs[1] + outputs = model(inputs[0]) + priors = send_to_device(self.priors, self.configs.device) + loss_l, loss_c, loss_landm = self.loss(outputs, priors, labels) + loss = 
cfg['loc_weight'] * loss_l + loss_c + loss_landm + return loss + + def compute_metric(self, model, inputs): + model.eval() + model.phase = "eval" + epoch_metric = None + labels = inputs[1] + model.cpu() + predicts = [] + gts = [] + + img = send_to_device(inputs[0], torch.device("cpu")) + outputs = model.inference(img[0]) + + npreds = outputs[0].shape[0] + ngts = labels[0].shape[0] + predicts.append(dict(boxes=outputs[0][:,:4], scores=outputs[0][:,4], labels=torch.IntTensor([0 for i in range(npreds)]))) + gts.append(dict(boxes=torch.vstack(labels)[:,:4], labels=torch.IntTensor([0 for i in range(ngts)]))) + + if self.metric is not None: + self.metric.update(predicts, gts) + epoch_metric = self.metric.compute()['map'].item() + return epoch_metric + + @torch.no_grad() + def evaluate_step(self, model, inputs): + loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False) + step_logs = {"eval_step_loss": 0, "eval_epoch_loss": loss_metric, + "eval_epoch_metric": epoch_metric} + return step_logs + + def get_train_dataloader(self): + if self.configs.n_gpu > 1: + self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data) + train_batch_sampler = torch.utils.data.BatchSampler( + self.train_sampler, 32, drop_last=True) + training_loader = torch.utils.data.DataLoader(train_data, + batch_sampler=train_batch_sampler, + num_workers=4, # self.configs.dataloader_num_workers, + pin_memory=True, + collate_fn=detection_collate + ) + else: + training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4, collate_fn=detection_collate) + return training_loader + + #def get_eval_dataloader(self): + def get_eval_dataloader(self): + if self.configs.n_gpu > 1: + eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data) + eval_batch_sampler = torch.utils.data.BatchSampler( + eval_sampler, 1, drop_last=True) + eval_loader = torch.utils.data.DataLoader(eval_data, + batch_sampler=eval_batch_sampler, + num_workers=1, # self.configs.dataloader_num_workers, + pin_memory=True, + collate_fn=detection_collate + ) + else: + eval_loader = torch.utils.data.DataLoader(eval_data, 1, + num_workers=1, # self.configs.dataloader_num_workers, + ) + return eval_loader + + + + trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors) + + op.trainer = trainer + + criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device) + + trainer.set_loss(criterion, "multibox_loss") + op.train() +
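For reference, a minimal usage sketch of the operator described in the README's Interface section follows. It assumes the `towhee.types.Image` constructor used in the original `train.py` (bytes, width, height, channel count, mode, ndarray) and the output field names listed in the README (`boxes`, `keypoints`, `cropped_imgs`); exact signatures may vary across Towhee versions, so treat this as illustrative rather than authoritative.

```python
# Illustrative sketch, not part of this patch. The input path is hypothetical;
# the Image constructor mirrors the one used in the original train.py.
import numpy as np
from PIL import Image as PILImage
from towhee.types import Image

from retinaface_face_detection import RetinafaceFaceDetection

pil_img = PILImage.open('some_photo.jpg')  # hypothetical input image
towhee_img = Image(pil_img.tobytes(), pil_img.width, pil_img.height,
                   len(pil_img.split()), pil_img.mode, np.array(pil_img))

op = RetinafaceFaceDetection()
faces = op(towhee_img)  # one entry per detected face
for face in faces:
    # field names taken from the README's Returns section
    print(face.boxes, face.keypoints.shape, face.cropped_imgs.shape)
```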
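As a sanity check on the anchor generation added in `pytorch/prior_box.py`, the snippet below recomputes the expected anchor count for the configuration used in `train.py` (img_dim 840, steps [8, 16, 32], two min_sizes per level). This is a derived illustration, not part of the patch.

```python
# Derived from the cfg in train.py and PriorBox.forward in pytorch/prior_box.py:
# each feature-map cell emits one anchor per entry in min_sizes[k].
from math import ceil

image_size = (840, 840)
steps = [8, 16, 32]
min_sizes = [[16, 32], [64, 128], [256, 512]]

total = sum(
    ceil(image_size[0] / s) * ceil(image_size[1] / s) * len(m)
    for s, m in zip(steps, min_sizes)
)
print(total)  # 29126, i.e. PriorBox(cfg, image_size=(840, 840)).forward().shape[0]
```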
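The new `pytorch/main.py` only calls `torch.load` on the bundled checkpoint. If you want to inspect those weights on a machine without a GPU, a hedged variant is sketched below; `map_location` is a standard `torch.load` argument, and the assignment to a variable is added purely for illustration.

```python
# Sketch only: load the bundled RetinaFace MobileNet checkpoint for inspection.
# map_location='cpu' keeps CUDA-saved tensors loadable on CPU-only machines.
import torch

state = torch.load('pytorch_retinaface_mobilenet_widerface.pth', map_location='cpu')
print(type(state))  # typically a mapping of parameter names to tensors
```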