logo
Browse Source

modify to training

main
zhang chen 3 years ago
parent
commit
ca90e87322
  1. 21
      pytorch/model.py
  2. 14
      vit_image_embedding.py

21
pytorch/model.py

@@ -14,6 +14,8 @@
import torch import torch
from torch.nn import Linear
from torch import nn
import timm import timm
@@ -21,19 +23,20 @@ class Model():
""" """
PyTorch model class PyTorch model class
""" """
def __init__(self, model_name: str, weights_path: str):
def __init__(self, model_name: str, weights_path: str, num_classes=1000):
super().__init__() super().__init__()
if weights_path: if weights_path:
self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=0)
self._model = timm.create_model(model_name, checkpoint_path=weights_path, num_classes=num_classes)
else: else:
self._model = timm.create_model(model_name, pretrained=True, num_classes=0)
self._model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
self._model.eval() self._model.eval()
def __call__(self, img_tensor: torch.Tensor): def __call__(self, img_tensor: torch.Tensor):
return self._model(img_tensor)
self._model.eval()
features = self._model.forward_features(img_tensor)
if features.dim() == 4: # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1
global_pool = nn.AdaptiveAvgPool2d(1)
features = global_pool(features)
return features.flatten().detach().numpy()
def train(self):
"""
For training model
"""
pass

14
vit_image_embedding.py

@@ -18,10 +18,11 @@ from typing import NamedTuple
from pathlib import Path from pathlib import Path
from PIL import Image from PIL import Image
import torch import torch
from torch import nn as nn
import numpy import numpy
import os import os
from towhee.operator import Operator
from towhee.operator import Operator, NNOperator
from timm.data import resolve_data_config from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform from timm.data.transforms_factory import create_transform
from towhee.utils.pil_utils import to_pil from towhee.utils.pil_utils import to_pil
@@ -29,7 +30,7 @@ from towhee.utils.pil_utils import to_pil
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
class VitImageEmbedding(Operator):
class VitImageEmbedding(NNOperator):
""" """
Embedding extractor using ViT. Embedding extractor using ViT.
Args: Args:
@@ -39,7 +40,7 @@ class VitImageEmbedding(Operator):
Path to local weights. Path to local weights.
""" """
def __init__(self, model_name: str = 'vit_large_patch16_224',
def __init__(self, model_name: str = 'vit_large_patch16_224', num_classes: int = 1000,
framework: str = 'pytorch', weights_path: str = None) -> None: framework: str = 'pytorch', weights_path: str = None) -> None:
super().__init__() super().__init__()
if framework == 'pytorch': if framework == 'pytorch':
@@ -49,7 +50,7 @@ class VitImageEmbedding(Operator):
spec = importlib.util.spec_from_file_location(opname, path) spec = importlib.util.spec_from_file_location(opname, path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
self.model = module.Model(model_name, weights_path)
self.model = module.Model(model_name, weights_path, num_classes=num_classes)
config = resolve_data_config({}, model=self.model._model) config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config) self.tfms = create_transform(**config)
@@ -57,4 +58,7 @@ class VitImageEmbedding(Operator):
img = self.tfms(to_pil(image)).unsqueeze(0) img = self.tfms(to_pil(image)).unsqueeze(0)
Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
features = self.model(img) features = self.model(img)
return Outputs(features.flatten().detach().numpy())
return Outputs(features)
def get_model(self) -> nn.Module:
return self.model._model
Loading…
Cancel
Save