Spaces:
Build error
Build error
| # -*- coding: utf-8 -*- | |
| """ | |
| 鸟类知识科普系统(Qwen3优化版) by [你的名字] | |
| ISOM5240 Group Project | |
| """ | |
| import transformers | |
| import gradio as gr | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
| from PIL import Image | |
| import torch | |
| # 初始化模型(兼容Qwen3) | |
| def init_models(): | |
| # 鸟类分类模型(保持不变) | |
| classifier = pipeline( | |
| "image-classification", | |
| model="chriamue/bird-species-classifier", | |
| device=0 if torch.cuda.is_available() else -1 | |
| ) | |
| # 更新为Qwen3模型(官方支持版本) | |
| text_generator = pipeline( | |
| "text-generation", | |
| model="Qwen/Qwen-7B-Chat", # 使用官方维护版本 | |
| device_map="auto", | |
| torch_dtype=torch.bfloat16, | |
| trust_remote_code=True, # 必须开启 | |
| model_kwargs={ | |
| "revision": "main", | |
| "force_download": True # 替换弃用参数 | |
| } | |
| ) | |
| # 语音合成模型(保持不变) | |
| tts = pipeline( | |
| "text-to-speech", | |
| model="facebook/mms-tts-eng", | |
| device=0 if torch.cuda.is_available() else -1 | |
| ) | |
| return classifier, text_generator, tts | |
| # 生成儿童友好的鸟类描述 | |
| def generate_child_friendly_text(bird_name): | |
| PROMPT = f"""以6-12岁儿童能理解的方式描述{bird_name}: | |
| 1. 用比喻手法(如:羽毛像彩虹糖纸) | |
| 2. 包含一个趣味冷知识(例如:每天吃相当于自身体重30%的食物) | |
| 3. 语句长度不超过15个英文单词 | |
| 4. 避免使用专业术语""" | |
| response = text_generator( | |
| PROMPT, | |
| max_new_tokens=150, | |
| temperature=0.7, | |
| do_sample=True | |
| ) | |
| return response[0]['generated_text'].split('\n')[2] | |
| # 主处理流程 | |
| def process_image(image): | |
| try: | |
| classification = classifier(image) | |
| bird_name = classification[0]['label'] | |
| description = generate_child_friendly_text(bird_name) | |
| speech = tts(description, forward_params={"speaker_id": 6}) | |
| return { | |
| "bird_name": bird_name, | |
| "description": description, | |
| "audio": speech["audio"] | |
| } | |
| except Exception as e: | |
| return f"处理错误: {str(e)}" | |
| # 初始化模型 | |
| classifier, text_generator, tts = init_models() | |
| # 创建Gradio界面 | |
| with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 800px}") as demo: | |
| gr.Markdown("# 🐦 鸟类知识小课堂(Qwen3版)") | |
| with gr.Row(): | |
| image_input = gr.Image(type="pil", label="上传鸟类图片", height=300) | |
| audio_output = gr.Audio(label="语音讲解", autoplay=True) | |
| with gr.Column(): | |
| name_output = gr.Textbox(label="识别到的鸟类") | |
| text_output = gr.Textbox(label="趣味知识", lines=4) | |
| examples = gr.Examples( | |
| examples=["eagle.jpg", "penguin.jpg", "peacock.jpg"], | |
| inputs=image_input, | |
| label="示例图片" | |
| ) | |
| image_input.change( | |
| process_image, | |
| inputs=image_input, | |
| outputs=[name_output, text_output, audio_output] | |
| ) | |
| # 部署配置 | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |