From 901594b55444051b93aa187b8247f96b00068293 Mon Sep 17 00:00:00 2001 From: zhang chen Date: Mon, 20 Dec 2021 14:48:31 +0800 Subject: [PATCH] matching embedding test output with towhee op output --- pytorch/model.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pytorch/model.py b/pytorch/model.py index 125ec3a..3a32524 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -28,9 +28,15 @@ class Model(): super().__init__() model_func = getattr(torchvision.models, model_name) self._model = model_func(pretrained=True) - state_dict = torch.hub.load_state_dict_from_url( - 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') - self._model.load_state_dict(state_dict) + state_dict = None + if model_name == 'resnet101': + state_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') + if model_name == 'resnet50': + state_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth') + if state_dict: + self._model.load_state_dict(state_dict) self._model.fc = torch.nn.Identity() self._model.eval()