
refactor the training code.

training
wxywb, 3 years ago
commit cabafc2339
Changed files (lines changed in parentheses):

  1. README.md (49)
  2. pytorch/data_augment.py (237)
  3. pytorch/main.py (4)
  4. pytorch/multibox_loss.py (14)
  5. pytorch/prior_box.py (34)
  6. pytorch/pytorch_retinaface_mobilenet_widerface.pth (BIN)
  7. pytorch/wider_face.py (102)
  8. retinaface_face_detection.py (7)
  9. retinaface_training_yaml.yaml (7)
  10. train.py (194)

README.md (49)

@@ -1,2 +1,49 @@
# retinaface-face-detection
# Retinaface Face Detection (Pytorch)
Authors: wxywb
## Overview
This operator detects faces in images using the RetinaFace detector [1]. It returns the locations, five facial keypoints, and the cropped face images from the original image. This repo is an adaptation of [2].
## Interface
```python
__call__(self, image: 'towhee.types.Image')
```
**Args:**

- image:
  - the image in which to detect faces.
  - supported types: `towhee.types.Image`

**Returns:**

The operator returns a tuple `Tuple[('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)]` containing the following fields:

- boxes:
  - bounding boxes of detected faces.
  - data type: `numpy.ndarray`
  - shape: (num_faces, 4)
- keypoints:
  - facial keypoints of each detected face.
  - data type: `numpy.ndarray`
  - shape: (10)
- cropped_imgs:
  - cropped face images.
  - data type: `numpy.ndarray`
  - shape: (h, w, 3)
## Requirements
The required Python packages are listed in [requirements.txt](./requirements.txt).
## How it works
The `towhee/retinaface-face-detection` operator implements face detection. An example pipeline can be found at [face-embedding-retinaface-inceptionresnetv1](https://towhee.io/towhee/face-embedding-retinaface-inceptionresnetv1).
## Reference
[1]. https://arxiv.org/abs/1905.00641
[2]. https://github.com/biubug6/Pytorch_Retinaface
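
For orientation, a minimal usage sketch of the operator (the image path is hypothetical; the `towhee.types.Image` construction mirrors what train.py in this commit does):

```python
import numpy as np
from PIL import Image as PILImage
from towhee.types import Image
from retinaface_face_detection import RetinafaceFaceDetection

# Wrap a PIL image into a towhee Image, as train.py below does.
pil_img = PILImage.open('example.jpg')  # hypothetical input image
towhee_img = Image(pil_img.tobytes(), pil_img.width, pil_img.height,
                   len(pil_img.split()), pil_img.mode, np.array(pil_img))

op = RetinafaceFaceDetection()
outputs = op(towhee_img)  # one (boxes, keypoints, cropped_imgs) entry per detected face
```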

pytorch/data_augment.py (237)

@@ -0,0 +1,237 @@
import cv2
import numpy as np
import random
from .box_utils import matrix_iof


def _crop(image, boxes, labels, landm, img_dim):
    height, width, _ = image.shape
    pad_image_flag = True

    for _ in range(250):
        """
        if random.uniform(0, 1) <= 0.2:
            scale = 1.0
        else:
            scale = random.uniform(0.3, 1.0)
        """
        PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0]
        scale = random.choice(PRE_SCALES)
        short_side = min(width, height)
        w = int(scale * short_side)
        h = w

        if width == w:
            l = 0
        else:
            l = random.randrange(width - w)
        if height == h:
            t = 0
        else:
            t = random.randrange(height - h)
        roi = np.array((l, t, l + w, t + h))

        value = matrix_iof(boxes, roi[np.newaxis])
        flag = (value >= 1)
        if not flag.any():
            continue

        centers = (boxes[:, :2] + boxes[:, 2:]) / 2
        mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1)
        boxes_t = boxes[mask_a].copy()
        labels_t = labels[mask_a].copy()
        landms_t = landm[mask_a].copy()
        landms_t = landms_t.reshape([-1, 5, 2])

        if boxes_t.shape[0] == 0:
            continue

        image_t = image[roi[1]:roi[3], roi[0]:roi[2]]

        boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2])
        boxes_t[:, :2] -= roi[:2]
        boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:])
        boxes_t[:, 2:] -= roi[:2]

        # landm
        landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2]
        landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0]))
        landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2])
        landms_t = landms_t.reshape([-1, 10])

        # make sure that the cropped image contains at least one face > 16 pixel at training image scale
        b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim
        b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim
        mask_b = np.minimum(b_w_t, b_h_t) > 0.0
        boxes_t = boxes_t[mask_b]
        labels_t = labels_t[mask_b]
        landms_t = landms_t[mask_b]

        if boxes_t.shape[0] == 0:
            continue

        pad_image_flag = False

        return image_t, boxes_t, labels_t, landms_t, pad_image_flag
    return image, boxes, labels, landm, pad_image_flag


def _distort(image):

    def _convert(image, alpha=1, beta=0):
        tmp = image.astype(float) * alpha + beta
        tmp[tmp < 0] = 0
        tmp[tmp > 255] = 255
        image[:] = tmp

    image = image.copy()

    if random.randrange(2):

        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    else:

        # brightness distortion
        if random.randrange(2):
            _convert(image, beta=random.uniform(-32, 32))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        # saturation distortion
        if random.randrange(2):
            _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5))

        # hue distortion
        if random.randrange(2):
            tmp = image[:, :, 0].astype(int) + random.randint(-18, 18)
            tmp %= 180
            image[:, :, 0] = tmp

        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

        # contrast distortion
        if random.randrange(2):
            _convert(image, alpha=random.uniform(0.5, 1.5))

    return image


def _expand(image, boxes, fill, p):
    if random.randrange(2):
        return image, boxes

    height, width, depth = image.shape

    scale = random.uniform(1, p)
    w = int(scale * width)
    h = int(scale * height)

    left = random.randint(0, w - width)
    top = random.randint(0, h - height)

    boxes_t = boxes.copy()
    boxes_t[:, :2] += (left, top)
    boxes_t[:, 2:] += (left, top)

    expand_image = np.empty(
        (h, w, depth),
        dtype=image.dtype)
    expand_image[:, :] = fill
    expand_image[top:top + height, left:left + width] = image
    image = expand_image

    return image, boxes_t


def _mirror(image, boxes, landms):
    _, width, _ = image.shape
    if random.randrange(2):
        image = image[:, ::-1]
        boxes = boxes.copy()
        boxes[:, 0::2] = width - boxes[:, 2::-2]

        # landm
        landms = landms.copy()
        landms = landms.reshape([-1, 5, 2])
        landms[:, :, 0] = width - landms[:, :, 0]
        tmp = landms[:, 1, :].copy()
        landms[:, 1, :] = landms[:, 0, :]
        landms[:, 0, :] = tmp
        tmp1 = landms[:, 4, :].copy()
        landms[:, 4, :] = landms[:, 3, :]
        landms[:, 3, :] = tmp1
        landms = landms.reshape([-1, 10])

    return image, boxes, landms


def _pad_to_square(image, rgb_mean, pad_image_flag):
    if not pad_image_flag:
        return image
    height, width, _ = image.shape
    long_side = max(width, height)
    image_t = np.empty((long_side, long_side, 3), dtype=image.dtype)
    image_t[:, :] = rgb_mean
    image_t[0:0 + height, 0:0 + width] = image
    return image_t


def _resize_subtract_mean(image, insize, rgb_mean):
    interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4]
    interp_method = interp_methods[random.randrange(5)]
    image = cv2.resize(image, (insize, insize), interpolation=interp_method)
    image = image.astype(np.float32)
    image -= rgb_mean
    return image.transpose(2, 0, 1)


class preproc(object):

    def __init__(self, img_dim, rgb_means):
        self.img_dim = img_dim
        self.rgb_means = rgb_means

    def __call__(self, image, targets):
        assert targets.shape[0] > 0, "this image does not have gt"

        boxes = targets[:, :4].copy()
        labels = targets[:, -1].copy()
        landm = targets[:, 4:-1].copy()

        image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim)
        image_t = _distort(image_t)
        image_t = _pad_to_square(image_t, self.rgb_means, pad_image_flag)
        image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t)
        height, width, _ = image_t.shape
        image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means)
        boxes_t[:, 0::2] /= width
        boxes_t[:, 1::2] /= height
        landm_t[:, 0::2] /= width
        landm_t[:, 1::2] /= height

        labels_t = np.expand_dims(labels_t, 1)
        targets_t = np.hstack((boxes_t, landm_t, labels_t))

        return image_t, targets_t
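
To make the annotation layout concrete: `preproc` expects a BGR image and an `(n, 15)` target array of `[x1, y1, x2, y2, ten landmark coordinates, label]` rows, the same layout `wider_face.py` below emits. A minimal sketch with synthetic data (it assumes the module is imported as part of the `pytorch` package, since `_crop` pulls `matrix_iof` from the sibling `box_utils`):

```python
import numpy as np
from pytorch.data_augment import preproc

# One 300x400 BGR image with a single annotated face (values illustrative).
image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
targets = np.array([[50., 60., 150., 180.,            # box: x1, y1, x2, y2
                     70., 90., 120., 90., 95., 120.,  # landmarks 0-2
                     80., 150., 115., 150.,           # landmarks 3-4
                     1.]])                            # label

p = preproc(img_dim=840, rgb_means=(104, 117, 123))
img_t, targets_t = p(image, targets)
print(img_t.shape)      # (3, 840, 840): CHW, float32, mean-subtracted
print(targets_t.shape)  # (n_faces, 15): boxes/landmarks normalized to [0, 1]
```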

pytorch/main.py (4)

@@ -0,0 +1,4 @@
import torch
torch.load('pytorch_retinaface_mobilenet_widerface.pth')

pytorch/multibox_loss.py (14)

@@ -27,7 +27,7 @@ class MultiBoxLoss(nn.Module):
    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu=False, device=None):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
@@ -38,7 +38,8 @@
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = [0.1, 0.2]
        self.GPU = False
        self.GPU = use_gpu
        self.device = device

    def forward(self, predictions, priors, targets):
        """Multibox Loss
@@ -68,9 +69,9 @@
            defaults = priors.data
            match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx)
        if self.GPU:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()
            landm_t = landm_t.cuda()
            loc_t = loc_t.to(self.device)
            conf_t = conf_t.to(self.device)
            landm_t = landm_t.to(self.device)

        zeros = torch.tensor(0).cuda()
        # landm Loss (Smooth L1)
@@ -114,7 +115,8 @@
        targets_weighted = conf_t[(pos+neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N    N = max(num_pos.data.sum().float(), 1)
        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1
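
With the new signature, device placement is explicit instead of hard-coded CUDA. train.py below constructs the loss positionally; the same call with keywords spelled out, as a readability sketch:

```python
import torch
from pytorch.multibox_loss import MultiBoxLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = MultiBoxLoss(num_classes=2, overlap_thresh=0.35,
                         prior_for_matching=True, bkg_label=0,
                         neg_mining=True, neg_pos=7, neg_overlap=0.35,
                         encode_target=False,
                         use_gpu=torch.cuda.is_available(), device=device)
```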

pytorch/prior_box.py (34)

@@ -0,0 +1,34 @@
import torch
from itertools import product as product
import numpy as np
from math import ceil


class PriorBox(object):
    def __init__(self, cfg, image_size=None, phase='train'):
        super(PriorBox, self).__init__()
        self.min_sizes = cfg['min_sizes']
        self.steps = cfg['steps']
        self.clip = cfg['clip']
        self.image_size = image_size
        self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
        self.name = "s"

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
                    dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
                    for cy, cx in product(dense_cy, dense_cx):
                        anchors += [cx, cy, s_kx, s_ky]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output
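
As a sanity check on the anchor layout: with the same `cfg` that train.py builds below, `PriorBox` emits one anchor per `min_sizes` entry at every feature-map cell, so for an 840x840 input the count works out as follows (illustrative sketch, not part of the commit):

```python
import torch
from math import ceil
from pytorch.prior_box import PriorBox

cfg = {'min_sizes': [[16, 32], [64, 128], [256, 512]],
       'steps': [8, 16, 32], 'clip': False}
img_dim = 840

priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
with torch.no_grad():
    priors = priorbox.forward()

expected = sum(ceil(img_dim / step) ** 2 * len(sizes)
               for step, sizes in zip(cfg['steps'], cfg['min_sizes']))
print(priors.shape, expected)  # torch.Size([29126, 4]) 29126
```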

pytorch/pytorch_retinaface_mobilenet_widerface.pth (BIN)

Binary file not shown.

pytorch/wider_face.py (102)

@@ -0,0 +1,102 @@
import os
import os.path
import sys
import torch
import torch.utils.data as data
import cv2
import numpy as np


class WiderFaceDetection(data.Dataset):
    def __init__(self, txt_path, preproc=None):
        self.preproc = preproc
        self.imgs_path = []
        self.words = []
        f = open(txt_path, 'r')
        lines = f.readlines()
        isFirst = True
        labels = []
        for line in lines:
            line = line.rstrip()
            if line.startswith('#'):
                if isFirst is True:
                    isFirst = False
                else:
                    labels_copy = labels.copy()
                    self.words.append(labels_copy)
                    labels.clear()
                path = line[2:]
                path = txt_path.replace('label.txt', 'images/') + path
                self.imgs_path.append(path)
            else:
                line = line.split(' ')
                label = [float(x) for x in line]
                labels.append(label)
        self.words.append(labels)

    def __len__(self):
        return len(self.imgs_path)

    def __getitem__(self, index):
        img = cv2.imread(self.imgs_path[index])
        height, width, _ = img.shape

        labels = self.words[index]
        annotations = np.zeros((0, 15))
        if len(labels) == 0:
            return annotations
        for idx, label in enumerate(labels):
            annotation = np.zeros((1, 15))
            # bbox
            annotation[0, 0] = label[0]  # x1
            annotation[0, 1] = label[1]  # y1
            annotation[0, 2] = label[0] + label[2]  # x2
            annotation[0, 3] = label[1] + label[3]  # y2

            # landmarks
            annotation[0, 4] = label[4]    # l0_x
            annotation[0, 5] = label[5]    # l0_y
            annotation[0, 6] = label[7]    # l1_x
            annotation[0, 7] = label[8]    # l1_y
            annotation[0, 8] = label[10]   # l2_x
            annotation[0, 9] = label[11]   # l2_y
            annotation[0, 10] = label[13]  # l3_x
            annotation[0, 11] = label[14]  # l3_y
            annotation[0, 12] = label[16]  # l4_x
            annotation[0, 13] = label[17]  # l4_y
            if (annotation[0, 4] < 0):
                annotation[0, 14] = -1
            else:
                annotation[0, 14] = 1

            annotations = np.append(annotations, annotation, axis=0)
        target = np.array(annotations)
        if self.preproc is not None:
            img, target = self.preproc(img, target)

        return torch.from_numpy(img), target


def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a different
    number of associated object annotations (bounding boxes).

    Arguments:
        batch: (tuple) A tuple of tensor images and lists of annotations

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image are stacked on 0 dim
    """
    targets = []
    imgs = []
    for _, sample in enumerate(batch):
        for _, tup in enumerate(sample):
            if torch.is_tensor(tup):
                imgs.append(tup)
            elif isinstance(tup, type(np.empty(0))):
                annos = torch.from_numpy(tup).float()
                targets.append(annos)

    return (torch.stack(imgs, 0), targets)
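
Putting the dataset and collate function together, a minimal loading sketch (the label path is the one train.py uses and must exist locally; `preproc` comes from `pytorch/data_augment.py` above):

```python
from torch.utils.data import DataLoader
from pytorch.data_augment import preproc
from pytorch.wider_face import WiderFaceDetection, detection_collate

dataset = WiderFaceDetection('./data/widerface/train/label.txt',
                             preproc(img_dim=840, rgb_means=(104, 117, 123)))
loader = DataLoader(dataset, batch_size=32, shuffle=True,
                    num_workers=4, collate_fn=detection_collate)

images, targets = next(iter(loader))
print(images.shape)  # torch.Size([32, 3, 840, 840])
print(len(targets))  # 32: one (n_faces_i, 15) tensor per image
```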

retinaface_face_detection.py (7)

@@ -15,6 +15,7 @@
from typing import NamedTuple, List
from PIL import Image
import torch
from torch import nn
from torchvision import transforms
import sys
import towhee
@@ -23,11 +24,12 @@ import numpy
from towhee.operator import Operator
from towhee.utils.pil_utils import to_pil
from towhee.operator import NNOperator
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import os


class RetinafaceFaceDetection(Operator):
class RetinafaceFaceDetection(NNOperator):
    """
    Face detection operator based on RetinaFace.

    Args:
@@ -71,3 +73,6 @@ class RetinafaceFaceDetection(Operator):
            output = Outputs(bboxes[i], keypoints[i, :], croppeds[i])
            outputs.append(output)
        return outputs

    def get_model(self) -> nn.Module:
        return self.model._model
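
The switch to `NNOperator` plus the `get_model()` hook is what gives the trainer access to the underlying `nn.Module`; train.py below relies on exactly this:

```python
op = RetinafaceFaceDetection()
model = op.get_model()  # the raw nn.Module wrapped by the operator
# train.py then hands it to the custom trainer:
#   trainer = RetinafaceTrainer(op.get_model(), training_config, ...)
```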

retinaface_training_yaml.yaml (7)

@@ -1,13 +1,14 @@
device:
  device_str: null
  n_gpu: -1
  device_str: cuda
  n_gpu: 2
  sync_bn: true
metrics:
  metric: Accuracy
  metric: MeanAveragePrecision
train:
  batch_size: 32
  overwrite_output_dir: true
  epoch_num: 2
  eval_strategy: eval_epoch
learning:
  optimizer:
    name_: SGD
train.py (194)

@@ -1,84 +1,152 @@
import numpy as np
import ipdb
import torch
from torch.optim import AdamW
from torch.utils import data
from torchvision import transforms
from torchvision.transforms import RandomResizedCrop, Lambda
from towhee.trainer.modelcard import ModelCard
#from towhee.trainer.modelcard import ModelCard
from towhee.trainer.training_config import TrainingConfig
from towhee.trainer.dataset import get_dataset
from resnet_image_embedding import ResnetImageEmbedding
from towhee.trainer.trainer import Trainer
from retinaface_face_detection import RetinafaceFaceDetection
from towhee.types import Image
from towhee.trainer.training_config import dump_default_yaml
from towhee.trainer.utils.trainer_utils import send_to_device
from PIL import Image as PILImage
from timm.models.resnet import ResNet
from torch import nn
from pytorch.wider_face import WiderFaceDetection, detection_collate
from pytorch.multibox_loss import MultiBoxLoss
from pytorch.data_augment import preproc
from pytorch.prior_box import PriorBox


if __name__ == '__main__':
    dump_default_yaml(yaml_path='default_config.yaml')
    # img = torch.rand([1, 3, 224, 224])
    img_path = './ILSVRC2012_val_00049771.JPEG'
    # # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
    img = PILImage.open(img_path)
    img_bytes = img.tobytes()
    img_width = img.width
    img_height = img.height
    img_channel = len(img.split())
    img_mode = img.mode
    img_array = np.array(img)
    array_size = np.array(img).shape
    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    op = ResnetImageEmbedding('resnet34')
    # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
    # old_out = op(towhee_img)
    # print(old_out.feature_vector[0])
    op = RetinafaceFaceDetection()

    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    # dump_default_yaml(yaml_path=yaml_path)
    yaml_path = 'retinaface_training_yaml.yaml'  # 'resnet_training_yaml.yaml'
    training_config.load_from_yaml(yaml_path)
    # output_dir='./temp_output',
    # overwrite_output_dir=True,
    # epoch_num=2,
    # per_gpu_train_batch_size=16,
    # prediction_loss_only=True,
    # metric='Accuracy'
    # # device_str='cuda',
    # # n_gpu=4
    # )

    mnist_transform = transforms.Compose([transforms.ToTensor(),
                                          RandomResizedCrop(224),
                                          Lambda(lambda x: x.repeat(3, 1, 1)),
                                          transforms.Normalize(mean=[0.5], std=[0.5])])
    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
    eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
    # fake_transform = transforms.Compose([transforms.ToTensor(),
    #                                      RandomResizedCrop(224),])
    # train_data = get_dataset('fake', size=20, transform=fake_transform)

    op.change_before_train(10)
    trainer = op.setup_trainer()
    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
    # op.setup_trainer()
    # trainer.add_callback()
    # trainer.set_optimizer()
    # op.trainer.set_optimizer(my_optimimzer)
    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
    # my_loss = nn.BCELoss()
    # trainer.set_loss(my_loss, 'my_loss111')
    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
    # op.trainer._create_optimizer()
    # op.trainer.set_optimizer()
    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
    # training_config.num_epoch = 3
    # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
    # op.save('./test_save')
    # op.load('./test_save')
    # new_out = op(towhee_img)
    # assert (new_out[0]!=old_out[0]).all()

    training_dataset = './data/widerface/train/label.txt'
    img_dim = 840
    rgb_mean = (104, 117, 123)
    train_data = WiderFaceDetection(training_dataset, preproc(img_dim, rgb_mean))
    eval_data = WiderFaceDetection(training_dataset)

    cfg = {}
    cfg['min_sizes'] = [[16, 32], [64, 128], [256, 512]]
    cfg['steps'] = [8, 16, 32]
    cfg['clip'] = False
    cfg['loc_weight'] = 2.0
    priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
    with torch.no_grad():
        priors = priorbox.forward()

    class RetinafaceTrainer(Trainer):
        def __init__(
                self,
                model: nn.Module = None,
                training_config=None,
                train_dataset=None,
                eval_dataset=None,
                model_card=None,
                train_dataloader=None,
                eval_dataloader=None,
                priors=priors
        ):
            super(RetinafaceTrainer, self).__init__(model, training_config, train_dataset, eval_dataset, model_card, train_dataloader, eval_dataloader)
            self.priors = priors

        def compute_loss(self, model: nn.Module, inputs):
            model.train()
            labels = inputs[1]
            outputs = model(inputs[0])
            priors = send_to_device(self.priors, self.configs.device)
            loss_l, loss_c, loss_landm = self.loss(outputs, priors, labels)
            loss = cfg['loc_weight'] * loss_l + loss_c + loss_landm
            return loss

        def compute_metric(self, model, inputs):
            model.eval()
            model.phase = "eval"
            epoch_metric = None
            labels = inputs[1]
            model.cpu()
            predicts = []
            gts = []
            img = send_to_device(inputs[0], torch.device("cpu"))
            outputs = model.inference(img[0])
            npreds = outputs[0].shape[0]
            ngts = labels[0].shape[0]
            predicts.append(dict(boxes=outputs[0][:, :4], scores=outputs[0][:, 4], labels=torch.IntTensor([0 for i in range(npreds)])))
            gts.append(dict(boxes=torch.vstack(labels)[:, :4], labels=torch.IntTensor([0 for i in range(ngts)])))
            if self.metric is not None:
                self.metric.update(predicts, gts)
                epoch_metric = self.metric.compute()['map'].item()
            return epoch_metric

        @torch.no_grad()
        def evaluate_step(self, model, inputs):
            loss_metric, epoch_metric = self.update_metrics(model, inputs, 0, training=False)
            step_logs = {"eval_step_loss": 0, "eval_epoch_loss": loss_metric,
                         "eval_epoch_metric": epoch_metric}
            return step_logs

        def get_train_dataloader(self):
            if self.configs.n_gpu > 1:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
                train_batch_sampler = torch.utils.data.BatchSampler(
                    self.train_sampler, 32, drop_last=True)
                training_loader = torch.utils.data.DataLoader(train_data,
                                                              batch_sampler=train_batch_sampler,
                                                              num_workers=4,  # self.configs.dataloader_num_workers,
                                                              pin_memory=True,
                                                              collate_fn=detection_collate
                                                              )
            else:
                training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4, collate_fn=detection_collate)
            return training_loader

        # def get_eval_dataloader(self):
        def get_eval_dataloader(self):
            if self.configs.n_gpu > 1:
                eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)
                eval_batch_sampler = torch.utils.data.BatchSampler(
                    eval_sampler, 1, drop_last=True)
                eval_loader = torch.utils.data.DataLoader(eval_data,
                                                          batch_sampler=eval_batch_sampler,
                                                          num_workers=1,  # self.configs.dataloader_num_workers,
                                                          pin_memory=True,
                                                          collate_fn=detection_collate
                                                          )
            else:
                eval_loader = torch.utils.data.DataLoader(eval_data, 1,
                                                          num_workers=1,  # self.configs.dataloader_num_workers,
                                                          )
            return eval_loader

    trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors)
    op.trainer = trainer
    criterion = MultiBoxLoss(2, 0.35, True, 0, True, 7, 0.35, False, True, trainer.configs.device)
    trainer.set_loss(criterion, "multibox_loss")
    op.train()
