towhee
/
            
              retinaface-face-detection
              
                 
                
            
          copied
				 6 changed files with 145 additions and 0 deletions
			
			
		| @ -0,0 +1,13 @@ | |||
| # Copyright 2021 Zilliz. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| #     http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| @ -0,0 +1,14 @@ | |||
| # Copyright 2021 Zilliz. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| #     http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| 
 | |||
| @ -0,0 +1,41 @@ | |||
| # Copyright 2021 Zilliz. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| #     http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import os | |||
| 
 | |||
| import torch | |||
| 
 | |||
| from towhee.models.retina_face.retinaface import RetinaFace | |||
| from towhee.models.retina_face.configs import build_configs | |||
| from towhee.models.utils.pretrained_utils import load_pretrained_weights | |||
| 
 | |||
| class Model: | |||
|     """ | |||
|     Pytorch model class | |||
|     """ | |||
|     def __init__(self): | |||
|         model_name = 'cfg_mnet' | |||
|         cfg = build_configs(model_name) | |||
|         self._model = RetinaFace(cfg=cfg, phase='test') | |||
|         load_pretrained_weights(self._model, 'mnet', None, os.path.dirname(__file__) + '/pytorch_retinaface_mobilenet_widerface.pth') | |||
|         self._model.eval() | |||
| 
 | |||
|     def __call__(self, img_tensor: torch.Tensor): | |||
|         outputs = self._model.inference(img_tensor) | |||
|         return outputs | |||
| 
 | |||
|     def train(self): | |||
|         """ | |||
|         For training model | |||
|         """ | |||
|         pass | |||
								
									Binary file not shown.
								
							
						
					| @ -0,0 +1 @@ | |||
| torch | |||
| @ -0,0 +1,73 @@ | |||
| # Copyright 2021 Zilliz. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| #     http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| 
 | |||
| from typing import NamedTuple, List | |||
| from PIL import Image | |||
| import torch | |||
| from torchvision import transforms | |||
| import sys | |||
| import towhee | |||
| from pathlib import Path | |||
| import numpy | |||
| 
 | |||
| from towhee.operator import Operator | |||
| from towhee.utils.pil_utils import to_pil | |||
| from timm.data import resolve_data_config | |||
| from timm.data.transforms_factory import create_transform | |||
| import os | |||
| 
 | |||
| class RetinafaceFaceDetection(Operator): | |||
|     """ | |||
|     Embedding extractor using efficientnet. | |||
|     Args: | |||
|         model_name (`string`): | |||
|             Model name. | |||
|         weights_path (`string`): | |||
|             Path to local weights. | |||
|     """ | |||
| 
 | |||
|     def __init__(self, need_crop = True, framework: str = 'pytorch') -> None: | |||
|         super().__init__() | |||
|         if framework == 'pytorch': | |||
|             import importlib.util | |||
|             path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') | |||
|             opname = os.path.basename(str(Path(__file__))).split('.')[0] | |||
|             spec = importlib.util.spec_from_file_location(opname, path) | |||
|             module = importlib.util.module_from_spec(spec) | |||
|             spec.loader.exec_module(module) | |||
|         self.need_crop = need_crop | |||
|         self.model = module.Model() | |||
| 
 | |||
|     def __call__(self, image: 'towhee.types.Image') -> List[NamedTuple('Outputs', [('boxes', numpy.ndarray), | |||
|                                                                             ('keypoints', numpy.ndarray), | |||
|                                                                            ('cropped_imgs', numpy.ndarray)])]: | |||
|         Outputs = NamedTuple('Outputs', [('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)]) | |||
|         img = torch.FloatTensor(numpy.asarray(to_pil(image))) | |||
|         bboxes, keypoints = self.model(img) | |||
|         croppeds = [] | |||
|         if self.need_crop is True: | |||
|             h, w, _ = img.shape | |||
|             for bbox in bboxes: | |||
|                 x1, y1, x2, y2, _ = bbox | |||
|                 x1 = max(int(x1), 0) | |||
|                 y1 = max(int(y1), 0) | |||
|                 x2 = min(int(x2), w) | |||
|                 y2 = min(int(y2), h) | |||
|                 croppeds.append(img[y1:y2, x1:x2, :].numpy()) | |||
|         outputs = [] | |||
| 
 | |||
|         for i in range(len(croppeds)): | |||
|             output = Outputs(bboxes[i], keypoints[i,:], croppeds[i]) | |||
|             outputs.append(output) | |||
|         return outputs | |||
					Loading…
					
					
				
		Reference in new issue
	
	