|
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from torch import nn
 import torch
+import torch.nn as nn
-import numpy as np
 import sys
 from pathlib import Path

+from towhee.models.vggish.torch_vggish import VGG
+
 sys.path.append(str(Path(__file__).parent))

 import vggish_input
|
@@ -26,39 +27,18 @@ class Model(nn.Module):
     """
     PyTorch model class
     """
-    def __init__(self):
+    def __init__(self, weights_path: str=None):
         super().__init__()
-        self.features = nn.Sequential(
-            nn.Conv2d(1, 64, 3, 1, 1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, 2),
-            nn.Conv2d(64, 128, 3, 1, 1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, 2),
-            nn.Conv2d(128, 256, 3, 1, 1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(256, 256, 3, 1, 1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, 2),
-            nn.Conv2d(256, 512, 3, 1, 1),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(512, 512, 3, 1, 1),
-            nn.ReLU(inplace=True),
-            nn.MaxPool2d(2, 2))
-        self.embeddings = nn.Sequential(
-            nn.Linear(512 * 24, 4096),
-            nn.ReLU(inplace=True),
-            nn.Linear(4096, 4096),
-            nn.ReLU(inplace=True),
-            nn.Linear(4096, 128),
-            #nn.ReLU(inplace=True)
-        )
+        self._model = VGG()
+        if not weights_path:
+            path = str(Path(__file__).parent)
+            weights_path = path + '/vggish.pth'
+        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
+        self._model.load_state_dict(state_dict)
+        self._model.eval()

     def forward(self, x):
-        x = self.features(x).permute(0, 2, 3, 1).contiguous()
-        x = x.view(x.size(0), -1)
-        x = self.embeddings(x)
-        return x
+        return self._model(x)

     def preprocess(self, audio_path: str):
         audio_tensors = vggish_input.wavfile_to_examples(audio_path)
|
|
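Not part of the diff itself: a minimal usage sketch of the refactored operator. It assumes the pretrained `vggish.pth` file sits next to the module, that the truncated `preprocess` hunk goes on to return the tensors produced by `wavfile_to_examples` (shape `[num_examples, 1, 96, 64]` in the common VGGish ports), and that the file is importable as `model` — the module name and audio path below are placeholders.

```python
import torch

from model import Model  # placeholder import; the actual module path may differ

model = Model()  # by default loads vggish.pth from the module's own directory

# Turn a wav file into a batch of ~0.96 s log-mel examples.
examples = model.preprocess('audio.wav')  # placeholder path

# One 128-d VGGish embedding per example; the wrapped VGG is already in eval mode.
with torch.no_grad():
    embeddings = model(examples)

print(embeddings.shape)  # expected: (num_examples, 128)
```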