Browse Source
Update
Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
2 changed files with
8 additions and
8 deletions
-
__init__.py
-
vit_image_embedding.py
|
|
@ -12,10 +12,10 @@ |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
import os |
|
|
|
# import os |
|
|
|
|
|
|
|
# For requirements. |
|
|
|
try: |
|
|
|
import timm |
|
|
|
except ModuleNotFoundError: |
|
|
|
os.system('pip install timm') |
|
|
|
# # For requirements. |
|
|
|
# try: |
|
|
|
# import timm |
|
|
|
# except ModuleNotFoundError: |
|
|
|
# os.system('pip install timm') |
|
|
|
|
|
@ -18,8 +18,6 @@ from typing import NamedTuple |
|
|
|
from pathlib import Path |
|
|
|
from PIL import Image |
|
|
|
import torch |
|
|
|
from timm.data import resolve_data_config |
|
|
|
from timm.data.transforms_factory import create_transform |
|
|
|
import numpy |
|
|
|
|
|
|
|
from towhee.operator import Operator |
|
|
@ -41,6 +39,8 @@ class VitImageEmbedding(Operator): |
|
|
|
sys.path.append(str(Path(__file__).parent)) |
|
|
|
if framework == 'pytorch': |
|
|
|
from pytorch.model import Model |
|
|
|
from timm.data import resolve_data_config |
|
|
|
from timm.data.transforms_factory import create_transform |
|
|
|
self.model = Model(model_name, weights_path) |
|
|
|
config = resolve_data_config({}, model=self.model._model) |
|
|
|
self.tfms = create_transform(**config) |
|
|
|