diff --git a/__init__.py b/__init__.py
index 8516979..9379ed6 100644
--- a/__init__.py
+++ b/__init__.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
+# import os
 
-# For requirements.
-try:
-    import timm
-except ModuleNotFoundError:
-    os.system('pip install timm')
+# # For requirements.
+# try:
+#     import timm
+# except ModuleNotFoundError:
+#     os.system('pip install timm')
diff --git a/pytorch/__init__.py b/pytorch/__init__.py
index 37f5bd7..86316bc 100644
--- a/pytorch/__init__.py
+++ b/pytorch/__init__.py
@@ -11,3 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import os
+
+# For requirements.
+try:
+    import timm
+except ModuleNotFoundError:
+    os.system('pip install timm')
\ No newline at end of file
diff --git a/vit_image_embedding.py b/vit_image_embedding.py
index 5d4706e..a416bba 100644
--- a/vit_image_embedding.py
+++ b/vit_image_embedding.py
@@ -20,12 +20,12 @@
 from PIL import Image
 import torch
 from timm.data import resolve_data_config
 from timm.data.transforms_factory import create_transform
-
+import numpy
 from towhee.operator import Operator
 
 
-class VisionTransformerEmbeddingOperator(Operator):
+class VitImageEmbedding(Operator):
     """
     Embedding extractor using ViT.
     Args:
@@ -40,13 +40,13 @@
         super().__init__()
         sys.path.append(str(Path(__file__).parent))
         if framework == 'pytorch':
-            from vit_embedding.pytorch.model import Model
+            from pytorch.model import Model
             self.model = Model(model_name, weights_path)
             config = resolve_data_config({}, model=self.model._model)
             self.tfms = create_transform(**config)
 
-    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('embedding', torch.Tensor)]):
-        Outputs = NamedTuple('Outputs', [('embedding', torch.Tensor)])
+    def __call__(self, img_path: str) -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
+        Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
         img = self.tfms(Image.open(img_path)).unsqueeze(0)
         features = self.model(img)
         return Outputs(features.flatten().detach().numpy())