towhee/resnet-image-embedding
6 changed files with 65 additions and 230 deletions
@@ -1,39 +1,38 @@
+train:
+  output_dir: ./output_dir
+  overwrite_output_dir: true
+  eval_strategy: epoch
+  eval_steps:
+  batch_size: 8
+  val_batch_size: -1
+  seed: 42
+  epoch_num: 2
+  dataloader_pin_memory: true
+  dataloader_drop_last: true
+  dataloader_num_workers: 0
+  load_best_model_at_end: false
+  freeze_bn: false
+learning:
+  lr: 5e-05
+  loss: CrossEntropyLoss
+  optimizer: Adam
+  lr_scheduler_type: linear
+  warmup_ratio: 0.0
+  warmup_steps: 0
 callback:
   early_stopping:
-    mode: max
     monitor: eval_epoch_metric
     patience: 4
+    mode: max
   model_checkpoint:
     every_n_epoch: 1
   tensorboard:
+    log_dir:
     comment: ''
-    log_dir: null
-device:
-  device_str: null
-  n_gpu: -1
-  sync_bn: false
-learning:
-  loss: CrossEntropyLoss
-  lr: 5.0e-05
-  lr_scheduler_type: linear
-  optimizer: Adam
-  warmup_ratio: 0.0
-  warmup_steps: 0
 logging:
-  print_steps: null
+  print_steps:
 metrics:
   metric: Accuracy
-train:
-  batch_size: 8
-  dataloader_drop_last: true
-  dataloader_num_workers: 0
-  dataloader_pin_memory: true
-  epoch_num: 2
-  eval_steps: null
-  eval_strategy: epoch
-  freeze_bn: false
-  load_best_model_at_end: false
-  output_dir: ./output_dir
-  overwrite_output_dir: true
-  seed: 42
-  val_batch_size: -1
+device:
+  device_str:
+  sync_bn: false
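For reference, a config in this layout is consumed through towhee's training-config helpers. A minimal sketch, based only on the calls used by the example script removed later in this diff (dump_default_yaml, TrainingConfig.load_from_yaml); the file name 'default_config.yaml' is taken from that script and is otherwise a placeholder:

    from towhee.trainer.training_config import TrainingConfig, dump_default_yaml

    # Write towhee's current default training config to a YAML file
    # in the same shape as the config shown above.
    dump_default_yaml(yaml_path='default_config.yaml')

    # Read a config of this shape back into a TrainingConfig object.
    training_config = TrainingConfig()
    training_config.load_from_yaml('default_config.yaml')
    print(training_config.epoch_num, training_config.batch_size)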
@@ -1,13 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
@@ -1,52 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import torch
-from torch.nn import Linear
-from torch import nn
-import timm
-
-
-class Model():
-    """
-    PyTorch model class
-    """
-
-    def __init__(self, model_name, num_classes=1000):
-        super().__init__()
-        self._model = timm.create_model(model_name, pretrained=True)
-        pretrained_dict = None
-        if model_name == 'resnet101':
-            pretrained_dict = torch.hub.load_state_dict_from_url(
-                'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth')
-        if model_name == 'resnet50':
-            pretrained_dict = torch.hub.load_state_dict_from_url(
-                'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth')
-        if pretrained_dict:
-            self._model.load_state_dict(pretrained_dict, strict=False)
-        if num_classes != 1000:
-            self.create_classifier(num_classes=num_classes)
-        self._model.eval()
-
-    def __call__(self, img_tensor: torch.Tensor):
-        self._model.eval()
-        features = self._model.forward_features(img_tensor)
-        if features.dim() == 4:  # if the shape of feature map is [N, C, H, W], where H > 1 and W > 1
-            global_pool = nn.AdaptiveAvgPool2d(1)
-            features = global_pool(features)
-        return features.flatten().detach().numpy()
-
-    def create_classifier(self, num_classes):
-        self._model.fc = Linear(self._model.fc.in_features, num_classes, bias=True)
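The file above is removed by this change. If the standalone timm wrapper is still needed, a minimal usage sketch looks like the following; it assumes the Model class from the removed file is copied into scope, and the 224x224 random input and the 2048-dim output size for resnet50 are illustrative assumptions:

    import torch

    # Hypothetical usage of the removed Model wrapper: __call__ returns a
    # flattened numpy feature vector for a single image tensor.
    model = Model('resnet50', num_classes=1000)
    dummy = torch.rand(1, 3, 224, 224)   # NCHW, ImageNet-sized input (assumed)
    embedding = model(dummy)             # numpy array of globally pooled features
    print(embedding.shape)               # e.g. (2048,) for resnet50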
@@ -1,5 +1,2 @@
-torch==1.9.0
-torchvision==0.10.0
-pillow==8.3.1
-numpy==1.19.5
-timm==0.5.4
+numpy
+timm>=0.5.4
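With the exact pins dropped in favor of numpy and timm>=0.5.4, a quick runtime check of the installed versions can be useful; this small sketch only mirrors the new requirement and is not part of the change:

    import numpy
    import timm

    # Print the installed versions and check the loosened timm floor.
    print('numpy', numpy.__version__)
    print('timm', timm.__version__)
    assert tuple(int(p) for p in timm.__version__.split('.')[:3]) >= (0, 5, 4)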
@@ -1,104 +0,0 @@
-import numpy as np
-from torch.optim import AdamW
-from torchvision import transforms
-from torchvision.transforms import RandomResizedCrop, Lambda
-from towhee import dataset
-from towhee.trainer.modelcard import ModelCard
-
-from towhee.trainer.training_config import TrainingConfig
-# from towhee.trainer.dataset import get_dataset
-from towhee.trainer.utils.layer_freezer import LayerFreezer
-
-from resnet_image_embedding import ResnetImageEmbedding
-from towhee.types import Image
-from towhee.trainer.training_config import dump_default_yaml
-from PIL import Image as PILImage
-from timm.models.resnet import ResNet
-from torch import nn
-
-if __name__ == '__main__':
-    dump_default_yaml(yaml_path='default_config.yaml')
-    # img = torch.rand([1, 3, 224, 224])
-    img_path = './ILSVRC2012_val_00049771.JPEG'
-    # logo_path = os.path.join(Path(__file__).parent.parent.parent.parent.resolve(), 'towhee_logo.png')
-    img = PILImage.open(img_path)
-    img_bytes = img.tobytes()
-    img_width = img.width
-    img_height = img.height
-    img_channel = len(img.split())
-    img_mode = img.mode
-    img_array = np.array(img)
-    array_size = np.array(img).shape
-    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array)
-
-    op = ResnetImageEmbedding('resnet18', num_classes=10)
-    # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")
-    old_out = op(towhee_img)
-    # print(old_out.feature_vector[0][:10])
-    # print(old_out.feature_vector[:10])
-    # print(old_out.feature_vector.shape)
-
-    training_config = TrainingConfig()
-    yaml_path = 'resnet_training_yaml.yaml'
-    # dump_default_yaml(yaml_path=yaml_path)
-    training_config.load_from_yaml(yaml_path)
-
-    # training_config.overwrite_output_dir=True
-    # training_config.epoch_num=3
-    # training_config.batch_size=256
-    training_config.device_str='cpu'
-    training_config.tensorboard = None
-    training_config.batch_size = 2
-    training_config.epoch_num = 2
-    # training_config.n_gpu=-1
-    # training_config.save_to_yaml(yaml_path)
-
-    mnist_transform = transforms.Compose([transforms.ToTensor(),
-                                          Lambda(lambda x: x.repeat(3, 1, 1)),
-                                          transforms.Normalize(mean=[0.1307,0.1307,0.1307], std=[0.3081,0.3081,0.3081])])
-    train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
-    eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)
-    # training_config.output_dir = 'mnist_output'
-    # fake_transform = transforms.Compose([transforms.ToTensor(),
-    #                                      RandomResizedCrop(224),])
-    # train_data = dataset('fake', size=100, transform=fake_transform)
-    # eval_data = dataset('fake', size=10, transform=fake_transform)
-    # training_config.output_dir = 'mnist_0228_5'
-    # op.model_card = ModelCard(datasets='mnist dataset')
-
-    # fake_transform = transforms.Compose([transforms.ToTensor()])
-    # # RandomResizedCrop(224)])
-    # train_data = dataset('fake', size=20, transform=fake_transform)
-    # eval_data = dataset('fake', size=10, transform=fake_transform)
-    # training_config.output_dir = 'fake_output'
-    # op.model_card = ModelCard(datasets='fake dataset')
-
-    # op.trainer
-    # trainer = op.setup_trainer()
-    # print(op.get_model())
-    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False)
-    # op.setup_trainer()
-    #
-    # trainer.add_callback()
-    # trainer.set_optimizer()
-    #
-    # op.trainer.set_optimizer(my_optimimzer)
-    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml')
-    #
-    # my_loss = nn.BCELoss()
-    # trainer.set_loss(my_loss, 'my_loss111')
-    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml')
-    # op.trainer._create_optimizer()
-    # op.trainer.set_optimizer()
-    # trainer = op.setup_trainer(training_config, train_dataset=train_data, eval_dataset=eval_data)
-
-    # freezer = LayerFreezer(op.get_model())
-    # freezer.set_slice(-1)
-    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)
-    # op.train(training_config, train_dataset=train_data, eval_dataset=eval_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2')
-
-    # op.save('./test_save')
-    # op.load('./test_save')
-
-    # new_out = op(towhee_img)
-    # assert (new_out.feature_vector==old_out.feature_vector).all()
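The removed script above exercised the operator's fine-tuning path end to end. A condensed sketch of that flow, kept only for reference and based solely on the calls shown in the removed script; the YAML file name, CPU overrides, and MNIST dataset choice are carried over from that script rather than required by the operator:

    from torchvision import transforms
    from torchvision.transforms import Lambda
    from towhee import dataset
    from towhee.trainer.training_config import TrainingConfig

    from resnet_image_embedding import ResnetImageEmbedding

    # Operator with a 10-class head, as in the removed script.
    op = ResnetImageEmbedding('resnet18', num_classes=10)

    # Load the training config and force a small CPU run.
    training_config = TrainingConfig()
    training_config.load_from_yaml('resnet_training_yaml.yaml')
    training_config.device_str = 'cpu'
    training_config.batch_size = 2
    training_config.epoch_num = 2

    # MNIST images are single-channel, so repeat the channel to fit the ResNet stem.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize(mean=[0.1307] * 3, std=[0.3081] * 3),
    ])
    train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)
    eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)

    # Fine-tune through the operator's trainer entry point.
    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)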