312 lines
8.7 KiB
Python
312 lines
8.7 KiB
Python
"""
Multi-turn conversation manager with optional RAG (retrieval-augmented generation) support.
"""
|
||
from typing import List, Dict, Any, Optional
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.db.models import Conversation, Message
|
||
from app.services.model_manager import ModelManager
|
||
from app.services.kb_manager import KnowledgeBaseManager
|
||
from app.config import Settings
|
||
from app.utils.exceptions import ResourceNotFoundError
|
||
from app.utils.logger import get_logger
|
||
|
||
# Module-level logger; call sites in this module pass an event name plus
# keyword fields, e.g. logger.info("message_added", conversation_id=...).
logger = get_logger(__name__)
|
||
|
||
|
||
class ConversationManager:
    """Manage multi-turn conversations with optional RAG support.

    Conversations and messages are persisted through the SQLAlchemy session
    passed to the constructor; every write method commits immediately (and
    rolls the session back if that commit fails).
    """

    # Most recent messages (including the just-saved user turn) fetched when
    # building the prompt's conversation history.
    HISTORY_LIMIT = 10
    # Number of knowledge-base chunks requested per RAG query.
    RAG_TOP_K = 3
    # Maximum characters of each retrieved chunk copied into the prompt.
    RAG_SNIPPET_CHARS = 500

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """Initialize the conversation manager.

        Args:
            db_session: Database session used for all reads and writes.
            model_manager: Model manager used to obtain the LLM.
            kb_manager: Knowledge-base manager used for RAG retrieval.
            settings: Application settings (stored; not read by this class).
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def _commit(self) -> None:
        """Commit the current transaction, rolling back if the commit fails.

        Without the rollback, a failed commit/flush leaves the session in an
        unusable state for every later call made through this manager.

        Raises:
            Exception: Whatever the commit raised, re-raised after rollback.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    def create_conversation(
        self, user_id: Optional[int] = None, title: str = "New Conversation"
    ) -> Conversation:
        """Create and persist a new conversation.

        Args:
            user_id: Optional owning user ID.
            title: Conversation title.

        Returns:
            Conversation: The created conversation, refreshed from the DB.
        """
        logger.info("creating_conversation", user_id=user_id, title=title)

        conversation = Conversation(user_id=user_id, title=title)
        self.db.add(conversation)
        self._commit()
        self.db.refresh(conversation)

        logger.info("conversation_created", conversation_id=conversation.id)
        return conversation

    def get_conversation(self, conversation_id: int) -> Conversation:
        """Fetch a conversation by ID.

        Args:
            conversation_id: Conversation ID.

        Returns:
            Conversation: The matching conversation.

        Raises:
            ResourceNotFoundError: If no conversation has that ID.
        """
        conversation = (
            self.db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        if not conversation:
            raise ResourceNotFoundError(f"Conversation not found: {conversation_id}")
        return conversation

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """Append a message to a conversation and persist it.

        Args:
            conversation_id: Conversation ID.
            role: Message role ('user', 'assistant' or 'system').
            content: Message text.
            metadata: Optional metadata, stored as ``msg_metadata`` ({} if None).

        Returns:
            Message: The created message, refreshed from the DB.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate that the conversation exists before inserting.
        self.get_conversation(conversation_id)

        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            msg_metadata=metadata or {},
        )
        self.db.add(message)
        self._commit()
        self.db.refresh(message)

        logger.info(
            "message_added",
            conversation_id=conversation_id,
            role=role,
            message_id=message.id,
        )
        return message

    def get_messages(
        self,
        conversation_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Message]:
        """Return a page of a conversation's messages, oldest first.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to return.
            offset: Number of messages to skip (pagination).

        Returns:
            List[Message]: Messages in chronological order.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate that the conversation exists.
        self.get_conversation(conversation_id)

        return (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            # id breaks ties between messages sharing a created_at value so
            # pagination is deterministic.
            .order_by(Message.created_at.asc(), Message.id.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )

    async def chat(
        self,
        conversation_id: int,
        user_input: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
        use_rag: bool = True,
    ) -> Message:
        """Handle a user message and generate the assistant's reply.

        The user message is persisted first, then the prompt is built from the
        recent history (plus knowledge-base context when RAG is enabled), the
        LLM is called, and the reply is persisted as an assistant message.

        NOTE(review): declared ``async`` but every call below (DB session,
        ``kb_manager.query_kb``, ``llm.predict``) is synchronous and blocks the
        event loop; consider offloading ``llm.predict`` once its thread-safety
        is confirmed.

        Args:
            conversation_id: Conversation ID.
            user_input: The user's message text.
            kb_id: Optional knowledge-base ID used for RAG.
            llm_name: Optional LLM model name (recorded in the reply metadata).
            use_rag: Whether to use RAG (default True); RAG runs only when a
                ``kb_id`` is also given.

        Returns:
            Message: The persisted assistant reply. If generation fails, the
            reply contains a fixed apology text instead of raising.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        logger.info(
            "processing_chat",
            conversation_id=conversation_id,
            use_rag=use_rag,
            kb_id=kb_id,
        )

        # Validate that the conversation exists before saving anything.
        self.get_conversation(conversation_id)

        self.add_message(conversation_id, "user", user_input)
        history = self._build_history(conversation_id, limit=self.HISTORY_LIMIT)
        llm = self.model_manager.get_llm(llm_name)

        context, sources = "", []
        # `is not None` so a knowledge base with ID 0 is not silently skipped.
        if use_rag and kb_id is not None:
            context, sources = self._retrieve_context(kb_id, user_input)

        prompt = self._build_prompt(context, history, user_input)
        response = self._generate_response(llm, prompt, conversation_id)

        metadata: Dict[str, Any] = {}
        if sources:
            metadata["sources"] = sources
        if llm_name:
            metadata["model"] = llm_name

        assistant_message = self.add_message(
            conversation_id, "assistant", response, metadata
        )

        logger.info("chat_complete", conversation_id=conversation_id)
        return assistant_message

    def _retrieve_context(self, kb_id: int, query: str) -> tuple:
        """Query the knowledge base and format the hits as prompt context.

        Retrieval is best-effort: any failure is logged and replaced by a
        placeholder context with no sources, so stale partial sources are
        never attached to the reply.

        Args:
            kb_id: Knowledge-base ID.
            query: Text to search for.

        Returns:
            tuple: ``(context, sources)`` -- the context string and the list
            of each hit's ``metadata`` value.
        """
        try:
            results = self.kb_manager.query_kb(kb_id, query, k=self.RAG_TOP_K)
            context, sources = self._format_context(results, self.RAG_SNIPPET_CHARS)
        except Exception as e:
            logger.warning("rag_context_failed", kb_id=kb_id, error=str(e))
            return "(No relevant context found)", []
        logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(sources))
        return context, sources

    @staticmethod
    def _format_context(results: List[Dict[str, Any]], max_chars: int = 500) -> tuple:
        """Render retrieval hits as numbered snippets.

        Each hit must provide ``'content'`` (text) and ``'metadata'`` keys;
        content longer than ``max_chars`` is cut and suffixed with ``...``.

        Args:
            results: Retrieval hits, as returned by ``query_kb``.
            max_chars: Maximum characters kept from each hit's content.

        Returns:
            tuple: ``(context, sources)`` -- snippets joined by blank lines,
            and the hits' metadata values in the same order.
        """
        parts = []
        sources = []
        for index, result in enumerate(results, 1):
            content = result["content"]
            snippet = content[:max_chars] + "..." if len(content) > max_chars else content
            parts.append(f"[{index}] {snippet}")
            sources.append(result["metadata"])
        return "\n\n".join(parts), sources

    @staticmethod
    def _build_prompt(context: str, history: str, user_input: str) -> str:
        """Assemble the LLM prompt.

        With a non-empty ``context`` the prompt instructs the model to answer
        from it; otherwise it is just the history followed by the user turn.
        """
        if context:
            return (
                "You are a helpful AI assistant. Use the following context from the "
                "knowledge base to answer the question.\n"
                "\n"
                f"Context:\n{context}\n"
                "\n"
                f"Conversation history:\n{history}\n"
                "\n"
                f"User: {user_input}\n"
                "Assistant:"
            )
        return f"{history}\n\nUser: {user_input}\nAssistant:"

    def _generate_response(self, llm: Any, prompt: str, conversation_id: int) -> str:
        """Call the LLM, falling back to an apology text on any error.

        Failures are logged rather than raised so that the chat turn still
        records an assistant reply.
        """
        try:
            response = llm.predict(prompt)
        except Exception as e:
            logger.error(
                "llm_generation_failed", conversation_id=conversation_id, error=str(e)
            )
            return "I apologize, but I encountered an error generating a response. Please try again."
        logger.info("llm_response_generated", conversation_id=conversation_id)
        return response

    def _build_history(self, conversation_id: int, limit: int = 10) -> str:
        """Render the most recent messages as a history string.

        Fetches the newest ``limit`` messages and drops the newest one, which
        in ``chat`` is the user message just saved (the prompt adds it
        separately).

        Args:
            conversation_id: Conversation ID.
            limit: Number of most recent messages to fetch (the newest of
                them is excluded from the output).

        Returns:
            str: ``"User: ..."`` / ``"Assistant: ..."`` lines, or
            ``"(No previous messages)"`` when there is nothing to show.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        self.get_conversation(conversation_id)

        # Newest first so LIMIT keeps the *latest* messages; id breaks
        # created_at ties (assumes ids grow with insertion order -- confirm).
        newest_first = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc(), Message.id.desc())
            .limit(limit)
            .all()
        )
        # Skip the newest message, then restore chronological order.
        return self._format_history(list(reversed(newest_first[1:])))

    @staticmethod
    def _format_history(messages: List[Message]) -> str:
        """Format messages as role-prefixed lines; roles other than
        'user'/'assistant' (e.g. 'system') are omitted."""
        labels = {"user": "User", "assistant": "Assistant"}
        lines = [
            f"{labels[msg.role]}: {msg.content}" for msg in messages if msg.role in labels
        ]
        return "\n".join(lines) if lines else "(No previous messages)"

    def delete_conversation(self, conversation_id: int) -> bool:
        """Delete a conversation.

        NOTE(review): whether its messages are removed too depends on the
        ORM relationship / foreign-key cascade configuration of the models,
        which is not visible here -- confirm.

        Args:
            conversation_id: Conversation ID.

        Returns:
            bool: Always True once the deletion is committed.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = self.get_conversation(conversation_id)

        logger.info("deleting_conversation", conversation_id=conversation_id)
        self.db.delete(conversation)
        self._commit()

        logger.info("conversation_deleted", conversation_id=conversation_id)
        return True