# -*- coding: utf-8 -*-
"""Streamlit app that identifies a bird in an uploaded photo and explains it to children.

Flow: image classification -> child-friendly text generation -> speech synthesis.
"""
import io
import tempfile  # noqa: F401  (no longer used by main(); import retained)

import streamlit as st
import torch
from PIL import Image
from transformers import AutoConfig, pipeline

# ========== Model configuration ==========
# Each entry holds the Hub model id plus the extra keyword arguments used when
# loading it. Only plain data lives here so that importing this module does no
# disk or network I/O (Streamlit re-executes the script on every interaction).
MODEL_CONFIG = {
    "image_to_text": {
        "model": "chriamue/bird-species-classifier",
        "config": {"use_fast": True},  # force the fast processor
    },
    "text_generation": {
        "model": "Qwen/Qwen-7B-Chat",
        # The Qwen-7B-Chat repository ships custom config/model code, so it can
        # only be loaded with trust_remote_code=True.
        # SECURITY: this executes Python code downloaded from the Hub; consider
        # pinning `revision` to a reviewed commit hash instead of "main".
        "config": {"revision": "main", "trust_remote_code": True},
    },
    "text_to_speech": {
        "model": "facebook/mms-tts-eng",
        # Intended as a child-like voice.
        # NOTE(review): confirm that the text-to-speech pipeline accepts
        # `speaker_id` as a constructor kwarg and that this checkpoint has more
        # than one speaker; the model is also English-only while the prompt
        # below is Chinese.
        "config": {"speaker_id": 6},
    },
}


# ========== Model initialisation ==========
@st.cache_resource
def init_pipelines():
    """Load the three Hugging Face pipelines once and cache them.

    Returns:
        A tuple ``(img_pipeline, text_pipeline, tts_pipeline)``.

    On any loading failure the error is shown in the page and the script run
    is halted via ``st.stop()``, so callers never receive a partial result.
    """
    try:
        img_cfg = MODEL_CONFIG["image_to_text"]
        img_pipeline = pipeline(
            "image-classification",
            model=img_cfg["model"],
            **img_cfg["config"],
        )

        text_cfg = MODEL_CONFIG["text_generation"]
        # Built here (not at import time) so the cost is paid once, inside the
        # cached loader, and failures are reported through the handler below.
        text_model_config = AutoConfig.from_pretrained(
            text_cfg["model"], **text_cfg["config"]
        )
        text_pipeline = pipeline(
            "text-generation",
            model=text_cfg["model"],
            config=text_model_config,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            **text_cfg["config"],
        )

        tts_cfg = MODEL_CONFIG["text_to_speech"]
        tts_pipeline = pipeline(
            "text-to-speech",
            model=tts_cfg["model"],
            **tts_cfg["config"],
        )
        return img_pipeline, text_pipeline, tts_pipeline
    except Exception as e:  # top-level boundary: report and stop the run
        st.error(f"模型加载失败: {str(e)}")
        st.stop()


# ========== Core functionality ==========
def generate_description(_pipe, bird_name):
    """Generate a child-friendly description of ``bird_name``.

    Args:
        _pipe: A text-generation pipeline; called as
            ``_pipe(prompt, max_new_tokens=..., return_full_text=False)`` and
            expected to return ``[{"generated_text": str}, ...]``.
        bird_name: Name of the bird to describe.

    Returns:
        The generated text with the prompt removed and surrounding whitespace
        stripped.
    """
    prompt = f"用6-12岁儿童能理解的语言描述{bird_name},使用比喻和趣味知识:"
    # Ask for the completion only. Splitting the full text on ":" would drop
    # everything before the last colon the model itself happens to produce.
    outputs = _pipe(prompt, max_new_tokens=120, return_full_text=False)
    text = outputs[0]["generated_text"]
    # Defensive: strip the prompt if the pipeline echoed it anyway.
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()


# ========== Page layout ==========
st.set_page_config(page_title="鸟类知识百科", page_icon="🐦")
st.title("🐦 智能鸟类科普系统")
st.markdown("上传鸟类图片,获取趣味知识讲解")


# Main flow
def main():
    """Run one pass of the app: upload -> classify -> describe -> speak."""
    img_pipe, text_pipe, tts_pipe = init_pipelines()

    uploaded_file = st.file_uploader("选择图片文件", type=["jpg", "png", "jpeg"])
    if not uploaded_file:
        return

    # Decode straight from memory. Writing to a NamedTemporaryFile and
    # reopening it by name without flushing can read a partial/empty file and
    # fails on Windows while the file is still open.
    try:
        image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        image.load()  # force decoding now so corrupt files are caught here
    except OSError:
        st.error("无法读取图片,请上传有效的图片文件")
        return

    with st.spinner("识别中..."):
        # Identify the bird species
        result = img_pipe(image)
        bird_name = result[0]["label"]
        st.success(f"识别结果:{bird_name}")

        # Generate the description
        desc = generate_description(text_pipe, bird_name)
        st.subheader("趣味知识")
        st.write(desc)

        # Speech synthesis; skipped when there is nothing to read aloud
        if desc:
            audio = tts_pipe(desc[:1000])  # cap the text length
            st.audio(audio["audio"], sample_rate=audio["sampling_rate"])


if __name__ == "__main__":
    main()