"""
Multi-turn conversation manager with RAG support.
"""

from typing import List, Dict, Any, Optional, Tuple

from sqlalchemy.orm import Session

from app.db.models import Conversation, Message
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.config import Settings
from app.utils.exceptions import ResourceNotFoundError
from app.utils.logger import get_logger

logger = get_logger(__name__)

# Number of knowledge-base chunks requested from the KB per user question.
_RAG_TOP_K = 3
# Each retrieved chunk is truncated to this many characters in the prompt.
_RAG_SNIPPET_MAX_CHARS = 500
# Number of most recent messages (including the current user turn) used
# to build the conversation history for the prompt.
_HISTORY_LIMIT = 10


class ConversationManager:
    """
    Manage multi-turn conversations with optional RAG support.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the conversation manager.

        Args:
            db_session: Database session.
            model_manager: Model manager instance (provides LLMs).
            kb_manager: Knowledge base manager instance (used for RAG queries).
            settings: Application settings.
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def _commit(self) -> None:
        """
        Commit the current transaction, rolling back the session on failure.

        Without an explicit rollback, a failed flush/commit leaves the
        SQLAlchemy session unusable for any later call on this manager.

        Raises:
            Exception: Whatever ``Session.commit`` raised, re-raised unchanged.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def create_conversation(
        self, user_id: Optional[int] = None, title: str = "New Conversation"
    ) -> Conversation:
        """
        Create a new conversation.

        Args:
            user_id: Optional user ID.
            title: Conversation title.

        Returns:
            Conversation: The created (and refreshed) conversation.
        """
        logger.info("creating_conversation", user_id=user_id, title=title)

        conversation = Conversation(user_id=user_id, title=title)
        self.db.add(conversation)
        self._commit()
        # Refresh to load DB-generated fields such as the primary key.
        self.db.refresh(conversation)

        logger.info("conversation_created", conversation_id=conversation.id)
        return conversation

    def get_conversation(self, conversation_id: int) -> Conversation:
        """
        Fetch a conversation by ID.

        Args:
            conversation_id: Conversation ID.

        Returns:
            Conversation: The conversation instance.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = (
            self.db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        if not conversation:
            raise ResourceNotFoundError(f"Conversation not found: {conversation_id}")
        return conversation

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """
        Append a message to a conversation and persist it.

        Args:
            conversation_id: Conversation ID.
            role: Message role ('user', 'assistant', 'system').
            content: Message content.
            metadata: Optional metadata, stored as ``msg_metadata`` ({} if None).

        Returns:
            Message: The created message.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Verify the conversation exists before inserting.
        self.get_conversation(conversation_id)

        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            msg_metadata=metadata or {},
        )
        self.db.add(message)
        self._commit()
        self.db.refresh(message)

        logger.info(
            "message_added",
            conversation_id=conversation_id,
            role=role,
            message_id=message.id,
        )
        return message

    def get_messages(
        self,
        conversation_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Message]:
        """
        Fetch messages of a conversation in chronological order.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to return.
            offset: Pagination offset (counted from the oldest message).

        Returns:
            List[Message]: Messages, oldest first.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Verify the conversation exists.
        self.get_conversation(conversation_id)

        return (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            # id breaks ties between messages that share a timestamp so that
            # pagination is deterministic.
            .order_by(Message.created_at.asc(), Message.id.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    async def chat(
        self,
        conversation_id: int,
        user_input: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
        use_rag: bool = True,
    ) -> Message:
        """
        Process a user message and generate the assistant's reply.

        The user message is persisted first, then a prompt is built from the
        recent history and (optionally) knowledge-base context, the LLM is
        called, and the reply is persisted as an assistant message.

        Args:
            conversation_id: Conversation ID.
            user_input: The user's message.
            kb_id: Optional knowledge base ID (used for RAG).
            llm_name: Optional LLM model name.
            use_rag: Whether to use RAG (default True). Only takes effect
                when ``kb_id`` is also truthy.

        Returns:
            Message: The assistant's reply message. If generation fails, this
            holds a fixed apology text instead of raising.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        logger.info(
            "processing_chat",
            conversation_id=conversation_id,
            use_rag=use_rag,
            kb_id=kb_id,
        )

        # Verify the conversation exists before doing any work.
        self.get_conversation(conversation_id)

        # Persist the user's turn so it is part of the stored transcript even
        # if generation fails below.
        self.add_message(conversation_id, "user", user_input)

        history = self._build_history(conversation_id, limit=_HISTORY_LIMIT)
        llm = self.model_manager.get_llm(llm_name)

        context = ""
        sources: List[Any] = []
        if use_rag and kb_id:
            context, sources = self._retrieve_context(kb_id, user_input)

        prompt = self._build_prompt(history, user_input, context)

        try:
            # NOTE(review): llm.predict is synchronous and blocks the event loop
            # inside this coroutine; consider an async API or an executor.
            response = llm.predict(prompt)
            logger.info("llm_response_generated", conversation_id=conversation_id)
        except Exception as e:
            # Best-effort: store a fallback reply instead of failing the request.
            logger.error("llm_generation_failed", error=str(e))
            response = "I apologize, but I encountered an error generating a response. Please try again."

        metadata: Dict[str, Any] = {}
        if sources:
            metadata["sources"] = sources
        if llm_name:
            metadata["model"] = llm_name

        assistant_message = self.add_message(
            conversation_id, "assistant", response, metadata
        )

        logger.info("chat_complete", conversation_id=conversation_id)
        return assistant_message

    def _retrieve_context(self, kb_id: int, query: str) -> Tuple[str, List[Any]]:
        """
        Query the knowledge base and format the hits as numbered snippets.

        Retrieval is best-effort: any failure is logged and yields a
        placeholder context with no sources, so the chat still proceeds.

        Args:
            kb_id: Knowledge base ID.
            query: Text to search for (the user's message).

        Returns:
            Tuple of (context string, list of each hit's ``metadata``). The
            context is "" when the KB returns no hits.
        """
        # Build into locals and publish only on success, so a failure halfway
        # through never leaves metadata for snippets that were not used.
        context_parts: List[str] = []
        sources: List[Any] = []
        try:
            results = self.kb_manager.query_kb(kb_id, query, k=_RAG_TOP_K)
            for i, result in enumerate(results, 1):
                content = result["content"]
                if len(content) > _RAG_SNIPPET_MAX_CHARS:
                    snippet = f"{content[:_RAG_SNIPPET_MAX_CHARS]}..."
                else:
                    snippet = content
                context_parts.append(f"[{i}] {snippet}")
                sources.append(result["metadata"])
        except Exception as e:
            logger.warning("rag_context_failed", error=str(e))
            return "(No relevant context found)", []

        logger.info(
            "rag_context_retrieved", kb_id=kb_id, source_count=len(context_parts)
        )
        return "\n\n".join(context_parts), sources

    @staticmethod
    def _build_prompt(history: str, user_input: str, context: str) -> str:
        """
        Assemble the LLM prompt from history, the user's input and optional context.

        Args:
            history: Formatted conversation history.
            user_input: The current user message.
            context: Knowledge-base context; when empty, a plain chat prompt
                without the system instruction is produced.

        Returns:
            str: The prompt text, ending with "Assistant:".
        """
        if context:
            return (
                "You are a helpful AI assistant. Use the following context from "
                "the knowledge base to answer the question.\n\n"
                f"Context:\n{context}\n\n"
                f"Conversation history:\n{history}\n\n"
                f"User: {user_input}\n"
                "Assistant:"
            )
        return f"{history}\n\nUser: {user_input}\nAssistant:"

    def _build_history(self, conversation_id: int, limit: int = 10) -> str:
        """
        Build the conversation history string from the most recent messages.

        The newest ``limit`` messages are loaded and the last of them is
        dropped, because the caller has just stored the current user turn and
        includes it in the prompt separately. Only 'user' and 'assistant'
        messages are rendered; other roles are skipped.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to load (including the current
                user turn).

        Returns:
            str: "Role: content" lines, or "(No previous messages)".
        """
        # Take the newest rows (descending), then restore chronological order.
        recent = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc(), Message.id.desc())
            .limit(limit)
            .all()
        )
        recent.reverse()

        labels = {"user": "User", "assistant": "Assistant"}
        history_parts = [
            f"{labels[msg.role]}: {msg.content}"
            for msg in recent[:-1]  # exclude the current user message
            if msg.role in labels
        ]
        return "\n".join(history_parts) if history_parts else "(No previous messages)"

    def delete_conversation(self, conversation_id: int) -> bool:
        """
        Delete a conversation.

        Whether its messages are removed too depends on the ORM/DB cascade
        configuration of ``Conversation`` (not visible here) — TODO confirm.

        Args:
            conversation_id: Conversation ID.

        Returns:
            bool: True once the conversation has been deleted.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = self.get_conversation(conversation_id)

        logger.info("deleting_conversation", conversation_id=conversation_id)
        self.db.delete(conversation)
        self._commit()

        logger.info("conversation_deleted", conversation_id=conversation_id)
        return True