|
|
from agents import Agent, Runner, ModelSettings |
|
|
from agents import function_tool |
|
|
import transformers |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM |
|
|
from transformers.pipelines import pipeline |
|
|
from pydantic import BaseModel, Field |
|
|
from dotenv import load_dotenv |
|
|
from agents.tool import WebSearchTool |
|
|
from sendgrid import Email, To, Content, Mail |
|
|
import sendgrid |
|
|
from sendgrid import SendGridAPIClient |
|
|
import gradio as gr |
|
|
import os |
|
|
import re |
|
|
import asyncio |
|
|
load_dotenv()


from peft import PeftModel


from langchain_community.utilities import GoogleSerperAPIWrapper

# SECURITY: the Serper API key used to be hard-coded on this line, exposing the
# secret to anyone who can read the source or its history (rotate that key).
# It must now be provided through the SERPER_API_KEY environment variable;
# load_dotenv() above also loads it from a .env file. The wrapper is built
# without arguments and reads the key from the environment, exactly as it did
# when the key was injected into os.environ here.
online_search = GoogleSerperAPIWrapper()


from huggingface_hub import login

# Hugging Face token used to download the fine-tuned model below.
hf_token = os.getenv("HF_TOKEN")

# Only log in when a token is configured: huggingface_hub's login() with
# token=None falls back to an interactive prompt, which would block a
# non-interactive launch. Without a token, downloads fall back to anonymous
# access / any cached credentials.
if hf_token:
    login(token=hf_token)
|
|
|
|
|
def noop_compile(*args, **kwargs):
    """Stand-in for ``torch.compile`` that returns its target uncompiled.

    Supports both call forms of ``torch.compile``:

    * ``noop_compile(fn_or_module, **options)`` returns ``fn_or_module`` itself.
    * ``noop_compile(**options)`` (decorator-factory form, e.g.
      ``@torch.compile(mode=...)``) returns an identity decorator.

    All keyword options are accepted and ignored.
    """
    if args:
        return args[0]

    def decorator(func):
        return func

    return decorator


# Replace torch.compile process-wide so nothing loaded later gets compiled.
# NOTE(review): presumably this works around torch.compile failures when
# loading/running the Gemma model in this environment -- confirm.
# If torch is not installed there is nothing to patch.
try:
    import torch
except ImportError:
    pass
else:
    torch.compile = noop_compile
|
|
|
|
|
|
|
|
|
|
|
# Fine-tuned causal LM that backs the conversation_with_personal_assistant tool.
model_name = "A-Asif/Gemma3-ft"

conversational_model = AutoModelForCausalLM.from_pretrained(model_name)
conversational_tokenizer = AutoTokenizer.from_pretrained(model_name)

# Text-generation pipeline shared by every call of the conversational tool.
# do_sample=True is set explicitly because temperature/top_p only take effect
# in sampling mode; under greedy decoding transformers ignores them (and
# warns), so the tuned values below would otherwise depend on whatever the
# model's own generation config happens to specify.
conversation = pipeline(
    "text-generation",
    model=conversational_model,
    tokenizer=conversational_tokenizer,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
)
|
|
|
|
|
@function_tool
def get_information_from_internet(query: str) -> str:
    """ Use this tool to get information from internet or online info"""
    # The docstring above is the tool description the LLM sees, so it stays
    # verbatim. The query is forwarded to the module-level Serper search
    # wrapper and its text result is handed back unchanged.
    return online_search.run(query)
|
|
|
|
|
|
|
|
@function_tool
def send_alert(message: str):
    """ Use this tool when you have to alert guadian about patient sensitivity that he can take harmful action"""
    # The docstring above is the tool description the LLM sees; kept verbatim.
    #
    # Send `message` as a plain-text email through SendGrid, print the HTTP
    # status, and report success to the calling agent.
    # NOTE(review): both addresses appear redacted as "[email protected]" in
    # this copy -- confirm the real sender/guardian addresses.
    body = Content("text/plain", message)
    alert_mail = Mail(
        Email("[email protected]"),
        To("[email protected]"),
        "Your Patient is Trying to attempt something Harmful",
        body,
    )

    # The API key comes from the SENDGRID_API_KEY environment variable.
    client = SendGridAPIClient(api_key=os.environ.get("SENDGRID_API_KEY"))
    response = client.client.mail.send.post(request_body=alert_mail.get())
    print(f"Email status: {response.status_code}")

    return {"status": "success"}
|
|
|
|
|
@function_tool
def conversation_with_personal_assistant(user_input: str) -> str:
    """ Use this tool to generate friendly response according to user query, and later update response to give intractive and friendly response."""
    # The docstring above is the tool description the LLM sees; kept verbatim.
    #
    # Prompt template for the fine-tuned model. "\\response:" is a literal
    # backslash followed by "response:"; it is left unchanged because the model
    # may have been fine-tuned on exactly this format (TODO confirm).
    prompt = f"instruction: {user_input}\\response:"

    # return_full_text=False makes the pipeline return only the newly generated
    # continuation, so the reply no longer has to be located by searching the
    # output for the delimiter. The old search could match a "\response:" typed
    # by the user. Its fallback partitioned on "\r" + "esponse:" (a
    # carriage-return escape), which never matches the prompt, so it returned "".
    outputs = conversation(
        prompt,
        pad_token_id=conversational_tokenizer.eos_token_id,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
|
|
|
|
|
@function_tool
def get_motivational_quote(query: str) -> str:
    """ Use this tool to get best motivational quote according the current mental state of patient that is extracting from tool classify_mental_state"""
    # The docstring above is the tool description the LLM sees; kept verbatim.
    # NOTE(review): it mentions a classify_mental_state tool that is not
    # registered on any agent visible in this file -- confirm or update it.
    #
    # `query: str` gives the generated tool schema an explicit string type;
    # the parameter used to be unannotated, leaving its schema untyped.
    raw = online_search.run(query)

    # Strip underscores, em dashes, parentheses, digits, middle dots and
    # slashes from the search snippets.
    cleaned = re.sub(r"[_—()0-9·/]+", "", raw)
    # Replace every literal "..." with a single space.
    cleaned = cleaned.replace("...", " ")
    # Collapse the runs of spaces left behind by the removals above.
    cleaned = re.sub(r" +", " ", cleaned)
    return cleaned.strip()
|
|
|
|
|
# Sub-agent that fetches motivational quotes with get_motivational_quote and
# returns them formatted; it is exposed to other agents as a tool below.
# Instructions are built from adjacent literals with explicit "\n" so no stray
# backslashes or indentation end up in the prompt (the old triple-quoted text
# contained a literal "\" before a line break).
quotetional_agent = Agent(
    name="quotetional_agent",
    instructions=(
        "You are an expert in giving the best motivational quotes according to "
        "the patient's current mental state. Use the tool get_motivational_quote "
        "to get quotes, and give this tool a different query for each type of "
        "mental state.\n"
        "After getting quotes from this tool, refine them and return them in a "
        "well-formatted form."
    ),
    model="litellm/gemini/gemini-2.5-flash",
    tools=[get_motivational_quote],
    model_settings=ModelSettings(
        # Require a tool call so the quotes come from the search tool.
        tool_choice="required",
        # NOTE(review): 128 output tokens may truncate longer formatted quote
        # lists -- confirm this limit is intended.
        max_output_tokens=128,
    ),
)

# The tool name is referenced by name in the root agent's instructions, so it
# must stay "quotetional_agent_tool".
quotetional_agent_tool = quotetional_agent.as_tool(
    tool_name="quotetional_agent_tool",
    tool_description=(
        "Use this tool to search for quotes that match the patient's mental "
        "state and can motivate them, then use the returned quotes in the main "
        "response."
    ),
)
|
|
|
|
|
# System prompt for the root mental_agent built in main(). The tool names below
# must match the tools registered on that agent. Lines are joined with explicit
# "\n"; the old triple-quoted version sent a literal "\" at the end of every
# line. It also duplicated "give" and misspelled "suicide" in the safety rule.
root_instructions = (
    "You are an expert agent in supportive conversation with mental health patients.\n"
    "You are provided with multiple tools for generating friendly responses.\n"
    "After receiving the user's message, if you realize the user is asking a general "
    "question about health or mental health, give the user the best response according "
    "to your understanding; if needed, use get_information_from_internet to get "
    "information from the internet or current info.\n"
    "But if you realize the user is sharing their current mental state, then:\n"
    "1. Use the conversation_with_personal_assistant tool to get a friendly message "
    "according to the user's message.\n"
    "2. If you realize the user may take harmful action, such as suicide or killing "
    "someone, use the send_alert tool to alert their guardian so they can take proper "
    "action to mitigate risks.\n"
    "3. Then you must use the quotetional_agent_tool agent tool to get motivational "
    "quotes according to the user's current mental state.\n"
    "After getting the mental state, the response from the tool, and the motivational "
    "quotes, generate the best supportive, friendly response for the user, formatted "
    "as a conversation in an AI chatbot.\n"
    "That response should be clear and concise, should ease the user's distress, and "
    "should motivate them."
)
|
|
|
|
|
async def main():
    """Build the root mental-health agent and serve it through a Gradio chat UI.

    The agent can use the internet-search tool, the fine-tuned conversational
    model, the guardian e-mail alert and the quote sub-agent. The Gradio app is
    then launched locally (share=False) and opened in a browser.
    """
    mental_agent = Agent(
        name="mental_agent",
        instructions=root_instructions,
        model="litellm/gemini/gemini-2.5-flash",
        tools=[
            get_information_from_internet,
            conversation_with_personal_assistant,
            send_alert,
            quotetional_agent_tool,
        ],
        model_settings=ModelSettings(
            tool_choice="required",
            max_output_tokens=1000,
        ),
    )

    async def run_with_retry(agent, message, max_retries=3, initial_delay=2):
        """Run the agent with retry logic for transient API errors.

        Retries with exponential backoff (initial_delay, 2 * initial_delay, ...)
        when the error text looks like an overload or rate-limit response.

        Args:
            agent: Agent to run.
            message: Run input passed straight to ``Runner.run``; ``chat``
                passes a list of role/content message dicts.
            max_retries: Total number of attempts.
            initial_delay: Seconds to wait before the first retry.

        Returns:
            The result of ``Runner.run``.

        Raises:
            Exception: the last error, re-raised unchanged when it is not
                retryable or no attempts remain; ``Exception("Max retries
                reached")`` only if ``max_retries`` is less than 1.
        """
        retryable_markers = ("503", "overloaded", "unavailable", "429", "rate limit")
        for attempt in range(max_retries):
            try:
                return await Runner.run(agent, message)
            except Exception as e:
                # Errors are classified by their message text (status codes /
                # keywords) rather than by exception type.
                error_str = str(e).lower()
                is_retryable = any(marker in error_str for marker in retryable_markers)
                if not is_retryable or attempt >= max_retries - 1:
                    # Bare raise re-raises with the original traceback intact.
                    raise
                delay = initial_delay * (2 ** attempt)
                print(f"⚠️ API error (attempt {attempt + 1}/{max_retries}): {str(e)[:100]}...")
                print(f" Retrying in {delay} seconds...")
                await asyncio.sleep(delay)
        raise Exception("Max retries reached")

    async def chat(message: str, history):
        """Handle one chat turn from the Gradio UI.

        Args:
            message: Text the user just submitted.
            history: Current Chatbot value ("messages" format: dicts with
                "role" and "content").

        Returns:
            The new history with the user message and the agent reply appended.
        """
        history = history or []
        # Don't spend an agent run on an empty submission.
        if not message or not message.strip():
            return history

        # Give the agent the earlier turns, not just the latest message; before,
        # every turn was answered with no memory of the conversation. Only
        # role/content are forwarded: Gradio message dicts can carry extra
        # keys (e.g. metadata) that the agent input format does not expect,
        # and non-text content (files, components) is skipped.
        prior_turns = [
            {"role": turn["role"], "content": turn["content"]}
            for turn in history
            if isinstance(turn.get("content"), str)
        ]
        agent_input = prior_turns + [{"role": "user", "content": message}]

        response = await run_with_retry(mental_agent, agent_input)
        ai_reply = response.final_output

        return history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": ai_reply},
        ]

    with gr.Blocks(theme=gr.themes.Default(primary_hue="emerald")) as demo:
        with gr.Row():
            chatbot = gr.Chatbot(label="Personalized Ai companion", height=500, type="messages")
        with gr.Row():
            message = gr.Textbox(show_label=False, placeholder="Feel free to share your thoughts, I`m your personal AI companion")
        with gr.Row():
            button = gr.Button("Share!", variant="primary")
        button.click(chat, [message, chatbot], [chatbot])

    demo.launch(share=False, debug=True, inbrowser=True)


if __name__ == "__main__":
    asyncio.run(main())