--- /dev/null
+++ b/alignment/gentle_alignment.py
@@ -0,0 +1,143 @@
+"""Force-align transcripts with Gentle and slice out vowel phonemes.
+
+Reads a tab-separated input file (id, wav path relative to wav_root,
+transcript), aligns each recording with Gentle, maps the aligned phones back
+to the CMU dictionary (cmudict-0.7b.txt, expected in the working directory),
+and writes one wav file per vowel phoneme plus an index line to the output
+file. Paths are taken from the environment variables input_csv, phoneme_path,
+output_csv and wav_root.
+"""
+import json
+import logging
+import multiprocessing
+import os
+
+import gentle
+import scipy.io.wavfile as sciwav
+
+DISFLUENCIES = {'uh', 'um'}  # tokens Gentle can treat as disfluencies
+RESOURCES = gentle.Resources()
+N_THREADS = multiprocessing.cpu_count()
+
+logging.basicConfig(level=logging.INFO)
+
+
+def _on_progress(p):
+    for k, v in p.items():
+        logging.debug("%s: %s", k, v)
+
+
+def _get_key_val_pair(line):
+    """Turn one cmudict line into a (key, phonemes) pair.
+
+    The key is the lower-cased word followed by its phonemes with stress
+    digits stripped (the form Gentle reports); the value keeps the stress
+    digits so vowels can be recognised later.
+    """
+    line_split = line.strip().split()
+    word = line_split[0]
+    # Alternate pronunciations are listed as WORD(1), WORD(2), ...
+    if word[-1] == ')':
+        word = word.split('(')[0]
+
+    word = word.lower()
+    key = [word]
+    val = []
+    for phoneme in line_split[1:]:
+        val.append(phoneme.lower())
+        # Strip the stress digit (e.g. AH0 -> AH) to match Gentle's phone names.
+        if phoneme[-1].isdigit():
+            phoneme = phoneme[:-1]
+
+        key.append(phoneme.lower())
+
+    return " ".join(key), tuple(val)
+
+
+def _create_dict():
+    """Load cmudict-0.7b.txt into a {key: phoneme tuple} lookup."""
+    phoneme_alignment_dict = {}
+
+    # Read as Latin-1: the cmudict file contains bytes that are not valid UTF-8.
+    with open('cmudict-0.7b.txt', 'r', encoding='latin-1') as cmu_file:
+        for line in cmu_file:
+            if line.startswith(';;;'):  # skip comment lines
+                continue
+            key, val = _get_key_val_pair(line)
+            phoneme_alignment_dict[key] = val
+
+    return phoneme_alignment_dict
+
+
+def align_audio(wav_path, transcript):
+    """Run Gentle forced alignment and return the result as a dict."""
+    with gentle.resampled(wav_path) as wavfile:
+        logging.info("starting alignment %s", wav_path)
+        aligner = gentle.ForcedAligner(RESOURCES, transcript, nthreads=N_THREADS,
+                                       disfluency=False, conservative=False,
+                                       disfluencies=DISFLUENCIES)
+        result = aligner.transcribe(wavfile, progress_cb=_on_progress, logging=logging)
+
+    return json.loads(result.to_json())
+
+
+def main(input_csv, phoneme_path, output_csv, wav_root):
+    alignment_dict = _create_dict()
+
+    with open(input_csv, 'r') as in_file, open(output_csv, 'w') as out_file:
+        for line in in_file:
+            id_, wav_file, transcript = line.rstrip('\n').split('\t')
+            wav_file = wav_root + '/' + wav_file
+            sr, signal = sciwav.read(wav_file)
+            alignment = align_audio(wav_file, transcript)
+
+            for word in alignment['words']:
+                if word['case'] != 'success':
+                    continue
+
+                start_time = word['start']
+                aligned_word = word['alignedWord']
+                # Rebuild the dictionary key from Gentle's phones, which look
+                # like "ah_I": the part before the underscore is the phoneme.
+                key = [aligned_word.lower()]
+                for phoneme in word['phones']:
+                    key.append(phoneme['phone'].split('_')[0])
+
+                key = ' '.join(key)
+                phoneme_tuple = alignment_dict.get(key, ())
+
+                if len(phoneme_tuple) == 0:
+                    logging.info('word: %s not in dict, skipping...', aligned_word)
+                    continue
+
+                if len(phoneme_tuple) != len(word['phones']):
+                    logging.info('word: %s not aligned properly, skipping...', aligned_word)
+                    continue
+
+                # Map phonemes to times and slice the wav; only vowels
+                # (phonemes ending in a stress digit) are written out.
+                for i, phoneme in enumerate(word['phones']):
+                    phone_start = start_time
+                    phone_end = phone_start + phoneme['duration']
+
+                    if phoneme_tuple[i][-1].isdigit():
+                        file_name = '{}_{}_{}_{}_{}.wav'.format(
+                            id_, aligned_word, phoneme_tuple[i],
+                            int(phone_start * 1000), int(phone_end * 1000))
+
+                        start_frame, end_frame = int(phone_start * sr), int(phone_end * sr)
+                        sciwav.write(phoneme_path + '/' + file_name, sr,
+                                     signal[start_frame:end_frame])
+                        out_file.write(file_name + '\t' + id_ + '\t' + aligned_word +
+                                       '\t' + phoneme_tuple[i] + '\n')
+
+                    start_time = phone_end
+
+            logging.info('done alignment and slicing for file: %s', wav_file)
+
+
+if __name__ == '__main__':
+    main(input_csv=os.getenv('input_csv'),
+         phoneme_path=os.getenv('phoneme_path'),
+         output_csv=os.getenv('output_csv'),
+         wav_root=os.getenv('wav_root'))
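
A minimal sketch of the dictionary keying used above, assuming the patched module is importable as alignment.gentle_alignment; the sample entry is a cmudict-style line for illustration, not quoted verbatim from cmudict-0.7b:

    from alignment.gentle_alignment import _get_key_val_pair

    key, val = _get_key_val_pair("READ(1)  R IY1 D\n")
    print(key)  # "read r iy d"     -> word plus stress-stripped phones, the form Gentle reports
    print(val)  # ('r', 'iy1', 'd') -> stress digits kept, so vowels are detected by the trailing digit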