towhee
/
retinaface-face-detection
copied
6 changed files with 145 additions and 0 deletions
@ -0,0 +1,13 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
@ -0,0 +1,14 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
@ -0,0 +1,41 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
import os |
|||
|
|||
import torch |
|||
|
|||
from towhee.models.retina_face.retinaface import RetinaFace |
|||
from towhee.models.retina_face.configs import build_configs |
|||
from towhee.models.utils.pretrained_utils import load_pretrained_weights |
|||
|
|||
class Model: |
|||
""" |
|||
Pytorch model class |
|||
""" |
|||
def __init__(self): |
|||
model_name = 'cfg_mnet' |
|||
cfg = build_configs(model_name) |
|||
self._model = RetinaFace(cfg=cfg, phase='test') |
|||
load_pretrained_weights(self._model, 'mnet', None, os.path.dirname(__file__) + '/pytorch_retinaface_mobilenet_widerface.pth') |
|||
self._model.eval() |
|||
|
|||
def __call__(self, img_tensor: torch.Tensor): |
|||
outputs = self._model.inference(img_tensor) |
|||
return outputs |
|||
|
|||
def train(self): |
|||
""" |
|||
For training model |
|||
""" |
|||
pass |
Binary file not shown.
@ -0,0 +1 @@ |
|||
torch |
@ -0,0 +1,73 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from typing import NamedTuple, List |
|||
from PIL import Image |
|||
import torch |
|||
from torchvision import transforms |
|||
import sys |
|||
import towhee |
|||
from pathlib import Path |
|||
import numpy |
|||
|
|||
from towhee.operator import Operator |
|||
from towhee.utils.pil_utils import to_pil |
|||
from timm.data import resolve_data_config |
|||
from timm.data.transforms_factory import create_transform |
|||
import os |
|||
|
|||
class RetinafaceFaceDetection(Operator): |
|||
""" |
|||
Embedding extractor using efficientnet. |
|||
Args: |
|||
model_name (`string`): |
|||
Model name. |
|||
weights_path (`string`): |
|||
Path to local weights. |
|||
""" |
|||
|
|||
def __init__(self, need_crop = True, framework: str = 'pytorch') -> None: |
|||
super().__init__() |
|||
if framework == 'pytorch': |
|||
import importlib.util |
|||
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') |
|||
opname = os.path.basename(str(Path(__file__))).split('.')[0] |
|||
spec = importlib.util.spec_from_file_location(opname, path) |
|||
module = importlib.util.module_from_spec(spec) |
|||
spec.loader.exec_module(module) |
|||
self.need_crop = need_crop |
|||
self.model = module.Model() |
|||
|
|||
def __call__(self, image: 'towhee.types.Image') -> List[NamedTuple('Outputs', [('boxes', numpy.ndarray), |
|||
('keypoints', numpy.ndarray), |
|||
('cropped_imgs', numpy.ndarray)])]: |
|||
Outputs = NamedTuple('Outputs', [('boxes', numpy.ndarray), ('keypoints', numpy.ndarray), ('cropped_imgs', numpy.ndarray)]) |
|||
img = torch.FloatTensor(numpy.asarray(to_pil(image))) |
|||
bboxes, keypoints = self.model(img) |
|||
croppeds = [] |
|||
if self.need_crop is True: |
|||
h, w, _ = img.shape |
|||
for bbox in bboxes: |
|||
x1, y1, x2, y2, _ = bbox |
|||
x1 = max(int(x1), 0) |
|||
y1 = max(int(y1), 0) |
|||
x2 = min(int(x2), w) |
|||
y2 = min(int(y2), h) |
|||
croppeds.append(img[y1:y2, x1:x2, :].numpy()) |
|||
outputs = [] |
|||
|
|||
for i in range(len(croppeds)): |
|||
output = Outputs(bboxes[i], keypoints[i,:], croppeds[i]) |
|||
outputs.append(output) |
|||
return outputs |
Loading…
Reference in new issue