diff --git a/pytorch/model.py b/pytorch/model.py index ea0e408..6a97391 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -28,6 +28,7 @@ class Model(): super().__init__() model_func = getattr(torchvision.models, model_name) self._model = model_func(pretrained=True) + self._model.fc = torch.nn.Identify() self._model.eval() def __call__(self, img_tensor: torch.Tensor):