"""Multi-turn conversation manager with optional RAG support."""

from typing import List, Dict, Any, Optional, Tuple

from sqlalchemy.orm import Session
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

from app.db.models import Conversation, Message
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.config import Settings
from app.utils.exceptions import ResourceNotFoundError
from app.utils.logger import get_logger

logger = get_logger(__name__)


class ConversationManager:
    """Manages multi-turn conversations with optional RAG support."""

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """Initialize the conversation manager.

        Args:
            db_session: Database session.
            model_manager: Model manager instance.
            kb_manager: Knowledge base manager instance.
            settings: Application settings.
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def create_conversation(
        self, user_id: Optional[int] = None, title: str = "New Conversation"
    ) -> Conversation:
        """Create a new conversation.

        Args:
            user_id: Optional user ID.
            title: Conversation title.

        Returns:
            Conversation: The created conversation.
        """
        logger.info("creating_conversation", user_id=user_id, title=title)
        conversation = Conversation(user_id=user_id, title=title)
        self.db.add(conversation)
        self.db.commit()
        self.db.refresh(conversation)
        logger.info("conversation_created", conversation_id=conversation.id)
        return conversation

    def get_conversation(self, conversation_id: int) -> Conversation:
        """Fetch a conversation by ID.

        Args:
            conversation_id: Conversation ID.

        Returns:
            Conversation: The conversation instance.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = (
            self.db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        if not conversation:
            raise ResourceNotFoundError(f"Conversation not found: {conversation_id}")
        return conversation

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """Append a message to a conversation.

        Args:
            conversation_id: Conversation ID.
            role: Message role ('user', 'assistant', 'system').
            content: Message content.
            metadata: Optional metadata.

        Returns:
            Message: The created message.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate the conversation exists (raises on missing).
        self.get_conversation(conversation_id)
        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            msg_metadata=metadata or {},
        )
        self.db.add(message)
        self.db.commit()
        self.db.refresh(message)
        logger.info(
            "message_added",
            conversation_id=conversation_id,
            role=role,
            message_id=message.id,
        )
        return message

    def get_messages(
        self,
        conversation_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Message]:
        """Fetch messages in a conversation, oldest first.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to return.
            offset: Pagination offset.

        Returns:
            List[Message]: Messages in chronological order.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate the conversation exists (raises on missing).
        self.get_conversation(conversation_id)
        messages = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return messages

    def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List:
        """Build the conversation history as LangChain message objects.

        Fetches the *most recent* ``limit`` messages. (Querying ascending with
        a limit, as before, returned the oldest messages and dropped recent
        context in long conversations.) The newest message is excluded because
        it is the current user input, which the chat chain receives separately.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to include.

        Returns:
            List: LangChain message objects in chronological order.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate the conversation exists (raises on missing).
        self.get_conversation(conversation_id)
        # Grab the newest `limit` rows, then flip back to chronological order.
        recent = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .limit(limit)
            .all()
        )
        recent.reverse()
        history_messages: List = []
        for msg in recent[:-1]:  # exclude the just-saved current user message
            if msg.role == "user":
                history_messages.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                history_messages.append(AIMessage(content=msg.content))
        return history_messages

    def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]:
        """Retrieve context from a knowledge base.

        Best-effort: any retrieval failure is logged and a placeholder context
        is returned so the chat can still proceed without RAG grounding.

        Args:
            kb_id: Knowledge base ID.
            query: Query text.
            k: Number of documents to return.

        Returns:
            Tuple: (formatted context string, list of source metadata dicts).
        """
        try:
            results = self.kb_manager.query_kb(kb_id, query, k=k)
            context_parts = []
            sources = []
            for i, result in enumerate(results, 1):
                content = result['content']
                # Truncate long chunks so the prompt stays bounded.
                if len(content) > 500:
                    content = f"{content[:500]}..."
                context_parts.append(f"[{i}] {content}")
                sources.append(result['metadata'])
            context = "\n\n".join(context_parts)
            logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
            return context, sources
        except Exception as e:
            logger.warning("rag_context_failed", error=str(e))
            return "(No relevant context found)", []

    def _create_rag_chain(self, llm, kb_id: Optional[int] = None):
        """Create an LCEL chain, optionally with RAG retrieval.

        Args:
            llm: LLM instance.
            kb_id: Optional knowledge base ID; when given the chain retrieves
                context before prompting and returns a dict (with ``response``
                and ``_sources`` keys) instead of a plain string.

        Returns:
            Runnable: The LCEL chain.
        """
        if kb_id:
            # Prompt template with RAG context injected into the system message.
            prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"),
                MessagesPlaceholder(variable_name="chat_history", optional=True),
                ("human", "{input}"),
            ])

            def retrieve_context(inputs):
                # Enrich the chain input with retrieved context; `_sources` is
                # carried through by RunnablePassthrough.assign for the caller.
                context, sources = self._retrieve_context(kb_id, inputs["input"])
                return {
                    **inputs,
                    "context": context,
                    "_sources": sources
                }

            # LCEL chain: retrieve -> prompt -> LLM -> parse output.
            chain = (
                RunnableLambda(retrieve_context)
                | RunnablePassthrough.assign(
                    response=prompt | llm | StrOutputParser()
                )
            )
        else:
            # Plain prompt template without RAG.
            prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI assistant."),
                MessagesPlaceholder(variable_name="chat_history", optional=True),
                ("human", "{input}"),
            ])
            # LCEL chain: prompt -> LLM -> parse output.
            chain = prompt | llm | StrOutputParser()
        return chain

    async def chat(
        self,
        conversation_id: int,
        user_input: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
        use_rag: bool = True,
    ) -> Message:
        """Process a user message and generate an assistant reply via LCEL.

        Args:
            conversation_id: Conversation ID.
            user_input: User message.
            kb_id: Optional knowledge base ID (for RAG).
            llm_name: Optional LLM model name.
            use_rag: Whether to use RAG (default True).

        Returns:
            Message: The assistant's reply message.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        logger.info(
            "processing_chat",
            conversation_id=conversation_id,
            use_rag=use_rag,
            kb_id=kb_id,
        )
        # Validate the conversation exists (raises on missing).
        self.get_conversation(conversation_id)
        # Persist the user message before building history so the history
        # helper can reliably exclude it as the current turn.
        self.add_message(conversation_id, "user", user_input)
        llm = self.model_manager.get_llm(llm_name)
        chat_history = self._build_chat_history_messages(conversation_id, limit=10)
        # Build the chain; RAG is only active when both requested and a KB given.
        kb_id_for_rag = kb_id if use_rag else None
        chain = self._create_rag_chain(llm, kb_id_for_rag)
        chain_input = {
            "input": user_input,
            "chat_history": chat_history,
        }
        try:
            if kb_id_for_rag:
                # RAG chain returns a dict with response and sources.
                result = await chain.ainvoke(chain_input)
                response = result.get("response", "")
                sources = result.get("_sources", [])
            else:
                # Simple chain returns the parsed string directly.
                response = await chain.ainvoke(chain_input)
                sources = []
            logger.info("llm_response_generated", conversation_id=conversation_id)
        except Exception as e:
            # Best-effort: degrade to an apology message rather than failing the turn.
            logger.error("llm_generation_failed", error=str(e))
            response = "I apologize, but I encountered an error generating a response. Please try again."
            sources = []
        # Persist the assistant reply with provenance metadata when available.
        metadata = {}
        if sources:
            metadata["sources"] = sources
        if llm_name:
            metadata["model"] = llm_name
        assistant_message = self.add_message(
            conversation_id, "assistant", response, metadata
        )
        logger.info("chat_complete", conversation_id=conversation_id)
        return assistant_message

    def delete_conversation(self, conversation_id: int) -> bool:
        """Delete a conversation and all of its messages.

        Args:
            conversation_id: Conversation ID.

        Returns:
            bool: True if deleted.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = self.get_conversation(conversation_id)
        logger.info("deleting_conversation", conversation_id=conversation_id)
        self.db.delete(conversation)
        self.db.commit()
        logger.info("conversation_deleted", conversation_id=conversation_id)
        return True