# FinalProject / app.py
# Source-page residue (Hugging Face file viewer): author sshenai, commit
# 11c2804 "Update app.py" (verified), 3.22 kB; page links: raw / history / blame.
# -*- coding: utf-8 -*-
"""
鸟类知识科普系统(Qwen3优化版) by [你的名字]
ISOM5240 Group Project
"""
import transformers
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import torch
# Model initialisation.
def init_models(force_download=False):
    """Load the three Hugging Face pipelines used by the app.

    Args:
        force_download: When True, re-download the text-generation model's
            files even if they are already cached. Defaults to False. The
            previous code hard-coded True, which re-fetched the full 7B
            checkpoint on every start-up.

    Returns:
        A ``(classifier, text_generator, tts)`` tuple: an image-classification
        pipeline, a text-generation pipeline and a text-to-speech pipeline.
    """
    # Pipeline ``device`` convention: 0 is the first CUDA device, -1 is CPU.
    device = 0 if torch.cuda.is_available() else -1

    # Bird-species image classifier.
    classifier = pipeline(
        "image-classification",
        model="chriamue/bird-species-classifier",
        device=device,
    )

    # NOTE(review): labels elsewhere in this file say "Qwen3", but this loads
    # the Qwen/Qwen-7B-Chat checkpoint. That repo ships its own modelling
    # code, hence trust_remote_code=True.
    text_generator = pipeline(
        "text-generation",
        model="Qwen/Qwen-7B-Chat",
        device_map="auto",  # let accelerate decide weight placement
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        model_kwargs={
            "revision": "main",
            # Opt-in only: forcing a download bypasses the local cache.
            "force_download": force_download,
        },
    )

    # Text-to-speech (MMS, English checkpoint).
    tts = pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
        device=device,
    )
    return classifier, text_generator, tts
# Generate a child-friendly description of a bird.
def generate_child_friendly_text(bird_name, generator=None):
    """Ask the text-generation model for a kid-friendly description of a bird.

    Args:
        bird_name: Species label to describe, e.g. the classifier's top label.
        generator: Optional callable following the Hugging Face
            text-generation pipeline calling convention. Defaults to the
            module-level ``text_generator``. It is a parameter so the function
            can be exercised without loading a model.

    Returns:
        Only the text the model generated, with blank lines dropped and the
        remaining lines stripped and joined by single spaces. Returns an empty
        string if the model produced nothing.
    """
    if generator is None:
        generator = text_generator
    # Rule 5 is new. The voice-over checkpoint (facebook/mms-tts-eng) is
    # English, and rule 3 already measures sentences in English words, so the
    # answer language is now stated explicitly.
    prompt = "\n".join((
        f"以6-12岁儿童能理解的方式描述{bird_name}",
        "1. 用比喻手法(如:羽毛像彩虹糖纸)",
        "2. 包含一个趣味冷知识(例如:每天吃相当于自身体重30%的食物)",
        "3. 语句长度不超过15个英文单词",
        "4. 避免使用专业术语",
        "5. 请用英文回答",
    ))
    response = generator(
        prompt,
        max_new_tokens=150,
        temperature=0.7,
        do_sample=True,
        # Return only the continuation. With the default (full text) the
        # prompt is echoed back, and the old ``split('\n')[2]`` returned the
        # prompt's own rule "2. ..." instead of anything the model wrote.
        return_full_text=False,
    )
    generated = response[0]["generated_text"]
    # Collapse the model's lines into one paragraph for display and speech.
    return " ".join(line.strip() for line in generated.splitlines() if line.strip())
# 主处理流程
def process_image(image):
try:
classification = classifier(image)
bird_name = classification[0]['label']
description = generate_child_friendly_text(bird_name)
speech = tts(description, forward_params={"speaker_id": 6})
return {
"bird_name": bird_name,
"description": description,
"audio": speech["audio"]
}
except Exception as e:
return f"处理错误: {str(e)}"
# Load every model once, at import time. process_image and
# generate_child_friendly_text look these names up as module globals.
(classifier, text_generator, tts) = init_models()
# Build the Gradio interface.
# NOTE(review): the indentation of this block was lost in the recovered copy.
# The nesting below (Column as a sibling of Row, both inside Blocks) is a
# reconstruction; confirm the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo:
    # NOTE(review): the heading says "Qwen3", but init_models loads
    # Qwen/Qwen-7B-Chat.
    gr.Markdown("# 🐦 鸟类知识小课堂(Qwen3版)")
    with gr.Row():
        # type="pil" hands the upload to the handler as a PIL image.
        image_input = gr.Image(type="pil", label="上传鸟类图片", height=300)
        audio_output = gr.Audio(label="语音讲解", autoplay=True)
    with gr.Column():
        name_output = gr.Textbox(label="识别到的鸟类")
        text_output = gr.Textbox(label="趣味知识", lines=4)
    # Relative paths: presumably these images sit next to app.py. Verify
    # that they exist in the deployment.
    examples = gr.Examples(
        examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"],
        inputs=image_input,
        label="示例图片"
    )
    # Run the full pipeline whenever the image changes. `outputs` lists the
    # three components the handler must fill, in this order.
    image_input.change(
        process_image,
        inputs=image_input,
        outputs=[name_output, text_output, audio_output]
    )
# Deployment: when run as a script, serve on host 0.0.0.0, port 7860.
if __name__ == "__main__":
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )