diff --git a/pytorch/multibox_loss.py b/pytorch/multibox_loss.py index 15e78ee..9d2dfdc 100644 --- a/pytorch/multibox_loss.py +++ b/pytorch/multibox_loss.py @@ -73,7 +73,7 @@ class MultiBoxLoss(nn.Module): conf_t = conf_t.to(self.device) landm_t = landm_t.to(self.device) - zeros = torch.tensor(0).cuda() + zeros = torch.tensor(0).to(self.device) # landm Loss (Smooth L1) # Shape: [batch,num_priors,10] pos1 = conf_t > zeros diff --git a/train.py b/train.py index 31d6df5..427efbd 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,4 @@ import numpy as np -import ipdb import torch from torch.optim import AdamW from torch.utils import data @@ -82,18 +81,18 @@ if __name__ == '__main__': model.phase = "eval" epoch_metric = None labels = inputs[1] - model.cpu() predicts = [] gts = [] - img = send_to_device(inputs[0], torch.device("cpu")) + img = inputs[0] outputs = model.inference(img[0]) npreds = outputs[0].shape[0] ngts = labels[0].shape[0] - predicts.append(dict(boxes=outputs[0][:,:4], scores=outputs[0][:,4], labels=torch.IntTensor([0 for i in range(npreds)]))) - gts.append(dict(boxes=torch.vstack(labels)[:,:4], labels=torch.IntTensor([0 for i in range(ngts)]))) - + predicts.append(dict(boxes=outputs[0][:,:4], scores=outputs[0][:,4], labels=torch.IntTensor([0 for i in range(npreds)]).to(trainer.configs.device))) + if isinstance(labels, list): + labels = torch.vstack(labels) + gts.append(dict(boxes=labels.reshape(-1,15)[:,:4], labels=torch.IntTensor([0 for i in range(ngts)]).to(trainer.configs.device))) if self.metric is not None: self.metric.update(predicts, gts) epoch_metric = self.metric.compute()['map'].item() @@ -121,7 +120,6 @@ if __name__ == '__main__': training_loader = data.DataLoader(train_data, 32, shuffle=True, num_workers=4, collate_fn=detection_collate) return training_loader - #def get_eval_dataloader(self): def get_eval_dataloader(self): if self.configs.n_gpu > 1: eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_data) @@ -139,8 +137,6 @@ if __name__ == '__main__': ) return eval_loader - - trainer = RetinafaceTrainer(op.get_model(), training_config, train_data, eval_data, None, None, None, priors) op.trainer = trainer