towhee
/
resnet-image-embedding
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
66 lines
2.5 KiB
66 lines
2.5 KiB
3 years ago
|
import numpy as np
|
||
|
from torchvision import transforms
|
||
|
from torchvision.transforms import RandomResizedCrop, Lambda
|
||
|
from towhee.trainer.modelcard import ModelCard
|
||
|
|
||
|
from towhee.trainer.training_config import TrainingConfig
|
||
|
from towhee.trainer.dataset import get_dataset
|
||
|
from resnet_image_embedding import ResnetImageEmbedding
|
||
|
from towhee.types import Image
|
||
|
from towhee.trainer.training_config import dump_default_yaml
|
||
|
from PIL import Image as PILImage
|
||
|
from timm.models.resnet import ResNet
|
||
|
|
||
|
def _load_towhee_image(img_path):
    """Read an image file and wrap it in a towhee ``Image``.

    Every field handed to ``Image`` (raw bytes, width, height, band count,
    mode, pixel array) is read while the file is open, and the file handle
    is closed before returning.

    Args:
        img_path: path of the image file to read.

    Returns:
        A ``towhee.types.Image`` built from the decoded pixels.
    """
    with PILImage.open(img_path) as img:
        img_bytes = img.tobytes()
        # getbands() returns one name per band without copying pixel data
        # the way split() does; its length is the channel count.
        img_channel = len(img.getbands())
        img_array = np.array(img)
        return Image(img_bytes, img.width, img.height, img_channel, img.mode, img_array)


def _build_train_dataset(dataset_name):
    """Build the training dataset for this script.

    Args:
        dataset_name: ``'mnist'`` (downloaded into ``./data``) or ``'fake'``
            (a 20-sample dataset from ``get_dataset('fake', ...)``).

    Returns:
        Whatever ``towhee.trainer.dataset.get_dataset`` returns for the name.

    Raises:
        ValueError: if ``dataset_name`` is neither ``'mnist'`` nor ``'fake'``.
    """
    if dataset_name == 'mnist':
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            RandomResizedCrop(224),
            # Tile the first (channel) dimension 3x. Presumably MNIST tensors
            # are single-channel and the ResNet stem expects 3 — TODO confirm.
            Lambda(lambda x: x.repeat(3, 1, 1)),
            # A single mean/std value broadcasts across all channels.
            transforms.Normalize(mean=[0.5], std=[0.5]),
        ])
        return get_dataset('mnist', transform=mnist_transform, download=True, root='data')
    if dataset_name == 'fake':
        fake_transform = transforms.Compose([transforms.ToTensor(), RandomResizedCrop(224)])
        return get_dataset('fake', size=20, transform=fake_transform)
    raise ValueError(f"unsupported dataset: {dataset_name!r}")


if __name__ == '__main__':
    from pathlib import Path

    img_path = './ILSVRC2012_val_00049771.JPEG'
    # NOTE(review): towhee_img is currently unused; it is the input for the
    # disabled before/after embedding check described in the TODO below.
    towhee_img = _load_towhee_image(img_path)

    op = ResnetImageEmbedding('resnet34')
    op.model_card = ModelCard(model_details="resnet test modelcard", training_data="use resnet test data")

    # TODO: re-enable the sanity check that embeds towhee_img with `op`
    # before training, runs it again after training, and asserts that the
    # feature vectors changed.

    # Training settings (output_dir, epoch_num, batch size, metric, ...) were
    # previously passed inline and are now read from this YAML file. Write
    # the library defaults first if the file does not exist yet, so the
    # load below does not fail on a fresh checkout.
    yaml_path = 'resnet_training_yaml.yaml'
    if not Path(yaml_path).exists():
        dump_default_yaml(yaml_path=yaml_path)
    training_config = TrainingConfig()
    training_config.load_from_yaml(yaml_path)

    # Switch to 'fake' for a small 20-sample run that needs no download.
    train_data = _build_train_dataset('mnist')

    # Presumably the number of output classes for the new head (MNIST has
    # 10) — confirm against ResnetImageEmbedding.change_before_train.
    op.change_before_train(10)
    op.train(training_config, train_dataset=train_data)