diff --git a/pytorch/model.py b/pytorch/model.py index 6a97391..5a3a050 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -28,7 +28,7 @@ class Model(): super().__init__() model_func = getattr(torchvision.models, model_name) self._model = model_func(pretrained=True) - self._model.fc = torch.nn.Identify() + self._model.fc = torch.nn.Identity() self._model.eval() def __call__(self, img_tensor: torch.Tensor):