Real-time transcription with faster-whisper

I wrote a Python tool that uses faster-whisper to transcribe microphone input and speaker output in real time, and this post publishes it.
faster-whisper is a lightweight reimplementation of Whisper, the speech-to-text AI released by OpenAI in September 2022. This post uses faster-whisper-large-v3, released in November 2023.

The tool implements the following features.

  • Real-time transcription of microphone input
  • Real-time transcription of speaker output
  • Transcription of audio files
  • Saving the input audio
  • Exporting the generated text in SRT format

Source code

Environment setup

Tested with Python 3.12.

The target platform is Windows.
WASAPI (Windows Audio Session API) is used to capture speaker output.
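
If the audio you want to capture is not playing on the default output device, the following minimal sketch lists the WASAPI loopback devices that PyAudioWPatch exposes; one of the printed indices can then be passed as device_id to the AudioStream class described later.

import pyaudiowpatch as pyaudio


# list WASAPI loopback devices; a printed index can be passed as device_id to AudioStream.open(mode='output', ...)
p = pyaudio.PyAudio()
try:
    wasapi = p.get_host_api_info_by_type(pyaudio.paWASAPI)
    print('default output device index:', wasapi['defaultOutputDevice'])
    for device in p.get_loopback_device_info_generator():
        print(device['index'], device['name'], int(device['defaultSampleRate']))
finally:
    p.terminate()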

Required packages

Install them with pip or similar; a quick import check follows the list.

  • torch
  • faster-whisper
  • PyAudioWPatch
  • pydub
  • numpy
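
As a quick sanity check of the environment, the sketch below simply imports the packages and reports the torch version and CUDA availability; the run scripts later call torch.cuda.synchronize(), so a CUDA-enabled build of torch is assumed.

import torch
import pyaudiowpatch  # noqa: F401
import pydub          # noqa: F401
import faster_whisper  # noqa: F401

# the run scripts call torch.cuda.synchronize(), so a CUDA-enabled build of torch is assumed
print('torch:', torch.__version__, 'CUDA available:', torch.cuda.is_available())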

Directory structure

src/
├── model/
│   └── faster-whisper-large-v3
├── module/
│   ├── audiostream.py
│   ├── vad.py
│   ├── transcribe.py
│   └── util.py
├── ffmpeg.exe
└── ffprobe.exe

  • Download ffmpeg and place ffmpeg.exe and ffprobe.exe in the root folder.
  • Download the pretrained model files for faster-whisper-large-v3 and place the whole folder under the model folder (see the sketch after this list).
  • Place the scripts below in the module folder.
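
As one way to fetch the model files, here is a sketch that uses huggingface_hub (installed as a dependency of faster-whisper) to download the model into the model folder; the repository name Systran/faster-whisper-large-v3 is an assumption here, and downloading the files manually and copying the folder works just as well.

from huggingface_hub import snapshot_download


# repository name is an assumption; downloading the files manually and copying the folder also works
snapshot_download(
    repo_id='Systran/faster-whisper-large-v3',
    local_dir='model/faster-whisper-large-v3',
)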

module/audiostream.py

AudioStream is a class that takes in audio and processes it in small units called chunks.
It supports three sources: microphone input (InputAudioStream), speaker output (OutputAudioStream), and audio files (FileAudioStream).

SpeechDeque uses the VAD described later to split the stream at pauses, so that each segment is a suitable input for faster-whisper, and stores the segments in a deque (a toy example follows the listing).

from collections import deque
import threading
import wave

import numpy as np
import pyaudiowpatch as pyaudio
from pydub import AudioSegment
from pydub.utils import make_chunks


class AudioStream:
    def __init__(self, vad=None, wave_save_path=None, frames_per_chunk=None):
        self.frames_per_chunk = vad.frames_per_chunk if vad is not None else frames_per_chunk
        self.speech_deque = SpeechDeque(vad, self.frames_per_chunk)
        self.frame_count = 0
        self.wave_save_path = wave_save_path

        self.p = pyaudio.PyAudio()

    def stream_callback(self, in_data, frame_count=None, time_info=None, status=None):
        in_data = self.subsample(in_data)
        if self.ww is not None:
            self.ww.writeframes(in_data)
        self.speech_deque.put_data(in_data, self.frame_count)

        self.frame_count += self.frames_per_chunk

        return in_data, pyaudio.paContinue

    def subsample(self, in_data):
        # naively resample the chunk to 16 kHz by picking frames_per_chunk evenly spaced 16-bit samples
        in_data2 = []
        size = len(in_data)/2
        for i in range(self.frames_per_chunk):
            offset = int(i * size / self.frames_per_chunk)
            in_data2.append(in_data[2*offset: 2*offset+2])
        return b''.join(in_data2)

    def get_speech(self):
        if len(self.speech_deque) > 0:
            return {
                **self.speech_deque.popleft(),  # audio, frame_from, frame_to
                'is_buffer': False
            }
        else:  # is buffer
            return {
                'audio': self.speech_deque.get_buffer(),
                'frame_from': self.speech_deque.frame_count,
                'frame_to': self.frame_count,
                'is_buffer': True
            }

    def __enter__(self):
        if self.wave_save_path is not None:
            self.ww = wave.open(self.wave_save_path, 'wb')
            self.ww.setnchannels(1)
            self.ww.setframerate(16000)
            self.ww.setnframes(self.frames_per_chunk)
            self.ww.setsampwidth(2)
        else:
            self.ww = None
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.ww is not None:
            self.ww.close()

    @staticmethod
    def open(mode=None, device_id=None, frames_per_buffer=None, vad=None, wave_save_path=None, path=None):
        if mode == 'input':
            return InputAudioStream(device_id=device_id, vad=vad, wave_save_path=wave_save_path, frames_per_chunk=frames_per_buffer)
        elif mode == 'output':
            return OutputAudioStream(device_id=device_id, vad=vad, wave_save_path=wave_save_path, frames_per_chunk=frames_per_buffer)
        elif mode == 'file':
            return FileAudioStream(file_path=path, vad=vad, wave_save_path=wave_save_path, frames_per_chunk=frames_per_buffer)


class InputAudioStream(AudioStream):
    def __init__(self, device_id=None, vad=None, wave_save_path=None, frames_per_chunk=None):
        super().__init__(vad, wave_save_path, frames_per_chunk)

        if device_id is None:
            device = self.p.get_default_input_device_info()
        else:
            device = self.p.get_device_info_by_index(device_id)

        self.pyaudio_option = {
            'format': pyaudio.paInt16,
            'channels': 1,
            'rate': int(device['defaultSampleRate']),
            'frames_per_buffer': int(self.frames_per_chunk * device['defaultSampleRate'] / 16000),
            'input': True,
            'input_device_index': device['index'],
            'stream_callback': self.stream_callback,
        }

    def __enter__(self):
        super().__enter__()
        self.stream = self.p.open(**self.pyaudio_option)
        self.stream.start_stream()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        self.stream.stop_stream()
        self.stream.close()


class OutputAudioStream(AudioStream):
    def __init__(self, device_id=None, vad=None, wave_save_path=None, frames_per_chunk=None):
        super().__init__(vad, wave_save_path, frames_per_chunk)

        # use WASAPI
        if device_id is None:
            device_id = self.p.get_host_api_info_by_type(pyaudio.paWASAPI)['defaultOutputDevice']
        device = self.p.get_device_info_by_index(device_id)
        if not device['isLoopbackDevice']:
            for loopback in self.p.get_loopback_device_info_generator():
                if device['name'] in loopback['name']:
                    device = loopback
                    break

        self.pyaudio_option = {
            'format': pyaudio.paInt16,
            'channels': 1,
            'rate': int(device['defaultSampleRate']),
            'frames_per_buffer': int(self.frames_per_chunk * device['defaultSampleRate'] / 16000),
            'input': True,
            'input_device_index': device['index'],
            'stream_callback': self.stream_callback,
        }

    def __enter__(self):
        super().__enter__()
        self.stream = self.p.open(**self.pyaudio_option)
        self.stream.start_stream()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        self.stream.stop_stream()
        self.stream.close()


class FileAudioStream(AudioStream):
    def __init__(self, file_path, vad=None, wave_save_path=None, frames_per_chunk=None):
        super().__init__(vad, wave_save_path, frames_per_chunk)
        self.seg = AudioSegment.from_file(file_path)
        self.seg = self.seg.set_frame_rate(16000)

    def __enter__(self):
        super().__enter__()
        threading.Thread(target=self.stream_write).start()  # threading
        return self

    def stream_write(self):
        for chunk in make_chunks(self.seg, self.frames_per_chunk / 16):  # chunk length in ms (512 samples at 16 kHz = 32 ms)
            chunk = chunk.get_array_of_samples()[::self.seg.channels]  # to monaural
            chunk = b''.join([v.to_bytes(2, signed=True, byteorder='little') for v in chunk])
            self.stream_callback(chunk)
        self.speech_deque.append({
            'audio': None,
            'frame_from': 0,
            'frame_to': 0
        })

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)


class SpeechDeque(deque):
    def __init__(self, vad=None, frames_per_chunk=None):
        super().__init__()
        self.vad = vad if vad is not None else lambda _: True
        self.frames_per_chunk = vad.frames_per_chunk if vad is not None else frames_per_chunk

        self.min_silent_length = 2400 // self.frames_per_chunk
        self.max_silent_length = 5760 // self.frames_per_chunk
        self.min_speech_length = 7200 // self.frames_per_chunk
        self.max_speech_length = 288000 // self.frames_per_chunk

        self.speech_buffer = []
        self.silent_length = []
        self.current_silent_length = 0
        self.frame_count = 0

    def put_data(self, in_data, frame_count):
        is_speech = self.vad(in_data)
        if not is_speech and not self.speech_buffer:
            return

        # append buffer
        if is_speech:  # reuse the result above; the VAD is stateful, so it must not be called twice per chunk
            if not self.speech_buffer:
                self.frame_count = frame_count
            self.speech_buffer.append(np.frombuffer(in_data, dtype=np.int16))
            self.silent_length.append(self.current_silent_length)
            self.current_silent_length = 0
        elif self.speech_buffer:
            self.speech_buffer.append(np.frombuffer(in_data, dtype=np.int16))
            self.silent_length.append(self.current_silent_length)
            self.current_silent_length += 1

        # append speech
        if self.current_silent_length >= self.max_silent_length:  # not is_speech
            self.append_audio(self.speech_buffer[:-self.current_silent_length])
            self.speech_buffer.clear()
            self.silent_length.clear()
            self.current_silent_length = 0
        while (n := len(self.speech_buffer)) >= self.max_speech_length:
            j = n - 1 - np.argmax(self.silent_length[::-1])  # last argmax index
            l = self.silent_length[j]

            if l >= self.min_silent_length:
                # split buffer on the longest silent interval
                self.append_audio(self.speech_buffer[:j - l])
                self.frame_count += self.frames_per_chunk * j
                self.speech_buffer = self.speech_buffer[j:]
                self.silent_length = self.silent_length[j:]
                self.silent_length[0] = 0
            else:
                i = n - self.current_silent_length  # current_silent_length can be 0
                self.append_audio(self.speech_buffer[:i])
                self.frame_count += self.frames_per_chunk * i
                self.speech_buffer = self.speech_buffer[i:]
                self.silent_length = self.silent_length[i:]

    def get_buffer(self):
        if len(self.speech_buffer) >= self.min_speech_length:
            return np.concatenate(self.speech_buffer)
        else:
            return None

    def append_audio(self, buffer):
        if len(buffer) >= self.min_speech_length:
            super().append({
                'audio': np.concatenate(buffer),
                'frame_from': self.frame_count,
                'frame_to': self.frame_count + self.frames_per_chunk * len(buffer)
            })
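
To show how SpeechDeque groups chunks into speech segments, here is a toy sketch with a hypothetical DummyVad (not part of the tool) that treats any non-zero chunk as speech: sixteen speech chunks followed by eleven silent chunks (the max_silent_length at 512 frames per chunk) yield exactly one segment in the deque.

import numpy as np

from module.audiostream import SpeechDeque


class DummyVad:
    # hypothetical toy VAD: a chunk counts as speech if any sample is non-zero
    frames_per_chunk = 512

    def __call__(self, in_data):
        return np.abs(np.frombuffer(in_data, dtype=np.int16)).max() > 0


speech_deque = SpeechDeque(vad=DummyVad())
speech = (np.ones(512, dtype=np.int16) * 1000).tobytes()
silence = np.zeros(512, dtype=np.int16).tobytes()

# 16 speech chunks followed by 11 silent chunks (= max_silent_length at 512 frames per chunk)
for i, chunk in enumerate([speech] * 16 + [silence] * 11):
    speech_deque.put_data(chunk, i * 512)

print(len(speech_deque))  # 1
segment = speech_deque.popleft()
print(segment['audio'].shape, segment['frame_from'], segment['frame_to'])  # (8192,) 0 8192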

module/vad.py

A VAD (Voice Activity Detection) model detects which parts of the audio contain speech and which are silent.
Here we use the Silero VAD bundled with faster-whisper (a usage sketch follows the listing).

import numpy as np
import torch

from faster_whisper import vad as silerovad


class SileroVad:
    def __init__(self):
        self.vad = silerovad.get_vad_model()  # faster_whisper/assets/silero_vad.onnx
        self.state, self.context = self.vad.get_initial_states(batch_size=1)
        self.frames_per_chunk = 512

    def __call__(self, in_data):
        chunk = torch.from_numpy(np.frombuffer(in_data, dtype=np.int16).copy()).to(torch.float16)
        speech_prob, self.state, self.context = self.vad(chunk, self.state, self.context, 16000)
        return speech_prob[0][0] > 0.5
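
A minimal usage sketch for the class above, assuming a 16 kHz mono 16-bit WAV file such as one saved via wave_save_path (the path below is hypothetical): it counts how many 512-sample chunks the VAD classifies as speech.

import wave

from module.vad import SileroVad


vad = SileroVad()
n_speech = n_total = 0
# any 16 kHz mono 16-bit WAV works, e.g. a file saved via wave_save_path (path is hypothetical)
with wave.open('save/20240101_000000_input.wav', 'rb') as wf:
    while len(data := wf.readframes(vad.frames_per_chunk)) == vad.frames_per_chunk * 2:
        n_total += 1
        if vad(data):
            n_speech += 1
print(f'{n_speech}/{n_total} chunks classified as speech')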

module/transcribe.py

The faster-whisper model itself.
Given audio as input, it outputs Japanese text (a usage sketch follows the listing).

import functools

import numpy as np

from faster_whisper import WhisperModel


class Model:
    def __init__(self):
        super().__init__()
        self.model = WhisperModel('model/faster-whisper-large-v3')

    @functools.lru_cache(maxsize=5)
    def __call__(self, audio):
        segments, _ = self.model.transcribe(
            audio=np.array(audio).astype(np.float16), beam_size=3, language='ja', without_timestamps=True,
            vad_filter=True,
        )
        return segments
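
A minimal standalone sketch for Model, again assuming a 16 kHz mono 16-bit WAV (the path below is hypothetical). The audio is passed as a tuple of int16 samples, exactly as the run scripts below do, because functools.lru_cache only accepts hashable arguments.

import wave

import numpy as np

from module.transcribe import Model


model = Model()
# any 16 kHz mono 16-bit WAV works, e.g. a file saved via wave_save_path (path is hypothetical)
with wave.open('save/20240101_000000_input.wav', 'rb') as wf:
    audio = np.frombuffer(wf.readframes(wf.getnframes()), dtype=np.int16)

# pass a tuple because functools.lru_cache requires hashable arguments (same as the run scripts below)
segments = model(tuple(audio.tolist()))
print(''.join(segment.text for segment in segments))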

module/util.py

Handles timestamp formatting, console and file output, and so on (a short usage sketch follows the listing).

import time


def format_frame_count(frame, srt=False):
    t = frame / 16000  # frame is a sample index at 16 kHz
    ms = int((t % 1 * 1000) // 1)
    h = t // 1
    s = int(h % 60)
    h = h // 60
    m = int(h % 60)
    h = int(h // 60)

    return f'{h:02}:{m:02}:{s:02}{',' if srt else '.'}{ms:03}'


def get_timestamp():
    return time.strftime('%Y%m%d_%H%M%S')


def filter_text(text_list, blacklist):
    return [text for text in text_list if text not in blacklist]


def format_log_string(frame_from, frame_to, text, *args):
    return '\t'.join([
        format_frame_count(frame_from),
        format_frame_count(frame_to)
    ] + [str(v) for v in args] + [text])


def console_log(frame_from, frame_to, is_buffer, segments, *debuginfo):
    if not is_buffer:
        for i, text in enumerate(segments):
            text = format_log_string(frame_from, frame_to, text, *debuginfo)
            if i == 0:
                print(f'\r{text}')
            else:
                print(text)
    else:
        text = format_log_string(frame_from, frame_to, '\n'.join(segments), *debuginfo)
        if text is not None:
            print(f'\r{text}', end='')


class SRTLogger:
    def __init__(self, path):
        self.count = 1
        self.path = path

    def format_srt_log(self, frame_from, frame_to, segments):
        log = [
            str(self.count),
            f'{format_frame_count(frame_from, True)} --> {format_frame_count(frame_to, True)}'
        ]
        for text in segments:
            if len(text) > 0:
                log.append(text)
        self.count += 1
        return log

    def write(self, frame_from, frame_to, segments):
        log = self.format_srt_log(frame_from, frame_to, segments)
        for line in log:
            self.f.write(line + '\n')
            self.f.flush()
        self.f.write('\n')

    def __enter__(self):
        self.f = open(self.path, 'w', encoding='utf_8', newline='\n')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.f.close()

    @staticmethod
    def open(*args):
        return SRTLogger(*args)
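
A short sketch of the helpers above (the output path is hypothetical and the save folder must exist): frame counts are 16 kHz sample indices, format_frame_count turns them into timestamps, and SRTLogger writes numbered SRT entries.

from module.util import SRTLogger, filter_text, format_frame_count


# frame counts are sample indices at 16 kHz
print(format_frame_count(56000))        # 00:00:03.500
print(format_frame_count(56000, True))  # 00:00:03,500

print(filter_text(['こんにちは', 'エンディング'], ['エンディング']))  # ['こんにちは']

with SRTLogger.open('save/example.srt') as srt:  # hypothetical output path
    srt.write(0, 56000, ['こんにちは'])  # writes "1", "00:00:00,000 --> 00:00:03,500", "こんにちは"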

Run scripts

  • Passing wave_save_path when initializing AudioStream saves the input audio to that path (create the save folder beforehand).
  • SRTLogger writes the generated text out in SRT format.
  • Probably due to its training data, faster-whisper sometimes emits phrases that were never actually spoken, such as "ご視聴ありがとうございました" ("Thank you for watching"). To suppress such outputs, the scripts filter them against a blacklist.

Real-time transcription of microphone input

import time

import torch

from module.audiostream import AudioStream
from module.vad import SileroVad
from module.transcribe import Model
from module.util import console_log, SRTLogger, filter_text, get_timestamp


if __name__ == '__main__':
    timestamp = get_timestamp()
    print(timestamp)

    args = {
        'mode': 'input',
        'wave_save_path': f'save/{timestamp}_input.wav'
    }
    output_srt_path = f'save/{timestamp}_output.srt'

    model = Model()
    with (AudioStream.open(vad=SileroVad(), **args) as stream,
          SRTLogger.open(output_srt_path) as srt):
        print('main loop started')
        while (audio_data := stream.get_speech())['is_buffer'] or audio_data['audio'] is not None:  # end of stream
            try:
                if (audio := audio_data['audio']) is not None:
                    start = time.time()
                    segments = model(tuple(audio.tolist()))
                    torch.cuda.synchronize()
                    td = time.time() - start

                    segments = [segment.text for segment in segments]
                    segments = filter_text(segments, [
                        'ご視聴ありがとうございました',
                        'ご視聴ありがとうございました。',
                        'エンディング',
                    ])

                    # console log
                    console_log(
                        audio_data['frame_from'], audio_data['frame_to'], audio_data['is_buffer'], segments,
                        len(stream.speech_deque), len(audio), '{:.3f}'.format(td)
                    )

                    # write to SRT
                    if not audio_data['is_buffer']:
                        srt.write(audio_data['frame_from'], audio_data['frame_to'], segments)

                if args['mode'] != 'file':
                    time.sleep(0.2)

            except KeyboardInterrupt:
                break

Real-time transcription of speaker output

import time

import torch

from module.audiostream import AudioStream
from module.vad import SileroVad
from module.transcribe import Model
from module.util import console_log, SRTLogger, filter_text, get_timestamp


if __name__ == '__main__':
    timestamp = get_timestamp()
    print(timestamp)

    args = {
        'mode': 'output',
        'wave_save_path': f'save/{timestamp}_output.wav'
    }
    output_srt_path = f'save/{timestamp}_output.srt'

    model = Model()
    with (AudioStream.open(vad=SileroVad(), **args) as stream,
          SRTLogger.open(output_srt_path) as srt):
        print('main loop started')
        while (audio_data := stream.get_speech())['is_buffer'] or audio_data['audio'] is not None:  # end of stream
            try:
                if (audio := audio_data['audio']) is not None:
                    start = time.time()
                    segments = model(tuple(audio.tolist()))
                    torch.cuda.synchronize()
                    td = time.time() - start

                    segments = [segment.text for segment in segments]
                    segments = filter_text(segments, [
                        'ご視聴ありがとうございました',
                        'ご視聴ありがとうございました。',
                        'エンディング',
                    ])

                    # console log
                    console_log(
                        audio_data['frame_from'], audio_data['frame_to'], audio_data['is_buffer'], segments,
                        len(stream.speech_deque), len(audio), '{:.3f}'.format(td)
                    )

                    # write to SRT
                    if not audio_data['is_buffer']:
                        srt.write(audio_data['frame_from'], audio_data['frame_to'], segments)

                if args['mode'] != 'file':
                    time.sleep(0.2)

            except KeyboardInterrupt:
                break

Transcription of an audio file

import time

import torch

from module.audiostream import AudioStream
from module.vad import SileroVad
from module.transcribe import Model
from module.util import console_log, SRTLogger, filter_text, get_timestamp


if __name__ == '__main__':
    timestamp = get_timestamp()
    print(timestamp)

    args = {
        'mode': 'file',
        'path': 'save/test.mp4',
    }
    output_srt_path = f'save/{'.'.join(args['path'].split('/')[-1].split('.')[:-1])}.srt'

    model = Model()
    with (AudioStream.open(vad=SileroVad(), **args) as stream,
          SRTLogger.open(output_srt_path) as srt):
        print('main loop started')
        while (audio_data := stream.get_speech())['is_buffer'] or audio_data['audio'] is not None:  # end of stream
            try:
                if (audio := audio_data['audio']) is not None:
                    start = time.time()
                    segments = model(tuple(audio.tolist()))
                    torch.cuda.synchronize()
                    td = time.time() - start

                    segments = [segment.text for segment in segments]
                    segments = filter_text(segments, [
                        'ご視聴ありがとうございました',
                        'ご視聴ありがとうございました。',
                        'エンディング',
                    ])

                    # console log
                    console_log(
                        audio_data['frame_from'], audio_data['frame_to'], audio_data['is_buffer'], segments,
                        len(stream.speech_deque), len(audio), '{:.3f}'.format(td)
                    )

                    # write to SRT
                    if not audio_data['is_buffer']:
                        srt.write(audio_data['frame_from'], audio_data['frame_to'], segments)

                if args['mode'] != 'file':
                    time.sleep(0.2)

            except KeyboardInterrupt:
                break

References

faster-whisper

PyAudio / pydub

SRT