logo
Browse Source

make multi-gpu works.

Signed-off-by: wxywb <xy.wang@zilliz.com>
training
wxywb 3 years ago
parent
commit
e8d553f57b
  1. 2
      pytorch/multibox_loss.py
  2. 14
      train.py

2
pytorch/multibox_loss.py

@ -73,7 +73,7 @@ class MultiBoxLoss(nn.Module):
conf_t = conf_t.to(self.device)
landm_t = landm_t.to(self.device)
zeros = torch.tensor(0).cuda()
zeros = torch.tensor(0).to(self.device)
# landm Loss (Smooth L1)
# Shape: [batch,num_priors,10]
pos1 = conf_t > zeros

14
train.py

@ -1,5 +1,4 @@
import numpy as np
import ipdb
import torch
from torch.optim import AdamW
from torch.utils import data
@ -82,18 +81,18 @@ if __name__ == '__main__':
model.phase = "eval"
epoch_metric = None
labels = inputs[1]
model.cpu()
predicts = []
gts = []
img = send_to_device(inputs[0], torch.device("cpu"))
img = inputs[0]
outputs = model.inference(img[0])
npreds = outputs[0].shape[0]
ngts = labels[0].shape[0]
predicts.append(dict(boxes=outputs[0][:,:4], scores=outputs[0][:,4], labels=torch.IntTensor([0 for i in range(npreds)])))
gts.append(dict(boxes=torch.vstack(labels)[:,:4], labels=torch.IntTensor([0 for i in range(ngts)])))
predicts.append(dict(boxes=outputs[0][:,:4], scores=outputs[0][:,4], labels=torch.IntTensor([0 for i in range(npreds)]).to(trainer.configs.device)))
if isinstance(labels, list):
labels = torch.vstack(labels)
gts.append(dict(boxes=labels.reshape(-1,15)[:,:4], labels=torch.IntTensor([0 for i in range(ngts)]).to(trainer.configs.device)))
if self.metric is not None:
self.metric.update(predicts, gts)
epoch_metric = self.metric.compute()['map'].item()
@ -121,7 +120,6 @@ if __name__ == '__main__':
training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4, collate_fn=detection_collate)
return training_loader
#def get_eval_dataloader(self):
def get_eval_dataloader(self):
if self.configs.n_gpu > 1:
eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data)
@ -139,8 +137,6 @@ if __name__ == '__main__':
)
return eval_loader
trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors)
op.trainer = trainer

Loading…
Cancel
Save