towhee
/
resnet-image-embedding
copied
6 changed files with 65 additions and 230 deletions
@ -1,39 +1,38 @@ |
|||
callback: |
|||
early_stopping: |
|||
mode: max |
|||
monitor: eval_epoch_metric |
|||
patience: 4 |
|||
model_checkpoint: |
|||
every_n_epoch: 1 |
|||
tensorboard: |
|||
comment: '' |
|||
log_dir: null |
|||
device: |
|||
device_str: null |
|||
n_gpu: -1 |
|||
sync_bn: false |
|||
train: |
|||
output_dir: ./output_dir |
|||
overwrite_output_dir: true |
|||
eval_strategy: epoch |
|||
eval_steps: |
|||
batch_size: 8 |
|||
val_batch_size: -1 |
|||
seed: 42 |
|||
epoch_num: 2 |
|||
dataloader_pin_memory: true |
|||
dataloader_drop_last: true |
|||
dataloader_num_workers: 0 |
|||
load_best_model_at_end: false |
|||
freeze_bn: false |
|||
learning: |
|||
loss: CrossEntropyLoss |
|||
lr: 5.0e-05 |
|||
lr_scheduler_type: linear |
|||
optimizer: Adam |
|||
warmup_ratio: 0.0 |
|||
warmup_steps: 0 |
|||
lr: 5e-05 |
|||
loss: CrossEntropyLoss |
|||
optimizer: Adam |
|||
lr_scheduler_type: linear |
|||
warmup_ratio: 0.0 |
|||
warmup_steps: 0 |
|||
callback: |
|||
early_stopping: |
|||
monitor: eval_epoch_metric |
|||
patience: 4 |
|||
mode: max |
|||
model_checkpoint: |
|||
every_n_epoch: 1 |
|||
tensorboard: |
|||
log_dir: |
|||
comment: '' |
|||
logging: |
|||
print_steps: null |
|||
print_steps: |
|||
metrics: |
|||
metric: Accuracy |
|||
train: |
|||
batch_size: 8 |
|||
dataloader_drop_last: true |
|||
dataloader_num_workers: 0 |
|||
dataloader_pin_memory: true |
|||
epoch_num: 2 |
|||
eval_steps: null |
|||
eval_strategy: epoch |
|||
freeze_bn: false |
|||
load_best_model_at_end: false |
|||
output_dir: ./output_dir |
|||
overwrite_output_dir: true |
|||
seed: 42 |
|||
val_batch_size: -1 |
|||
metric: Accuracy |
|||
device: |
|||
device_str: |
|||
sync_bn: false |
|||
|
@ -1,13 +0,0 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
@ -1,52 +0,0 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
|
|||
import torch |
|||
from torch.nn import Linear |
|||
from torch import nn |
|||
import timm |
|||
|
|||
|
|||
class Model:
    """
    PyTorch model class.

    Wraps a timm backbone and exposes it as a callable that turns an image
    tensor into a flat numpy feature vector.
    """

    # Extra checkpoints that are loaded on top of the timm weights for the
    # backbones listed here.
    _RSB_WEIGHT_URLS = {
        'resnet101': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth',
        'resnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth',
    }

    def __init__(self, model_name, num_classes=1000):
        """
        Build the backbone and put it in eval mode.

        Args:
            model_name: Name passed to ``timm.create_model``.
            num_classes: Output size of the classifier head. When it differs
                from 1000 the ``fc`` layer is replaced by a fresh one.
        """
        super().__init__()
        self._model = timm.create_model(model_name, pretrained=True)

        weight_url = self._RSB_WEIGHT_URLS.get(model_name)
        if weight_url is not None:
            pretrained_dict = torch.hub.load_state_dict_from_url(weight_url)
            # strict=False: keys that do not match the model are ignored.
            self._model.load_state_dict(pretrained_dict, strict=False)

        if num_classes != 1000:
            self.create_classifier(num_classes=num_classes)
        self._model.eval()

    def __call__(self, img_tensor: torch.Tensor):
        """
        Extract features for ``img_tensor``.

        Args:
            img_tensor: Input passed to the backbone's ``forward_features``.

        Returns:
            A one-dimensional numpy array holding the flattened features.
        """
        # The model may have been switched to train mode since construction.
        self._model.eval()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            features = self._model.forward_features(img_tensor)
            if features.dim() == 4:
                # Feature map is [N, C, H, W]; pool the spatial dims to 1x1.
                features = nn.functional.adaptive_avg_pool2d(features, 1)
            flat = features.flatten()
        # detach() covers a backbone that returns its input unchanged;
        # cpu() is required before numpy() when the tensor is on another device.
        return flat.detach().cpu().numpy()

    def create_classifier(self, num_classes):
        """
        Replace the backbone's ``fc`` layer with a new linear head.

        Args:
            num_classes: Number of outputs of the new head. The input width
                is taken from the existing ``fc`` layer.
        """
        self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True)
@ -1,5 +1,2 @@ |
|||
torch==1.9.0 |
|||
torchvision==0.10.0 |
|||
pillow==8.3.1 |
|||
numpy==1.19.5 |
|||
timm==0.5.4 |
|||
numpy |
|||
timm>=0.5.4 |
|||
|
@ -1,104 +0,0 @@ |
|||
import numpy as np |
|||
from torch.optim import AdamW |
|||
from torchvision import transforms |
|||
from torchvision.transforms import RandomResizedCrop, Lambda |
|||
from towhee import dataset |
|||
from towhee.trainer.modelcard import ModelCard |
|||
|
|||
from towhee.trainer.training_config import TrainingConfig |
|||
# from towhee.trainer.dataset import get_dataset |
|||
from towhee.trainer.utils.layer_freezer import LayerFreezer |
|||
|
|||
from resnet_image_embedding import ResnetImageEmbedding |
|||
from towhee.types import Image |
|||
from towhee.trainer.training_config import dump_default_yaml |
|||
from PIL import Image as PILImage |
|||
from timm.models.resnet import ResNet |
|||
from torch import nn |
|||
|
|||
if __name__ == '__main__':
    # Smoke test: run the operator once on a sample image, then train it
    # for a couple of epochs on MNIST using a YAML-driven configuration.
    dump_default_yaml(yaml_path='default_config.yaml')

    # Wrap the sample picture in a towhee Image.
    img_path = './ILSVRC2012_val_00049771.JPEG'
    img = PILImage.open(img_path)
    towhee_img = Image(
        img.tobytes(),
        img.width,
        img.height,
        len(img.split()),
        img.mode,
        np.array(img),
    )
    array_size = np.array(img).shape

    # Operator under test; output before training is kept for comparison.
    op = ResnetImageEmbedding('resnet18', num_classes=10)
    old_out = op(towhee_img)

    # Load the training settings, then override a few for a quick CPU run.
    yaml_path = 'resnet_training_yaml.yaml'
    training_config = TrainingConfig()
    training_config.load_from_yaml(yaml_path)
    for option, value in (
        ('device_str', 'cpu'),
        ('tensorboard', None),
        ('batch_size', 2),
        ('epoch_num', 2),
    ):
        setattr(training_config, option, value)

    # MNIST tensors are repeated three times along the first dimension
    # before normalization.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize(mean=[0.1307,0.1307,0.1307], std=[0.3081,0.3081,0.3081]),
    ])
    mnist_kwargs = dict(transform=mnist_transform, download=True, root='data')
    train_data = dataset('mnist', train=True, **mnist_kwargs)
    eval_data = dataset('mnist', train=False, **mnist_kwargs)
    # Alternative data source used during development:
    # train_data = dataset('fake', size=20, transform=transforms.Compose([transforms.ToTensor()]))
    # eval_data = dataset('fake', size=10, transform=transforms.Compose([transforms.ToTensor()]))

    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
    # To resume from a saved epoch instead:
    # op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
Loading…
Reference in new issue