
Implement with towhee.models.vggish

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Jael Gu committed 3 years ago
commit e3f1b9d8a8
Changed files:
  1. pytorch/model.py (44)
  2. torch_vggish.py (9)

pytorch/model.py (44)

@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from torch import nn
 import torch
+import torch.nn as nn
+import numpy as np
 import sys
 from pathlib import Path
+from towhee.models.vggish.torch_vggish import VGG
 sys.path.append(str(Path(__file__).parent))
 import vggish_input
@@ -26,39 +27,18 @@ class Model(nn.Module):
""" """
PyTorch model class PyTorch model class
""" """
def __init__(self):
def __init__(self, weights_path: str=None):
super().__init__() super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 64, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(256, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2))
self.embeddings = nn.Sequential(
nn.Linear(512 * 24, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, 128),
#nn.ReLU(inplace=True)
)
self._model = VGG()
if not weights_path:
path = str(Path(__file__).parent)
weights_path = path + '/vggish.pth'
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
self._model.load_state_dict(state_dict)
self._model.eval()
def forward(self, x): def forward(self, x):
x = self.features(x).permute(0, 2, 3, 1).contiguous()
x = x.view(x.size(0), -1)
x = self.embeddings(x)
return x
return self._model(x)
def preprocess(self, audio_path: str): def preprocess(self, audio_path: str):
audio_tensors = vggish_input.wavfile_to_examples(audio_path) audio_tensors = vggish_input.wavfile_to_examples(audio_path)
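
After this change, Model is a thin wrapper around towhee's VGG. A minimal usage sketch, not part of the commit: it assumes the default vggish.pth sits next to model.py, that preprocess() returns the example tensor built by vggish_input.wavfile_to_examples (the rest of preprocess is outside this hunk), and 'test.wav' is a placeholder path.

    # Hedged sketch; run from inside pytorch/ so that model.py is importable.
    import torch
    from model import Model

    net = Model()                            # weights_path=None -> ./vggish.pth
    examples = net.preprocess('test.wav')    # log-mel examples via vggish_input
    with torch.no_grad():
        embs = net(examples)                 # forward() delegates to the wrapped VGG
    print(embs.shape)                        # one 128-d VGGish embedding per example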

torch_vggish.py (9)

@@ -29,7 +29,7 @@ class TorchVggish(Operator):
""" """
""" """
def __init__(self, framework: str = 'pytorch') -> None:
def __init__(self, framework: str = 'pytorch', weights_path: str=None) -> None:
super().__init__() super().__init__()
if framework == 'pytorch': if framework == 'pytorch':
import importlib.util import importlib.util
@@ -38,13 +38,10 @@ class TorchVggish(Operator):
             spec = importlib.util.spec_from_file_location(opname, path)
             module = importlib.util.module_from_spec(spec)
             spec.loader.exec_module(module)
-            self.model = module.Model()
-            path = str(Path(__file__).parent)
-            self.model.load_state_dict(torch.load(path + '/pytorch/vggish.pth', map_location=torch.device('cpu')))
+            self.model = module.Model(weights_path)
 
     def __call__(self, audio_path: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]):
         audio_tensors = self.model.preprocess(audio_path)
-        features = self.model.forward(audio_tensors)
+        features = self.model._model(audio_tensors)
         Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)])
         return Outputs(features.detach().numpy())
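
Correspondingly, a hedged sketch of driving the operator after this commit. The class name and signatures come from the hunks above; the wav path is a placeholder, and torch_vggish.py is assumed to be importable from the repo root with the default weights file in place.

    # Hedged sketch, not part of the commit.
    from torch_vggish import TorchVggish

    op = TorchVggish(framework='pytorch')    # weights_path=None -> pytorch/vggish.pth
    out = op('path/to/audio.wav')            # placeholder wav path
    print(out.embs.shape)                    # numpy.ndarray of per-example embeddings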
