SUSE_audio_assistant/test/fast_whisper.py
Alex Lau (AvengerMoJo) 7ba9d8d3db
Just MVT demo for the concept of AI audio assistant.
Signed-off-by: Alex Lau (AvengerMoJo) <alau@suse.com>
2023-11-10 02:26:16 +08:00

112 lines
4.0 KiB
Python

import io
import tempfile
import pyaudio
from pydub import AudioSegment
import wave
import re
import queue
from transformers import pipeline
from datasets import load_dataset
from faster_whisper import WhisperModel
import torch
from TTS.api import TTS
from audio_utils import AudioSplit
# Microphone capture settings shared by the recording and playback code.
CHUNK = 1_024              # frames requested per stream.read() call
FORMAT = pyaudio.paInt16   # 16-bit resolution
CHANNELS = 1               # mono
RATE = 16_000              # sample rate (frames per second)
DURATION = 2               # length of one listening window, in seconds
# Wake-word pattern searched for in the lower-cased Whisper transcript.
# The letters s-u-s-e must appear in this order, separated only by non-word
# characters, so "suse", "s.u.s.e" and "s u s e" all match. (With r"s*u*s*e"
# every letter but the final "e" was optional, so ANY text containing an "e"
# matched -- including the word "hey" itself.)
# NOTE(review): Whisper may spell the name differently (e.g. "susie"); widen
# the pattern if genuine wake-ups are missed.
SUSE = r"s\W*u\W*s\W*e"
# Shared PyAudio instance: used below to open both the microphone input
# streams and the playback output stream.
p = pyaudio.PyAudio()
def record(stream, stop_event=None, out_queue=None, chunk=None):
    """Continuously read audio from *stream* and push the raw chunks onto a queue.

    Args:
        stream: an opened input stream; needs ``read(n)``, ``stop_stream()``
            and ``close()`` (e.g. a PyAudio input stream).
        stop_event: optional object with ``is_set()`` (e.g. ``threading.Event``).
            Recording stops once it is set. When ``None``, recording runs until
            ``stream.read`` raises, which matches the original endless loop.
        out_queue: queue receiving each chunk; defaults to the module-level
            ``audio_queue``.
        chunk: frames per read; defaults to the module-level ``CHUNK``.

    The stream is always stopped and closed on exit, including when ``read``
    raises. The PyAudio instance itself is not terminated here; it is shared
    with the caller, which owns it.
    """
    if out_queue is None:
        out_queue = audio_queue
    if chunk is None:
        chunk = CHUNK
    print("Recording started...")
    try:
        while stop_event is None or not stop_event.is_set():
            out_queue.put(stream.read(chunk))
    finally:
        # This cleanup used to sit after an unconditional `while True` and could
        # never run. The old `audio.terminate()` referenced a name that is
        # either undefined or an AudioSegment, so it was dropped.
        stream.stop_stream()
        stream.close()
def play(play_stream, filename, chunk=None):
    """Write every frame of the WAV file *filename* to *play_stream*.

    The file's rate, channel count and sample width are printed for
    diagnostics. Frames are written unchanged, with no resampling or format
    conversion, so *play_stream* must already be opened with a matching format.

    Args:
        play_stream: output stream with a ``write(bytes)`` method.
        filename: path to a WAV file.
        chunk: frames read per write; defaults to the module-level ``CHUNK``.
    """
    if chunk is None:
        chunk = CHUNK
    # `with` releases the file handle even if play_stream.write raises.
    with wave.open(filename, 'rb') as wave_file:
        print(f"Wave: rate={wave_file.getframerate()} channels={wave_file.getnchannels()} width={wave_file.getsampwidth()}")
        out_data = wave_file.readframes(chunk)
        while out_data:
            play_stream.write(out_data)
            out_data = wave_file.readframes(chunk)
# Torch device for the TTS model: the GPU when CUDA is available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Init TTS (xtts_v2 is a multilingual alternative model).
# tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
tts = TTS("tts_models/en/blizzard2013/capacitron-t2-c150_v2").to(device)
# Queue for sharing captured audio between threads; record() puts chunks here.
audio_queue = queue.Queue()
# Whisper model size: larger models are slower.
# model_size = "large-v2"
model_size = "small.en"
# model_size = "tiny.en"
# Whisper runs on the CPU with int8 quantisation, regardless of `device` above.
# (For GPU/FP16, faster-whisper documents device="cuda", compute_type="float16".)
model = WhisperModel(model_size, device="cpu", compute_type="int8")
# Look up the PyAudio device named "default" (currently only printed). It is
# stored under its own name so the torch `device` string above is not
# overwritten, and it stays None when no such device exists. The old loop
# left the *last* device bound in that case.
output_device_info = None
for index in range(p.get_device_count()):
    info = p.get_device_info_by_index(index)
    if info['name'] == "default":
        output_device_info = info
        print(output_device_info)
        break
# Output stream for spoken replies: 2-byte (16-bit) samples, mono, 24000 Hz.
# NOTE(review): play() writes WAV frames unchanged, so the clips played here
# must already be 16-bit mono 24 kHz -- confirm data/audio/suse_intro.wav.
playback_stream = p.open(format=p.get_format_from_width(2),
                         channels=1,
                         rate=24000,
                         # output_device_index=output_device_info['index'],
                         output=True)
print(f"Stream: playback->{playback_stream.get_write_available()}")
# Alternative ASR backend (unused):
# generator = pipeline(task="automatic-speech-recognition", model="microsoft/speecht5_asr")
# Main loop:
#   1. Record DURATION-second windows from the microphone.
#   2. Transcribe each window with Whisper.
#   3. Play the intro clip when a segment contains "hey" and matches the SUSE
#      wake-word pattern.
# Ctrl+C ends the loop, and the `finally` releases the audio resources. The
# old trailing p.terminate() sat after `while True` and never ran.
try:
    while True:
        # delete=True: each temp file is removed when its `with` closes it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True, mode='wb') as tf, \
                tempfile.NamedTemporaryFile(suffix=".mp3", delete=True, mode='wb') as mp3_tf:
            temp_filename = tf.name
            mp3_tf_filename = mp3_tf.name
            with wave.open(temp_filename, 'wb') as wav_file:
                wav_file.setnchannels(CHANNELS)
                wav_file.setsampwidth(p.get_sample_size(FORMAT))
                wav_file.setframerate(RATE)
                stream = p.open(format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK)
                print("Listening...")
                try:
                    # Multiply before dividing. With RATE // CHUNK * DURATION
                    # the division truncated first: 30 chunks (~1.92 s) instead
                    # of 31 (~1.98 s).
                    for _ in range(RATE * DURATION // CHUNK):
                        wav_file.writeframes(stream.read(CHUNK))
                finally:
                    # Close the input stream even if read() raises (e.g. overflow).
                    stream.close()
            # The WAV header is finalised when the writer above is closed, so
            # the file is read back only after leaving that `with` block.
            clip = AudioSegment.from_wav(temp_filename)
            clip.export(mp3_tf_filename, format="mp3")
            segments, _ = model.transcribe(mp3_tf_filename)
            for segment in segments:
                print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
                text_input = segment.text.lower()
                if "hey" in text_input and re.search(SUSE, text_input):
                    # A pre-recorded clip is played instead of synthesising
                    # the reply with the `tts` model at runtime.
                    play(playback_stream, "data/audio/suse_intro.wav")
except KeyboardInterrupt:
    print("Stopping...")
finally:
    playback_stream.stop_stream()
    playback_stream.close()
    p.terminate()