Switch to side-by-side view

--- a
+++ b/alignment/gentle_alignment.py
@@ -0,0 +1,125 @@
+import json
+import logging
+import multiprocessing
+import os
+
+import gentle
+import scipy.io.wavfile as sciwav
+
+DISFLUENCIES = {'uh', 'um'}  # set of disfluencies
+RESOURCES = gentle.Resources()
+N_THREADS = multiprocessing.cpu_count()
+
+logging.getLogger().setLevel("INFO")
+
+
+def _on_progress(p):
+    for k, v in p.items():
+        logging.debug("%s: %s" % (k, v))
+
+
+def _get_key_val_pair(line):
+    line_split = line[:-1].split()
+    word = line_split[0]
+    if word[-1] == ')':
+        word = word.split('(')[0]
+
+    word = word.lower()
+    key = [word]
+    val = []
+    for phoneme in line_split[1:]:
+        val.append(phoneme.lower())
+        if phoneme[-1].isdigit():
+            phoneme = phoneme[:-1]
+
+        phoneme = phoneme.lower()
+        key.append(phoneme)
+
+    key = " ".join(key)
+    val = tuple(val)
+    return key, val
+
+
+def _create_dict():
+    phoneme_alignment_dict = dict()
+
+    cmu_file = open('cmudict-0.7b.txt', 'r')
+    for line in cmu_file:
+        key, val = _get_key_val_pair(line)
+        phoneme_alignment_dict[key] = val
+
+    return phoneme_alignment_dict
+
+
+def align_audio(wav_path, transcript):
+    with gentle.resampled(wav_path) as wavfile:
+        print("starting alignment {}".format(wav_path))
+        aligner = gentle.ForcedAligner(RESOURCES, transcript, nthreads=N_THREADS, disfluency=False,
+                                       conservative=False, disfluencies=DISFLUENCIES)
+        result = aligner.transcribe(wavfile, progress_cb=_on_progress, logging=logging)
+        result_json = json.loads(result.to_json())
+
+    return result_json
+
+
+def main(input_csv, phoneme_path, output_csv, wav_root):
+    alignment_dict = _create_dict()
+
+    in_file = open(input_csv, 'r')
+    out_file = open(output_csv, 'w')
+
+    for line in in_file:
+        id_, wav_file, transcript = line.split('\t')
+        wav_file = wav_root + '/' + wav_file
+        sr, signal = sciwav.read(wav_file)
+        alignment = align_audio(wav_file, transcript)
+
+        for word in alignment['words']:
+            if word['case'] != 'success':
+                continue
+
+            start_time, end_time = word['start'], word['end']
+            aligned_word = word['alignedWord']
+            key = [aligned_word.lower()]
+            for phoneme in word['phones']:
+                phone = phoneme['phone']
+                key.append(phone.split('_')[0])
+
+            key = ' '.join(key)
+            phoneme_tuple = alignment_dict.get(key, ())
+
+            if len(phoneme_tuple) == 0:
+                print('word: {} not in dict, skipping...'.format(word))
+                continue
+
+            if len(phoneme_tuple) != len(word['phones']):
+                print('word: {} not aligned properly, skipping...'.format(word))
+                continue
+
+            # now map phonemes and slice wav
+            for i, phoneme in enumerate(word['phones']):
+                phone_start = start_time
+                phone_end = phone_start + phoneme['duration']
+                # check if vowel phoneme
+                if phoneme_tuple[i][-1].isdigit():
+
+                    file_name = id_ + '_' + aligned_word + '_' + phoneme_tuple[i] + '_' + \
+                                str(int(phone_start * 1000)) + '_' + str(int(phone_end * 1000)) + '.wav'
+
+                    start_frame, end_frame = int(phone_start * sr), int(phone_end * sr)
+                    sciwav.write(phoneme_path + '/' + file_name, sr, signal[start_frame:end_frame])
+                    out_file.write(file_name + '\t' + id_ + '\t' + aligned_word + '\t' + phoneme_tuple[i] + '\n')
+
+                start_time = phone_end
+
+        print('done alignment and slicing for file: {}'.format(wav_file))
+
+    in_file.close()
+    out_file.close()
+
+
+if __name__ == '__main__':
+    main(input_csv=os.getenv('input_csv'),
+         phoneme_path=os.getenv('phoneme_path'),
+         output_csv=os.getenv('output_csv'),
+         wav_root=os.getenv('wav_root'))