logo
Browse Source

Implement with towhee.models.vggish

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
e3f1b9d8a8
  1. 44
      pytorch/model.py
  2. 9
      torch_vggish.py

44
pytorch/model.py

@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn
import torch
import torch.nn as nn
import numpy as np
import sys
from pathlib import Path
from towhee.models.vggish.torch_vggish import VGG
sys.path.append(str(Path(__file__).parent))
import vggish_input
@ -26,39 +27,18 @@ class Model(nn.Module):
"""
PyTorch model class
"""
def __init__(self):
def __init__(self, weights_path: str=None):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 64, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(64, 128, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(128, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
nn.Conv2d(256, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2))
self.embeddings = nn.Sequential(
nn.Linear(512 * 24, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, 128),
#nn.ReLU(inplace=True)
)
self._model = VGG()
if not weights_path:
path = str(Path(__file__).parent)
weights_path = path + '/vggish.pth'
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
self._model.load_state_dict(state_dict)
self._model.eval()
def forward(self, x):
x = self.features(x).permute(0, 2, 3, 1).contiguous()
x = x.view(x.size(0), -1)
x = self.embeddings(x)
return x
return self._model(x)
def preprocess(self, audio_path: str):
audio_tensors = vggish_input.wavfile_to_examples(audio_path)

9
torch_vggish.py

@ -29,7 +29,7 @@ class TorchVggish(Operator):
"""
"""
def __init__(self, framework: str = 'pytorch') -> None:
def __init__(self, framework: str = 'pytorch', weights_path: str=None) -> None:
super().__init__()
if framework == 'pytorch':
import importlib.util
@ -38,13 +38,10 @@ class TorchVggish(Operator):
spec = importlib.util.spec_from_file_location(opname, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.model = module.Model()
path = str(Path(__file__).parent)
self.model.load_state_dict(torch.load(path + '/pytorch/vggish.pth', map_location=torch.device('cpu')))
self.model = module.Model(weights_path)
def __call__(self, audio_path: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]):
audio_tensors = self.model.preprocess(audio_path)
features = self.model.forward(audio_tensors)
features = self.model._model(audio_tensors)
Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)])
return Outputs(features.detach().numpy())

Loading…
Cancel
Save