logo
Browse Source

matching embedding test output with towhee op output

training
zhang chen 3 years ago
parent
commit
901594b554
  1. 12
      pytorch/model.py

12
pytorch/model.py

@ -28,9 +28,15 @@ class Model():
super().__init__() super().__init__()
model_func = getattr(torchvision.models, model_name) model_func = getattr(torchvision.models, model_name)
self._model = model_func(pretrained=True) self._model = model_func(pretrained=True)
state_dict = torch.hub.load_state_dict_from_url(
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
self._model.load_state_dict(state_dict)
state_dict = None
if model_name == 'resnet101':
state_dict = torch.hub.load_state_dict_from_url(
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
if model_name == 'resnet50':
state_dict = torch.hub.load_state_dict_from_url(
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth')
if state_dict:
self._model.load_state_dict(state_dict)
self._model.fc = torch.nn.Identity() self._model.fc = torch.nn.Identity()
self._model.eval() self._model.eval()

Loading…
Cancel
Save