rahul7star committed
Commit 401afad · verified · 1 Parent(s): 8c48eed

Update app_low.py

Files changed (1)
  1. app_low.py +67 -54
app_low.py CHANGED
@@ -1,75 +1,88 @@
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from huggingface_hub import snapshot_download
-import os
 
 # ============================================================
-# 1️⃣ Download model efficiently (avoid exceeding space limits)
 # ============================================================
-MODEL_ID = "Qwen/Qwen2.5-1.5B"
 
-# Store in /tmp to reduce Space storage pressure
-model_dir = snapshot_download(repo_id=MODEL_ID, cache_dir="/tmp/qwen_model")
 
 # ============================================================
-# 2️⃣ Load model with CPU or GPU offload
 # ============================================================
-device = "cuda" if torch.cuda.is_available() else "cpu"
 
-model = AutoModelForCausalLM.from_pretrained(
-    model_dir,
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto" if torch.cuda.is_available() else None,
-    low_cpu_mem_usage=True,
-)
 
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
 
 # ============================================================
-# 3️⃣ Define chat function
 # ============================================================
-def chat_with_qwen(message, history):
-    history = history or []
-    messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
-    for human, bot in history:
-        messages.append({"role": "user", "content": human})
-        messages.append({"role": "assistant", "content": bot})
-    messages.append({"role": "user", "content": message})
-
-    # Tokenize input messages
-    inputs = tokenizer.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,  # needed so inputs is dict-like, not a bare tensor
-        return_tensors="pt"
     )
 
-    inputs = {k: v.to(device) for k, v in inputs.items()}
 
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=256,
-            temperature=0.8,
-            do_sample=True,
-            pad_token_id=tokenizer.eos_token_id
-        )
 
-    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
-    history.append((message, response))
-    return history
 
 # ============================================================
-# 4️⃣ Gradio UI
 # ============================================================
-with gr.Blocks(theme="soft", title="Qwen 2.5 Chatbot") as demo:
-    gr.Markdown("## 🤖 Qwen 2.5 Chatbot — Optimized for CPU/GPU Offload")
-    # history is (user, bot) tuples, so use the tuples chat format
-    chatbot = gr.Chatbot(height=480, label="Chat with Qwen 2.5", type="tuples")
-    msg = gr.Textbox(placeholder="Type your question here...", label="Your Message")
-    clear = gr.Button("🧹 Clear Chat")
-
-    msg.submit(chat_with_qwen, [msg, chatbot], chatbot)
-    clear.click(lambda: None, None, chatbot, queue=False)
-
-demo.launch(server_name="0.0.0.0", server_port=7860)
 
 import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import torch
 
 # ============================================================
+# 1️⃣ Load model and tokenizer
 # ============================================================
+MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
 
+# Use CPU-friendly settings
+device = 0 if torch.cuda.is_available() else -1
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
+
+# Text-generation pipeline
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    device=device,  # 0 for GPU, -1 for CPU
+)
 
 # ============================================================
+# 2️⃣ Define the generation function
 # ============================================================
+def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
+    """Enhance the user prompt and maintain the chat history."""
+    if not user_prompt.strip():
+        return chat_history + [{"role": "assistant", "content": "⚠️ Please enter a prompt."}]
+
+    full_prompt = f"Enhance and expand the following prompt with more details and context: {user_prompt}"
+
+    # Generate only the continuation; return_full_text=False drops the echoed prompt
+    output = pipe(
+        full_prompt,
+        max_new_tokens=int(max_tokens),
+        temperature=float(temperature),
+        do_sample=True,
+        return_full_text=False,
+    )
+
+    result = output[0]["generated_text"].strip()
+    # gr.Chatbot(type="messages") expects role/content dicts, not [user, bot] pairs
+    chat_history = chat_history + [
+        {"role": "user", "content": user_prompt},
+        {"role": "assistant", "content": result},
+    ]
+    return chat_history
 
 # ============================================================
+# 3️⃣ Gradio UI
 # ============================================================
+with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # Prompt Enhancer (Gemma 3 270M)
+        Enter a short prompt, and the model will expand it with extra details, context, and creativity.
+        """
     )
 
+    with gr.Row():
+        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
+        with gr.Column(scale=1):
+            user_prompt = gr.Textbox(
+                placeholder="Enter a short prompt...",
+                label="Your Prompt",
+                lines=3,
+            )
+            # Keep the minimum above 0: temperature=0 with do_sample=True raises an error
+            temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
+            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
+            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
+            clear_btn = gr.Button("🧹 Clear Chat")
 
+    # Bind functions
+    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+    clear_btn.click(lambda: [], None, chatbot)
 
+    gr.Markdown(
+        """
+        ---
+        💡 Tips:
+        - Works best with short, descriptive prompts (e.g., "A cat sitting on a chair").
+        - Adjust temperature for creativity: higher = more diverse output.
+        """
+    )
 
 # ============================================================
+# 4️⃣ Launch
 # ============================================================
+if __name__ == "__main__":
+    demo.launch(show_error=True)
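
As a quick sanity check of the new generation path, a minimal sketch along these lines can be run outside Gradio (assuming the same checkpoint as the Space; with return_full_text=False the text-generation pipeline returns only the continuation, as a list of dicts with a "generated_text" key):

# Hypothetical smoke test for the pipeline path; not part of the commit itself.
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="gokaygokay/prompt-enhancer-gemma-3-270m-it",
    device=-1,  # CPU
)
out = pipe(
    "Enhance and expand the following prompt with more details and context: A cat sitting on a chair",
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
    return_full_text=False,  # drop the echoed prompt, keep only new tokens
)
print(out[0]["generated_text"])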
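Because the Chatbot is created with type="messages", its value is a flat list of role/content dicts rather than [user, bot] pairs, which is the shape the enhance_prompt handler returns. A minimal sketch of that history format (the sample strings here are made up):

# History shape expected by gr.Chatbot(type="messages").
history = []
history += [
    {"role": "user", "content": "A cat sitting on a chair"},
    {"role": "assistant", "content": "A fluffy tabby cat lounging on a sunlit wooden chair..."},
]
# Returning this list from the click/submit handler re-renders the chat.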
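On the removed Qwen path, apply_chat_template(..., tokenize=True) returns a bare tensor unless return_dict=True is passed, which is why the deleted dict comprehension over inputs.items() only works with that flag. A sketch of the difference, assuming the Qwen tokenizer from the old version:

# Without return_dict=True, apply_chat_template returns a tensor of ids;
# with it, a dict-like BatchEncoding with input_ids and attention_mask.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
msgs = [{"role": "user", "content": "hi"}]
enc = tok.apply_chat_template(
    msgs,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
print(enc.keys())  # dict_keys(['input_ids', 'attention_mask'])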
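One more note on the deleted loading step: device_map="auto" dispatches layers across available devices via the accelerate package, so that path would only work with accelerate installed. A condensed sketch of what it did:

# Old loading path, condensed; device_map="auto" requires `pip install accelerate`.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,  # GPU offload
    low_cpu_mem_usage=True,  # stream weights in to cut peak RAM
)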