| 
					
					
						
							
						
					
					
				 | 
				@ -29,7 +29,7 @@ if __name__ == '__main__': | 
			
		
		
	
		
			
				 | 
				 | 
				    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) | 
				 | 
				 | 
				    towhee_img = Image(img_bytes, img_width, img_height, img_channel, img_mode, img_array) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    op = ResnetImageEmbedding('resnet34') | 
				 | 
				 | 
				    op = ResnetImageEmbedding('resnet34') | 
			
		
		
	
		
			
				 | 
				 | 
				    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    # op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data") | 
			
		
		
	
		
			
				 | 
				 | 
				    # old_out = op(towhee_img) | 
				 | 
				 | 
				    # old_out = op(towhee_img) | 
			
		
		
	
		
			
				 | 
				 | 
				    # print(old_out.feature_vector[0]) | 
				 | 
				 | 
				    # print(old_out.feature_vector[0]) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -51,33 +51,31 @@ if __name__ == '__main__': | 
			
		
		
	
		
			
				 | 
				 | 
				                                          RandomResizedCrop(224), | 
				 | 
				 | 
				                                          RandomResizedCrop(224), | 
			
		
		
	
		
			
				 | 
				 | 
				                                          Lambda(lambda x: x.repeat(3, 1, 1)), | 
				 | 
				 | 
				                                          Lambda(lambda x: x.repeat(3, 1, 1)), | 
			
		
		
	
		
			
				 | 
				 | 
				                                          transforms.Normalize(mean=[0.5], std=[0.5])]) | 
				 | 
				 | 
				                                          transforms.Normalize(mean=[0.5], std=[0.5])]) | 
			
		
		
	
		
			
				 | 
				 | 
				    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data') | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    train_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=True) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    eval_data = get_dataset('mnist', transform=mnist_transform, download=True, root='data', train=False) | 
			
		
		
	
		
			
				 | 
				 | 
				    # fake_transform = transforms.Compose([transforms.ToTensor(), | 
				 | 
				 | 
				    # fake_transform = transforms.Compose([transforms.ToTensor(), | 
			
		
		
	
		
			
				 | 
				 | 
				    #                                       RandomResizedCrop(224),]) | 
				 | 
				 | 
				    #                                       RandomResizedCrop(224),]) | 
			
		
		
	
		
			
				 | 
				 | 
				    # train_data = get_dataset('fake', size=20, transform=fake_transform) | 
				 | 
				 | 
				    # train_data = get_dataset('fake', size=20, transform=fake_transform) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    op.change_before_train(10) | 
				 | 
				 | 
				    op.change_before_train(10) | 
			
		
		
	
		
			
				 | 
				 | 
				    trainer = op.setup_trainer() | 
				 | 
				 | 
				    trainer = op.setup_trainer() | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				    my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    # my_optimimzer = AdamW(op.get_model().parameters(), lr=0.002, betas=(0.91, 0.98), eps=1e-08, weight_decay=0.01, amsgrad=False) | 
			
		
		
	
		
			
				 | 
				 | 
				    # op.setup_trainer() | 
				 | 
				 | 
				    # op.setup_trainer() | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    # trainer.add_callback() | 
				 | 
				 | 
				    # trainer.add_callback() | 
			
		
		
	
		
			
				 | 
				 | 
				    # trainer.set_optimizer() | 
				 | 
				 | 
				    # trainer.set_optimizer() | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    op.trainer.set_optimizer(my_optimimzer) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				    trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    # op.trainer.set_optimizer(my_optimimzer) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    # trainer.configs.save_to_yaml('changed_optimizer_yaml.yaml') | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    # my_loss = nn.BCELoss() | 
				 | 
				 | 
				    # my_loss = nn.BCELoss() | 
			
		
		
	
		
			
				 | 
				 | 
				    # trainer.set_loss(my_loss, 'my_loss111') | 
				 | 
				 | 
				    # trainer.set_loss(my_loss, 'my_loss111') | 
			
		
		
	
		
			
				 | 
				 | 
				    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') | 
				 | 
				 | 
				    # trainer.configs.save_to_yaml('chaned_loss_yaml.yaml') | 
			
		
		
	
		
			
				 | 
				 | 
				    # op.trainer._create_optimizer() | 
				 | 
				 | 
				    # op.trainer._create_optimizer() | 
			
		
		
	
		
			
				 | 
				 | 
				    # op.trainer.set_optimizer() | 
				 | 
				 | 
				    # op.trainer.set_optimizer() | 
			
		
		
	
		
			
				 | 
				 | 
				    op.train(training_config, train_dataset=train_data) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				    training_config.num_epoch = 3 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				    op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    op.train(training_config, train_dataset=train_data, eval_dataset=eval_data) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    # training_config.num_epoch = 3 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    # op.train(training_config, train_dataset=train_data, resume_checkpoint_path=training_config.output_dir + '/epoch_2') | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    # op.save('./test_save') | 
				 | 
				 | 
				    # op.save('./test_save') | 
			
		
		
	
		
			
				 | 
				 | 
				    # op.load('./test_save') | 
				 | 
				 | 
				    # op.load('./test_save') | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |