logo
Browse Source

matching embedding test output with towhee op output

training
zhang chen 3 years ago
parent
commit
901594b554
  1. 6
      pytorch/model.py

6
pytorch/model.py

@ -28,8 +28,14 @@ class Model():
super().__init__() super().__init__()
model_func = getattr(torchvision.models, model_name) model_func = getattr(torchvision.models, model_name)
self._model = model_func(pretrained=True) self._model = model_func(pretrained=True)
state_dict = None
if model_name == 'resnet101':
state_dict = torch.hub.load_state_dict_from_url( state_dict = torch.hub.load_state_dict_from_url(
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
if model_name == 'resnet50':
state_dict = torch.hub.load_state_dict_from_url(
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth')
if state_dict:
self._model.load_state_dict(state_dict) self._model.load_state_dict(state_dict)
self._model.fc = torch.nn.Identity() self._model.fc = torch.nn.Identity()

Loading…
Cancel
Save