towhee
/
resnet-image-embedding
copied
6 changed files with 190 additions and 9 deletions
After Width: | Height: | Size: 92 KiB |
@ -0,0 +1,66 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
import pprint |
||||
|
|
||||
|
class EmbeddingOutput:
    """
    Callable collector meant to be installed as a module forward hook.

    Each call stores the output it is handed, in call order, in
    ``self.embeddings``.
    """
    def __init__(self):
        self.embeddings = list()

    def __call__(self, module, module_in, module_out):
        # Only the output is recorded; the module and its input are ignored.
        # Nothing is returned.
        collected = self.embeddings
        collected.append(module_out)

    def clear(self):
        """
        Drop everything collected so far.

        A fresh list is bound to ``self.embeddings``; a list object handed
        out earlier is left untouched.
        """
        self.embeddings = list()
||||
|
|
||||
|
|
||||
|
class EmbeddingExtractor:
    """
    Embedding extractor from a layer.

    Builds a name -> module lookup from ``model.named_modules`` and owns one
    ``EmbeddingOutput`` collector (``self.emb_out``) that is attached as a
    forward hook to every layer passed to ``register``.

    Args:
        model (`nn.Module`):
            Model used for inference.
    """
    def __init__(self, model):
        # remove_duplicate=False is forwarded to named_modules as in the
        # original code; presumably it keeps names that alias the same module
        # instance -- confirm against the PyTorch documentation.
        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
        self.emb_out = EmbeddingOutput()

    def disp_modules(self, full=False):
        """
        Display the modules of the model.

        Args:
            full (`bool`):
                If False (default), pretty-print only the module names;
                otherwise pretty-print the whole name -> module mapping.
        """
        if full:
            pprint.pprint(self.modules_dict)
        else:
            pprint.pprint(list(self.modules_dict))

    def register(self, layer_name: str):
        """
        Registration for embedding extraction.

        Args:
            layer_name (`str`):
                Name of the layer from which the embedding is extracted.

        Returns:
            Whatever ``layer.register_forward_hook`` returns. For PyTorch
            modules this is expected to be a removable hook handle, so the
            caller can detach the hook again -- confirm against the PyTorch
            documentation.

        Raises:
            ValueError: If ``layer_name`` is not a known module name.
        """
        if layer_name not in self.modules_dict:
            raise ValueError('layer_name not in modules')
        layer = self.modules_dict[layer_name]
        # Hand the handle back instead of dropping it; otherwise the hook
        # could never be removed once installed.
        return layer.register_forward_hook(self.emb_out)
@ -0,0 +1,22 @@ |
|||||
|
# Training configuration loaded by the training script via
# TrainingConfig.load_from_yaml.
# NOTE(review): nesting below is reconstructed from the key order -- confirm
# against the TrainingConfig schema.
device:
    device_str: null
    n_gpu: -1
    sync_bn: true
metrics:
    metric: Accuracy
train:
    batch_size: 16
learning:
    optimizer:
        name_: SGD
        lr: 0.03
        momentum: 0.001
        # Boolean switch; presumably forwarded to torch.optim.SGD(nesterov=...)
        # -- confirm. Use true/false rather than an arbitrary integer.
        nesterov: true
    loss:
        name_: CrossEntropyLoss
        label_smoothing: 0.1
# Alternative optimizer setup (disabled):
#learning:
#    optimizer:
#        name_: Adam
#        lr: 0.02
#        eps: 0.001
@ -0,0 +1,65 @@ |
|||||
|
import numpy as np |
||||
|
from torchvision import transforms |
||||
|
from torchvision.transforms import RandomResizedCrop, Lambda |
||||
|
from towhee.trainer.modelcard import ModelCard |
||||
|
|
||||
|
from towhee.trainer.training_config import TrainingConfig |
||||
|
from towhee.trainer.dataset import get_dataset |
||||
|
from resnet_image_embedding import ResnetImageEmbedding |
||||
|
from towhee.types import Image |
||||
|
from towhee.trainer.training_config import dump_default_yaml |
||||
|
from PIL import Image as PILImage |
||||
|
from timm.models.resnet import ResNet |
||||
|
|
||||
|
if __name__ == '__main__':
    # Manual smoke script: wrap a sample picture in a towhee Image, then
    # train a ResnetImageEmbedding operator on MNIST using the YAML config.
    img_path = './ILSVRC2012_val_00049771.JPEG'
    # The context manager closes the underlying file once all pixel data
    # and metadata have been read.
    with PILImage.open(img_path) as img:
        img_bytes = img.tobytes()
        img_width = img.width
        img_height = img.height
        img_channel = len(img.split())
        img_mode = img.mode
        img_array = np.array(img)
    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    op = ResnetImageEmbedding('resnet34')
    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
    # Optional: embedding before training, for comparison.
    # old_out = op(towhee_img)
    # print(old_out.feature_vector[0])

    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    # Uncomment to (re)generate the YAML file with default values first.
    # dump_default_yaml(yaml_path=yaml_path)
    training_config.load_from_yaml(yaml_path)

    # The repeat(3, 1, 1) step triples the channel dimension of each tensor
    # before normalisation (presumably because MNIST is single-channel and
    # the model expects three channels -- confirm).
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        RandomResizedCrop(224),
        Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data')

    # Alternative dataset for quick local runs (disabled).
    fake_transform = transforms.Compose([
        transforms.ToTensor(),
        RandomResizedCrop(224),
    ])
    # train_data = get_dataset('fake', size=20, transform=fake_transform)

    op.change_before_train(10)
    op.train(training_config, train_dataset=train_data)
Loading…
Reference in new issue