Spaces:
Runtime error
Runtime error
Commit
·
384005b
1
Parent(s):
b4fc999
Set pad_token to eos_token and exclude user query from response
Browse files
app.py
CHANGED
|
@@ -6,6 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
| 6 |
import threading
|
| 7 |
import torch
|
| 8 |
import os
|
|
|
|
| 9 |
|
| 10 |
# Define the API URL to use the internal server
|
| 11 |
API_URL = "http://localhost:5000/chat"
|
|
@@ -43,7 +44,7 @@ def chat():
|
|
| 43 |
inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
|
| 44 |
input_ids = inputs['input_ids']
|
| 45 |
attention_mask = inputs['attention_mask']
|
| 46 |
-
outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=
|
| 47 |
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
|
| 48 |
|
| 49 |
return jsonify({"response": response_text})
|
|
@@ -72,6 +73,9 @@ def messages_to_history(messages: Messages) -> History:
|
|
| 72 |
history.append((q['content'], r['content']))
|
| 73 |
return history
|
| 74 |
|
|
|
|
|
|
|
|
|
|
| 75 |
def model_chat(query: str, history: History) -> Tuple[str, History]:
|
| 76 |
if not query.strip():
|
| 77 |
return '', history
|
|
@@ -163,7 +167,10 @@ with gr.Blocks(css='''
|
|
| 163 |
print(f"Query: {query}") # Debug print statement
|
| 164 |
response, history = model_chat(query, history)
|
| 165 |
print(f"Response: {response}") # Debug print statement
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
demo_state = gr.State([])
|
| 169 |
|
|
|
|
| 6 |
import threading
|
| 7 |
import torch
|
| 8 |
import os
|
| 9 |
+
import re
|
| 10 |
|
| 11 |
# Define the API URL to use the internal server
|
| 12 |
API_URL = "http://localhost:5000/chat"
|
|
|
|
| 44 |
inputs = tokenizer(user_input, return_tensors='pt', padding=True, truncation=True)
|
| 45 |
input_ids = inputs['input_ids']
|
| 46 |
attention_mask = inputs['attention_mask']
|
| 47 |
+
outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=1000, pad_token_id=tokenizer.eos_token_id)
|
| 48 |
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True).replace(user_input, '').strip()
|
| 49 |
|
| 50 |
return jsonify({"response": response_text})
|
|
|
|
| 73 |
history.append((q['content'], r['content']))
|
| 74 |
return history
|
| 75 |
|
| 76 |
+
def is_hebrew(text: str) -> bool:
|
| 77 |
+
return bool(re.search(r'[\u0590-\u05FF]', text))
|
| 78 |
+
|
| 79 |
def model_chat(query: str, history: History) -> Tuple[str, History]:
|
| 80 |
if not query.strip():
|
| 81 |
return '', history
|
|
|
|
| 167 |
print(f"Query: {query}") # Debug print statement
|
| 168 |
response, history = model_chat(query, history)
|
| 169 |
print(f"Response: {response}") # Debug print statement
|
| 170 |
+
if is_hebrew(response):
|
| 171 |
+
return history, gr.update(value="", interactive=True, lines=2, rtl=True), history
|
| 172 |
+
else:
|
| 173 |
+
return history, gr.update(value="", interactive=True, lines=2, rtl=False), history
|
| 174 |
|
| 175 |
demo_state = gr.State([])
|
| 176 |
|