towhee
/
resnet-image-embedding
copied
6 changed files with 190 additions and 9 deletions
After Width: | Height: | Size: 92 KiB |
@ -0,0 +1,66 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
import pprint |
|||
|
|||
class EmbeddingOutput:
    """
    Callable collector that remembers every module output it is handed.

    Instances accept the ``(module, module_in, module_out)`` argument triple
    and keep each ``module_out`` in ``self.embeddings``, in call order.
    """

    def __init__(self):
        # Outputs gathered so far, oldest first.
        self.embeddings = list()

    def __call__(self, module, module_in, module_out):
        # Only the output is of interest; the module and its input are ignored.
        collected = self.embeddings
        collected.append(module_out)

    def clear(self):
        """
        Drop all collected outputs by starting over with a new empty list.
        """
        # Rebind rather than empty in place: a list handed out earlier keeps
        # its contents.
        self.embeddings = list()
|||
|
|||
|
|||
class EmbeddingExtractor:
    """
    Embedding extractor from a layer.

    The extractor indexes the sub-modules of a model by name and can attach
    a shared ``EmbeddingOutput`` collector (``self.emb_out``) to any of them
    as a forward hook.

    Args:
        model (`nn.Module`):
            Model used for inference.
    """

    def __init__(self, model):
        # remove_duplicate=False asks named_modules() not to drop repeated
        # module instances, so every registered name stays addressable.
        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
        self.emb_out = EmbeddingOutput()

    def disp_modules(self, full=False):
        """
        Display the modules of the model.

        Args:
            full (`bool`):
                If False (default), print only the module names; otherwise
                print the complete name-to-module mapping.
        """
        if full:
            pprint.pprint(self.modules_dict)
        else:
            pprint.pprint(list(self.modules_dict))

    def register(self, layer_name: str):
        """
        Registration for embedding extraction.

        Args:
            layer_name (`str`):
                Name of the layer from which the embedding is extracted.

        Returns:
            Whatever the layer's ``register_forward_hook`` returns. For a
            ``torch.nn.Module`` this is a handle whose ``remove()`` detaches
            the hook again.

        Raises:
            ValueError: If ``layer_name`` is not a known module name.
        """
        try:
            layer = self.modules_dict[layer_name]
        except KeyError:
            # Name the offending layer so the caller can see what was wrong.
            raise ValueError(
                f'layer_name not in modules: {layer_name!r}'
            ) from None
        # Hand the handle back instead of discarding it; otherwise the hook
        # could never be removed once installed.
        return layer.register_forward_hook(self.emb_out)
@ -0,0 +1,22 @@ |
|||
device: |
|||
device_str: null |
|||
n_gpu: -1 |
|||
sync_bn: true |
|||
metrics: |
|||
metric: Accuracy |
|||
train: |
|||
batch_size: 16 |
|||
learning: |
|||
optimizer: |
|||
name_: SGD |
|||
lr: 0.03 |
|||
momentum: 0.001 |
|||
nesterov: 111 |
|||
loss: |
|||
name_: CrossEntropyLoss |
|||
label_smoothing: 0.1 |
|||
#learning: |
|||
# optimizer: |
|||
# name_: Adam |
|||
# lr: 0.02 |
|||
# eps: 0.001 |
@ -0,0 +1,65 @@ |
|||
import numpy as np |
|||
from torchvision import transforms |
|||
from torchvision.transforms import RandomResizedCrop, Lambda |
|||
from towhee.trainer.modelcard import ModelCard |
|||
|
|||
from towhee.trainer.training_config import TrainingConfig |
|||
from towhee.trainer.dataset import get_dataset |
|||
from resnet_image_embedding import ResnetImageEmbedding |
|||
from towhee.types import Image |
|||
from towhee.trainer.training_config import dump_default_yaml |
|||
from PIL import Image as PILImage |
|||
from timm.models.resnet import ResNet |
|||
|
|||
if __name__ == '__main__':
    # Manual smoke test: wrap a local JPEG in a towhee Image, build the
    # ResNet embedding operator, and train it on MNIST with the settings
    # read from the YAML training config.
    img_path = './ILSVRC2012_val_00049771.JPEG'

    # Read everything needed from the PIL image inside the ``with`` block so
    # the underlying file is closed deterministically.
    with PILImage.open(img_path) as img:
        img_bytes = img.tobytes()
        img_width = img.width
        img_height = img.height
        img_channel = len(img.split())
        img_mode = img.mode
        img_array = np.array(img)
    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)

    op = ResnetImageEmbedding('resnet34')
    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")

    # Training settings come from the YAML file next to this script.
    training_config = TrainingConfig()
    yaml_path = 'resnet_training_yaml.yaml'
    training_config.load_from_yaml(yaml_path)

    # The Lambda step repeats the tensor three times along its first axis;
    # presumably this turns single-channel MNIST digits into the 3-channel
    # input the ResNet expects -- confirm against the dataset.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        RandomResizedCrop(224),
        Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data')

    # NOTE(review): presumably adapts the model to 10 output classes (the
    # MNIST digits) -- confirm against ResnetImageEmbedding.
    op.change_before_train(10)
    op.train(training_config, train_dataset=train_data)
Loading…
Reference in new issue