langchain-learning-kit/app/services/conv_manager.py

"""
支持 RAG 的多轮对话管理器
"""
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.orm import Session
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from app.db.models import Conversation, Message
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.config import Settings
from app.utils.exceptions import ResourceNotFoundError
from app.utils.logger import get_logger
logger = get_logger(__name__)


class ConversationManager:
    """
    Manages multi-turn conversations with optional RAG support
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the conversation manager

        Args:
            db_session: Database session
            model_manager: Model manager instance
            kb_manager: Knowledge base manager instance
            settings: Application settings
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def create_conversation(
        self, user_id: Optional[int] = None, title: str = "New Conversation"
    ) -> Conversation:
        """
        Create a new conversation

        Args:
            user_id: Optional user ID
            title: Conversation title

        Returns:
            Conversation: The created conversation
        """
        logger.info("creating_conversation", user_id=user_id, title=title)
        conversation = Conversation(user_id=user_id, title=title)
        self.db.add(conversation)
        self.db.commit()
        self.db.refresh(conversation)
        logger.info("conversation_created", conversation_id=conversation.id)
        return conversation
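
    # Illustrative usage (a sketch, not executed by this module; assumes a
    # SQLAlchemy Session and the app's managers have already been wired up
    # by the application's startup code):
    #
    #     manager = ConversationManager(db, model_manager, kb_manager, settings)
    #     conv = manager.create_conversation(user_id=1, title="RAG demo")
    #     conv.id  # primary key is populated after commit/refresh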

    def get_conversation(self, conversation_id: int) -> Conversation:
        """
        Get a conversation by ID

        Args:
            conversation_id: Conversation ID

        Returns:
            Conversation: The conversation instance

        Raises:
            ResourceNotFoundError: If the conversation is not found
        """
        conversation = (
            self.db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        if not conversation:
            raise ResourceNotFoundError(f"Conversation not found: {conversation_id}")
        return conversation

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """
        Add a message to a conversation

        Args:
            conversation_id: Conversation ID
            role: Message role ('user', 'assistant', or 'system')
            content: Message content
            metadata: Optional metadata

        Returns:
            Message: The created message

        Raises:
            ResourceNotFoundError: If the conversation is not found
        """
        # Verify the conversation exists
        self.get_conversation(conversation_id)
        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            msg_metadata=metadata or {},
        )
        self.db.add(message)
        self.db.commit()
        self.db.refresh(message)
        logger.info(
            "message_added",
            conversation_id=conversation_id,
            role=role,
            message_id=message.id,
        )
        return message
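
    # Example metadata payload (illustrative values; the inner source dict
    # shape is an assumption). chat() below stores retrieval sources and the
    # model name under these keys; any JSON-serializable dict should work:
    #
    #     manager.add_message(
    #         conv.id,
    #         "assistant",
    #         "Answer text...",
    #         metadata={"sources": [{"source": "doc.pdf"}], "model": "gpt-4o-mini"},
    #     )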

    def get_messages(
        self,
        conversation_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Message]:
        """
        Get the messages in a conversation

        Args:
            conversation_id: Conversation ID
            limit: Maximum number of messages
            offset: Pagination offset

        Returns:
            List[Message]: List of messages

        Raises:
            ResourceNotFoundError: If the conversation is not found
        """
        # Verify the conversation exists
        self.get_conversation(conversation_id)
        messages = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return messages

    def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List:
        """
        Build the conversation history as LangChain message objects

        Args:
            conversation_id: Conversation ID
            limit: Maximum number of messages to include

        Returns:
            List: List of LangChain message objects
        """
        # Fetch the most recent messages directly: get_messages() paginates
        # from the start of the conversation, so it would return the oldest
        # window once the conversation grows past `limit`.
        messages = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .limit(limit)
            .all()
        )
        messages.reverse()  # restore chronological order
        history_messages = []
        for msg in messages[:-1]:  # Exclude the latest user message (added separately)
            if msg.role == "user":
                history_messages.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                history_messages.append(AIMessage(content=msg.content))
        return history_messages
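
    # The resulting history is a plain list of LangChain messages, oldest
    # first, ready to fill a MessagesPlaceholder, e.g.:
    #
    #     [HumanMessage(content="Hi"), AIMessage(content="Hello! How can I help?")]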

    def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]:
        """
        Retrieve context from a knowledge base

        Args:
            kb_id: Knowledge base ID
            query: Query text
            k: Number of documents to return

        Returns:
            Tuple: (formatted context string, list of source metadata)
        """
        try:
            results = self.kb_manager.query_kb(kb_id, query, k=k)
            context_parts = []
            sources = []
            for i, result in enumerate(results, 1):
                content = result['content']
                if len(content) > 500:
                    content = f"{content[:500]}..."
                context_parts.append(f"[{i}] {content}")
                sources.append(result['metadata'])
            context = "\n\n".join(context_parts)
            logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
            return context, sources
        except Exception as e:
            logger.warning("rag_context_failed", error=str(e))
            return "(No relevant context found)", []

    def _create_rag_chain(self, llm, kb_id: Optional[int] = None):
        """
        Create an LCEL chain, optionally with RAG

        Args:
            llm: LLM instance
            kb_id: Optional knowledge base ID

        Returns:
            Runnable: The LCEL chain
        """
        if kb_id:
            # Prompt template with RAG context
            prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"),
                MessagesPlaceholder(variable_name="chat_history", optional=True),
                ("human", "{input}"),
            ])

            # Retriever step: enrich the chain input with retrieved context
            def retrieve_context(inputs):
                context, sources = self._retrieve_context(kb_id, inputs["input"])
                return {
                    **inputs,
                    "context": context,
                    "_sources": sources,  # Keep the sources for later use
                }

            # LCEL chain: retrieve -> prompt -> LLM -> parse output
            chain = (
                RunnableLambda(retrieve_context)
                | RunnablePassthrough.assign(
                    response=prompt | llm | StrOutputParser()
                )
            )
        else:
            # Simple prompt template without RAG
            prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI assistant."),
                MessagesPlaceholder(variable_name="chat_history", optional=True),
                ("human", "{input}"),
            ])
            # LCEL chain: prompt -> LLM -> parse output
            chain = prompt | llm | StrOutputParser()
        return chain
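
    # Sketch of the two output shapes (assuming an initialized chat model).
    # The RAG chain passes the whole input dict through and adds "response"
    # via RunnablePassthrough.assign, so the answer and sources come back as
    # dict keys; the plain chain ends in StrOutputParser and returns a string:
    #
    #     out = await rag_chain.ainvoke({"input": "What is LCEL?", "chat_history": []})
    #     out["response"]   # generated answer text
    #     out["_sources"]   # retrieval metadata from _retrieve_context()
    #
    #     text = await plain_chain.ainvoke({"input": "Hi", "chat_history": []})
    #     # -> "Hello! ..."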

    async def chat(
        self,
        conversation_id: int,
        user_input: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
        use_rag: bool = True,
    ) -> Message:
        """
        Process a user message and generate an assistant reply (via LCEL)

        Args:
            conversation_id: Conversation ID
            user_input: User message
            kb_id: Optional knowledge base ID (for RAG)
            llm_name: Optional LLM model name
            use_rag: Whether to use RAG (default True)

        Returns:
            Message: The assistant's reply message

        Raises:
            ResourceNotFoundError: If the conversation is not found
        """
        logger.info(
            "processing_chat",
            conversation_id=conversation_id,
            use_rag=use_rag,
            kb_id=kb_id,
        )
        # Verify the conversation exists
        self.get_conversation(conversation_id)

        # Save the user message
        self.add_message(conversation_id, "user", user_input)

        # Get the LLM
        llm = self.model_manager.get_llm(llm_name)

        # Get the conversation history (LangChain message format)
        chat_history = self._build_chat_history_messages(conversation_id, limit=10)

        # Create the LCEL chain
        kb_id_for_rag = kb_id if use_rag else None
        chain = self._create_rag_chain(llm, kb_id_for_rag)

        # Prepare the input
        chain_input = {
            "input": user_input,
            "chat_history": chat_history,
        }

        # Run the chain
        try:
            if kb_id_for_rag:
                # The RAG chain returns a dict
                result = await chain.ainvoke(chain_input)
                response = result.get("response", "")
                sources = result.get("_sources", [])
            else:
                # The simple chain returns a string directly
                response = await chain.ainvoke(chain_input)
                sources = []
            logger.info("llm_response_generated", conversation_id=conversation_id)
        except Exception as e:
            logger.error("llm_generation_failed", error=str(e))
            response = "I apologize, but I encountered an error generating a response. Please try again."
            sources = []

        # Save the assistant message
        metadata = {}
        if sources:
            metadata["sources"] = sources
        if llm_name:
            metadata["model"] = llm_name
        assistant_message = self.add_message(
            conversation_id, "assistant", response, metadata
        )
        logger.info("chat_complete", conversation_id=conversation_id)
        return assistant_message
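
    # Illustrative call site (a sketch; `manager`/`conv` as in the comments
    # above). chat() is async, so drive it from an event loop:
    #
    #     import asyncio
    #
    #     reply = asyncio.run(manager.chat(conv.id, "Summarize the docs", kb_id=1))
    #     reply.content                         # assistant text
    #     reply.msg_metadata.get("sources")     # retrieval metadata, if RAG ran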

    def delete_conversation(self, conversation_id: int) -> bool:
        """
        Delete a conversation and all of its messages

        Args:
            conversation_id: Conversation ID

        Returns:
            bool: True if deleted

        Raises:
            ResourceNotFoundError: If the conversation is not found
        """
        conversation = self.get_conversation(conversation_id)
        logger.info("deleting_conversation", conversation_id=conversation_id)
        self.db.delete(conversation)
        self.db.commit()
        logger.info("conversation_deleted", conversation_id=conversation_id)
        return True
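
    # Cleanup sketch tying the pieces together (illustrative; per the
    # docstring above, message deletion is assumed to cascade from the
    # relationship configured on Conversation in app.db.models):
    #
    #     for m in manager.get_messages(conv.id):
    #         print(f"{m.role}: {m.content[:60]}")
    #     manager.delete_conversation(conv.id)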