towhee/retinaface-face-detection

10 changed files with 575 additions and 75 deletions
@ -1,2 +1,49 @@ README.md

# Retinaface Face Detection (Pytorch)

Authors: wxywb

## Overview

This operator detects faces in images using the RetinaFace detector [1]. It returns the bounding box locations, five facial keypoints, and the cropped face images from the original image. This repo is an adaptation of [2].

## Interface

```python
__call__(self, image: 'towhee.types.Image')
```

**Args:**

- image:
  - the image in which to detect faces.
  - supported types: towhee.types.Image

**Returns:**

The operator returns a tuple Tuple[('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)] containing the following fields:

- boxes:
  - bounding boxes of detected faces.
  - data type: `numpy.ndarray`
  - shape: (num_faces, 4)
- keypoints:
  - facial keypoints for each face.
  - data type: `numpy.ndarray`
  - shape: (10)
- cropped_imgs:
  - cropped face images.
  - data type: `numpy.ndarray`
  - shape: (h, w, 3)

## Requirements

You can install the required Python packages from [requirements.txt](./requirements.txt).

## How it works

The `towhee/retinaface-face-detection` operator implements face detection. An example pipeline can be found at [face-embedding-retinaface-inceptionresnetv1](https://towhee.io/towhee/face-embedding-retinaface-inceptionresnetv1). A minimal usage sketch is shown below.
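The sketch below calls the operator directly. The `towhee.types.Image` construction follows the training example in this repo; treat the exact field order as an assumption for your towhee version, and note that `face.jpg` is a placeholder path:

```python
import numpy as np
from PIL import Image as PILImage
from towhee.types import Image

from retinaface_face_detection import RetinafaceFaceDetection

# Build a towhee Image from a PIL image ('face.jpg' is a placeholder path).
pil_img = PILImage.open('face.jpg')
towhee_img = Image(pil_img.tobytes(), pil_img.width, pil_img.height,
                   len(pil_img.split()), pil_img.mode, np.array(pil_img))

op = RetinafaceFaceDetection()
boxes, keypoints, cropped_imgs = op(towhee_img)
print(boxes.shape)  # (num_faces, 4)
```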

## Reference

[1]. https://arxiv.org/abs/1905.00641

[2]. https://github.com/biubug6/Pytorch_Retinaface
@ -0,0 +1,237 @@ pytorch/data_augment.py
import cv2
import numpy as np
import random
from .box_utils import matrix_iof


def _crop(image, boxes, labels, landm, img_dim):
    height, width, _ = image.shape
    pad_image_flag = True

    for _ in range(250):
        """
        if random.uniform(0, 1) <= 0.2:
            scale = 1.0
        else:
            scale = random.uniform(0.3, 1.0)
        """
        PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
        scale = random.choice(PRE_SCALES)
        short_side = min(width, height)
        w = int(scale * short_side)
        h = w

        if width == w:
            l = 0
        else:
            l = random.randrange(width - w)
        if height == h:
            t = 0
        else:
            t = random.randrange(height - h)
        roi = np.array((l, t, l + w, t + h))

        value = matrix_iof(boxes, roi[np.newaxis])
        flag = (value >= 1)
        if not flag.any():
            continue

        centers = (boxes[:, :2] + boxes[:, 2:]) / 2
        mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
        boxes_t = boxes[mask_a].copy()
        labels_t = labels[mask_a].copy()
        landms_t = landm[mask_a].copy()
        landms_t = landms_t.reshape([-1, 5, 2])

        if boxes_t.shape[0] == 0:
            continue

        image_t = image[roi[1]:roi[3], roi[0]:roi[2]]

        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
        boxes_t[:, :2] -= roi[:2]
        boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
        boxes_t[:, 2:] -= roi[:2]

        # landm
        landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
        landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
        landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
        landms_t = landms_t.reshape([-1, 10])

        # make sure that the cropped image contains at least one face > 16 pixel at training image scale
        b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
        b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
        mask_b = np.minimum(b_w_t, b_h_t) > 0.0
        boxes_t = boxes_t[mask_b]
        labels_t = labels_t[mask_b]
        landms_t = landms_t[mask_b]

        if boxes_t.shape[0] == 0:
            continue

        pad_image_flag = False

        return image_t, boxes_t, labels_t, landms_t, pad_image_flag
    return image, boxes, labels, landm, pad_image_flag


def _distort(image):

    def _convert(image, alpha=1, beta=0):
        tmp = image.astype(float) * alpha + beta
        tmp[tmp < 0] = 0
        tmp[tmp > 255] = 255
        image[:] = tmp

    image = image.copy()

    if random.randrange(2):

        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    else:

        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

    return image


def _expand(image, boxes, fill, p):
    if random.randrange(2):
        return image, boxes

    height, width, depth = image.shape

    scale = random.uniform(1, p)
    w = int(scale * width)
    h = int(scale * height)

    left = random.randint(0, w - width)
    top = random.randint(0, h - height)

    boxes_t = boxes.copy()
    boxes_t[:, :2] += (left, top)
    boxes_t[:, 2:] += (left, top)
    expand_image = np.empty(
        (h, w, depth),
        dtype=image.dtype)
    expand_image[:, :] = fill
    expand_image[top:top + height, left:left + width] = image
    image = expand_image

    return image, boxes_t


def _mirror(image, boxes, landms):
    _, width, _ = image.shape
    if random.randrange(2):
        image = image[:, ::-1]
        boxes = boxes.copy()
        boxes[:, 0::2] = width - boxes[:, 2::-2]

        # landm
        landms = landms.copy()
        landms = landms.reshape([-1, 5, 2])
        landms[:, :, 0] = width - landms[:, :, 0]
        tmp = landms[:, 1, :].copy()
        landms[:, 1, :] = landms[:, 0, :]
        landms[:, 0, :] = tmp
        tmp1 = landms[:, 4, :].copy()
        landms[:, 4, :] = landms[:, 3, :]
        landms[:, 3, :] = tmp1
        landms = landms.reshape([-1, 10])

    return image, boxes, landms


def _pad_to_square(image, rgb_mean, pad_image_flag):
    if not pad_image_flag:
        return image
    height, width, _ = image.shape
    long_side = max(width, height)
    image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
    image_t[:, :] = rgb_mean
    image_t[0:0 + height, 0:0 + width] = image
    return image_t


def _resize_subtract_mean(image, insize, rgb_mean):
    interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
    interp_method = interp_methods[random.randrange(5)]
    image = cv2.resize(image, (insize, insize), interpolation=interp_method)
    image = image.astype(np.float32)
    image -= rgb_mean
    return image.transpose(2, 0, 1)


class preproc(object):

    def __init__(self, img_dim, rgb_means):
        self.img_dim = img_dim
        self.rgb_means = rgb_means

    def __call__(self, image, targets):
        assert targets.shape[0] > 0, "this image does not have gt"

        boxes = targets[:, :4].copy()
        labels = targets[:, -1].copy()
        landm = targets[:, 4:-1].copy()

        image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
        image_t = _distort(image_t)
        image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag)
        image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
        height, width, _ = image_t.shape
        image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
        boxes_t[:, 0::2] /= width
        boxes_t[:, 1::2] /= height

        landm_t[:, 0::2] /= width
        landm_t[:, 1::2] /= height

        labels_t = np.expand_dims(labels_t, 1)
        targets_t = np.hstack((boxes_t, landm_t, labels_t))

        return image_t, targets_t

@ -0,0 +1,4 @@
import torch

torch.load('pytorch_retinaface_mobilenet_widerface.pth')
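As a usage sketch for the `preproc` class in data_augment.py above: it consumes a BGR image and an (N, 15) annotation array (4 box coordinates, 10 landmark coordinates, 1 label) and emits a CHW float32 image with targets normalized to [0, 1]. The image and annotation values below are fabricated for illustration:

```python
import numpy as np
from pytorch.data_augment import preproc

# A fake 300x300 BGR image containing one annotated face.
img = np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8)
# One target row: x1, y1, x2, y2, five (x, y) landmarks, label.
target = np.array([[60., 60., 180., 200.,
                    90., 120., 150., 120., 120., 150.,
                    100., 170., 140., 170., 1.]])

p = preproc(img_dim=840, rgb_means=(104, 117, 123))
image_t, targets_t = p(img, target)
print(image_t.shape)    # (3, 840, 840), float32, CHW
print(targets_t.shape)  # (n_kept, 15), boxes/landmarks normalized to [0, 1]
```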
@ -0,0 +1,34 @@ pytorch/prior_box.py
import torch
from itertools import product as product
import numpy as np
from math import ceil


class PriorBox(object):
    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
        self.name = "s"

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
                    dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output
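A quick sketch of driving `PriorBox` (the cfg values mirror the training script later in this diff). Each feature-map cell contributes one anchor per entry in its `min_sizes` list, as a (cx, cy, w, h) row normalized to the image size:

```python
from pytorch.prior_box import PriorBox

cfg = {
    'min_sizes': [[16, 32], [64, 128], [256, 512]],
    'steps': [8, 16, 32],
    'clip': False,
}
priorbox = PriorBox(cfg, image_size=(840, 840))
priors = priorbox.forward()
# ceil(840/8) = 105, ceil(840/16) = 53, ceil(840/32) = 27 cells per side,
# two sizes each: (105*105 + 53*53 + 27*27) * 2 = 29126 anchors.
print(priors.shape)  # torch.Size([29126, 4])
```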
Binary file not shown.
@ -0,0 +1,102 @@ pytorch/wider_face.py
import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np


class WiderFaceDetection(data.Dataset):
    def __init__(self, txt_path, preproc=None):
        self.preproc = preproc
        self.imgs_path = []
        self.words = []
        f = open(txt_path, 'r')
        lines = f.readlines()
        isFirst = True
        labels = []
        for line in lines:
            line = line.rstrip()
            if line.startswith('#'):
                if isFirst is True:
                    isFirst = False
                else:
                    labels_copy = labels.copy()
                    self.words.append(labels_copy)
                    labels.clear()
                path = line[2:]
                path = txt_path.replace('label.txt', 'images/') + path
                self.imgs_path.append(path)
            else:
                line = line.split(' ')
                label = [float(x) for x in line]
                labels.append(label)

        self.words.append(labels)

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        img = cv2.imread(self.imgs_path[index])
        height, width, _ = img.shape

        labels = self.words[index]
        annotations = np.zeros((0, 15))
        if len(labels) == 0:
            return annotations
        for idx, label in enumerate(labels):
            annotation = np.zeros((1, 15))
            # bbox
            annotation[0, 0] = label[0]  # x1
            annotation[0, 1] = label[1]  # y1
            annotation[0, 2] = label[0] + label[2]  # x2
            annotation[0, 3] = label[1] + label[3]  # y2

            # landmarks
            annotation[0, 4] = label[4]    # l0_x
            annotation[0, 5] = label[5]    # l0_y
            annotation[0, 6] = label[7]    # l1_x
            annotation[0, 7] = label[8]    # l1_y
            annotation[0, 8] = label[10]   # l2_x
            annotation[0, 9] = label[11]   # l2_y
            annotation[0, 10] = label[13]  # l3_x
            annotation[0, 11] = label[14]  # l3_y
            annotation[0, 12] = label[16]  # l4_x
            annotation[0, 13] = label[17]  # l4_y
            if (annotation[0, 4] < 0):
                annotation[0, 14] = -1
            else:
                annotation[0, 14] = 1

            annotations = np.append(annotations, annotation, axis=0)
        target = np.array(annotations)
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target


def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on 0 dim
    """
    targets = []
    imgs = []
    for _, sample in enumerate(batch):
        for _, tup in enumerate(sample):
            if torch.is_tensor(tup):
                imgs.append(tup)
            elif isinstance(tup, type(np.empty(0))):
                annos = torch.from_numpy(tup).float()
                targets.append(annos)

    return (torch.stack(imgs, 0), targets)
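A sketch of wiring `WiderFaceDetection` into a `DataLoader` with `detection_collate` (the label path is the one the training script below assumes; image paths are resolved by replacing `label.txt` with `images/`). Because face counts vary per image, each batch is a stacked image tensor plus a list of per-image annotation tensors:

```python
import torch.utils.data as data
from pytorch.data_augment import preproc
from pytorch.wider_face import WiderFaceDetection, detection_collate

dataset = WiderFaceDetection('./data/widerface/train/label.txt',
                             preproc(840, (104, 117, 123)))
loader = data.DataLoader(dataset, batch_size=32, shuffle=True,
                         num_workers=4, collate_fn=detection_collate)

images, targets = next(iter(loader))
print(images.shape)  # (32, 3, 840, 840)
print(len(targets))  # 32 tensors, each (num_faces, 15)
```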
@ -1,84 +1,152 @@
 import numpy as np
+import ipdb
+import torch
 from torch.optim import AdamW
+from torch.utils import data
 from torchvision import transforms
 from torchvision.transforms import RandomResizedCrop, Lambda
-from towhee.trainer.modelcard import ModelCard
+#from towhee.trainer.modelcard import ModelCard
 from towhee.trainer.training_config import TrainingConfig
-from towhee.trainer.dataset import get_dataset
-from resnet_image_embedding import ResnetImageEmbedding
+from towhee.trainer.trainer import Trainer
+from retinaface_face_detection import RetinafaceFaceDetection
 from towhee.types import Image
 from towhee.trainer.training_config import dump_default_yaml
+from towhee.trainer.utils.trainer_utils import send_to_device
 from PIL import Image as PILImage
 from timm.models.resnet import ResNet
 from torch import nn
+from pytorch.wider_face import WiderFaceDetection, detection_collate
+from pytorch.multibox_loss import MultiBoxLoss
+from pytorch.data_augment import preproc
+from pytorch.prior_box import PriorBox


 if __name__ == '__main__':
-    dump_default_yaml(yaml_path='default_config.yaml')
-    # img = torch.rand([1, 3, 224, 224])
-    img_path = './ILSVRC2012_val_00049771.JPEG'
-    # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
-    img = PILImage.open(img_path)
-    img_bytes = img.tobytes()
-    img_width = img.width
-    img_height = img.height
-    img_channel = len(img.split())
-    img_mode = img.mode
-    img_array = np.array(img)
-    array_size = np.array(img).shape
-    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
-
-    op = ResnetImageEmbedding('resnet34')
-    # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
-    # old_out = op(towhee_img)
-    # print(old_out.feature_vector[0])
-
+    op = RetinafaceFaceDetection()
     training_config = TrainingConfig()
-    yaml_path = 'resnet_training_yaml.yaml'
-    # dump_default_yaml(yaml_path=yaml_path)
+    yaml_path = 'retinaface_training_yaml.yaml'  # 'resnet_training_yaml.yaml'
     training_config.load_from_yaml(yaml_path)
-    # output_dir='./temp_output',
-    # overwrite_output_dir=True,
-    # epoch_num=2,
-    # per_gpu_train_batch_size=16,
-    # prediction_loss_only=True,
-    # metric='Accuracy'
-    # # device_str='cuda',
-    # # n_gpu=4
-    # )
     mnist_transform = transforms.Compose([transforms.ToTensor(),
                                           RandomResizedCrop(224),
                                           Lambda(lambda x: x.repeat(3, 1, 1)),
                                           transforms.Normalize(mean=[0.5], std=[0.5])])
-    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
-    eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
-    # fake_transform = transforms.Compose([transforms.ToTensor(),
-    #                                      RandomResizedCrop(224),])
-    # train_data = get_dataset('fake', size=20, transform=fake_transform)
-
-    op.change_before_train(10)
-    trainer = op.setup_trainer()
-    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
-    # op.setup_trainer()
-
-    # trainer.add_callback()
-    # trainer.set_optimizer()
-
-    # op.trainer.set_optimizer(my_optimimzer)
-    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
-
-    # my_loss = nn.BCELoss()
-    # trainer.set_loss(my_loss, 'my_loss111')
-    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
-    # op.trainer._create_optimizer()
-    # op.trainer.set_optimizer()
-    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
-    # training_config.num_epoch = 3
-    # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
-
-    # op.save('./test_save')
-    # op.load('./test_save')
-    # new_out = op(towhee_img)
-
-    # assert (new_out[0]!=old_out[0]).all()
+    training_dataset = './data/widerface/train/label.txt'
+
+    img_dim = 840
+    rgb_mean = (104, 117, 123)
+    train_data = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))
+    eval_data = WiderFaceDetection(training_dataset)
+
+    cfg = {}
+    cfg['min_sizes'] = [[16, 32], [64, 128], [256, 512]]
+    cfg['steps'] = [8, 16, 32]
+    cfg['clip'] = False
+    cfg['loc_weight'] = 2.0
+
+    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
+    with torch.no_grad():
+        priors = priorbox.forward()
+
+    class RetinafaceTrainer(Trainer):
+        def __init__(
+                self,
+                model: nn.Module = None,
+                training_config = None,
+                train_dataset = None,
+                eval_dataset = None,
+                model_card = None,
+                train_dataloader = None,
+                eval_dataloader = None,
+                priors = priors
+        ):
+            super(RetinafaceTrainer, self).__init__(model, training_config, train_dataset, eval_dataset, model_card, train_dataloader, eval_dataloader)
+            self.priors = priors
+
+        def compute_loss(self, model: nn.Module, inputs):
+            model.train()
+            labels = inputs[1]
+            outputs = model(inputs[0])
+            priors = send_to_device(self.priors, self.configs.device)
+            loss_l, loss_c, loss_landm = self.loss(outputs, priors, labels)
+            loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
+            return loss
+
+        def compute_metric(self, model, inputs):
+            model.eval()
+            model.phase = "eval"
+            epoch_metric = None
+            labels = inputs[1]
+            model.cpu()
+            predicts = []
+            gts = []
+
+            img = send_to_device(inputs[0], torch.device("cpu"))
+            outputs = model.inference(img[0])
+
+            npreds = outputs[0].shape[0]
+            ngts = labels[0].shape[0]
+            predicts.append(dict(boxes=outputs[0][:, :4], scores=outputs[0][:, 4], labels=torch.IntTensor([0 for i in range(npreds)])))
+            gts.append(dict(boxes=torch.vstack(labels)[:, :4], labels=torch.IntTensor([0 for i in range(ngts)])))
+
+            if self.metric is not None:
+                self.metric.update(predicts, gts)
+                epoch_metric = self.metric.compute()['map'].item()
+            return epoch_metric
+
+        @torch.no_grad()
+        def evaluate_step(self, model, inputs):
+            loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False)
+            step_logs = {"eval_step_loss": 0, "eval_epoch_loss": loss_metric,
+                         "eval_epoch_metric": epoch_metric}
+            return step_logs
+
+        def get_train_dataloader(self):
+            if self.configs.n_gpu > 1:
+                self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
+                train_batch_sampler = torch.utils.data.BatchSampler(
+                    self.train_sampler, 32, drop_last=True)
+                training_loader = torch.utils.data.DataLoader(train_data,
+                                                              batch_sampler=train_batch_sampler,
+                                                              num_workers=4,  # self.configs.dataloader_num_workers,
+                                                              pin_memory=True,
+                                                              collate_fn=detection_collate
+                                                              )
+            else:
+                training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4, collate_fn=detection_collate)
+            return training_loader
+
+        def get_eval_dataloader(self):
+            if self.configs.n_gpu > 1:
+                eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)
+                eval_batch_sampler = torch.utils.data.BatchSampler(
+                    eval_sampler, 1, drop_last=True)
+                eval_loader = torch.utils.data.DataLoader(eval_data,
+                                                          batch_sampler=eval_batch_sampler,
+                                                          num_workers=1,  # self.configs.dataloader_num_workers,
+                                                          pin_memory=True,
+                                                          collate_fn=detection_collate
+                                                          )
+            else:
+                eval_loader = torch.utils.data.DataLoader(eval_data, 1,
+                                                          num_workers=1,  # self.configs.dataloader_num_workers,
+                                                          )
+            return eval_loader
+
+    trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors)
+
+    op.trainer = trainer
+
+    criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device)
+
+    trainer.set_loss(criterion, "multibox_loss")
+    op.train()