
refactor the training code.

training
wxywb · 3 years ago · commit cabafc2339
  1. README.md (49)
  2. pytorch/data_augment.py (237)
  3. pytorch/main.py (4)
  4. pytorch/multibox_loss.py (14)
  5. pytorch/prior_box.py (34)
  6. pytorch/pytorch_retinaface_mobilenet_widerface.pth (BIN)
  7. pytorch/wider_face.py (102)
  8. retinaface_face_detection.py (7)
  9. retinaface_training_yaml.yaml (7)
  10. train.py (194)

README.md (49)

@@ -1,2 +1,49 @@
-# retinaface-face-detection
+# Retinaface Face Detection (Pytorch)
Authors: wxywb
## Overview
This operator detects faces in images using the RetinaFace detector [1]. It returns the bounding-box locations, five facial keypoints, and the cropped face images from the original image. This repo is an adaptation of [2].
## Interface
```python
__call__(self, image: 'towhee.types.Image')
```
**Args:**
- image:
  - the image in which to detect faces.
  - supported types: `towhee.types.Image`
**Returns:**
The operator returns a tuple `Tuple[('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)]` containing the following fields:
- boxes:
  - bounding boxes of human faces.
  - data type: `numpy.ndarray`
  - shape: (num_faces, 4)
- keypoints:
  - keypoints of human faces.
  - data type: `numpy.ndarray`
  - shape: (10,)
- cropped_imgs:
  - cropped face images.
  - data type: `numpy.ndarray`
  - shape: (h, w, 3)
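**Example:**
A minimal usage sketch (the image path here is hypothetical; the `towhee.types.Image` construction mirrors the pattern used in train.py below):
```python
import numpy as np
from PIL import Image as PILImage
from towhee.types import Image
from retinaface_face_detection import RetinafaceFaceDetection

# Wrap a PIL image into a towhee Image.
pil_img = PILImage.open('./face.jpg')  # hypothetical test image
towhee_img = Image(pil_img.tobytes(), pil_img.width, pil_img.height,
                   len(pil_img.split()), pil_img.mode, np.array(pil_img))

op = RetinafaceFaceDetection()
for boxes, keypoints, cropped_img in op(towhee_img):
    print(boxes, keypoints, cropped_img.shape)
```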
## Requirements
The required Python packages are listed in [requirements.txt](./requirements.txt).
## How it works
The `towhee/retinaface-face-detection` operator implements face detection. An example pipeline can be found in [face-embedding-retinaface-inceptionresnetv1](https://towhee.io/towhee/face-embedding-retinaface-inceptionresnetv1).
## Reference
[1]. https://arxiv.org/abs/1905.00641
[2]. https://github.com/biubug6/Pytorch_Retinaface

pytorch/data_augment.py (237)

@@ -0,0 +1,237 @@
import cv2
import numpy as np
import random
from .box_utils import matrix_iof


def _crop(image, boxes, labels, landm, img_dim):
    height, width, _ = image.shape
    pad_image_flag = True

    for _ in range(250):
        """
        if random.uniform(0, 1) <= 0.2:
            scale = 1.0
        else:
            scale = random.uniform(0.3, 1.0)
        """
        PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
        scale = random.choice(PRE_SCALES)
        short_side = min(width, height)
        w = int(scale * short_side)
        h = w

        if width == w:
            l = 0
        else:
            l = random.randrange(width - w)
        if height == h:
            t = 0
        else:
            t = random.randrange(height - h)
        roi = np.array((l, t, l + w, t + h))

        value = matrix_iof(boxes, roi[np.newaxis])
        flag = (value >= 1)
        if not flag.any():
            continue

        centers = (boxes[:, :2] + boxes[:, 2:]) / 2
        mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
        boxes_t = boxes[mask_a].copy()
        labels_t = labels[mask_a].copy()
        landms_t = landm[mask_a].copy()
        landms_t = landms_t.reshape([-1, 5, 2])

        if boxes_t.shape[0] == 0:
            continue

        image_t = image[roi[1]:roi[3], roi[0]:roi[2]]

        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
        boxes_t[:, :2] -= roi[:2]
        boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
        boxes_t[:, 2:] -= roi[:2]

        # landm
        landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
        landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
        landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
        landms_t = landms_t.reshape([-1, 10])

        # make sure that the cropped image contains at least one face > 16 pixel at training image scale
        b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
        b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
        mask_b = np.minimum(b_w_t, b_h_t) > 0.0
        boxes_t = boxes_t[mask_b]
        labels_t = labels_t[mask_b]
        landms_t = landms_t[mask_b]

        if boxes_t.shape[0] == 0:
            continue

        pad_image_flag = False

        return image_t, boxes_t, labels_t, landms_t, pad_image_flag
    return image, boxes, labels, landm, pad_image_flag


def _distort(image):

    def _convert(image, alpha=1, beta=0):
        tmp = image.astype(float) * alpha + beta
        tmp[tmp < 0] = 0
        tmp[tmp > 255] = 255
        image[:] = tmp

    image = image.copy()

    if random.randrange(2):
        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
    else:
        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

    return image


def _expand(image, boxes, fill, p):
    if random.randrange(2):
        return image, boxes

    height, width, depth = image.shape

    scale = random.uniform(1, p)
    w = int(scale * width)
    h = int(scale * height)

    left = random.randint(0, w - width)
    top = random.randint(0, h - height)

    boxes_t = boxes.copy()
    boxes_t[:, :2] += (left, top)
    boxes_t[:, 2:] += (left, top)
    expand_image = np.empty(
        (h, w, depth),
        dtype=image.dtype)
    expand_image[:, :] = fill
    expand_image[top:top + height, left:left + width] = image
    image = expand_image

    return image, boxes_t


def _mirror(image, boxes, landms):
    _, width, _ = image.shape
    if random.randrange(2):
        image = image[:, ::-1]
        boxes = boxes.copy()
        boxes[:, 0::2] = width - boxes[:, 2::-2]

        # landm
        landms = landms.copy()
        landms = landms.reshape([-1, 5, 2])
        landms[:, :, 0] = width - landms[:, :, 0]
        tmp = landms[:, 1, :].copy()
        landms[:, 1, :] = landms[:, 0, :]
        landms[:, 0, :] = tmp
        tmp1 = landms[:, 4, :].copy()
        landms[:, 4, :] = landms[:, 3, :]
        landms[:, 3, :] = tmp1
        landms = landms.reshape([-1, 10])

    return image, boxes, landms


def _pad_to_square(image, rgb_mean, pad_image_flag):
    if not pad_image_flag:
        return image
    height, width, _ = image.shape
    long_side = max(width, height)
    image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
    image_t[:, :] = rgb_mean
    image_t[0:0 + height, 0:0 + width] = image
    return image_t


def _resize_subtract_mean(image, insize, rgb_mean):
    interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
    interp_method = interp_methods[random.randrange(5)]
    image = cv2.resize(image, (insize, insize), interpolation=interp_method)
    image = image.astype(np.float32)
    image -= rgb_mean
    return image.transpose(2, 0, 1)


class preproc(object):

    def __init__(self, img_dim, rgb_means):
        self.img_dim = img_dim
        self.rgb_means = rgb_means

    def __call__(self, image, targets):
        assert targets.shape[0] > 0, "this image does not have gt"

        boxes = targets[:, :4].copy()
        labels = targets[:, -1].copy()
        landm = targets[:, 4:-1].copy()

        image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
        image_t = _distort(image_t)
        image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag)
        image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
        height, width, _ = image_t.shape
        image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
        boxes_t[:, 0::2] /= width
        boxes_t[:, 1::2] /= height

        landm_t[:, 0::2] /= width
        landm_t[:, 1::2] /= height

        labels_t = np.expand_dims(labels_t, 1)
        targets_t = np.hstack((boxes_t, landm_t, labels_t))

        return image_t, targets_t
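A quick sanity-check sketch for the augmentation pipeline (it assumes `pytorch/box_utils.py` from the upstream repo [2] is present for `matrix_iof`; the 840 / (104, 117, 123) values mirror train.py below):
```python
import numpy as np
from pytorch.data_augment import preproc

# One fake image plus one ground-truth row:
# [x1, y1, x2, y2, ten landmark coordinates, label].
image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
targets = np.array([[100., 100., 200., 220.,
                     120., 130., 180., 130., 150., 160.,
                     125., 190., 175., 190., 1.]])

p = preproc(img_dim=840, rgb_means=(104, 117, 123))
img_t, targets_t = p(image, targets)
print(img_t.shape)      # (3, 840, 840): CHW, mean-subtracted float32
print(targets_t.shape)  # (num_faces, 15), coordinates normalized to [0, 1]
```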

pytorch/main.py (4)

@@ -0,0 +1,4 @@
import torch
torch.load('pytorch_retinaface_mobilenet_widerface.pth')
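main.py just smoke-tests that the bundled checkpoint deserializes. On a CPU-only machine the same check needs `map_location` (a sketch, not part of the commit):
```python
import torch

# Load the checkpoint onto the CPU so the check also passes without CUDA.
state = torch.load('pytorch_retinaface_mobilenet_widerface.pth',
                   map_location=torch.device('cpu'))
print(type(state))
```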

pytorch/multibox_loss.py (14)

@ -27,7 +27,7 @@ class MultiBoxLoss(nn.Module):
See: https://arxiv.org/pdf/1512.02325.pdf for more details. See: https://arxiv.org/pdf/1512.02325.pdf for more details.
""" """
def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu=False, device=None):
super(MultiBoxLoss, self).__init__() super(MultiBoxLoss, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.threshold = overlap_thresh self.threshold = overlap_thresh
@ -38,7 +38,8 @@ class MultiBoxLoss(nn.Module):
self.negpos_ratio = neg_pos self.negpos_ratio = neg_pos
self.neg_overlap = neg_overlap self.neg_overlap = neg_overlap
self.variance = [0.1, 0.2] self.variance = [0.1, 0.2]
self.GPU = False
self.GPU = use_gpu
self.device = device
def forward(self, predictions, priors, targets): def forward(self, predictions, priors, targets):
"""Multibox Loss """Multibox Loss
@ -68,9 +69,9 @@ class MultiBoxLoss(nn.Module):
defaults = priors.data defaults = priors.data
match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx) match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
if self.GPU: if self.GPU:
loc_t = loc_t.cuda()
conf_t = conf_t.cuda()
landm_t = landm_t.cuda()
loc_t = loc_t.to(self.device)
conf_t = conf_t.to(self.device)
landm_t = landm_t.to(self.device)
zeros = torch.tensor(0).cuda() zeros = torch.tensor(0).cuda()
# landm Loss (Smooth L1) # landm Loss (Smooth L1)
@ -114,7 +115,8 @@ class MultiBoxLoss(nn.Module):
targets_weighted = conf_t[(pos+neg).gt(0)] targets_weighted = conf_t[(pos+neg).gt(0)]
loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum') loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')
# Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N N = max(num_pos.data.sum().float(), 1)
# Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
N = max(num_pos.data.sum().float(), 1)
loss_l /= N loss_l /= N
loss_c /= N loss_c /= N
loss_landm /= N1 loss_landm /= N1
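The refactor threads an explicit `device` through the loss instead of hard-coding `.cuda()`. A minimal construction sketch (argument order taken from the call in train.py below; note that the unchanged context still builds `zeros` with `.cuda()`, so a pure-CPU run would need that line adjusted too):
```python
import torch
from pytorch.multibox_loss import MultiBoxLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# num_classes=2, overlap_thresh=0.35, prior_for_matching=True, bkg_label=0,
# neg_mining=True, neg_pos=7, neg_overlap=0.35, encode_target=False
criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False,
                         use_gpu=torch.cuda.is_available(), device=device)
```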

pytorch/prior_box.py (34)

@@ -0,0 +1,34 @@
import torch
from itertools import product as product
import numpy as np
from math import ceil


class PriorBox(object):
    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
        self.name = "s"

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
                    dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output
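For intuition, a sketch of what `forward()` produces at the 840x840 training resolution configured in train.py below (two `min_sizes` per level, so two anchors per feature-map cell):
```python
from pytorch.prior_box import PriorBox

cfg = {'min_sizes': [[16, 32], [64, 128], [256, 512]],
       'steps': [8, 16, 32],
       'clip': False}
priors = PriorBox(cfg, image_size=(840, 840)).forward()
# ceil(840/8)=105, ceil(840/16)=53, ceil(840/32)=27 feature-map sides:
# 105*105*2 + 53*53*2 + 27*27*2 = 29126 anchors in (cx, cy, w, h) form.
print(priors.shape)  # torch.Size([29126, 4])
```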

pytorch/pytorch_retinaface_mobilenet_widerface.pth (BIN)

Binary file not shown.

pytorch/wider_face.py (102)

@@ -0,0 +1,102 @@
import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np


class WiderFaceDetection(data.Dataset):
    def __init__(self, txt_path, preproc=None):
        self.preproc = preproc
        self.imgs_path = []
        self.words = []
        f = open(txt_path, 'r')
        lines = f.readlines()
        isFirst = True
        labels = []
        for line in lines:
            line = line.rstrip()
            if line.startswith('#'):
                if isFirst is True:
                    isFirst = False
                else:
                    labels_copy = labels.copy()
                    self.words.append(labels_copy)
                    labels.clear()
                path = line[2:]
                path = txt_path.replace('label.txt', 'images/') + path
                self.imgs_path.append(path)
            else:
                line = line.split(' ')
                label = [float(x) for x in line]
                labels.append(label)
        self.words.append(labels)

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        img = cv2.imread(self.imgs_path[index])
        height, width, _ = img.shape

        labels = self.words[index]
        annotations = np.zeros((0, 15))
        if len(labels) == 0:
            return annotations
        for idx, label in enumerate(labels):
            annotation = np.zeros((1, 15))
            # bbox
            annotation[0, 0] = label[0]  # x1
            annotation[0, 1] = label[1]  # y1
            annotation[0, 2] = label[0] + label[2]  # x2
            annotation[0, 3] = label[1] + label[3]  # y2

            # landmarks
            annotation[0, 4] = label[4]    # l0_x
            annotation[0, 5] = label[5]    # l0_y
            annotation[0, 6] = label[7]    # l1_x
            annotation[0, 7] = label[8]    # l1_y
            annotation[0, 8] = label[10]   # l2_x
            annotation[0, 9] = label[11]   # l2_y
            annotation[0, 10] = label[13]  # l3_x
            annotation[0, 11] = label[14]  # l3_y
            annotation[0, 12] = label[16]  # l4_x
            annotation[0, 13] = label[17]  # l4_y
            if (annotation[0, 4] < 0):
                annotation[0, 14] = -1
            else:
                annotation[0, 14] = 1

            annotations = np.append(annotations, annotation, axis=0)
        target = np.array(annotations)
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target


def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on 0 dim
    """
    targets = []
    imgs = []
    for _, sample in enumerate(batch):
        for _, tup in enumerate(sample):
            if torch.is_tensor(tup):
                imgs.append(tup)
            elif isinstance(tup, type(np.empty(0))):
                annos = torch.from_numpy(tup).float()
                targets.append(annos)

    return (torch.stack(imgs, 0), targets)
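A usage sketch (the path assumes the WIDER FACE layout used in train.py; `detection_collate` keeps per-image annotation tensors in a list because the number of faces varies per image):
```python
import torch.utils.data as data
from pytorch.data_augment import preproc
from pytorch.wider_face import WiderFaceDetection, detection_collate

dataset = WiderFaceDetection('./data/widerface/train/label.txt',
                             preproc(840, (104, 117, 123)))
loader = data.DataLoader(dataset, batch_size=8, shuffle=True,
                         collate_fn=detection_collate)
images, targets = next(iter(loader))
print(images.shape)  # torch.Size([8, 3, 840, 840])
print(len(targets))  # 8 tensors, each (num_faces, 15)
```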

retinaface_face_detection.py (7)

@@ -15,6 +15,7 @@
 from typing import NamedTuple, List
 from PIL import Image
 import torch
+from torch import nn
 from torchvision import transforms
 import sys
 import towhee
@@ -23,11 +24,12 @@ import numpy
 from towhee.operator import Operator
 from towhee.utils.pil_utils import to_pil
+from towhee.operator import NNOperator
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
 import os

-class RetinafaceFaceDetection(Operator):
+class RetinafaceFaceDetection(NNOperator):
     """
     Embedding extractor using efficientnet.
     Args:
@@ -71,3 +73,6 @@ class RetinafaceFaceDetection(Operator):
             output = Outputs(bboxes[i], keypoints[i,:], croppeds[i])
             outputs.append(output)
         return outputs
+
+    def get_model(self) -> nn.Module:
+        return self.model._model
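Rebasing the operator onto `NNOperator` and exposing `get_model()` is what lets train.py hand the raw network to the trainer. A sketch of the intended use (assuming, as the diff shows, that the operator keeps its network in `self.model._model`):
```python
from retinaface_face_detection import RetinafaceFaceDetection

op = RetinafaceFaceDetection()
model = op.get_model()  # the underlying torch nn.Module
print(sum(p.numel() for p in model.parameters()))
```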

retinaface_training_yaml.yaml (7)

@@ -1,13 +1,14 @@
 device:
-    device_str: null
-    n_gpu: -1
+    device_str: cuda
+    n_gpu: 2
     sync_bn: true
 metrics:
-    metric: Accuracy
+    metric: MeanAveragePrecision
 train:
     batch_size: 32
     overwrite_output_dir: true
     epoch_num: 2
+    eval_strategy: eval_epoch
 learning:
     optimizer:
         name_: SGD

train.py (194)

@@ -1,84 +1,152 @@
 import numpy as np
-import ipdb
+import torch
 from torch.optim import AdamW
+from torch.utils import data
 from torchvision import transforms
 from torchvision.transforms import RandomResizedCrop, Lambda
-from towhee.trainer.modelcard import ModelCard
+#from towhee.trainer.modelcard import ModelCard
 from towhee.trainer.training_config import TrainingConfig
-from towhee.trainer.dataset import get_dataset
-from resnet_image_embedding import ResnetImageEmbedding
+from towhee.trainer.trainer import Trainer
+from retinaface_face_detection import RetinafaceFaceDetection
 from towhee.types import Image
 from towhee.trainer.training_config import dump_default_yaml
+from towhee.trainer.utils.trainer_utils import send_to_device
 from PIL import Image as PILImage
 from timm.models.resnet import ResNet
 from torch import nn
+from pytorch.wider_face import WiderFaceDetection, detection_collate
+from pytorch.multibox_loss import MultiBoxLoss
+from pytorch.data_augment import preproc
+from pytorch.prior_box import PriorBox

 if __name__ == '__main__':
-    dump_default_yaml(yaml_path='default_config.yaml')
-    # img = torch.rand([1, 3, 224, 224])
-    img_path = './ILSVRC2012_val_00049771.JPEG'
-    # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
-    img = PILImage.open(img_path)
-    img_bytes = img.tobytes()
-    img_width = img.width
-    img_height = img.height
-    img_channel = len(img.split())
-    img_mode = img.mode
-    img_array = np.array(img)
-    array_size = np.array(img).shape
-    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
-    op = ResnetImageEmbedding('resnet34')
-    # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
-    # old_out = op(towhee_img)
-    # print(old_out.feature_vector[0])
+    op = RetinafaceFaceDetection()
     training_config = TrainingConfig()
-    yaml_path = 'resnet_training_yaml.yaml'
-    # dump_default_yaml(yaml_path=yaml_path)
+    yaml_path = 'retinaface_training_yaml.yaml'  # 'resnet_training_yaml.yaml'
     training_config.load_from_yaml(yaml_path)
-    # output_dir='./temp_output',
-    # overwrite_output_dir=True,
-    # epoch_num=2,
-    # per_gpu_train_batch_size=16,
-    # prediction_loss_only=True,
-    # metric='Accuracy'
-    # # device_str='cuda',
-    # # n_gpu=4
-    # )
     mnist_transform = transforms.Compose([transforms.ToTensor(),
                                           RandomResizedCrop(224),
                                           Lambda(lambda x: x.repeat(3, 1, 1)),
                                           transforms.Normalize(mean=[0.5], std=[0.5])])
-    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
-    eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
-    # fake_transform = transforms.Compose([transforms.ToTensor(),
-    #                                      RandomResizedCrop(224),])
-    # train_data = get_dataset('fake', size=20, transform=fake_transform)
-    op.change_before_train(10)
-    trainer = op.setup_trainer()
-    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
-    # op.setup_trainer()
-    # trainer.add_callback()
-    # trainer.set_optimizer()
-    # op.trainer.set_optimizer(my_optimimzer)
-    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
-    # my_loss = nn.BCELoss()
-    # trainer.set_loss(my_loss, 'my_loss111')
-    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
-    # op.trainer._create_optimizer()
-    # op.trainer.set_optimizer()
-    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
-    # training_config.num_epoch = 3
-    # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
-    # op.save('./test_save')
-    # op.load('./test_save')
-    # new_out = op(towhee_img)
-    # assert (new_out[0]!=old_out[0]).all()
+    training_dataset = './data/widerface/train/label.txt'
+    img_dim = 840
+    rgb_mean = (104, 117, 123)
+    train_data = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))
+    eval_data = WiderFaceDetection(training_dataset)
+
+    cfg = {}
+    cfg['min_sizes'] = [[16, 32], [64, 128], [256, 512]]
+    cfg['steps'] = [8, 16, 32]
+    cfg['clip'] = False
+    cfg['loc_weight'] = 2.0
+
+    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
+    with torch.no_grad():
+        priors = priorbox.forward()
+
+    class RetinafaceTrainer(Trainer):
+        def __init__(
+                self,
+                model: nn.Module = None,
+                training_config=None,
+                train_dataset=None,
+                eval_dataset=None,
+                model_card=None,
+                train_dataloader=None,
+                eval_dataloader=None,
+                priors=priors
+        ):
+            super(RetinafaceTrainer, self).__init__(model, training_config, train_dataset, eval_dataset, model_card, train_dataloader, eval_dataloader)
+            self.priors = priors
+
+        def compute_loss(self, model: nn.Module, inputs):
+            model.train()
+            labels = inputs[1]
+            outputs = model(inputs[0])
+            priors = send_to_device(self.priors, self.configs.device)
+            loss_l, loss_c, loss_landm = self.loss(outputs, priors, labels)
+            loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
+            return loss
+
+        def compute_metric(self, model, inputs):
+            model.eval()
+            model.phase = "eval"
+            epoch_metric = None
+            labels = inputs[1]
+            model.cpu()
+            predicts = []
+            gts = []
+            img = send_to_device(inputs[0], torch.device("cpu"))
+            outputs = model.inference(img[0])
+            npreds = outputs[0].shape[0]
+            ngts = labels[0].shape[0]
+            predicts.append(dict(boxes=outputs[0][:, :4], scores=outputs[0][:, 4], labels=torch.IntTensor([0 for i in range(npreds)])))
+            gts.append(dict(boxes=torch.vstack(labels)[:, :4], labels=torch.IntTensor([0 for i in range(ngts)])))
+            if self.metric is not None:
+                self.metric.update(predicts, gts)
+                epoch_metric = self.metric.compute()['map'].item()
+            return epoch_metric
+
+        @torch.no_grad()
+        def evaluate_step(self, model, inputs):
+            loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False)
+            step_logs = {"eval_step_loss": 0, "eval_epoch_loss": loss_metric,
+                         "eval_epoch_metric": epoch_metric}
+            return step_logs
+
+        def get_train_dataloader(self):
+            if self.configs.n_gpu > 1:
+                self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
+                train_batch_sampler = torch.utils.data.BatchSampler(
+                    self.train_sampler, 32, drop_last=True)
+                training_loader = torch.utils.data.DataLoader(train_data,
+                                                              batch_sampler=train_batch_sampler,
+                                                              num_workers=4,  # self.configs.dataloader_num_workers,
+                                                              pin_memory=True,
+                                                              collate_fn=detection_collate)
+            else:
+                training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4, collate_fn=detection_collate)
+            return training_loader
+
+        def get_eval_dataloader(self):
+            if self.configs.n_gpu > 1:
+                eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)
+                eval_batch_sampler = torch.utils.data.BatchSampler(
+                    eval_sampler, 1, drop_last=True)
+                eval_loader = torch.utils.data.DataLoader(eval_data,
+                                                          batch_sampler=eval_batch_sampler,
+                                                          num_workers=1,  # self.configs.dataloader_num_workers,
+                                                          pin_memory=True,
+                                                          collate_fn=detection_collate)
+            else:
+                eval_loader = torch.utils.data.DataLoader(eval_data, 1,
+                                                          num_workers=1)  # self.configs.dataloader_num_workers
+            return eval_loader
+
+    trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors)
+    op.trainer = trainer
+    criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device)
+    trainer.set_loss(criterion, "multibox_loss")
+    op.train()
