towhee
/
resnet-image-embedding
copied
5 changed files with 98 additions and 155 deletions
@ -1,66 +0,0 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
import pprint |
|||
|
|||
class EmbeddingOutput: |
|||
""" |
|||
Container for embedding extractor. |
|||
""" |
|||
def __init__(self): |
|||
self.embeddings = [] |
|||
|
|||
def __call__(self, module, module_in, module_out): |
|||
self.embeddings.append(module_out) |
|||
|
|||
def clear(self): |
|||
""" |
|||
clear list |
|||
""" |
|||
self.embeddings = [] |
|||
|
|||
|
|||
class EmbeddingExtractor: |
|||
""" |
|||
Embedding extractor from a layer |
|||
Args: |
|||
model (`nn.Module`): |
|||
Model used for inference. |
|||
""" |
|||
def __init__(self, model): |
|||
# self.modules = model.modules() |
|||
# self.modules_list = list(model.named_modules(remove_duplicate=False)) |
|||
self.modules_dict = dict(model.named_modules(remove_duplicate=False)) |
|||
self.emb_out = EmbeddingOutput() |
|||
|
|||
def disp_modules(self, full=False): |
|||
""" |
|||
Display the the modules of the model. |
|||
""" |
|||
if not full: |
|||
pprint.pprint(list(self.modules_dict.keys())) |
|||
else: |
|||
pprint.pprint(self.modules_dict) |
|||
|
|||
def register(self, layer_name: str): |
|||
""" |
|||
Registration for embedding extraction. |
|||
Args: |
|||
layer_name (`str`): |
|||
Name of the layer from which the embedding is extracted. |
|||
""" |
|||
if layer_name in self.modules_dict: |
|||
layer = self.modules_dict[layer_name] |
|||
layer.register_forward_hook(self.emb_out) |
|||
else: |
|||
raise ValueError('layer_name not in modules') |
@ -1,25 +1,42 @@ |
|||
callback: |
|||
early_stopping: |
|||
mode: max |
|||
monitor: eval_epoch_metric |
|||
patience: 2 |
|||
model_checkpoint: |
|||
every_n_epoch: 2 |
|||
tensorboard: |
|||
comment: '' |
|||
log_dir: null |
|||
device: |
|||
device_str: null |
|||
n_gpu: -1 |
|||
sync_bn: true |
|||
sync_bn: false |
|||
learning: |
|||
loss: CrossEntropyLoss |
|||
lr: 5.0e-05 |
|||
lr_scheduler_type: linear |
|||
optimizer: Adam |
|||
warmup_ratio: 0.0 |
|||
warmup_steps: 0 |
|||
logging: |
|||
logging_dir: null |
|||
logging_strategy: steps |
|||
print_steps: null |
|||
save_strategy: steps |
|||
metrics: |
|||
metric: Accuracy |
|||
train: |
|||
batch_size: 32 |
|||
batch_size: 16 |
|||
dataloader_drop_last: false |
|||
dataloader_num_workers: 0 |
|||
epoch_num: 16 |
|||
eval_steps: null |
|||
eval_strategy: epoch |
|||
load_best_model_at_end: false |
|||
max_steps: -1 |
|||
output_dir: ./output_dir |
|||
overwrite_output_dir: true |
|||
epoch_num: 2 |
|||
learning: |
|||
optimizer: |
|||
name_: SGD |
|||
lr: 0.04 |
|||
momentum: 0.001 |
|||
loss: |
|||
name_: CrossEntropyLoss |
|||
ignore_index: -1 |
|||
logging: |
|||
print_steps: 2 |
|||
#learning: |
|||
# optimizer: |
|||
# name_: Adam |
|||
# lr: 0.02 |
|||
# eps: 0.001 |
|||
resume_from_checkpoint: null |
|||
seed: 42 |
|||
val_batch_size: -1 |
|||
|
Loading…
Reference in new issue