rahul7star committed
Commit f2fce5e · verified · 1 Parent(s): 6f21ce1

Create app_cpu.py

Files changed (1): app_cpu.py (+182, -0)
app_cpu.py ADDED
@@ -0,0 +1,182 @@
import os
import time
import logging
import re
import gradio as gr
from huggingface_hub import snapshot_download

# ============================================================
# 1️⃣ Model auto-download during app load
# ============================================================

DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")

print(f"🔄 Checking local model at startup: {DEFAULT_MODEL_PATH}")
local_model_dir = snapshot_download(repo_id=DEFAULT_MODEL_PATH)
print(f"✅ Model downloaded and cached at: {local_model_dir}")

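# Note (added): snapshot_download resolves into the shared Hugging Face hub
# cache (controlled by the HF_HOME / HF_HUB_CACHE environment variables) and
# skips files that are already present, so restarts are cheap. A hedged
# sketch, if a fixed on-disk path is preferred over the cache:
#   local_model_dir = snapshot_download(repo_id=DEFAULT_MODEL_PATH, local_dir="./model")
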
# ============================================================
# 2️⃣ Helper utils
# ============================================================

# qwen_vl_utils is optional; this text-only app works with a no-op fallback
# that reports no image or video inputs.
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        return None, None

def replace_single_quotes(text):
    # Rewrite 'quoted' spans to "quoted" and normalize curly single quotes
    # to curly double quotes.
    pattern = r"\B'([^']*)'\B"
    replaced_text = re.sub(pattern, r'"\1"', text)
    replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
    return replaced_text

def _str_to_dtype(dtype_str):
    # Whitelist the dtypes the UI offers; anything else falls back to float32.
    if dtype_str in ("bfloat16", "float16", "float32"):
        return dtype_str
    return "float32"

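# Example (added for illustration):
#   replace_single_quotes("it's not 'fine'")  returns  it's not "fine"
# The \B anchors reject quotes that sit flush against word characters, so the
# apostrophe inside "it's" is left alone while 'fine' becomes "fine".
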
# ============================================================
# 3️⃣ CPU inference function
# ============================================================

def cpu_predict(model_path, torch_dtype, prompt_cot, sys_prompt, temperature, max_new_tokens):
    import torch
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

    if not logging.getLogger(__name__).handlers:
        logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    dtype = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }.get(torch_dtype, torch.float32)

    # Force CPU
    device = "cpu"

    logger.info("🔧 Loading model to CPU...")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map={"": device},  # CPU-only mapping
        attn_implementation="sdpa",
    )
    processor = AutoProcessor.from_pretrained(model_path)

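    # Note (added): a 32B-parameter model needs about 4 bytes/param in
    # float32, i.e. on the order of 128 GB of RAM for the weights alone;
    # bfloat16 halves that. Reloading the model on every call is also the
    # dominant cost here; caching it at module level would amortize it.
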
    org_prompt_cot = prompt_cot
    user_prompt_format = sys_prompt + "\n" + org_prompt_cot
    messages = [{"role": "user", "content": [{"type": "text", "text": user_prompt_format}]}]

    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    logger.info("🧠 Running generation on CPU...")
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        # Honor the Temperature slider: sample when temperature > 0, greedy
        # otherwise (generate() ignores temperature/top_k/top_p without sampling).
        do_sample=float(temperature) > 0,
        top_k=5,
        top_p=0.9,
    )

    # Drop the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    output_res = output_text[0]

    # The model is expected to wrap its chain of thought in <think>...</think>;
    # keep only the rewritten prompt after the closing tag, falling back to the
    # original prompt if the tags are missing or malformed.
    try:
        assert output_res.count("think>") == 2
        new_prompt = output_res.split("think>")[-1].lstrip("\n")
        new_prompt = replace_single_quotes(new_prompt)
    except Exception:
        new_prompt = org_prompt_cot

    return new_prompt, ""

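# Usage sketch (added; the argument values are hypothetical): cpu_predict can
# be smoke-tested without the UI, e.g.
#   enhanced, err = cpu_predict(
#       model_path=local_model_dir,
#       torch_dtype="float32",
#       prompt_cot="a cat on a chair",
#       sys_prompt="Rewrite the prompt with a chain of thought:",
#       temperature=0.0,
#       max_new_tokens=256,
#   )
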
# ============================================================
# 4️⃣ Gradio interface
# ============================================================

def run_single(prompt, sys_prompt, temperature, max_new_tokens, torch_dtype, state):
    if not prompt.strip():
        return "", "Please enter a prompt first.", state

    t0 = time.time()
    try:
        new_prompt, err = cpu_predict(
            model_path=local_model_dir,
            torch_dtype=_str_to_dtype(torch_dtype),
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        dt = time.time() - t0
        msg = f"Elapsed: {dt:.2f}s"
        if err:  # cpu_predict currently always returns err == ""
            msg = f"{err} ({msg})"
        return new_prompt, msg, state
    except Exception as e:
        return "", f"Inference failed: {e}", state

# ============================================================
# 5️⃣ UI
# ============================================================

test_list_zh = [
    # "Third-person view: a race car speeds along a city circuit; a minimap in
    # the top-left with the current ranking below it, and a dashboard in the
    # bottom-right showing the current speed."
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
]
test_list_en = [
    "Create a painting depicting a 30-year-old white-collar worker on a business trip by plane.",
]

with gr.Blocks(title="Prompt Enhancer (CPU Mode)") as demo:
    gr.Markdown("## 🧩 Prompt Enhancer (CPU Mode — model preloaded)")
    with gr.Row():
        sys_prompt = gr.Textbox(
            label="System Prompt",
            # i.e. "Based on the user's input, generate a chain-of-thought and rewrite the prompt:"
            value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
            lines=3,
        )
        temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
        max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
        torch_dtype = gr.Dropdown(["float32", "float16", "bfloat16"], value="float32", label="torch_dtype")

    state = gr.State(value=None)

    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Input Prompt", lines=6, placeholder="Paste the prompt to rewrite here...")
                run_btn = gr.Button("Generate Rewrite", variant="primary")
                gr.Examples(examples=test_list_zh + test_list_en, inputs=prompt)
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="Rewritten Result", lines=10)
                out_info = gr.Markdown("✅ Model loaded on CPU.")

    run_btn.click(
        run_single,
        inputs=[prompt, sys_prompt, temperature, max_new_tokens, torch_dtype, state],
        outputs=[out_text, out_info, state],
    )

if __name__ == "__main__":
    demo.launch(show_error=True, share=True)
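
A note on the final launch line: share=True opens a public Gradio tunnel, which is rarely wanted for a slow CPU-only box. A minimal, more conservative sketch (not part of the commit; same demo object, with a queue so CPU-bound requests run one at a time instead of overlapping):

    demo.queue(max_size=4)        # hold extra requests in a queue
    demo.launch(show_error=True)  # local only; add share=True explicitly if a tunnel is wanted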