sshenai committed on
Commit
4d8788c
·
verified ·
1 Parent(s): 11c2804

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -91
app.py CHANGED
@@ -1,107 +1,97 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- 鸟类知识科普系统(Qwen3优化版) by [你的名字]
4
- ISOM5240 Group Project
5
  """
6
- import transformers
7
- import gradio as gr
8
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
9
  from PIL import Image
 
 
10
  import torch
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
-
14
- # 初始化模型(兼容Qwen3)
15
- def init_models():
16
- # 鸟类分类模型(保持不变)
17
- classifier = pipeline(
18
- "image-classification",
19
- model="chriamue/bird-species-classifier",
20
- device=0 if torch.cuda.is_available() else -1
21
- )
22
-
23
- # 更新为Qwen3模型(官方支持版本)
24
- text_generator = pipeline(
25
- "text-generation",
26
- model="Qwen/Qwen-7B-Chat", # 使用官方维护版本
27
- device_map="auto",
28
- torch_dtype=torch.bfloat16,
29
- trust_remote_code=True, # 必须开启
30
- model_kwargs={
31
- "revision": "main",
32
- "force_download": True # 替换弃用参数
33
- }
34
- )
35
-
36
- # 语音合成模型(保持不变)
37
- tts = pipeline(
38
- "text-to-speech",
39
- model="facebook/mms-tts-eng",
40
- device=0 if torch.cuda.is_available() else -1
41
- )
42
-
43
- return classifier, text_generator, tts
44
-
45
- # 生成儿童友好的鸟类描述
46
- def generate_child_friendly_text(bird_name):
47
- PROMPT = f"""以6-12岁儿童能理解的方式描述{bird_name}:
48
- 1. 用比喻手法(如:羽毛像彩虹糖纸)
49
- 2. 包含一个趣味冷知识(例如:每天吃相当于自身体重30%的食物)
50
- 3. 语句长度不超过15个英文单词
51
- 4. 避免使用专业术语"""
52
-
53
- response = text_generator(
54
- PROMPT,
55
- max_new_tokens=150,
56
- temperature=0.7,
57
- do_sample=True
58
- )
59
-
60
- return response[0]['generated_text'].split('\n')[2]
61
-
62
- # 主处理流程
63
- def process_image(image):
64
  try:
65
- classification = classifier(image)
66
- bird_name = classification[0]['label']
67
- description = generate_child_friendly_text(bird_name)
68
- speech = tts(description, forward_params={"speaker_id": 6})
 
 
 
 
 
 
 
 
 
69
 
70
- return {
71
- "bird_name": bird_name,
72
- "description": description,
73
- "audio": speech["audio"]
74
- }
 
 
 
75
  except Exception as e:
76
- return f"处理错误: {str(e)}"
 
77
 
78
- # 初始化模型
79
- classifier, text_generator, tts = init_models()
 
 
 
80
 
81
- # 创建Gradio界面
82
- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo:
83
- gr.Markdown("# 🐦 鸟类知识小课堂(Qwen3版)")
84
-
85
- with gr.Row():
86
- image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
87
- audio_output = gr.Audio(label="语音讲解", autoplay=True)
88
-
89
- with gr.Column():
90
- name_output = gr.Textbox(label="识别到的鸟类")
91
- text_output = gr.Textbox(label="趣味知识", lines=4)
92
 
93
- examples = gr.Examples(
94
- examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
95
- inputs=image_input,
96
- label="示例图片"
97
- )
98
 
99
- image_input.change(
100
- process_image,
101
- inputs=image_input,
102
- outputs=[name_output, text_output, audio_output]
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # 部署配置
106
  if __name__ == "__main__":
107
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ 鸟类知识智能科普系统
 
4
  """
5
+
6
+ import streamlit as st
 
7
  from PIL import Image
8
+ import tempfile
9
+ from transformers import pipeline, AutoConfig
10
  import torch
11
 
12
+ # ========== 模型配置 ==========
13
# Per-task model settings consumed by init_pipelines(): the "config" dicts of
# the image and text-to-speech entries are splatted as keyword arguments into
# transformers.pipeline(); the text-generation entry is passed as `config=`.
MODEL_CONFIG = {
    "image_to_text": {
        "model": "chriamue/bird-species-classifier",
        "config": {"use_fast": True},  # request the fast processor variant
    },
    "text_generation": {
        "model": "Qwen/Qwen-7B-Chat",
        # Qwen-7B-Chat ships its own configuration/modeling code on the Hub,
        # so the config can only be resolved with trust_remote_code=True.
        "config": AutoConfig.from_pretrained(
            "Qwen/Qwen-7B-Chat",
            revision="main",
            trust_remote_code=True,
        ),
    },
    "text_to_speech": {
        "model": "facebook/mms-tts-eng",
        # Model-forward arguments have to be wrapped in `forward_params`;
        # a bare `speaker_id=` keyword is not a pipeline construction argument.
        # NOTE(review): the original author labelled speaker 6 as a child
        # voice; this checkpoint appears to be single-speaker, so confirm the
        # setting has any audible effect.
        "config": {"forward_params": {"speaker_id": 6}},
    },
}
27
 
28
+ # ========== 模型初始化 ==========
29
@st.cache_resource
def init_pipelines():
    """Build the three Hugging Face pipelines used by the app.

    Decorated with ``st.cache_resource`` so the loaded pipelines are reused
    instead of being initialised again on every call.

    Returns:
        tuple: ``(img_pipeline, text_pipeline, tts_pipeline)`` - the image
        classification, text generation and text-to-speech pipelines, built
        from the entries in ``MODEL_CONFIG``.

    If anything goes wrong while loading, the error is shown with
    ``st.error`` and the script run is halted with ``st.stop()``.
    """
    try:
        img_pipeline = pipeline(
            "image-classification",
            model=MODEL_CONFIG["image_to_text"]["model"],
            **MODEL_CONFIG["image_to_text"]["config"],
        )

        text_pipeline = pipeline(
            "text-generation",
            model=MODEL_CONFIG["text_generation"]["model"],
            config=MODEL_CONFIG["text_generation"]["config"],
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # The Qwen checkpoint defines its model class in Hub-hosted code;
            # without this flag the model cannot be instantiated.
            trust_remote_code=True,
        )

        tts_pipeline = pipeline(
            "text-to-speech",
            model=MODEL_CONFIG["text_to_speech"]["model"],
            **MODEL_CONFIG["text_to_speech"]["config"],
        )

        return img_pipeline, text_pipeline, tts_pipeline

    except Exception as e:
        # UI boundary: surface the failure to the user and stop this run
        # rather than continuing without models.
        st.error(f"模型加载失败: {str(e)}")
        st.stop()
58
 
59
+ # ========== 核心功能 ==========
60
def generate_description(_pipe, bird_name):
    """Generate a child-friendly description of a bird.

    Args:
        _pipe: Callable text-generation pipeline. It is called as
            ``_pipe(prompt, max_new_tokens=120)`` and must return a list whose
            first item is a dict with a ``'generated_text'`` entry.
        bird_name: Name of the bird to describe; interpolated into the prompt.

    Returns:
        str: The generated text with the echoed prompt (if present) removed
        and surrounding whitespace stripped.
    """
    prompt = f"用6-12岁儿童能理解的语言描述{bird_name},使用比喻和趣味知识:"
    generated = _pipe(prompt, max_new_tokens=120)[0]['generated_text']

    # The prompt ends with a full-width colon, so splitting on an ASCII ":"
    # never removed it and could cut the answer at any colon the model wrote.
    # Strip the echoed prompt by prefix instead.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()
64
 
65
+ # ========== 界面设计 ==========
66
# Page metadata first, then the visible heading and the usage hint.
st.set_page_config(
    page_title="鸟类知识百科",
    page_icon="🐦",
)
st.title("🐦 智能鸟类科普系统")
st.markdown("上传鸟类图片,获取趣味知识讲解")
69
+
70
+ # 主流程
71
def main():
    """Run the app flow: upload an image, classify it, describe it, speak it.

    Loads the pipelines through ``init_pipelines()``, waits for an uploaded
    image, shows the top classification label, generates a description for
    that label with ``generate_description()`` and plays the synthesised
    audio. Returns ``None``; all output goes through Streamlit widgets.
    """
    img_pipe, text_pipe, tts_pipe = init_pipelines()

    uploaded_file = st.file_uploader("选择图片文件", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    # Decode directly from the uploaded in-memory buffer. Writing it to a
    # NamedTemporaryFile and reopening by name without flushing could hand PIL
    # an empty/truncated file (and reopening by name fails on Windows).
    image = Image.open(uploaded_file)

    with st.spinner("识别中..."):
        # Classify the bird; the first result's label is shown and reused.
        result = img_pipe(image)
        bird_name = result[0]['label']
        st.success(f"识别结果:{bird_name}")

        # Generate the description text.
        desc = generate_description(text_pipe, bird_name)
        st.subheader("趣味知识")
        st.write(desc)

        # Nothing to read aloud: skip synthesis instead of feeding the TTS
        # pipeline an empty string.
        if not desc.strip():
            st.warning("未生成可朗读的描述,已跳过语音合成")
            return

        # Speech synthesis; the text is capped at 1000 characters.
        audio = tts_pipe(desc[:1000])
        st.audio(audio["audio"], sample_rate=audio["sampling_rate"])
95
 
 
96
  if __name__ == "__main__":
97
+ main()