towhee
/
resnet-image-embedding
copied
5 changed files with 98 additions and 155 deletions
@ -1,66 +0,0 @@ |
|||||
# Copyright 2021 Zilliz. All rights reserved. |
|
||||
# |
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
||||
# you may not use this file except in compliance with the License. |
|
||||
# You may obtain a copy of the License at |
|
||||
# |
|
||||
# http://www.apache.org/licenses/LICENSE-2.0 |
|
||||
# |
|
||||
# Unless required by applicable law or agreed to in writing, software |
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, |
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
||||
# See the License for the specific language governing permissions and |
|
||||
# limitations under the License. |
|
||||
import pprint |
|
||||
|
|
||||
class EmbeddingOutput: |
|
||||
""" |
|
||||
Container for embedding extractor. |
|
||||
""" |
|
||||
def __init__(self): |
|
||||
self.embeddings = [] |
|
||||
|
|
||||
def __call__(self, module, module_in, module_out): |
|
||||
self.embeddings.append(module_out) |
|
||||
|
|
||||
def clear(self): |
|
||||
""" |
|
||||
clear list |
|
||||
""" |
|
||||
self.embeddings = [] |
|
||||
|
|
||||
|
|
||||
class EmbeddingExtractor: |
|
||||
""" |
|
||||
Embedding extractor from a layer |
|
||||
Args: |
|
||||
model (`nn.Module`): |
|
||||
Model used for inference. |
|
||||
""" |
|
||||
def __init__(self, model): |
|
||||
# self.modules = model.modules() |
|
||||
# self.modules_list = list(model.named_modules(remove_duplicate=False)) |
|
||||
self.modules_dict = dict(model.named_modules(remove_duplicate=False)) |
|
||||
self.emb_out = EmbeddingOutput() |
|
||||
|
|
||||
def disp_modules(self, full=False): |
|
||||
""" |
|
||||
Display the the modules of the model. |
|
||||
""" |
|
||||
if not full: |
|
||||
pprint.pprint(list(self.modules_dict.keys())) |
|
||||
else: |
|
||||
pprint.pprint(self.modules_dict) |
|
||||
|
|
||||
def register(self, layer_name: str): |
|
||||
""" |
|
||||
Registration for embedding extraction. |
|
||||
Args: |
|
||||
layer_name (`str`): |
|
||||
Name of the layer from which the embedding is extracted. |
|
||||
""" |
|
||||
if layer_name in self.modules_dict: |
|
||||
layer = self.modules_dict[layer_name] |
|
||||
layer.register_forward_hook(self.emb_out) |
|
||||
else: |
|
||||
raise ValueError('layer_name not in modules') |
|
@ -1,25 +1,42 @@ |
|||||
|
callback: |
||||
|
early_stopping: |
||||
|
mode: max |
||||
|
monitor: eval_epoch_metric |
||||
|
patience: 2 |
||||
|
model_checkpoint: |
||||
|
every_n_epoch: 2 |
||||
|
tensorboard: |
||||
|
comment: '' |
||||
|
log_dir: null |
||||
device: |
device: |
||||
device_str: null |
device_str: null |
||||
n_gpu: -1 |
n_gpu: -1 |
||||
sync_bn: true |
|
||||
|
sync_bn: false |
||||
|
learning: |
||||
|
loss: CrossEntropyLoss |
||||
|
lr: 5.0e-05 |
||||
|
lr_scheduler_type: linear |
||||
|
optimizer: Adam |
||||
|
warmup_ratio: 0.0 |
||||
|
warmup_steps: 0 |
||||
|
logging: |
||||
|
logging_dir: null |
||||
|
logging_strategy: steps |
||||
|
print_steps: null |
||||
|
save_strategy: steps |
||||
metrics: |
metrics: |
||||
metric: Accuracy |
metric: Accuracy |
||||
train: |
train: |
||||
batch_size: 32 |
|
||||
|
batch_size: 16 |
||||
|
dataloader_drop_last: false |
||||
|
dataloader_num_workers: 0 |
||||
|
epoch_num: 16 |
||||
|
eval_steps: null |
||||
|
eval_strategy: epoch |
||||
|
load_best_model_at_end: false |
||||
|
max_steps: -1 |
||||
|
output_dir: ./output_dir |
||||
overwrite_output_dir: true |
overwrite_output_dir: true |
||||
epoch_num: 2 |
|
||||
learning: |
|
||||
optimizer: |
|
||||
name_: SGD |
|
||||
lr: 0.04 |
|
||||
momentum: 0.001 |
|
||||
loss: |
|
||||
name_: CrossEntropyLoss |
|
||||
ignore_index: -1 |
|
||||
logging: |
|
||||
print_steps: 2 |
|
||||
#learning: |
|
||||
# optimizer: |
|
||||
# name_: Adam |
|
||||
# lr: 0.02 |
|
||||
# eps: 0.001 |
|
||||
|
resume_from_checkpoint: null |
||||
|
seed: 42 |
||||
|
val_batch_size: -1 |
||||
|
Loading…
Reference in new issue