sshenai commited on
Commit
c7a77a7
·
verified ·
1 Parent(s): 124507a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -56
app.py CHANGED
@@ -1,70 +1,109 @@
1
- # 导入必要库
2
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
 
 
 
 
 
 
3
  from PIL import Image
4
- import requests
5
- from io import BytesIO
6
  import torch
7
- from transformers import pipeline
8
- import wikipedia # 用于获取鸟类百科信息
9
- from wikipedia.exceptions import DisambiguationError, PageError
10
 
11
- # 1. 鸟类图片识别(使用指定模型)
12
- def bird_classification(image_url):
13
- model_name = "chriamue/bird-species-classifier"
14
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
15
- model = AutoModelForImageClassification.from_pretrained(model_name)
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # 下载并预处理图片
20
- response = requests.get(image_url)
21
- img = Image.open(BytesIO(response.content)).convert("RGB")
22
- inputs = feature_extractor(img, return_tensors="pt").to(device)
 
 
23
 
24
- # 模型推理
25
- with torch.no_grad():
26
- outputs = model(**inputs)
27
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
 
 
 
 
 
28
 
29
- # 获取前1个预测结果
30
- predicted_id = torch.argmax(probabilities).item()
31
- labels = model.config.id2label
32
- bird_species = labels[predicted_id]
33
- confidence = round(probabilities[predicted_id].item(), 3)
 
34
 
35
- return bird_species, confidence
36
 
37
- # 2. 鸟类信息获取(使用维基百科API)
38
- def get_bird_info(species_name):
39
  try:
40
- # 去除可能的多余标签(如模型输出中的括号内容)
41
- clean_name = species_name.split("(")[0].strip()
42
- # 从维基百科获取摘要(英文转中文)
43
- summary = wikipedia.summary(clean_name, sentences=3, auto_suggest=False)
44
- return summary
45
- except (DisambiguationError, PageError):
46
- return "抱歉,未找到该鸟类的详细信息。"
 
 
 
 
 
47
 
48
- # 3. 文本转语音(使用TTS模型)
49
- def text_to_speech(text, output_file="bird_info.mp3"):
50
- tts = pipeline("text-to-speech", model="tts_models/en_US/tacotron2")
51
- speech = tts(text)
52
- with open(output_file, "wb") as f:
53
- f.write(speech["audio"])
54
- return output_file
55
 
56
- # 主函数
57
- def bird_knowledge_pipeline(image_url):
58
- # 1. 鸟类识别
59
- species, confidence = bird_classification(image_url)
60
- print(f"识别结果:{species}(置信度:{confidence*100:.1f}%)")
 
 
61
 
62
- # 2. 获取详细信息
63
- info = get_bird_info(species)
64
- print(f"鸟类介绍:\n{info}")
65
 
66
- # 3. 生成语音
67
- audio_file = text_to_speech(f"这是{species}的介绍:{info}")
68
- print(f"语音文件已保存:{audio_file}")
 
 
69
 
70
- return species, info, audio_file
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ 鸟类知识科普系统(Qwen3优化版) by [你的名字]
4
+ ISOM5240 Group Project
5
+ """
6
+
7
+ import gradio as gr
8
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
9
  from PIL import Image
 
 
10
  import torch
 
 
 
11
 
12
+ # 强制清理旧版缓存
13
+ from transformers.utils import move_cache
14
+ move_cache()
15
+
16
+ # 初始化模型(兼容Qwen3)
17
+ def init_models():
18
+ # 鸟类分类模型(保持不变)
19
+ classifier = pipeline(
20
+ "image-classification",
21
+ model="chriamue/bird-species-classifier",
22
+ device=0 if torch.cuda.is_available() else -1
23
+ )
24
+
25
+ # 更新为Qwen3模型(官方支持版本)
26
+ text_generator = pipeline(
27
+ "text-generation",
28
+ model="Qwen/Qwen-7B-Chat", # 使用官方维护版本
29
+ device_map="auto",
30
+ torch_dtype=torch.bfloat16,
31
+ trust_remote_code=True, # 必须开启
32
+ model_kwargs={
33
+ "revision": "main",
34
+ "force_download": True # 替换弃用参数
35
+ }
36
+ )
37
 
38
+ # 语音合成模型(保持不变)
39
+ tts = pipeline(
40
+ "text-to-speech",
41
+ model="facebook/mms-tts-eng",
42
+ device=0 if torch.cuda.is_available() else -1
43
+ )
44
 
45
+ return classifier, text_generator, tts
46
+
47
+ # 生成儿童友好的鸟类描述
48
+ def generate_child_friendly_text(bird_name):
49
+ PROMPT = f"""以6-12岁儿童能理解的方式描述{bird_name}:
50
+ 1. 用比喻手法(如:羽毛像彩虹糖纸)
51
+ 2. 包含一个趣味冷知识(例如:每天吃相当于自身体重30%的食物)
52
+ 3. 语句长度不超过15个英文单词
53
+ 4. 避免使用专业术语"""
54
 
55
+ response = text_generator(
56
+ PROMPT,
57
+ max_new_tokens=150,
58
+ temperature=0.7,
59
+ do_sample=True
60
+ )
61
 
62
+ return response[0]['generated_text'].split('\n')[2]
63
 
64
+ # 主处理流程
65
+ def process_image(image):
66
  try:
67
+ classification = classifier(image)
68
+ bird_name = classification[0]['label']
69
+ description = generate_child_friendly_text(bird_name)
70
+ speech = tts(description, forward_params={"speaker_id": 6})
71
+
72
+ return {
73
+ "bird_name": bird_name,
74
+ "description": description,
75
+ "audio": speech["audio"]
76
+ }
77
+ except Exception as e:
78
+ return f"处理错误: {str(e)}"
79
 
80
+ # 初始化模型
81
+ classifier, text_generator, tts = init_models()
 
 
 
 
 
82
 
83
+ # 创建Gradio界面
84
+ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo:
85
+ gr.Markdown("# 🐦 鸟类知识小课堂(Qwen3版)")
86
+
87
+ with gr.Row():
88
+ image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
89
+ audio_output = gr.Audio(label="语音��解", autoplay=True)
90
 
91
+ with gr.Column():
92
+ name_output = gr.Textbox(label="识别到的鸟类")
93
+ text_output = gr.Textbox(label="趣味知识", lines=4)
94
 
95
+ examples = gr.Examples(
96
+ examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
97
+ inputs=image_input,
98
+ label="示例图片"
99
+ )
100
 
101
+ image_input.change(
102
+ process_image,
103
+ inputs=image_input,
104
+ outputs=[name_output, text_output, audio_output]
105
+ )
106
+
107
+ # 部署配置
108
+ if __name__ == "__main__":
109
+ demo.launch(server_name="0.0.0.0", server_port=7860)