Luigi committed on
Commit 4804dfb · 1 Parent(s): 1ec5603

fix tts and improve rrs-related ergonomics

Files changed (2)
  1. app.py +78 -14
  2. requirements.txt +4 -1
app.py CHANGED
@@ -11,8 +11,15 @@ from pathlib import Path
  from faster_whisper import WhisperModel
  from huggingface_hub import hf_hub_download

+ # Additional imports for TTS and audio auto-play
+ import numpy as np
+ import io
+ import soundfile as sf
+ from kokoro import KPipeline
+ import base64
+
  llm_repo_id = "Qwen/Qwen2.5-7B-Instruct-GGUF"
- llm_filename="qwen2.5-7b-instruct-q2_k.gguf"
+ llm_filename = "qwen2.5-7b-instruct-q2_k.gguf"
  asr_repo_id = "Luigi/whisper-small-zh_tw-ct2"

  llm_model_path = hf_hub_download(repo_id=llm_repo_id, filename=llm_filename)
@@ -40,7 +47,8 @@ def load_transformers_model(model_id):

  @st.cache_resource
  def load_outlines_model():
-     model = outlines_llama_cpp(model_path=llm_model_path,
+     model = outlines_llama_cpp(
+         model_path=llm_model_path,
          n_ctx=1024,
          n_threads=2,
          n_threads_batch=2,
@@ -48,36 +56,33 @@ def load_outlines_model():
          n_gpu_layers=0,
          use_mlock=False,
          use_mmap=True,
-         verbose=False,)
+         verbose=False,
+     )
      return model

  def predict_with_llm(text):
      model = load_outlines_model()
-
      prompt = f"""
-     You are an expert in classification of restautant customers' message.
+     You are an expert in classification of restaurant customers' messages.

-     I'm going to provide you with a message from a restautant customer.
-     You have to classify it in one of the follwing two intents:
+     I'm going to provide you with a message from a restaurant customer.
+     You have to classify it in one of the following two intents:

      RESERVATION: Inquiries and requests highly related to table reservations and seating 與訂位與座位安排相關的詢問與請求
      NOT_RESERVATION: All other messages that do not involve table booking or reservations 所有非訂位或預約類的其他留言

      Please reply with *only* the name of the intent labels in a JSON object like:
-     {{\"result\": \"RESERVATION\"}} or {{\"result\": \"NOT_RESERVATION\"}}
+     {{"result": "RESERVATION"}} or {{"result": "NOT_RESERVATION"}}

      Here is the message to classify: {text}
      """.strip()
-
      classifier = choice(model, ["RESERVATION", "NOT_RESERVATION"])
      prediction = classifier(prompt)
-
      if prediction == "RESERVATION":
          return "📞 訂位意圖 (Reservation intent)"
      elif prediction == "NOT_RESERVATION":
          return "❌ 無訂位意圖 (Not Reservation intent)"

- # Standard Transformers classifier
  def predict_intent(text, model_id):
      tokenizer, model = load_transformers_model(model_id)
      inputs = tokenizer(text, return_tensors="pt")
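Note on the hunk above: predict_with_llm relies on outlines' constrained generation, so the model can only ever return one of the two label strings, whatever the prompt's JSON wording suggests. A minimal, hypothetical usage sketch built from the loader in this diff (the outlines_llama_cpp and choice aliases are imported earlier in app.py, outside this diff):

    # Hypothetical sketch: constrained-choice classification as wired in predict_with_llm.
    # choice() restricts decoding to the listed strings, so no JSON parsing is needed.
    model = load_outlines_model()                                  # cached llama.cpp-backed model
    classify = choice(model, ["RESERVATION", "NOT_RESERVATION"])   # output limited to these two labels
    label = classify("請問今晚八點還有兩位的位子嗎?")              # example message (hypothetical)
    print(label)                                                   # -> "RESERVATION" or "NOT_RESERVATION"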
@@ -90,14 +95,65 @@ def predict_intent(text, model_id):
      else:
          return f"❌ 無訂位意圖 (Not Reservation intent)(訂位信心度 Confidence: {confidence:.2%})"

- # Clean README
  def load_clean_readme(path="README.md"):
      text = Path(path).read_text(encoding="utf-8")
      text = re.sub(r"(?s)^---.*?---", "", text).strip()
      text = re.sub(r"^# .*?\n+", "", text)
      return text

- # App UI
+ # ---- TTS Integration using kokoro KPipeline ----
+
+ @st.cache_resource
+ def get_tts_pipeline():
+     # Instantiate and cache the KPipeline for TTS.
+     # Adjust lang_code as needed; here we set it to "zh" for Chinese.
+     return KPipeline(lang_code="zh")
+
+ def get_tts_message(intent_result):
+     """
+     Determine the TTS message based on the classification result.
+     Reservation intent returns one message; all others, another.
+     """
+     if "訂位意圖" in intent_result and "無" not in intent_result:
+         return "稍後您將會從簡訊收到訂位連結"
+     else:
+         return "我們將會將您的回饋傳達給負責人,謝謝您"
+
+ def play_tts_message(message, voice='af_heart'):
+     """
+     Synthesize speech using kokoro's KPipeline and return audio bytes in WAV format.
+     The pipeline returns a generator yielding tuples; the audio chunks are concatenated.
+     """
+     pipeline = get_tts_pipeline()
+     generator = pipeline(message, voice=voice)
+     audio_chunks = []
+     for i, (gs, ps, audio) in enumerate(generator):
+         audio_chunks.append(audio)
+     if audio_chunks:
+         audio_concat = np.concatenate(audio_chunks)
+     else:
+         audio_concat = np.array([])
+     wav_buffer = io.BytesIO()
+     # Using a sample rate of 24000 as in the example.
+     sf.write(wav_buffer, audio_concat, 24000, format="WAV")
+     wav_buffer.seek(0)
+     return wav_buffer.read()
+
+ def play_audio_auto(audio_data, mime="audio/wav"):
+     """
+     Auto-plays the audio by creating an HTML audio element with the autoplay attribute.
+     """
+     audio_base64 = base64.b64encode(audio_data).decode()
+     audio_html = f'''
+     <audio controls autoplay style="width: 100%;">
+         <source src="data:{mime};base64,{audio_base64}" type="{mime}">
+         Your browser does not support the audio element.
+     </audio>
+     '''
+     st.markdown(audio_html, unsafe_allow_html=True)
+
+ # ---- App UI ----
+
  st.title("🍽️ 餐廳訂位意圖識別")
  st.markdown("錄音或輸入文字,自動判斷是否具有訂位意圖。")

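Taken together, the helpers added in this hunk form a small pipeline from classification result to auto-played speech. A minimal usage sketch, using only the functions defined above (the sample result string is hypothetical):

    result = "📞 訂位意圖 (Reservation intent)"      # hypothetical classifier output
    tts_text = get_tts_message(result)               # pick the spoken reply for this intent
    wav_bytes = play_tts_message(tts_text)           # kokoro KPipeline -> 24 kHz WAV bytes
    play_audio_auto(wav_bytes, mime="audio/wav")     # embed a base64 data-URI <audio autoplay> element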
 
@@ -110,7 +166,6 @@ audio = mic_recorder(start_prompt="開始錄音", stop_prompt="停止錄音", ju
  if audio:
      st.success("錄音完成!")
      st.audio(audio["bytes"], format="audio/wav")
-
      with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmpfile:
          tmpfile.write(audio["bytes"])
          tmpfile_path = tmpfile.name
@@ -131,6 +186,11 @@ if audio:
      else:
          result = predict_intent(transcription, model_id)
      st.success(result)
+     tts_text = get_tts_message(result)
+     # Show the TTS message text on the page
+     st.info(f"TTS 語音內容: {tts_text}")
+     audio_message = play_tts_message(tts_text)
+     play_audio_auto(audio_message, mime="audio/wav")

  text_input = st.text_input("✍️ 或手動輸入語句")

@@ -141,6 +201,10 @@ if text_input and st.button("🚀 送出"):
      else:
          result = predict_intent(text_input, model_id)
      st.success(result)
+     tts_text = get_tts_message(result)
+     st.info(f"TTS 語音內容: {tts_text}")
+     audio_message = play_tts_message(tts_text)
+     play_audio_auto(audio_message, mime="audio/wav")

  with st.expander("ℹ️ 說明文件 / 使用說明 (README)", expanded=False):
      readme_md = load_clean_readme()
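The recording branch and the text branch now end with the same four TTS calls. A hypothetical consolidation, not part of this commit, that both branches could share:

    def speak_result(result):
        # Announce the classification result: show the reply text, synthesize it, auto-play it.
        tts_text = get_tts_message(result)
        st.info(f"TTS 語音內容: {tts_text}")
        play_audio_auto(play_tts_message(tts_text), mime="audio/wav")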
 
requirements.txt CHANGED
@@ -7,4 +7,7 @@ faster-whisper
  soundfile
  outlines[llamacpp]==0.0.36 # issue beyond 0.0.36 https://github.com/dottxt-ai/outlines/issues/820
  numpy>=1.24,<2.0
- llama-cpp-python
+ llama-cpp-python
+ kokoro
+ ordered-set
+ cn2an
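The three new entries back the kokoro TTS path added in app.py. A quick, hypothetical smoke test to confirm they resolve in the environment (module names assumed from the package names; ordered-set imports as ordered_set):

    import kokoro       # provides KPipeline for speech synthesis
    import ordered_set  # installed by the ordered-set package
    import cn2an        # Chinese numeral <-> Arabic numeral conversion
    import soundfile    # already pinned above; used to encode the WAV output
    print("TTS dependencies import OK")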