Browse Source
matching embedding test output with towhee op output
training
zhang chen
3 years ago
1 changed files with
9 additions and
3 deletions
-
pytorch/model.py
|
|
@ -28,9 +28,15 @@ class Model(): |
|
|
|
super().__init__() |
|
|
|
model_func = getattr(torchvision.models, model_name) |
|
|
|
self._model = model_func(pretrained=True) |
|
|
|
state_dict = torch.hub.load_state_dict_from_url( |
|
|
|
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') |
|
|
|
self._model.load_state_dict(state_dict) |
|
|
|
state_dict = None |
|
|
|
if model_name == 'resnet101': |
|
|
|
state_dict = torch.hub.load_state_dict_from_url( |
|
|
|
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') |
|
|
|
if model_name == 'resnet50': |
|
|
|
state_dict = torch.hub.load_state_dict_from_url( |
|
|
|
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth') |
|
|
|
if state_dict: |
|
|
|
self._model.load_state_dict(state_dict) |
|
|
|
|
|
|
|
self._model.fc = torch.nn.Identity() |
|
|
|
self._model.eval() |
|
|
|