--- /dev/null
+++ b/tools/data/build_audio_features.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import glob
+import os
+import os.path as osp
+import sys
+from multiprocessing import Pool
+
+import mmcv
+import numpy as np
+from scipy.io import wavfile
+
+try:
+    import librosa
+    import lws
+except ImportError:
+    print('Please install librosa and lws first.')
+
+sys.path.append('..')
+
+SILENCE_THRESHOLD = 2
+FMIN = 125
+FMAX = 7600
+FRAME_SHIFT_MS = None
+MIN_LEVEL_DB = -100
+REF_LEVEL_DB = 20
+RESCALING = True
+RESCALING_MAX = 0.999
+ALLOW_CLIPPING_IN_NORMALIZATION = True
+LOG_SCALE_MIN = -32.23619130191664
+NORM_AUDIO = True
+
+
+class AudioTools:
+    """All methods related to audio feature extraction. Code Reference:
+
+            `deepvoice3_pytorch <https://github.com/r9y9/deepvoice3_pytorch>`_,
+            `lws <https://pypi.org/project/lws/1.2.6/>`_.
+
+    Args:
+        frame_rate (int): The frame rate of the video. Default: 30.
+        sample_rate (int): The sample rate for audio sampling. Default: 16000.
+        num_mels (int): Number of channels of the melspectrogram. Default: 80.
+        fft_size (int): fft_size / sample_rate is the window size in
+            seconds. Default: 1280.
+        hop_size (int): hop_size / sample_rate is the step size in
+            seconds. Default: 320.
+        spectrogram_type (str): The method to generate the spectrogram,
+            either 'lws' or 'librosa'. Default: 'lws'.
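+
+    Example:
+        A minimal usage sketch (``sample.wav`` is a placeholder path):
+
+        >>> audio_tools = AudioTools(fft_size=512, hop_size=256)
+        >>> wav = audio_tools.read_audio('sample.wav')
+        >>> spectrogram = audio_tools.audio_to_spectrogram(wav)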
+    """
+
+    def __init__(self,
+                 frame_rate=30,
+                 sample_rate=16000,
+                 num_mels=80,
+                 fft_size=1280,
+                 hop_size=320,
+                 spectrogram_type='lws'):
+        self.frame_rate = frame_rate
+        self.sample_rate = sample_rate
+        self.silence_threshold = SILENCE_THRESHOLD
+        self.num_mels = num_mels
+        self.fmin = FMIN
+        self.fmax = FMAX
+        self.fft_size = fft_size
+        self.hop_size = hop_size
+        self.frame_shift_ms = FRAME_SHIFT_MS
+        self.min_level_db = MIN_LEVEL_DB
+        self.ref_level_db = REF_LEVEL_DB
+        self.rescaling = RESCALING
+        self.rescaling_max = RESCALING_MAX
+        self.allow_clipping_in_normalization = ALLOW_CLIPPING_IN_NORMALIZATION
+        self.log_scale_min = LOG_SCALE_MIN
+        self.norm_audio = NORM_AUDIO
+        self.spectrogram_type = spectrogram_type
+        assert spectrogram_type in ['lws', 'librosa']
+
+    def load_wav(self, path):
+        """Load an audio file into numpy array."""
+        return librosa.core.load(path, sr=self.sample_rate)[0]
+
+    @staticmethod
+    def audio_normalize(samples, desired_rms=0.1, eps=1e-4):
+        """RMS normalize the audio data."""
+        rms = np.maximum(eps, np.sqrt(np.mean(samples**2)))
+        samples = samples * (desired_rms / rms)
+        return samples
+
+    def generate_spectrogram_magphase(self, audio, with_phase=False):
+        """Separate a complex-valued spectrogram D into its magnitude (S)
+
+            and phase (P) components, so that D = S * P.
+
+        Args:
+            audio (np.ndarray): The input audio signal.
+            with_phase (bool): Determines whether to output the
+                phase components. Default: False.
+
+        Returns:
+            np.ndarray | tuple: The magnitude spectrogram, or a tuple of
+                the magnitude and phase spectrograms when ``with_phase``
+                is True.
+        """
+        spectro = librosa.core.stft(
+            audio,
+            hop_length=self.get_hop_size(),
+            n_fft=self.fft_size,
+            center=True)
+        spectro_mag, spectro_phase = librosa.core.magphase(spectro)
+        spectro_mag = np.expand_dims(spectro_mag, axis=0)
+        if with_phase:
+            spectro_phase = np.expand_dims(np.angle(spectro_phase), axis=0)
+            return spectro_mag, spectro_phase
+
+        return spectro_mag
+
+    def save_wav(self, wav, path):
+        """Save the wav to disk."""
+        # 32767 = (2 ^ 15 - 1) maximum of int16
+        wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+        wavfile.write(path, self.sample_rate, wav.astype(np.int16))
+
+    def trim(self, quantized):
+        """Trim the audio wavfile."""
+        start, end = self.start_and_end_indices(quantized,
+                                                self.silence_threshold)
+        return quantized[start:end]
+
+    def adjust_time_resolution(self, quantized, mel):
+        """Adjust time resolution by repeating features.
+
+        Args:
+            quantized (np.ndarray): (T,)
+            mel (np.ndarray): (N, D)
+
+        Returns:
+            tuple: Tuple of (T,) and (T, D)
+        """
+        assert quantized.ndim == 1
+        assert mel.ndim == 2
+
+        upsample_factor = quantized.size // mel.shape[0]
+        mel = np.repeat(mel, upsample_factor, axis=0)
+        n_pad = quantized.size - mel.shape[0]
+        if n_pad != 0:
+            assert n_pad > 0
+            mel = np.pad(
+                mel, [(0, n_pad), (0, 0)], mode='constant', constant_values=0)
+
+        # trim
+        start, end = self.start_and_end_indices(quantized,
+                                                self.silence_threshold)
+
+        return quantized[start:end], mel[start:end, :]
+
+    @staticmethod
+    def start_and_end_indices(quantized, silence_threshold=2):
+        """Trim the audio file when reaches the silence threshold."""
+        for start in range(quantized.size):
+            if abs(quantized[start] - 127) > silence_threshold:
+                break
+        for end in range(quantized.size - 1, 1, -1):
+            if abs(quantized[end] - 127) > silence_threshold:
+                break
+
+        assert abs(quantized[start] - 127) > silence_threshold
+        assert abs(quantized[end] - 127) > silence_threshold
+
+        return start, end
+
+    def melspectrogram(self, y):
+        """Generate the melspectrogram."""
+        D = self._lws_processor().stft(y).T
+        S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.ref_level_db
+        if not self.allow_clipping_in_normalization:
+            assert S.max() <= 0 and S.min() - self.min_level_db >= 0
+        return self._normalize(S)
+
+    def get_hop_size(self):
+        """Calculate the hop size."""
+        hop_size = self.hop_size
+        if hop_size is None:
+            assert self.frame_shift_ms is not None
+            hop_size = int(self.frame_shift_ms / 1000 * self.sample_rate)
+        return hop_size
+
+    def _lws_processor(self):
+        """Perform local weighted sum.
+
+        Please refer to <https://pypi.org/project/lws/1.2.6/>`_.
+        """
+        return lws.lws(self.fft_size, self.get_hop_size(), mode='speech')
+
+    @staticmethod
+    def lws_num_frames(length, fsize, fshift):
+        """Compute number of time frames of lws spectrogram.
+
+        Please refer to <https://pypi.org/project/lws/1.2.6/>`_.
+        """
+        pad = (fsize - fshift)
+        if length % fshift == 0:
+            M = (length + pad * 2 - fsize) // fshift + 1
+        else:
+            M = (length + pad * 2 - fsize) // fshift + 2
+        return M
+
+    def lws_pad_lr(self, x, fsize, fshift):
+        """Compute left and right padding lws internally uses.
+
+        Please refer to `lws <https://pypi.org/project/lws/1.2.6/>`_.
+        """
+        M = self.lws_num_frames(len(x), fsize, fshift)
+        pad = (fsize - fshift)
+        T = len(x) + 2 * pad
+        r = (M - 1) * fshift + fsize - T
+        return pad, pad + r
+
+    def _linear_to_mel(self, spectrogram):
+        """Warp linear scale spectrograms to the mel scale.
+
+        Please refer to `deepvoice3_pytorch
+        <https://github.com/r9y9/deepvoice3_pytorch>`_.
+        """
+        # Build the mel filter bank once and cache it on the instance.
+        if not hasattr(self, '_mel_basis'):
+            self._mel_basis = self._build_mel_basis()
+        return np.dot(self._mel_basis, spectrogram)
+
+    def _build_mel_basis(self):
+        """Build mel filters.
+
+        Please refer to `deepvoice3_pytorch
+        <https://github.com/r9y9/deepvoice3_pytorch>`_.
+        """
+        assert self.fmax <= self.sample_rate // 2
+        return librosa.filters.mel(
+            sr=self.sample_rate,
+            n_fft=self.fft_size,
+            fmin=self.fmin,
+            fmax=self.fmax,
+            n_mels=self.num_mels)
+
+    def _amp_to_db(self, x):
+        """Convert amplitude to decibels, flooring at ``min_level_db``."""
+        min_level = np.exp(self.min_level_db / 20 * np.log(10))
+        return 20 * np.log10(np.maximum(min_level, x))
+
+    @staticmethod
+    def _db_to_amp(x):
+        """Convert decibels back to amplitude."""
+        return np.power(10.0, x * 0.05)
+
+    def _normalize(self, S):
+        """Linearly map dB values from [min_level_db, 0] to [0, 1]."""
+        return np.clip((S - self.min_level_db) / -self.min_level_db, 0, 1)
+
+    def _denormalize(self, S):
+        """Invert ``_normalize``."""
+        return (np.clip(S, 0, 1) * -self.min_level_db) + self.min_level_db
+
+    def read_audio(self, audio_path):
+        """Load an audio file and normalize its amplitude."""
+        wav = self.load_wav(audio_path)
+        if self.norm_audio:
+            wav = self.audio_normalize(wav)
+        else:
+            wav = wav / np.abs(wav).max()
+
+        return wav
+
+    def audio_to_spectrogram(self, wav):
+        """Convert the audio waveform to a spectrogram using the backend
+        specified by ``spectrogram_type``."""
+        if self.spectrogram_type == 'lws':
+            spectrogram = self.melspectrogram(wav).astype(np.float32).T
+        elif self.spectrogram_type == 'librosa':
+            spectrogram = self.generate_spectrogram_magphase(wav)
+        return spectrogram
+
+
+def extract_audio_feature(wav_path, audio_tools, mel_out_dir):
+    """Extract the spectrogram of a single audio file and save it as a
+    ``.npy`` file."""
+    file_name, _ = osp.splitext(osp.basename(wav_path))
+    # Write the spectrograms to disk:
+    mel_filename = osp.join(mel_out_dir, file_name + '.npy')
+    if not osp.exists(mel_filename):
+        try:
+            wav = audio_tools.read_audio(wav_path)
+            spectrogram = audio_tools.audio_to_spectrogram(wav)
+            np.save(
+                mel_filename,
+                spectrogram.astype(np.float32),
+                allow_pickle=False)
+        except Exception:
+            print(f'Failed to extract features from [{wav_path}].')
+
+
+if __name__ == '__main__':
+    audio_tools = AudioTools(
+        fft_size=512, hop_size=256)  # window size: 32 ms, hop size: 16 ms
+
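+    # Example invocation (hypothetical paths):
+    #   python build_audio_features.py ${ROOT}/audio ${ROOT}/mel_features \
+    #       --level 2 --ext .wav --num-workers 8 --part 1/4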
+    parser = argparse.ArgumentParser()
+    parser.add_argument('audio_home_path', type=str)
+    parser.add_argument('spectrogram_save_path', type=str)
+    parser.add_argument('--level', type=int, default=1)
+    parser.add_argument('--ext', default='.m4a')
+    parser.add_argument('--num-workers', type=int, default=4)
+    parser.add_argument('--part', type=str, default='1/1')
+    args = parser.parse_args()
+
+    mmcv.mkdir_or_exist(args.spectrogram_save_path)
+
+    files = glob.glob(
+        osp.join(args.audio_home_path, '*/' * args.level, '*' + args.ext))
+    print(f'Found {len(files)} files.')
+    files = sorted(files)
+    this_part, num_parts = [int(i) for i in args.part.split('/')]
+    part_len = len(files) // num_parts
+    start = part_len * (this_part - 1)
+    end = len(files) if this_part == num_parts else part_len * this_part
+
+    p = Pool(args.num_workers)
+    for file in files[start:end]:
+        p.apply_async(
+            extract_audio_feature,
+            args=(file, audio_tools, args.spectrogram_save_path))
+    p.close()
+    p.join()