
Update

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 committed 3 years ago
commit 42c35af18d
3 changed files:
  1. __init__.py (12 changed lines)
  2. pytorch/__init__.py (8 changed lines)
  3. vit_image_embedding.py (10 changed lines)

__init__.py (12 changed lines)

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
+# import os

-# For requirements.
-try:
-    import timm
-except ModuleNotFoundError:
-    os.system('pip install timm')
+# # For requirements.
+# try:
+#     import timm
+# except ModuleNotFoundError:
+#     os.system('pip install timm')

pytorch/__init__.py (8 changed lines)

@@ -11,3 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import os
+
+# For requirements.
+try:
+    import timm
+except ModuleNotFoundError:
+    os.system('pip install timm')
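
The on-demand install added above shells out with os.system('pip install timm'), which installs into whatever pip happens to be first on PATH. Below is a minimal alternative sketch, not part of this commit (ensure_package is a hypothetical helper), that installs into the interpreter that is actually running via sys.executable and then imports the module:

import importlib
import subprocess
import sys

def ensure_package(module_name, pip_name=None):
    """Import module_name, installing it with pip on demand if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        # Install into the currently running interpreter, not whatever `pip` is on PATH.
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pip_name or module_name])
        return importlib.import_module(module_name)

timm = ensure_package('timm')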

vit_image_embedding.py (10 changed lines)

@@ -20,12 +20,12 @@ from PIL import Image
 import torch
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
+import numpy

 from towhee.operator import Operator

-
-class VisionTransformerEmbeddingOperator(Operator):
+class VitImageEmbedding(Operator):
     """
     Embedding extractor using ViT.

     Args:
@@ -40,13 +40,13 @@ class VisionTransformerEmbeddingOperator(Operator):
         super().__init__()
         sys.path.append(str(Path(__file__).parent))
         if framework == 'pytorch':
-            from vit_embedding.pytorch.model import Model
+            from pytorch.model import Model
             self.model = Model(model_name, weights_path)
         config = resolve_data_config({}, model=self.model._model)
         self.tfms = create_transform(**config)

-    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]):
-        Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
+    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
+        Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
         img = self.tfms(Image.open(img_path)).unsqueeze(0)
         features = self.model(img)
         return Outputs(features.flatten().detach().numpy())
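
After this commit the operator class is named VitImageEmbedding and its output field is feature_vector (a numpy.ndarray) rather than embedding (a torch.Tensor). A rough usage sketch follows; the keyword names come from the constructor body in the diff, while the module import path, model name, image path, and weights_path=None default are placeholder assumptions:

from vit_image_embedding import VitImageEmbedding

# 'vit_base_patch16_224', './test.jpg', and weights_path=None are placeholders.
op = VitImageEmbedding(model_name='vit_base_patch16_224', framework='pytorch', weights_path=None)
out = op('./test.jpg')
print(out.feature_vector.shape)  # flattened 1-D numpy.ndarray of ViT features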
