diff --git a/pytorch/model.py b/pytorch/model.py index 02bff05..7c8cdf6 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from torch import nn import torch -import torch.nn as nn -import numpy as np import sys from pathlib import Path +from towhee.models.vggish.torch_vggish import VGG + sys.path.append(str(Path(__file__).parent)) import vggish_input @@ -26,39 +27,18 @@ class Model(nn.Module): """ PyTorch model class """ - def __init__(self): + def __init__(self, weights_path: str=None): super().__init__() - self.features = nn.Sequential( - nn.Conv2d(1, 64, 3, 1, 1), - nn.ReLU(inplace=True), - nn.MaxPool2d(2, 2), - nn.Conv2d(64, 128, 3, 1, 1), - nn.ReLU(inplace=True), - nn.MaxPool2d(2, 2), - nn.Conv2d(128, 256, 3, 1, 1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 256, 3, 1, 1), - nn.ReLU(inplace=True), - nn.MaxPool2d(2, 2), - nn.Conv2d(256, 512, 3, 1, 1), - nn.ReLU(inplace=True), - nn.Conv2d(512, 512, 3, 1, 1), - nn.ReLU(inplace=True), - nn.MaxPool2d(2, 2)) - self.embeddings = nn.Sequential( - nn.Linear(512 * 24, 4096), - nn.ReLU(inplace=True), - nn.Linear(4096, 4096), - nn.ReLU(inplace=True), - nn.Linear(4096, 128), - #nn.ReLU(inplace=True) - ) + self._model = VGG() + if not weights_path: + path = str(Path(__file__).parent) + weights_path = path + '/vggish.pth' + state_dict = torch.load(weights_path, map_location=torch.device('cpu')) + self._model.load_state_dict(state_dict) + self._model.eval() def forward(self, x): - x = self.features(x).permute(0, 2, 3, 1).contiguous() - x = x.view(x.size(0), -1) - x = self.embeddings(x) - return x + return self._model(x) def preprocess(self, audio_path: str): audio_tensors = vggish_input.wavfile_to_examples(audio_path) diff --git a/torch_vggish.py b/torch_vggish.py index 388880d..f1b67a3 100644 --- a/torch_vggish.py +++ b/torch_vggish.py @@ -29,7 +29,7 @@ class TorchVggish(Operator): """ """ - def __init__(self, framework: str = 'pytorch') -> None: + def __init__(self, framework: str = 'pytorch', weights_path: str=None) -> None: super().__init__() if framework == 'pytorch': import importlib.util @@ -38,13 +38,10 @@ class TorchVggish(Operator): spec = importlib.util.spec_from_file_location(opname, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - self.model = module.Model() - - path = str(Path(__file__).parent) - self.model.load_state_dict(torch.load(path + '/pytorch/vggish.pth', map_location=torch.device('cpu'))) + self.model = module.Model(weights_path) def __call__(self, audio_path: str) -> NamedTuple('Outputs', [('embs', numpy.ndarray)]): audio_tensors = self.model.preprocess(audio_path) - features = self.model.forward(audio_tensors) + features = self.model._model(audio_tensors) Outputs = NamedTuple('Outputs', [('embs', numpy.ndarray)]) return Outputs(features.detach().numpy())