logo
Browse Source

Update

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
b84b7e54fc
  1. 12
      __init__.py
  2. 4
      vit_image_embedding.py

12
__init__.py

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# import os
# For requirements.
try:
import timm
except ModuleNotFoundError:
os.system('pip install timm')
# # For requirements.
# try:
# import timm
# except ModuleNotFoundError:
# os.system('pip install timm')

4
vit_image_embedding.py

@ -18,8 +18,6 @@ from typing import NamedTuple
from pathlib import Path
from PIL import Image
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import numpy
from towhee.operator import Operator
@ -41,6 +39,8 @@ class VitImageEmbedding(Operator):
sys.path.append(str(Path(__file__).parent))
if framework == 'pytorch':
from pytorch.model import Model
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
self.model = Model(model_name, weights_path)
config = resolve_data_config({}, model=self.model._model)
self.tfms = create_transform(**config)

Loading…
Cancel
Save