towhee / resnet-image-embedding
5 changed files with 98 additions and 155 deletions
@@ -1,66 +0,0 @@
-# Copyright 2021 Zilliz. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import pprint
-
-class EmbeddingOutput:
-    """
-    Container for embedding extractor.
-    """
-    def __init__(self):
-        self.embeddings = []
-
-    def __call__(self, module, module_in, module_out):
-        self.embeddings.append(module_out)
-
-    def clear(self):
-        """
-        clear list
-        """
-        self.embeddings = []
-
-
-class EmbeddingExtractor:
-    """
-    Embedding extractor from a layer
-    Args:
-        model (`nn.Module`):
-            Model used for inference.
-    """
-    def __init__(self, model):
-        # self.modules = model.modules()
-        # self.modules_list = list(model.named_modules(remove_duplicate=False))
-        self.modules_dict = dict(model.named_modules(remove_duplicate=False))
-        self.emb_out = EmbeddingOutput()
-
-    def disp_modules(self, full=False):
-        """
-        Display the modules of the model.
-        """
-        if not full:
-            pprint.pprint(list(self.modules_dict.keys()))
-        else:
-            pprint.pprint(self.modules_dict)
-
-    def register(self, layer_name: str):
-        """
-        Registration for embedding extraction.
-        Args:
-            layer_name (`str`):
-                Name of the layer from which the embedding is extracted.
-        """
-        if layer_name in self.modules_dict:
-            layer = self.modules_dict[layer_name]
-            layer.register_forward_hook(self.emb_out)
-        else:
-            raise ValueError('layer_name not in modules')
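For context, a minimal usage sketch of the EmbeddingExtractor being removed here (not part of the diff): it registers a forward hook on a named layer and collects that layer's output. The torchvision ResNet-50, the random weights, the layer name 'avgpool', and the input size are illustrative assumptions.

import torch
import torchvision.models as models

model = models.resnet50().eval()            # random weights are enough for a shape check
extractor = EmbeddingExtractor(model)
extractor.disp_modules()                    # print the available layer names
extractor.register('avgpool')               # hook the global-average-pool layer
with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))      # forward pass triggers the hook
embedding = extractor.emb_out.embeddings[0].flatten()
print(embedding.shape)                      # torch.Size([2048])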
@@ -1,25 +1,42 @@
+callback:
+  early_stopping:
+    mode: max
+    monitor: eval_epoch_metric
+    patience: 2
+  model_checkpoint:
+    every_n_epoch: 2
+  tensorboard:
+    comment: ''
+    log_dir: null
 device:
   device_str: null
   n_gpu: -1
-  sync_bn: true
+  sync_bn: false
+learning:
+  loss: CrossEntropyLoss
+  lr: 5.0e-05
+  lr_scheduler_type: linear
+  optimizer: Adam
+  warmup_ratio: 0.0
+  warmup_steps: 0
+logging:
+  logging_dir: null
+  logging_strategy: steps
+  print_steps: null
+  save_strategy: steps
 metrics:
   metric: Accuracy
 train:
-  batch_size: 32
+  batch_size: 16
+  dataloader_drop_last: false
+  dataloader_num_workers: 0
+  epoch_num: 16
+  eval_steps: null
+  eval_strategy: epoch
+  load_best_model_at_end: false
+  max_steps: -1
+  output_dir: ./output_dir
   overwrite_output_dir: true
-  epoch_num: 2
-learning:
-  optimizer:
-    name_: SGD
-    lr: 0.04
-    momentum: 0.001
-  loss:
-    name_: CrossEntropyLoss
-    ignore_index: -1
-logging:
-  print_steps: 2
-#learning:
-#  optimizer:
-#    name_: Adam
-#    lr: 0.02
-#    eps: 0.001
+  resume_from_checkpoint: null
+  seed: 42
+  val_batch_size: -1
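The updated config replaces the nested optimizer/loss blocks with a flat trainer-style layout (batch_size, eval_strategy, warmup_steps, seed, and so on). As a quick sanity check, a sketch that is not part of the diff: the file can be read with PyYAML; the filename resnet_training_yaml.yaml is an assumption for illustration.

import yaml

with open('resnet_training_yaml.yaml') as f:          # hypothetical file name
    cfg = yaml.safe_load(f)

print(cfg['train']['batch_size'])                      # 16
print(cfg['learning']['optimizer'])                    # Adam
print(cfg['callback']['early_stopping']['patience'])   # 2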