diff --git a/default_config.yaml b/default_config.yaml index e3b5cd6..8c1d885 100644 --- a/default_config.yaml +++ b/default_config.yaml @@ -1,39 +1,38 @@ -callback: - early_stopping: - mode: max - monitor: eval_epoch_metric - patience: 4 - model_checkpoint: - every_n_epoch: 1 - tensorboard: - comment: '' - log_dir: null -device: - device_str: null - n_gpu: -1 - sync_bn: false +train: + output_dir: ./output_dir + overwrite_output_dir: true + eval_strategy: epoch + eval_steps: + batch_size: 8 + val_batch_size: -1 + seed: 42 + epoch_num: 2 + dataloader_pin_memory: true + dataloader_drop_last: true + dataloader_num_workers: 0 + load_best_model_at_end: false + freeze_bn: false learning: - loss: CrossEntropyLoss - lr: 5.0e-05 - lr_scheduler_type: linear - optimizer: Adam - warmup_ratio: 0.0 - warmup_steps: 0 + lr: 5e-05 + loss: CrossEntropyLoss + optimizer: Adam + lr_scheduler_type: linear + warmup_ratio: 0.0 + warmup_steps: 0 +callback: + early_stopping: + monitor: eval_epoch_metric + patience: 4 + mode: max + model_checkpoint: + every_n_epoch: 1 + tensorboard: + log_dir: + comment: '' logging: - print_steps: null + print_steps: metrics: - metric: Accuracy -train: - batch_size: 8 - dataloader_drop_last: true - dataloader_num_workers: 0 - dataloader_pin_memory: true - epoch_num: 2 - eval_steps: null - eval_strategy: epoch - freeze_bn: false - load_best_model_at_end: false - output_dir: ./output_dir - overwrite_output_dir: true - seed: 42 - val_batch_size: -1 + metric: Accuracy +device: + device_str: + sync_bn: false diff --git a/pytorch/__init__.py b/pytorch/__init__.py deleted file mode 100644 index 37f5bd7..0000000 --- a/pytorch/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/pytorch/model.py b/pytorch/model.py deleted file mode 100644 index ccd9abf..0000000 --- a/pytorch/model.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2021 Zilliz. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from torch.nn import Linear -from torch import nn -import timm - - -class Model(): - """ - PyTorch model class - """ - - def __init__(self, model_name, num_classes=1000): - super().__init__() - self._model = timm.create_model(model_name, pretrained=True) - pretrained_dict = None - if model_name == 'resnet101': - pretrained_dict = torch.hub.load_state_dict_from_url( - 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') - if model_name == 'resnet50': - pretrained_dict = torch.hub.load_state_dict_from_url( - 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth') - if pretrained_dict: - self._model.load_state_dict(pretrained_dict, strict=False) - if num_classes != 1000: - self.create_classifier(num_classes=num_classes) - self._model.eval() - - def __call__(self, img_tensor: torch.Tensor): - self._model.eval() - features = self._model.forward_features(img_tensor) - if features.dim() == 4: # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1 - global_pool = nn.AdaptiveAvgPool2d(1) - features = global_pool(features) - return features.flatten().detach().numpy() - - def create_classifier(self, num_classes): - self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True) diff --git a/requirements.txt b/requirements.txt index 33d8673..7dba25f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,2 @@ -torch==1.9.0 -torchvision==0.10.0 -pillow==8.3.1 -numpy==1.19.5 -timm==0.5.4 +numpy +timm>=0.5.4 diff --git a/resnet_image_embedding.py b/resnet_image_embedding.py index f4de412..985e57c 100644 --- a/resnet_image_embedding.py +++ b/resnet_image_embedding.py @@ -11,50 +11,58 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import sys import numpy +import timm import torch -import torchvision -from PIL import Image from torch import nn as nn +from torch.nn import Linear from torchvision import transforms -from pathlib import Path from typing import NamedTuple -import os from torchvision.transforms import InterpolationMode from towhee.operator import NNOperator from towhee.utils.pil_utils import to_pil import warnings + warnings.filterwarnings("ignore") + class ResnetImageEmbedding(NNOperator): """ PyTorch model for image embedding. """ + def __init__(self, model_name: str, num_classes: int = 1000, framework: str = 'pytorch') -> None: super().__init__(framework=framework) - if framework == 'pytorch': - import importlib.util - path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') - opname = os.path.basename(str(Path(__file__))).split('.')[0] - spec = importlib.util.spec_from_file_location(opname, path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - self.model = module.Model(model_name, num_classes=num_classes) + self.model = timm.create_model(model_name, pretrained=True) + pretrained_dict = None + if model_name == 'resnet101': + pretrained_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth') + if model_name == 'resnet50': + pretrained_dict = torch.hub.load_state_dict_from_url( + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth') + if pretrained_dict: + self.model.load_state_dict(pretrained_dict, strict=False) + if num_classes != 1000: + self.create_classifier(num_classes=num_classes) + self.model.eval() + self.tfms = transforms.Compose([transforms.Resize(235, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) def __call__(self, image: 'towhee.types.Image') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): - img = self.tfms(to_pil(image)).unsqueeze(0) - embedding = self.model(img) + img_tensor = self.tfms(to_pil(image)).unsqueeze(0) + self.model.eval() + features = self.model.forward_features(img_tensor) + if features.dim() == 4: # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1 + global_pool = nn.AdaptiveAvgPool2d(1) + features = global_pool(features) + embedding = features.flatten().detach().numpy() Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(embedding) - def get_model(self) -> nn.Module: - return self.model._model - def change_before_train(self, num_classes: int = 0): - if num_classes > 0: - self.model.create_classifier(num_classes) \ No newline at end of file + def create_classifier(self, num_classes): + self.model.fc = Linear(self.model.fc.in_features, num_classes, bias=True) diff --git a/test.py b/test.py deleted file mode 100644 index 6c61ec1..0000000 --- a/test.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np -from torch.optim import AdamW -from torchvision import transforms -from torchvision.transforms import RandomResizedCrop, Lambda -from towhee import dataset -from towhee.trainer.modelcard import ModelCard - -from towhee.trainer.training_config import TrainingConfig -# from towhee.trainer.dataset import get_dataset -from towhee.trainer.utils.layer_freezer import LayerFreezer - -from resnet_image_embedding import ResnetImageEmbedding -from towhee.types import Image -from towhee.trainer.training_config import dump_default_yaml -from PIL import Image as PILImage -from timm.models.resnet import ResNet -from torch import nn - -if __name__ == '__main__': - dump_default_yaml(yaml_path='default_config.yaml') - # img = torch.rand([1, 3, 224, 224]) - img_path = './ILSVRC2012_val_00049771.JPEG' - # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png') - img = PILImage.open(img_path) - img_bytes = img.tobytes() - img_width = img.width - img_height = img.height - img_channel = len(img.split()) - img_mode = img.mode - img_array = np.array(img) - array_size = np.array(img).shape - towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) - - op = ResnetImageEmbedding('resnet18', num_classes=10) - # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") - old_out = op(towhee_img) - # print(old_out.feature_vector[0][:10]) - # print(old_out.feature_vector[:10]) - # print(old_out.feature_vector.shape) - - training_config = TrainingConfig() - yaml_path = 'resnet_training_yaml.yaml' - # dump_default_yaml(yaml_path=yaml_path) - training_config.load_from_yaml(yaml_path) - - # training_config.overwrite_output_dir=True - # training_config.epoch_num=3 - # training_config.batch_size=256 - training_config.device_str='cpu' - training_config.tensorboard = None - training_config.batch_size = 2 - training_config.epoch_num = 2 - # training_config.n_gpu=-1 - # training_config.save_to_yaml(yaml_path) - - mnist_transform = transforms.Compose([transforms.ToTensor(), - Lambda(lambda x: x.repeat(3, 1, 1)), - transforms.Normalize(mean=[0.1307,0.1307,0.1307], std=[0.3081,0.3081,0.3081])]) - train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True) - eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False) - # training_config.output_dir = 'mnist_output' - # fake_transform = transforms.Compose([transforms.ToTensor(), - # RandomResizedCrop(224),]) - # train_data = dataset('fake', size=100, transform=fake_transform) - # eval_data = dataset('fake', size=10, transform=fake_transform) - # training_config.output_dir = 'mnist_0228_5' - # op.model_card = ModelCard(datasets='mnist dataset') - - # fake_transform = transforms.Compose([transforms.ToTensor()]) - # # RandomResizedCrop(224)]) - # train_data = dataset('fake', size=20, transform=fake_transform) - # eval_data = dataset('fake', size=10, transform=fake_transform) - # training_config.output_dir = 'fake_output' - # op.model_card = ModelCard(datasets='fake dataset') - - # op.trainer - # trainer = op.setup_trainer() - # print(op.get_model()) - # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) - # op.setup_trainer() - # - # trainer.add_callback() - # trainer.set_optimizer() - # - # op.trainer.set_optimizer(my_optimimzer) - # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') - # - # my_loss = nn.BCELoss() - # trainer.set_loss(my_loss, 'my_loss111') - # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') - # op.trainer._create_optimizer() - # op.trainer.set_optimizer() - # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data) - - # freezer = LayerFreezer(op.get_model()) - # freezer.set_slice(-1) - op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) - # op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') - - # op.save('./test_save') - # op.load('./test_save')\ - - # new_out = op(towhee_img) - # assert (new_out.feature_vector==old_out.feature_vector).all()