import logging
import os
from typing import Annotated, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException

import common.dependencies as DI
from common.configuration import Configuration, Query
from components.llm.common import ChatRequest, LlmParams, LlmPredictParams, Message
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.utils import append_llm_response_to_history
from components.services.dataset import DatasetService
from components.services.entity import EntityService
from components.services.llm_config import LLMConfigService
from components.services.llm_prompt import LlmPromptService
router = APIRouter(prefix='/llm')
logger = logging.getLogger(__name__)

conf = DI.get_config()
llm_params = LlmParams(
    **{
        "url": conf.llm_config.base_url,
        "model": conf.llm_config.model,
        "tokenizer": "unsloth/Llama-3.3-70B-Instruct",
        "type": "deepinfra",
        "default": True,
        "predict_params": LlmPredictParams(
            temperature=0.15,
            top_p=0.95,
            min_p=0.05,
            seed=42,
            repetition_penalty=1.2,
            presence_penalty=1.1,
            n_predict=2000,
        ),
        "api_key": os.environ.get(conf.llm_config.api_key_env),
        "context_length": 128000,
    }
)
# TODO: move this into the DI container
llm_api = DeepInfraApi(params=llm_params)
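# Note: the chat handler below receives its own DeepInfraApi instance via
# Depends(DI.get_llm_service), which shadows this module-level client inside the
# handler; resolving the TODO above would remove the duplicated construction.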


# The decorator is an assumption: the original snippet defines `router` but
# attaches no route to it; the "/chat" path is hypothetical.
@router.post("/chat")
async def chat(
    request: ChatRequest,
    config: Annotated[Configuration, Depends(DI.get_config)],
    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
    entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
    dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
):
    try:
        p = llm_config_service.get_default()
        system_prompt = prompt_service.get_default()

        predict_params = LlmPredictParams(
            temperature=p.temperature,
            top_p=p.top_p,
            min_p=p.min_p,
            seed=p.seed,
            frequency_penalty=p.frequency_penalty,
            presence_penalty=p.presence_penalty,
            n_predict=p.n_predict,
            stop=[],
        )
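        # These per-request parameters come from the stored default LLM config
        # (llm_config_service), not from the module-level llm_params above.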
        # TODO: move these helpers out of the request handler
        def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
            """Return the most recent user message that has no search results attached."""
            return next(
                (
                    msg
                    for msg in reversed(chat_request.history)
                    if msg.role == "user"
                    and (msg.searchResults is None or not msg.searchResults)
                ),
                None,
            )

        def insert_search_results_to_message(
            chat_request: ChatRequest, new_content: str
        ) -> bool:
            """Overwrite the content of that same message; return False if none was found."""
            for msg in reversed(chat_request.history):
                if msg.role == "user" and (
                    msg.searchResults is None or not msg.searchResults
                ):
                    msg.content = new_content
                    return True
            return False
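        # Retrieval step: find the latest user query, search the current dataset
        # for similar chunks, and inject them into that message wrapped in
        # <search-results> tags before the request is sent to the LLM.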
        last_query = get_last_user_message(request)
        logger.info(f"last_query: {last_query}")

        if last_query:
            dataset = dataset_service.get_current_dataset()
            if dataset is None:
                raise HTTPException(status_code=400, detail="Dataset not found")

            logger.info(f"last_query: {last_query.content}")
            _, scores, chunk_ids = entity_service.search_similar(last_query.content, dataset.id)
            chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
            logger.info(f"chunk_ids: {chunk_ids[:3]}...{chunk_ids[-3:]}")
            logger.info(f"scores: {scores[:3]}...{scores[-3:]}")

            text_chunks = entity_service.build_text(chunks, scores)
            logger.info(f"text_chunks: {text_chunks[:3]}...{text_chunks[-3:]}")

            new_message = f'{last_query.content}\n<search-results>\n{text_chunks}\n</search-results>'
            insert_search_results_to_message(request, new_message)
| logger.info(f"request: {request}") | |
| response = await llm_api.predict_chat_stream( | |
| request, system_prompt.text, predict_params | |
| ) | |
| result = append_llm_response_to_history(request, response) | |
| return result | |
    except HTTPException:
        # Let FastAPI turn HTTPExceptions (e.g. the 400 above) into proper error responses.
        raise
    except Exception as e:
        logger.error(f"Error processing LLM request: {e}", exc_info=True)
        return {"error": str(e)}