towhee
/
retinaface-face-detection
copied
10 changed files with 575 additions and 75 deletions
@ -1,2 +1,49 @@ |
|||
# retinaface-face-detection |
|||
# Retinaface Face Detection (Pytorch) |
|||
|
|||
Authors: wxywb |
|||
|
|||
## Overview |
|||
|
|||
This opertator detects faces in the images by using RetinaFace Detector[1]. It will returns the locations, five keypoints and the cropped face images from origin images. This repo is a adopataion from [2]. |
|||
|
|||
## Interface |
|||
|
|||
```python |
|||
__call__(self, image: 'towhee.types.Image') |
|||
``` |
|||
|
|||
**Args:** |
|||
|
|||
- image: |
|||
- the image to detect faces. |
|||
- supported types: towhee.types.Image |
|||
|
|||
**Returns:** |
|||
|
|||
The Operator returns a tupe Tuple[('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)])] containing following fields: |
|||
|
|||
- boxes: |
|||
- boxes of human faces. |
|||
- data type: `numpy.ndarray` |
|||
- shape: (num_faces, 4) |
|||
- keypoints: |
|||
- keypoints of human faces. |
|||
- data type: `numpy.ndarray` |
|||
- shape: (10) |
|||
- cropped_imgs: |
|||
- cropped face images. |
|||
- data type: `numpy.ndarray` |
|||
- shape: (h, w, 3) |
|||
|
|||
## Requirements |
|||
|
|||
You can get the required python package by [requirements.txt](./requirements.txt). |
|||
|
|||
## How it works |
|||
The `towhee/retinaface-face-detection` Operators implents the function of face detection. The example pipeline can be found in [face-embedding-retinaface-inceptionresnetv1](https://towhee.io/towhee/face-embedding-retinaface-inceptionresnetv1) |
|||
|
|||
## Reference |
|||
|
|||
[1]. https://arxiv.org/abs/1905.00641 |
|||
[2]. https://github.com/biubug6/Pytorch_Retinaface |
|||
|
|||
|
@ -0,0 +1,237 @@ |
|||
import cv2 |
|||
import numpy as np |
|||
import random |
|||
from .box_utils import matrix_iof |
|||
|
|||
|
|||
def _crop(image, boxes, labels, landm, img_dim): |
|||
height, width, _ = image.shape |
|||
pad_image_flag = True |
|||
|
|||
for _ in range(250): |
|||
""" |
|||
if random.uniform(0, 1) <= 0.2: |
|||
scale = 1.0 |
|||
else: |
|||
scale = random.uniform(0.3, 1.0) |
|||
""" |
|||
PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0] |
|||
scale = random.choice(PRE_SCALES) |
|||
short_side = min(width, height) |
|||
w = int(scale * short_side) |
|||
h = w |
|||
|
|||
if width == w: |
|||
l = 0 |
|||
else: |
|||
l = random.randrange(width - w) |
|||
if height == h: |
|||
t = 0 |
|||
else: |
|||
t = random.randrange(height - h) |
|||
roi = np.array((l, t, l + w, t + h)) |
|||
|
|||
value = matrix_iof(boxes, roi[np.newaxis]) |
|||
flag = (value >= 1) |
|||
if not flag.any(): |
|||
continue |
|||
|
|||
centers = (boxes[:, :2] + boxes[:, 2:]) / 2 |
|||
mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1) |
|||
boxes_t = boxes[mask_a].copy() |
|||
labels_t = labels[mask_a].copy() |
|||
landms_t = landm[mask_a].copy() |
|||
landms_t = landms_t.reshape([-1, 5, 2]) |
|||
|
|||
if boxes_t.shape[0] == 0: |
|||
continue |
|||
|
|||
image_t = image[roi[1]:roi[3], roi[0]:roi[2]] |
|||
|
|||
boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2]) |
|||
boxes_t[:, :2] -= roi[:2] |
|||
boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:]) |
|||
boxes_t[:, 2:] -= roi[:2] |
|||
|
|||
# landm |
|||
landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2] |
|||
landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0])) |
|||
landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2]) |
|||
landms_t = landms_t.reshape([-1, 10]) |
|||
|
|||
|
|||
# make sure that the cropped image contains at least one face > 16 pixel at training image scale |
|||
b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim |
|||
b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim |
|||
mask_b = np.minimum(b_w_t, b_h_t) > 0.0 |
|||
boxes_t = boxes_t[mask_b] |
|||
labels_t = labels_t[mask_b] |
|||
landms_t = landms_t[mask_b] |
|||
|
|||
if boxes_t.shape[0] == 0: |
|||
continue |
|||
|
|||
pad_image_flag = False |
|||
|
|||
return image_t, boxes_t, labels_t, landms_t, pad_image_flag |
|||
return image, boxes, labels, landm, pad_image_flag |
|||
|
|||
|
|||
def _distort(image): |
|||
|
|||
def _convert(image, alpha=1, beta=0): |
|||
tmp = image.astype(float) * alpha + beta |
|||
tmp[tmp < 0] = 0 |
|||
tmp[tmp > 255] = 255 |
|||
image[:] = tmp |
|||
|
|||
image = image.copy() |
|||
|
|||
if random.randrange(2): |
|||
|
|||
#brightness distortion |
|||
if random.randrange(2): |
|||
_convert(image, beta=random.uniform(-32, 32)) |
|||
|
|||
#contrast distortion |
|||
if random.randrange(2): |
|||
_convert(image, alpha=random.uniform(0.5, 1.5)) |
|||
|
|||
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) |
|||
|
|||
#saturation distortion |
|||
if random.randrange(2): |
|||
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) |
|||
|
|||
#hue distortion |
|||
if random.randrange(2): |
|||
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) |
|||
tmp %= 180 |
|||
image[:, :, 0] = tmp |
|||
|
|||
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) |
|||
|
|||
else: |
|||
|
|||
#brightness distortion |
|||
if random.randrange(2): |
|||
_convert(image, beta=random.uniform(-32, 32)) |
|||
|
|||
image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) |
|||
|
|||
#saturation distortion |
|||
if random.randrange(2): |
|||
_convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) |
|||
|
|||
#hue distortion |
|||
if random.randrange(2): |
|||
tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) |
|||
tmp %= 180 |
|||
image[:, :, 0] = tmp |
|||
|
|||
image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) |
|||
|
|||
#contrast distortion |
|||
if random.randrange(2): |
|||
_convert(image, alpha=random.uniform(0.5, 1.5)) |
|||
|
|||
return image |
|||
|
|||
|
|||
def _expand(image, boxes, fill, p): |
|||
if random.randrange(2): |
|||
return image, boxes |
|||
|
|||
height, width, depth = image.shape |
|||
|
|||
scale = random.uniform(1, p) |
|||
w = int(scale * width) |
|||
h = int(scale * height) |
|||
|
|||
left = random.randint(0, w - width) |
|||
top = random.randint(0, h - height) |
|||
|
|||
boxes_t = boxes.copy() |
|||
boxes_t[:, :2] += (left, top) |
|||
boxes_t[:, 2:] += (left, top) |
|||
expand_image = np.empty( |
|||
(h, w, depth), |
|||
dtype=image.dtype) |
|||
expand_image[:, :] = fill |
|||
expand_image[top:top + height, left:left + width] = image |
|||
image = expand_image |
|||
|
|||
return image, boxes_t |
|||
|
|||
|
|||
def _mirror(image, boxes, landms): |
|||
_, width, _ = image.shape |
|||
if random.randrange(2): |
|||
image = image[:, ::-1] |
|||
boxes = boxes.copy() |
|||
boxes[:, 0::2] = width - boxes[:, 2::-2] |
|||
|
|||
# landm |
|||
landms = landms.copy() |
|||
landms = landms.reshape([-1, 5, 2]) |
|||
landms[:, :, 0] = width - landms[:, :, 0] |
|||
tmp = landms[:, 1, :].copy() |
|||
landms[:, 1, :] = landms[:, 0, :] |
|||
landms[:, 0, :] = tmp |
|||
tmp1 = landms[:, 4, :].copy() |
|||
landms[:, 4, :] = landms[:, 3, :] |
|||
landms[:, 3, :] = tmp1 |
|||
landms = landms.reshape([-1, 10]) |
|||
|
|||
return image, boxes, landms |
|||
|
|||
|
|||
def _pad_to_square(image, rgb_mean, pad_image_flag): |
|||
if not pad_image_flag: |
|||
return image |
|||
height, width, _ = image.shape |
|||
long_side = max(width, height) |
|||
image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) |
|||
image_t[:, :] = rgb_mean |
|||
image_t[0:0 + height, 0:0 + width] = image |
|||
return image_t |
|||
|
|||
|
|||
def _resize_subtract_mean(image, insize, rgb_mean): |
|||
interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] |
|||
interp_method = interp_methods[random.randrange(5)] |
|||
image = cv2.resize(image, (insize, insize), interpolation=interp_method) |
|||
image = image.astype(np.float32) |
|||
image -= rgb_mean |
|||
return image.transpose(2, 0, 1) |
|||
|
|||
|
|||
class preproc(object): |
|||
|
|||
def __init__(self, img_dim, rgb_means): |
|||
self.img_dim = img_dim |
|||
self.rgb_means = rgb_means |
|||
|
|||
def __call__(self, image, targets): |
|||
assert targets.shape[0] > 0, "this image does not have gt" |
|||
|
|||
boxes = targets[:, :4].copy() |
|||
labels = targets[:, -1].copy() |
|||
landm = targets[:, 4:-1].copy() |
|||
|
|||
image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim) |
|||
image_t = _distort(image_t) |
|||
image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag) |
|||
image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t) |
|||
height, width, _ = image_t.shape |
|||
image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means) |
|||
boxes_t[:, 0::2] /= width |
|||
boxes_t[:, 1::2] /= height |
|||
|
|||
landm_t[:, 0::2] /= width |
|||
landm_t[:, 1::2] /= height |
|||
|
|||
labels_t = np.expand_dims(labels_t, 1) |
|||
targets_t = np.hstack((boxes_t, landm_t, labels_t)) |
|||
|
|||
return image_t, targets_t |
@ -0,0 +1,4 @@ |
|||
import torch |
|||
|
|||
torch.load('pytorch_retinaface_mobilenet_widerface.pth') |
|||
|
@ -0,0 +1,34 @@ |
|||
import torch |
|||
from itertools import product as product |
|||
import numpy as np |
|||
from math import ceil |
|||
|
|||
|
|||
class PriorBox(object): |
|||
def __init__(self, cfg, image_size=None, phase='train'): |
|||
super(PriorBox, self).__init__() |
|||
self.min_sizes = cfg['min_sizes'] |
|||
self.steps = cfg['steps'] |
|||
self.clip = cfg['clip'] |
|||
self.image_size = image_size |
|||
self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps] |
|||
self.name = "s" |
|||
|
|||
def forward(self): |
|||
anchors = [] |
|||
for k, f in enumerate(self.feature_maps): |
|||
min_sizes = self.min_sizes[k] |
|||
for i, j in product(range(f[0]), range(f[1])): |
|||
for min_size in min_sizes: |
|||
s_kx = min_size / self.image_size[1] |
|||
s_ky = min_size / self.image_size[0] |
|||
dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] |
|||
dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] |
|||
for cy, cx in product(dense_cy, dense_cx): |
|||
anchors += [cx, cy, s_kx, s_ky] |
|||
|
|||
# back to torch land |
|||
output = torch.Tensor(anchors).view(-1, 4) |
|||
if self.clip: |
|||
output.clamp_(max=1, min=0) |
|||
return output |
Binary file not shown.
@ -0,0 +1,102 @@ |
|||
import os |
|||
import os.path |
|||
import sys |
|||
import torch |
|||
import torch.utils.data as data |
|||
import cv2 |
|||
import numpy as np |
|||
|
|||
class WiderFaceDetection(data.Dataset): |
|||
def __init__(self, txt_path, preproc=None): |
|||
self.preproc = preproc |
|||
self.imgs_path = [] |
|||
self.words = [] |
|||
f = open(txt_path,'r') |
|||
lines = f.readlines() |
|||
isFirst = True |
|||
labels = [] |
|||
for line in lines: |
|||
line = line.rstrip() |
|||
if line.startswith('#'): |
|||
if isFirst is True: |
|||
isFirst = False |
|||
else: |
|||
labels_copy = labels.copy() |
|||
self.words.append(labels_copy) |
|||
labels.clear() |
|||
path = line[2:] |
|||
path = txt_path.replace('label.txt','images/') + path |
|||
self.imgs_path.append(path) |
|||
else: |
|||
line = line.split(' ') |
|||
label = [float(x) for x in line] |
|||
labels.append(label) |
|||
|
|||
self.words.append(labels) |
|||
|
|||
def __len__(self): |
|||
return len(self.imgs_path) |
|||
|
|||
def __getitem__(self, index): |
|||
img = cv2.imread(self.imgs_path[index]) |
|||
height, width, _ = img.shape |
|||
|
|||
labels = self.words[index] |
|||
annotations = np.zeros((0, 15)) |
|||
if len(labels) == 0: |
|||
return annotations |
|||
for idx, label in enumerate(labels): |
|||
annotation = np.zeros((1, 15)) |
|||
# bbox |
|||
annotation[0, 0] = label[0] # x1 |
|||
annotation[0, 1] = label[1] # y1 |
|||
annotation[0, 2] = label[0] + label[2] # x2 |
|||
annotation[0, 3] = label[1] + label[3] # y2 |
|||
|
|||
# landmarks |
|||
annotation[0, 4] = label[4] # l0_x |
|||
annotation[0, 5] = label[5] # l0_y |
|||
annotation[0, 6] = label[7] # l1_x |
|||
annotation[0, 7] = label[8] # l1_y |
|||
annotation[0, 8] = label[10] # l2_x |
|||
annotation[0, 9] = label[11] # l2_y |
|||
annotation[0, 10] = label[13] # l3_x |
|||
annotation[0, 11] = label[14] # l3_y |
|||
annotation[0, 12] = label[16] # l4_x |
|||
annotation[0, 13] = label[17] # l4_y |
|||
if (annotation[0, 4]<0): |
|||
annotation[0, 14] = -1 |
|||
else: |
|||
annotation[0, 14] = 1 |
|||
|
|||
annotations = np.append(annotations, annotation, axis=0) |
|||
target = np.array(annotations) |
|||
if self.preproc is not None: |
|||
img, target = self.preproc(img, target) |
|||
|
|||
return torch.from_numpy(img), target |
|||
|
|||
def detection_collate(batch): |
|||
"""Custom collate fn for dealing with batches of images that have a different |
|||
number of associated object annotations (bounding boxes). |
|||
|
|||
Arguments: |
|||
batch: (tuple) A tuple of tensor images and lists of annotations |
|||
|
|||
Return: |
|||
A tuple containing: |
|||
1) (tensor) batch of images stacked on their 0 dim |
|||
2) (list of tensors) annotations for a given image are stacked on 0 dim |
|||
""" |
|||
targets = [] |
|||
imgs = [] |
|||
for _, sample in enumerate(batch): |
|||
for _, tup in enumerate(sample): |
|||
if torch.is_tensor(tup): |
|||
imgs.append(tup) |
|||
elif isinstance(tup, type(np.empty(0))): |
|||
annos = torch.from_numpy(tup).float() |
|||
targets.append(annos) |
|||
|
|||
return (torch.stack(imgs, 0), targets) |
|||
|
@ -1,84 +1,152 @@ |
|||
import numpy as np |
|||
import ipdb |
|||
import torch |
|||
from torch.optim import AdamW |
|||
from torch.utils import data |
|||
from torchvision import transforms |
|||
from torchvision.transforms import RandomResizedCrop, Lambda |
|||
from towhee.trainer.modelcard import ModelCard |
|||
#from towhee.trainer.modelcard import ModelCard |
|||
|
|||
from towhee.trainer.training_config import TrainingConfig |
|||
from towhee.trainer.dataset import get_dataset |
|||
from resnet_image_embedding import ResnetImageEmbedding |
|||
from towhee.trainer.trainer import Trainer |
|||
from retinaface_face_detection import RetinafaceFaceDetection |
|||
from towhee.types import Image |
|||
from towhee.trainer.training_config import dump_default_yaml |
|||
from towhee.trainer.utils.trainer_utils import send_to_device |
|||
from PIL import Image as PILImage |
|||
from timm.models.resnet import ResNet |
|||
from torch import nn |
|||
from pytorch.wider_face import WiderFaceDetection, detection_collate |
|||
from pytorch.multibox_loss import MultiBoxLoss |
|||
from pytorch.data_augment import preproc |
|||
from pytorch.prior_box import PriorBox |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
dump_default_yaml(yaml_path='default_config.yaml') |
|||
# img = torch.rand([1, 3, 224, 224]) |
|||
img_path = './ILSVRC2012_val_00049771.JPEG' |
|||
# # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') |
|||
img = PILImage.open(img_path) |
|||
img_bytes = img.tobytes() |
|||
img_width = img.width |
|||
img_height = img.height |
|||
img_channel = len(img.split()) |
|||
img_mode = img.mode |
|||
img_array = np.array(img) |
|||
array_size = np.array(img).shape |
|||
towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) |
|||
|
|||
op = ResnetImageEmbedding('resnet34') |
|||
# op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") |
|||
# old_out = op(towhee_img) |
|||
# print(old_out.feature_vector[0]) |
|||
|
|||
op = RetinafaceFaceDetection() |
|||
training_config = TrainingConfig() |
|||
yaml_path = 'resnet_training_yaml.yaml' |
|||
# dump_default_yaml(yaml_path=yaml_path) |
|||
yaml_path = 'retinaface_training_yaml.yaml' # 'resnet_training_yaml.yaml' |
|||
|
|||
training_config.load_from_yaml(yaml_path) |
|||
# output_dir='./temp_output', |
|||
# overwrite_output_dir=True, |
|||
# epoch_num=2, |
|||
# per_gpu_train_batch_size=16, |
|||
# prediction_loss_only=True, |
|||
# metric='Accuracy' |
|||
# # device_str='cuda', |
|||
# # n_gpu=4 |
|||
# ) |
|||
|
|||
mnist_transform = transforms.Compose([transforms.ToTensor(), |
|||
RandomResizedCrop(224), |
|||
Lambda(lambda x: x.repeat(3, 1, 1)), |
|||
transforms.Normalize(mean=[0.5], std=[0.5])]) |
|||
train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True) |
|||
eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False) |
|||
# fake_transform = transforms.Compose([transforms.ToTensor(), |
|||
# RandomResizedCrop(224),]) |
|||
# train_data = get_dataset('fake', size=20, transform=fake_transform) |
|||
|
|||
op.change_before_train(10) |
|||
trainer = op.setup_trainer() |
|||
# my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) |
|||
# op.setup_trainer() |
|||
|
|||
# trainer.add_callback() |
|||
# trainer.set_optimizer() |
|||
|
|||
# op.trainer.set_optimizer(my_optimimzer) |
|||
# trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') |
|||
|
|||
# my_loss = nn.BCELoss() |
|||
# trainer.set_loss(my_loss, 'my_loss111') |
|||
# trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') |
|||
# op.trainer._create_optimizer() |
|||
# op.trainer.set_optimizer() |
|||
op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) |
|||
# training_config.num_epoch = 3 |
|||
# op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') |
|||
|
|||
# op.save('./test_save') |
|||
# op.load('./test_save') |
|||
# new_out = op(towhee_img) |
|||
|
|||
# assert (new_out[0]!=old_out[0]).all() |
|||
|
|||
training_dataset = './data/widerface/train/label.txt' |
|||
|
|||
img_dim = 840 |
|||
rgb_mean = (104, 117, 123) |
|||
train_data = WiderFaceDetection( training_dataset,preproc(img_dim, rgb_mean)) |
|||
eval_data = WiderFaceDetection( training_dataset) |
|||
|
|||
cfg = {} |
|||
cfg['min_sizes'] = [[16, 32], [64, 128], [256, 512]] |
|||
cfg['steps'] = [8, 16, 32] |
|||
cfg['clip'] = False |
|||
cfg['loc_weight'] = 2.0 |
|||
|
|||
priorbox = PriorBox(cfg, image_size=(img_dim, img_dim)) |
|||
|
|||
with torch.no_grad(): |
|||
priors = priorbox.forward() |
|||
|
|||
class RetinafaceTrainer(Trainer): |
|||
def __init__( |
|||
self, |
|||
model: nn.Module = None, |
|||
training_config = None, |
|||
train_dataset = None, |
|||
eval_dataset = None, |
|||
model_card = None, |
|||
train_dataloader = None, |
|||
eval_dataloader = None, |
|||
priors = priors |
|||
): |
|||
super(RetinafaceTrainer, self).__init__(model, training_config, train_dataset, eval_dataset, model_card, train_dataloader, eval_dataloader) |
|||
self.priors = priors |
|||
|
|||
def compute_loss(self, model: nn.Module, inputs): |
|||
model.train() |
|||
labels = inputs[1] |
|||
outputs = model(inputs[0]) |
|||
priors = send_to_device(self.priors, self.configs.device) |
|||
loss_l, loss_c, loss_landm = self.loss(outputs, priors, labels) |
|||
loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm |
|||
return loss |
|||
|
|||
def compute_metric(self, model, inputs): |
|||
model.eval() |
|||
model.phase = "eval" |
|||
epoch_metric = None |
|||
labels = inputs[1] |
|||
model.cpu() |
|||
predicts = [] |
|||
gts = [] |
|||
|
|||
img = send_to_device(inputs[0], torch.device("cpu")) |
|||
outputs = model.inference(img[0]) |
|||
|
|||
npreds = outputs[0].shape[0] |
|||
ngts = labels[0].shape[0] |
|||
predicts.append(dict(boxes=outputs[0][:,:4], scores=outputs[0][:,4], labels=torch.IntTensor([0 for i in range(npreds)]))) |
|||
gts.append(dict(boxes=torch.vstack(labels)[:,:4], labels=torch.IntTensor([0 for i in range(ngts)]))) |
|||
|
|||
if self.metric is not None: |
|||
self.metric.update(predicts, gts) |
|||
epoch_metric = self.metric.compute()['map'].item() |
|||
return epoch_metric |
|||
|
|||
@torch.no_grad() |
|||
def evaluate_step(self, model, inputs): |
|||
loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False) |
|||
step_logs = {"eval_step_loss": 0, "eval_epoch_loss": loss_metric, |
|||
"eval_epoch_metric": epoch_metric} |
|||
return step_logs |
|||
|
|||
def get_train_dataloader(self): |
|||
if self.configs.n_gpu > 1: |
|||
self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data) |
|||
train_batch_sampler = torch.utils.data.BatchSampler( |
|||
self.train_sampler, 32, drop_last=True) |
|||
training_loader = torch.utils.data.DataLoader(train_data, |
|||
batch_sampler=train_batch_sampler, |
|||
num_workers=4, # self.configs.dataloader_num_workers, |
|||
pin_memory=True, |
|||
collate_fn=detection_collate |
|||
) |
|||
else: |
|||
training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4, collate_fn=detection_collate) |
|||
return training_loader |
|||
|
|||
#def get_eval_dataloader(self): |
|||
def get_eval_dataloader(self): |
|||
if self.configs.n_gpu > 1: |
|||
eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data) |
|||
eval_batch_sampler = torch.utils.data.BatchSampler( |
|||
eval_sampler, 1, drop_last=True) |
|||
eval_loader = torch.utils.data.DataLoader(eval_data, |
|||
batch_sampler=eval_batch_sampler, |
|||
num_workers=1, # self.configs.dataloader_num_workers, |
|||
pin_memory=True, |
|||
collate_fn=detection_collate |
|||
) |
|||
else: |
|||
eval_loader = torch.utils.data.DataLoader(eval_data, 1, |
|||
num_workers=1, # self.configs.dataloader_num_workers, |
|||
) |
|||
return eval_loader |
|||
|
|||
|
|||
|
|||
trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors) |
|||
|
|||
op.trainer = trainer |
|||
|
|||
criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device) |
|||
|
|||
trainer.set_loss(criterion, "multibox_loss") |
|||
op.train() |
|||
|
|||
|
Loading…
Reference in new issue