logo
Browse Source

Update to support pipeline with time-window

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
885930423a
  1. 6
      __init__.py
  2. 0
      mel_features.py
  3. 1
      pytorch/.gitattributes
  4. 24
      pytorch/__init__.py
  5. 51
      pytorch/model.py
  6. 84
      vggish.bak
  7. 0
      vggish.pth
  8. 83
      vggish.py
  9. 10
      vggish_input.py
  10. 0
      vggish_params.py

6
__init__.py

@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vggish import Vggish
def vggish(weights_path: str = None, framework: str = 'pytorch'):
return Vggish(weights_path, framework)

0
pytorch/mel_features.py → mel_features.py

1
pytorch/.gitattributes

@ -1 +0,0 @@
vggish.pth filter=lfs diff=lfs merge=lfs -text

24
pytorch/__init__.py

@ -1,24 +0,0 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# For requirements.
try:
import timm
except ModuleNotFoundError:
os.system('pip install timm')
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

51
pytorch/model.py

@ -1,51 +0,0 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn
import torch
import sys
from pathlib import Path
from towhee.models.vggish.torch_vggish import VGG
sys.path.append(str(Path(__file__).parent))
import vggish_input
class Model(nn.Module):
"""
PyTorch model class
"""
def __init__(self, weights_path: str=None):
super().__init__()
self._model = VGG()
if not weights_path:
path = str(Path(__file__).parent)
weights_path = path + '/vggish.pth'
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
self._model.load_state_dict(state_dict)
self._model.eval()
def forward(self, x):
return self._model(x)
def preprocess(self, audio_path: str):
audio_tensors = vggish_input.wavfile_to_examples(audio_path)
return audio_tensors
def train(self):
"""
For training model
"""
pass

84
vggish.bak

@ -0,0 +1,84 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings
import os
import sys
import numpy
from pathlib import Path
from typing import Union
import torch
from towhee.operator.base import NNOperator
from towhee.models.vggish.torch_vggish import VGG
from towhee import register
sys.path.append(str(Path(__file__).parent))
import vggish_input
warnings.filterwarnings('ignore')
log = logging.getLogger()
@register(output_schema=['vec'])
class Vggish(NNOperator):
"""
"""
def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
super().__init__(framework=framework)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = VGG()
if not weights_path:
path = str(Path(__file__).parent)
weights_path = os.path.join(path, 'vggish.pth')
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
self.model.load_state_dict(state_dict)
self.model.eval()
self.model.to(self.device)
def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
audio_tensors = self.preprocess(audio, sr).to(self.device)
features = self.model(audio_tensors)
outs = features.to("cpu")
return outs.detach().numpy()
def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
if isinstance(audio, str):
audio_tensors = vggish_input.wavfile_to_examples(audio)
elif isinstance(audio, numpy.ndarray):
try:
audio = audio.transpose()
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
except Exception as e:
log.error("Fail to load audio data.")
raise e
else:
log.error(f"Invalid input audio: {type(audio)}")
return audio_tensors
# if __name__ == '__main__':
# encoder = Vggish()
#
# # audio_path = '/path/to/audio'
# # vec = encoder(audio_path)
#
# audio_data = numpy.zeros((2, 441344))
# sample_rate = 44100
# vec = encoder(audio_data, sample_rate)
# print(vec)

0
pytorch/vggish.pth → vggish.pth

83
vggish.py

@ -0,0 +1,83 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import warnings
import os
import sys
import numpy
from pathlib import Path
from typing import Union, List, NamedTuple
import torch
from towhee.operator.base import NNOperator
from towhee.models.vggish.torch_vggish import VGG
from towhee import register
sys.path.append(str(Path(__file__).parent))
import vggish_input
warnings.filterwarnings('ignore')
log = logging.getLogger()
AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])
class Vggish(NNOperator):
"""
"""
def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
super().__init__(framework=framework)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = VGG()
if not weights_path:
path = str(Path(__file__).parent)
weights_path = os.path.join(path, 'vggish.pth')
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
self.model.load_state_dict(state_dict)
self.model.eval()
self.model.to(self.device)
def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
audios = numpy.stack([item.audio for item in datas])
sr = datas[0].sample_rate
audio_array = numpy.reshape(audios, (-1, 2))
audio_tensors = self.preprocess(audio_array, sr).to(self.device)
features = self.model(audio_tensors)
outs = features.to("cpu")
return [AudioOutput(outs.detach().numpy())]
def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
if audio.dtype == numpy.int32:
samples = audio / 2147483648.0
elif audio.dtype == numpy.int16:
samples = audio / 32768.0
return vggish_input.waveform_to_examples(samples, sr, return_tensor=True)
# if __name__ == '__main__':
# encoder = Vggish()
#
# # audio_path = '/path/to/audio'
# # vec = encoder(audio_path)
#
# audio_data = numpy.zeros((2, 441344))
# sample_rate = 44100
# vec = encoder(audio_data, sample_rate)
# print(vec)

10
pytorch/vggish_input.py → vggish_input.py

@ -17,14 +17,13 @@
# Modification: Return torch tensors rather than numpy arrays
import torch
import numpy as np
import resampy
import mel_features
import vggish_params
import soundfile as sf
import torchaudio
def waveform_to_examples(data, sample_rate, return_tensor=True):
@ -92,7 +91,6 @@ def wavfile_to_examples(wav_file, return_tensor=True):
Returns:
See waveform_to_examples.
"""
wav_data, sr = sf.read(wav_file, dtype='int16')
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr, return_tensor)
data, sr = torchaudio.load(wav_file)
wav_data = data.detach().numpy().transpose()
return waveform_to_examples(wav_data, sr, return_tensor)

0
pytorch/vggish_params.py → vggish_params.py

Loading…
Cancel
Save