# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from pathlib import Path

import torch
from torch import nn

from towhee.models.vggish.torch_vggish import VGG

# Make the sibling `vggish_input` module importable when this file is loaded
# as an operator (its directory is not necessarily on sys.path).
sys.path.append(str(Path(__file__).parent))
import vggish_input  # noqa: E402  # must come after the sys.path tweak above


class Model(nn.Module):
    """Inference-only wrapper around the torch VGGish network.

    Loads pretrained weights at construction time and immediately puts the
    wrapped network into eval mode; `train()` is intentionally a no-op so the
    operator stays in inference mode.
    """

    def __init__(self, weights_path: str = None):
        """
        Args:
            weights_path: path to a VGGish state-dict checkpoint. When falsy,
                falls back to `vggish.pth` next to this file.
        """
        super().__init__()
        self._model = VGG()
        if not weights_path:
            # Default checkpoint shipped alongside this source file.
            weights_path = str(Path(__file__).parent / 'vggish.pth')
        # Always load onto CPU; the caller moves the module to a device later.
        # NOTE(review): consider `weights_only=True` on recent torch versions
        # to avoid unpickling arbitrary objects from the checkpoint.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        self._model.load_state_dict(state_dict)
        self._model.eval()

    def forward(self, x):
        """Run the wrapped VGGish network on a batch of log-mel examples."""
        return self._model(x)

    def preprocess(self, audio_path: str):
        """Convert an audio file into VGGish input example tensors."""
        audio_tensors = vggish_input.wavfile_to_examples(audio_path)
        return audio_tensors

    def train(self, mode: bool = True):
        """Intentional no-op: this operator is inference-only.

        Accepts `mode` so the signature stays compatible with
        `nn.Module.train(mode)` — in particular, `nn.Module.eval()` calls
        `self.train(False)`, which would raise TypeError with a
        zero-argument override. Returns `self` like the base class.
        """
        return self