towhee
/
            
              retinaface-face-detection
              
                 
                
            
          copied
				 6 changed files with 145 additions and 0 deletions
			
			
		| @ -0,0 +1,13 @@ | |||||
|  | # Copyright 2021 Zilliz. All rights reserved. | ||||
|  | # | ||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  | # you may not use this file except in compliance with the License. | ||||
|  | # You may obtain a copy of the License at | ||||
|  | # | ||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | # | ||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  | # See the License for the specific language governing permissions and | ||||
|  | # limitations under the License. | ||||
| @ -0,0 +1,14 @@ | |||||
|  | # Copyright 2021 Zilliz. All rights reserved. | ||||
|  | # | ||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  | # you may not use this file except in compliance with the License. | ||||
|  | # You may obtain a copy of the License at | ||||
|  | # | ||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | # | ||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  | # See the License for the specific language governing permissions and | ||||
|  | # limitations under the License. | ||||
|  | 
 | ||||
| @ -0,0 +1,41 @@ | |||||
|  | # Copyright 2021 Zilliz. All rights reserved. | ||||
|  | # | ||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  | # you may not use this file except in compliance with the License. | ||||
|  | # You may obtain a copy of the License at | ||||
|  | # | ||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | # | ||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  | # See the License for the specific language governing permissions and | ||||
|  | # limitations under the License. | ||||
|  | import os | ||||
|  | 
 | ||||
|  | import torch | ||||
|  | 
 | ||||
|  | from towhee.models.retina_face.retinaface import RetinaFace | ||||
|  | from towhee.models.retina_face.configs import build_configs | ||||
|  | from towhee.models.utils.pretrained_utils import load_pretrained_weights | ||||
|  | 
 | ||||
|  | class Model: | ||||
|  |     """ | ||||
|  |     Pytorch model class | ||||
|  |     """ | ||||
|  |     def __init__(self): | ||||
|  |         model_name = 'cfg_mnet' | ||||
|  |         cfg = build_configs(model_name) | ||||
|  |         self._model = RetinaFace(cfg=cfg, phase='test') | ||||
|  |         load_pretrained_weights(self._model, 'mnet', None, os.path.dirname(__file__) + '/pytorch_retinaface_mobilenet_widerface.pth') | ||||
|  |         self._model.eval() | ||||
|  | 
 | ||||
|  |     def __call__(self, img_tensor: torch.Tensor): | ||||
|  |         outputs = self._model.inference(img_tensor) | ||||
|  |         return outputs | ||||
|  | 
 | ||||
|  |     def train(self): | ||||
|  |         """ | ||||
|  |         For training model | ||||
|  |         """ | ||||
|  |         pass | ||||
								
									Binary file not shown.
								
							
						
					| @ -0,0 +1 @@ | |||||
|  | torch | ||||
| @ -0,0 +1,73 @@ | |||||
|  | # Copyright 2021 Zilliz. All rights reserved. | ||||
|  | # | ||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  | # you may not use this file except in compliance with the License. | ||||
|  | # You may obtain a copy of the License at | ||||
|  | # | ||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | # | ||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  | # See the License for the specific language governing permissions and | ||||
|  | # limitations under the License. | ||||
|  | 
 | ||||
|  | from typing import NamedTuple, List | ||||
|  | from PIL import Image | ||||
|  | import torch | ||||
|  | from torchvision import transforms | ||||
|  | import sys | ||||
|  | import towhee | ||||
|  | from pathlib import Path | ||||
|  | import numpy | ||||
|  | 
 | ||||
|  | from towhee.operator import Operator | ||||
|  | from towhee.utils.pil_utils import to_pil | ||||
|  | from timm.data import resolve_data_config | ||||
|  | from timm.data.transforms_factory import create_transform | ||||
|  | import os | ||||
|  | 
 | ||||
|  | class RetinafaceFaceDetection(Operator): | ||||
|  |     """ | ||||
|  |     Embedding extractor using efficientnet. | ||||
|  |     Args: | ||||
|  |         model_name (`string`): | ||||
|  |             Model name. | ||||
|  |         weights_path (`string`): | ||||
|  |             Path to local weights. | ||||
|  |     """ | ||||
|  | 
 | ||||
|  |     def __init__(self, need_crop = True, framework: str = 'pytorch') -> None: | ||||
|  |         super().__init__() | ||||
|  |         if framework == 'pytorch': | ||||
|  |             import importlib.util | ||||
|  |             path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') | ||||
|  |             opname = os.path.basename(str(Path(__file__))).split('.')[0] | ||||
|  |             spec = importlib.util.spec_from_file_location(opname, path) | ||||
|  |             module = importlib.util.module_from_spec(spec) | ||||
|  |             spec.loader.exec_module(module) | ||||
|  |         self.need_crop = need_crop | ||||
|  |         self.model = module.Model() | ||||
|  | 
 | ||||
|  |     def __call__(self, image: 'towhee.types.Image') -> List[NamedTuple('Outputs', [('boxes', numpy.ndarray), | ||||
|  |                                                                             ('keypoints', numpy.ndarray), | ||||
|  |                                                                            ('cropped_imgs', numpy.ndarray)])]: | ||||
|  |         Outputs = NamedTuple('Outputs', [('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)]) | ||||
|  |         img = torch.FloatTensor(numpy.asarray(to_pil(image))) | ||||
|  |         bboxes, keypoints = self.model(img) | ||||
|  |         croppeds = [] | ||||
|  |         if self.need_crop is True: | ||||
|  |             h, w, _ = img.shape | ||||
|  |             for bbox in bboxes: | ||||
|  |                 x1, y1, x2, y2, _ = bbox | ||||
|  |                 x1 = max(int(x1), 0) | ||||
|  |                 y1 = max(int(y1), 0) | ||||
|  |                 x2 = min(int(x2), w) | ||||
|  |                 y2 = min(int(y2), h) | ||||
|  |                 croppeds.append(img[y1:y2, x1:x2, :].numpy()) | ||||
|  |         outputs = [] | ||||
|  | 
 | ||||
|  |         for i in range(len(croppeds)): | ||||
|  |             output = Outputs(bboxes[i], keypoints[i,:], croppeds[i]) | ||||
|  |             outputs.append(output) | ||||
|  |         return outputs | ||||
					Loading…
					
					
				
		Reference in new issue
	
	