# langchain-learning-kit/app/services/conv_manager.py
"""
支持 RAG 的多轮对话管理器
"""
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from app.db.models import Conversation, Message
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.config import Settings
from app.utils.exceptions import ResourceNotFoundError
from app.utils.logger import get_logger
logger = get_logger(__name__)
class ConversationManager:
    """
    Manages multi-turn conversations with optional RAG support.

    Persists conversations/messages through a SQLAlchemy session and
    delegates generation to a ModelManager and retrieval to a
    KnowledgeBaseManager.
    """

    #: Roles this module itself writes; other roles are not rejected,
    #: but only "user"/"assistant" are rendered into the history prompt.
    _KNOWN_ROLES = ("user", "assistant", "system")

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the conversation manager.

        Args:
            db_session: Database session.
            model_manager: Model manager instance.
            kb_manager: Knowledge-base manager instance.
            settings: Application settings.
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def create_conversation(
        self, user_id: Optional[int] = None, title: str = "New Conversation"
    ) -> Conversation:
        """
        Create a new conversation.

        Args:
            user_id: Optional user ID.
            title: Conversation title.

        Returns:
            Conversation: The created conversation.
        """
        logger.info("creating_conversation", user_id=user_id, title=title)
        conversation = Conversation(user_id=user_id, title=title)
        self.db.add(conversation)
        self.db.commit()
        # Refresh so DB-generated fields (id, timestamps) are populated.
        self.db.refresh(conversation)
        logger.info("conversation_created", conversation_id=conversation.id)
        return conversation

    def get_conversation(self, conversation_id: int) -> Conversation:
        """
        Get a conversation by ID.

        Args:
            conversation_id: Conversation ID.

        Returns:
            Conversation: The conversation instance.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = (
            self.db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        if not conversation:
            raise ResourceNotFoundError(f"Conversation not found: {conversation_id}")
        return conversation

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """
        Add a message to a conversation.

        Args:
            conversation_id: Conversation ID.
            role: Message role ('user', 'assistant', 'system').
            content: Message content.
            metadata: Optional metadata dict (stored as ``msg_metadata``).

        Returns:
            Message: The created message.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate that the conversation exists (raises if not).
        self.get_conversation(conversation_id)
        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            # "metadata" is reserved by SQLAlchemy's declarative base,
            # hence the msg_metadata column name.
            msg_metadata=metadata or {},
        )
        self.db.add(message)
        self.db.commit()
        self.db.refresh(message)
        logger.info(
            "message_added",
            conversation_id=conversation_id,
            role=role,
            message_id=message.id,
        )
        return message

    def get_messages(
        self,
        conversation_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Message]:
        """
        Get messages in a conversation, oldest first.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to return.
            offset: Pagination offset.

        Returns:
            List[Message]: Messages in chronological (ascending) order.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate that the conversation exists (raises if not).
        self.get_conversation(conversation_id)
        messages = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return messages

    async def chat(
        self,
        conversation_id: int,
        user_input: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
        use_rag: bool = True,
    ) -> Message:
        """
        Process a user message and generate an assistant reply.

        Args:
            conversation_id: Conversation ID.
            user_input: The user's message.
            kb_id: Optional knowledge-base ID for RAG retrieval.
            llm_name: Optional LLM model name.
            use_rag: Whether to use RAG (default True; only effective
                when ``kb_id`` is also provided).

        Returns:
            Message: The assistant's reply message.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        logger.info(
            "processing_chat",
            conversation_id=conversation_id,
            use_rag=use_rag,
            kb_id=kb_id,
        )
        # Validate that the conversation exists (raises if not).
        self.get_conversation(conversation_id)

        # Persist the user turn before building history; _build_history
        # drops this just-added message so it is not duplicated in the prompt.
        self.add_message(conversation_id, "user", user_input)

        history = self._build_history(conversation_id, limit=10)
        llm = self.model_manager.get_llm(llm_name)

        # If RAG is enabled, build context from the knowledge base.
        context = ""
        sources: List[Dict[str, Any]] = []
        if use_rag and kb_id:
            try:
                results = self.kb_manager.query_kb(kb_id, user_input, k=3)
                context_parts = []
                for i, result in enumerate(results, 1):
                    # Truncate long chunks so the prompt stays bounded.
                    content = result["content"]
                    if len(content) > 500:
                        context_parts.append(f"[{i}] {content[:500]}...")
                    else:
                        context_parts.append(f"[{i}] {content}")
                    sources.append(result["metadata"])
                context = "\n\n".join(context_parts)
                logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
            except Exception as e:
                # Best-effort: retrieval failure degrades to a placeholder
                # context rather than failing the whole chat turn.
                logger.warning("rag_context_failed", error=str(e))
                context = "(No relevant context found)"

        # Build the prompt (RAG template when any context string exists,
        # including the failure placeholder above).
        if context:
            prompt = f"""You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.
Context:
{context}
Conversation history:
{history}
User: {user_input}
Assistant:"""
        else:
            prompt = f"""{history}
User: {user_input}
Assistant:"""

        # Generate the reply.
        # NOTE(review): llm.predict() is a deprecated LangChain API
        # (superseded by invoke()) and is a *blocking* call inside an
        # async method — consider llm.ainvoke() or running in a thread
        # pool; confirm against ModelManager.get_llm's return type.
        try:
            response = llm.predict(prompt)
            logger.info("llm_response_generated", conversation_id=conversation_id)
        except Exception as e:
            # Degrade gracefully: persist an apology instead of raising,
            # so the conversation record stays consistent.
            logger.error("llm_generation_failed", error=str(e))
            response = "I apologize, but I encountered an error generating a response. Please try again."

        # Persist the assistant turn, recording provenance when available.
        metadata: Dict[str, Any] = {}
        if sources:
            metadata["sources"] = sources
        if llm_name:
            metadata["model"] = llm_name
        assistant_message = self.add_message(
            conversation_id, "assistant", response, metadata
        )
        logger.info("chat_complete", conversation_id=conversation_id)
        return assistant_message

    def _build_history(self, conversation_id: int, limit: int = 10) -> str:
        """
        Build a formatted conversation-history string.

        Fetches the *most recent* ``limit`` messages and restores
        chronological order. (The previous implementation paged from the
        oldest messages, so conversations longer than ``limit`` turns
        lost all recent context from the prompt.) The final message —
        the user turn just added by chat(), already embedded in the
        prompt — is excluded.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to include.

        Returns:
            str: Formatted conversation history.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate that the conversation exists (raises if not).
        self.get_conversation(conversation_id)
        recent = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .limit(limit)
            .all()
        )
        recent.reverse()  # newest-first query result -> chronological order
        history_parts = []
        for msg in recent[:-1]:  # exclude the just-added user message
            if msg.role == "user":
                history_parts.append(f"User: {msg.content}")
            elif msg.role == "assistant":
                history_parts.append(f"Assistant: {msg.content}")
        return "\n".join(history_parts) if history_parts else "(No previous messages)"

    def delete_conversation(self, conversation_id: int) -> bool:
        """
        Delete a conversation and all of its messages.

        Relies on the ORM cascade configured on the Conversation model
        to remove child messages — confirm the relationship declares
        delete cascade.

        Args:
            conversation_id: Conversation ID.

        Returns:
            bool: True if deleted.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = self.get_conversation(conversation_id)
        logger.info("deleting_conversation", conversation_id=conversation_id)
        self.db.delete(conversation)
        self.db.commit()
        logger.info("conversation_deleted", conversation_id=conversation_id)
        return True