import os
import torch
import spaces
import psycopg2
import gradio as gr
from threading import Thread
from collections.abc import Iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Token limits. MAX_INPUT_TOKEN_LENGTH caps the prompt length inside
# translate_message; the two *_NEW_TOKENS values are not referenced in the
# visible part of this file (presumably bounds/default for a UI control
# defined elsewhere — confirm).
MAX_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
# Required Hugging Face access token: a missing HF_TOKEN environment variable
# raises KeyError while this module is being imported.
HF_TOKEN = os.environ["HF_TOKEN"]
# The model is loaded once, at import time, in float16 with device_map="auto".
model_id = "ai4bharat/IndicTrans3-beta"
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
)
# NOTE(review): the tokenizer is loaded from a different repository than
# model_id and without token=HF_TOKEN — presumably the model shares the
# Llama 3.2 tokenizer and chat template; confirm.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
# Names of the target languages offered for translation (presumably the
# choices of the language selector in the UI — usage is not visible here).
# NOTE(review): this list has 21 entries while DESCRIPTION mentions 22 Indic
# languages — confirm whether one is intentionally omitted.
LANGUAGES = [
    "Hindi", "Bengali", "Telugu", "Marathi", "Tamil",
    "Urdu", "Gujarati", "Kannada", "Odia", "Malayalam",
    "Punjabi", "Assamese", "Maithili", "Santali", "Kashmiri",
    "Nepali", "Sindhi", "Konkani", "Dogri", "Manipuri",
    "Bodo",
]
def format_message_for_translation(message, target_lang):
    """Build the instruction prompt asking for a translation.

    Args:
        message: Text to be translated.
        target_lang: Name of the language to translate into.

    Returns:
        A single-line prompt naming the target language followed by the text.
    """
    prompt = f"Translate the following text to {target_lang}: {message}"
    return prompt
def store_feedback(rating, feedback_text, chat_history, tgt_lang):
    """Validate user feedback and insert it into the ``feedback`` table.

    Each failed validation shows a Gradio warning and returns without
    touching the database. Connection settings are read from the ``DB_HOST``,
    ``DB_NAME``, ``DB_USER``, ``DB_PASSWORD`` and ``DB_PORT`` environment
    variables.

    Args:
        rating: Rating chosen in the UI; converted with ``int()`` before it
            is stored.
        feedback_text: Free-text feedback; must contain non-whitespace text.
        chat_history: Conversation so far; must be non-empty and its first
            entry must hold at least two items.
        tgt_lang: Target language stored alongside the feedback.

    Returns:
        None in every case.

    Raises:
        gr.Error: If the rating cannot be converted, or connecting to the
            database or inserting the row fails.
    """
    # Input validation stays outside the try block so that only the database
    # work is covered by the error handler.
    if not rating:
        gr.Warning("Please select a rating before submitting feedback.", duration=5)
        return None
    if not feedback_text or not feedback_text.strip():
        gr.Warning("Please provide some feedback before submitting.", duration=5)
        return None
    if not chat_history:
        gr.Warning(
            "Please provide the input text before submitting feedback.", duration=5
        )
        return None
    if len(chat_history[0]) < 2:
        gr.Warning(
            "Please translate the input text before submitting feedback.",
            duration=5,
        )
        return None

    insert_query = """
        INSERT INTO feedback
        (tgt_lang, rating, feedback_txt, chat_history)
        VALUES (%s, %s, %s, %s)
    """
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        try:
            # The cursor context manager closes the cursor even on failure.
            with conn.cursor() as cursor:
                cursor.execute(
                    insert_query,
                    (tgt_lang, int(rating), feedback_text, chat_history),
                )
            conn.commit()
        finally:
            # Always release the connection, including when the insert fails.
            conn.close()
    except Exception as exc:
        # gr.Error is an exception class: it only reaches the user when it is
        # raised, so merely constructing it would hide the failure.
        raise gr.Error(
            "An error occurred while storing feedback. Please try again later.",
            duration=5,
        ) from exc

    gr.Info("Thank you for your feedback! 🙏", duration=5)
    return None
def store_output(tgt_lang, input_text, output_text):
    """Insert one translation record into the ``translation`` table.

    Connection settings are read from the ``DB_HOST``, ``DB_NAME``,
    ``DB_USER``, ``DB_PASSWORD`` and ``DB_PORT`` environment variables.
    Database errors are not handled here and propagate to the caller.

    Args:
        tgt_lang: Target language of the translation.
        input_text: Text that was submitted for translation.
        output_text: Translation that was produced.
    """
    insert_query = """
        INSERT INTO translation
        (input_txt, output_txt, tgt_lang)
        VALUES (%s, %s, %s)
    """
    conn = psycopg2.connect(
        host=os.getenv("DB_HOST"),
        database=os.getenv("DB_NAME"),
        user=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        port=os.getenv("DB_PORT"),
    )
    try:
        # The cursor context manager closes the cursor even on failure.
        with conn.cursor() as cursor:
            cursor.execute(insert_query, (input_text, output_text, tgt_lang))
        conn.commit()
    finally:
        # Without this the connection is never closed, so every call leaks
        # one database connection.
        conn.close()
@spaces.GPU
def translate_message(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a translation of ``message`` into ``target_language``.

    The request is sent to the model as a single user turn; ``chat_history``
    is accepted but not used. Generation runs on a background thread while
    this generator yields the text accumulated so far after every streamed
    chunk. Once streaming ends, the input and the full output are recorded
    with ``store_output``.

    Args:
        message: Text to translate.
        chat_history: Previous conversation turns (unused).
        target_language: Language to translate into.
        max_new_tokens: Generation limit passed to ``model.generate``.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        The cumulative translated text after each streamed chunk.
    """
    prompt = format_message_for_translation(message, target_language)
    conversation = [{"role": "user", "content": prompt}]

    input_ids = tokenizer.apply_chat_template(
        conversation, return_tensors="pt", add_generation_prompt=True
    )
    prompt_length = input_ids.shape[1]
    if prompt_length > MAX_INPUT_TOKEN_LENGTH:
        # Over-long prompts keep only their last MAX_INPUT_TOKEN_LENGTH tokens.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(
            f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
        )
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
    )
    generation_args = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # Generation runs on its own thread so the streamer can be consumed here.
    worker = Thread(target=model.generate, kwargs=generation_args)
    worker.start()

    translated = ""
    for chunk in streamer:
        translated += chunk
        yield translated

    store_output(target_language, message, translated)
# Custom stylesheet text for the app's layout classes (feedback section,
# container, language selector, advanced options). Where it is consumed is
# not visible in this part of the file.
# NOTE(review): the "# body" lines below look like an attempt to comment out
# a rule, but "#" is not CSS comment syntax (CSS uses /* ... */) — confirm the
# rule is being ignored as intended.
css = """
# body {
# background-color: #f7f7f7;
# }
.feedback-section {
margin-top: 30px;
border-top: 1px solid #ddd;
padding-top: 20px;
}
.container {
max-width: 90%;
margin: 0 auto;
}
.language-selector {
margin-bottom: 20px;
padding: 10px;
background-color: #ffffff;
border-radius: 8px;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
.advanced-options {
margin-top: 20px;
}
"""
DESCRIPTION = """\
IndicTrans3 is the latest state-of-the-art (SOTA) translation model from AI4Bharat, designed to handle translations across 22 Indic languages with high accuracy. It supports document-level machine translation (MT) and is built to match the performance of other leading SOTA models.
📢 Training data will be released soon!