# langchain-learning-kit/app/services/conv_manager.py
"""
支持 RAG 的多轮对话管理器。
"""
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.orm import Session
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from app.db.models import Conversation, Message
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.config import Settings
from app.utils.exceptions import ResourceNotFoundError
from app.utils.logger import get_logger
logger = get_logger(__name__)
class ConversationManager:
    """
    Manage multi-turn conversations with optional RAG (retrieval-augmented
    generation) support.

    Persists conversations/messages through a SQLAlchemy session and builds
    LangChain LCEL chains that optionally inject knowledge-base context.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the conversation manager.

        Args:
            db_session: Database session.
            model_manager: Model manager instance.
            kb_manager: Knowledge-base manager instance.
            settings: Application settings.
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def create_conversation(
        self, user_id: Optional[int] = None, title: str = "New Conversation"
    ) -> Conversation:
        """
        Create a new conversation.

        Args:
            user_id: Optional user ID.
            title: Conversation title.

        Returns:
            Conversation: The created conversation.
        """
        logger.info("creating_conversation", user_id=user_id, title=title)
        conversation = Conversation(user_id=user_id, title=title)
        self.db.add(conversation)
        self.db.commit()
        # Refresh so DB-generated fields (id, timestamps) are populated.
        self.db.refresh(conversation)
        logger.info("conversation_created", conversation_id=conversation.id)
        return conversation

    def get_conversation(self, conversation_id: int) -> Conversation:
        """
        Fetch a conversation by ID.

        Args:
            conversation_id: Conversation ID.

        Returns:
            Conversation: The conversation instance.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = (
            self.db.query(Conversation)
            .filter(Conversation.id == conversation_id)
            .first()
        )
        if not conversation:
            raise ResourceNotFoundError(f"Conversation not found: {conversation_id}")
        return conversation

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """
        Append a message to a conversation.

        Args:
            conversation_id: Conversation ID.
            role: Message role ('user', 'assistant', or 'system').
            content: Message content.
            metadata: Optional metadata dict.

        Returns:
            Message: The created message.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate that the conversation exists (raises otherwise).
        self.get_conversation(conversation_id)
        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            msg_metadata=metadata or {},
        )
        self.db.add(message)
        self.db.commit()
        self.db.refresh(message)
        logger.info(
            "message_added",
            conversation_id=conversation_id,
            role=role,
            message_id=message.id,
        )
        return message

    def get_messages(
        self,
        conversation_id: int,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Message]:
        """
        Fetch messages of a conversation in chronological order.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of messages to return.
            offset: Pagination offset.

        Returns:
            List[Message]: The list of messages.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        # Validate that the conversation exists (raises otherwise).
        self.get_conversation(conversation_id)
        messages = (
            self.db.query(Message)
            .filter(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return messages

    def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List:
        """
        Build the conversation history as LangChain message objects.

        Args:
            conversation_id: Conversation ID.
            limit: Maximum number of stored messages to consider.

        Returns:
            List: LangChain message objects (Human/AI only; 'system' rows
            are intentionally skipped).
        """
        messages = self.get_messages(conversation_id, limit=limit)
        history_messages = []
        # Exclude the last message: it is the just-saved user input, which
        # is supplied to the chain separately as "input".
        for msg in messages[:-1]:
            if msg.role == "user":
                history_messages.append(HumanMessage(content=msg.content))
            elif msg.role == "assistant":
                history_messages.append(AIMessage(content=msg.content))
        return history_messages

    def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]:
        """
        Retrieve context from a knowledge base (best-effort).

        Args:
            kb_id: Knowledge-base ID.
            query: Query text.
            k: Number of documents to retrieve.

        Returns:
            Tuple: (formatted context string, list of source metadata).
            On any retrieval failure, returns a placeholder string and an
            empty source list instead of raising.
        """
        try:
            results = self.kb_manager.query_kb(kb_id, query, k=k)
            context_parts = []
            sources = []
            for i, result in enumerate(results, 1):
                content = result['content']
                # Truncate long chunks to keep the prompt compact.
                if len(content) > 500:
                    content = f"{content[:500]}..."
                context_parts.append(f"[{i}] {content}")
                sources.append(result['metadata'])
            context = "\n\n".join(context_parts)
            logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
            return context, sources
        except Exception as e:
            # Deliberate best-effort: chat should still work if retrieval fails.
            logger.warning("rag_context_failed", error=str(e))
            return "(No relevant context found)", []

    def _create_rag_chain(self, llm, kb_id: Optional[int] = None):
        """
        Create an LCEL chain, with RAG when a knowledge base is given.

        Args:
            llm: LLM instance.
            kb_id: Optional knowledge-base ID.

        Returns:
            Runnable: The LCEL chain. With RAG it yields a dict containing
            "response" and "_sources"; without RAG it yields a plain string.
        """
        # Explicit None check: a knowledge-base ID of 0 is valid and must
        # not be treated as "no RAG" by truthiness.
        if kb_id is not None:
            # Prompt template with retrieved context injected.
            prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"),
                MessagesPlaceholder(variable_name="chat_history", optional=True),
                ("human", "{input}"),
            ])

            def retrieve_context(inputs):
                # Enrich the chain input with retrieved context and sources.
                context, sources = self._retrieve_context(kb_id, inputs["input"])
                return {
                    **inputs,
                    "context": context,
                    "_sources": sources,  # carried through for later use
                }

            # LCEL chain: retrieve -> prompt -> LLM -> parse output.
            chain = (
                RunnableLambda(retrieve_context)
                | RunnablePassthrough.assign(
                    response=prompt | llm | StrOutputParser()
                )
            )
        else:
            # Plain prompt template without RAG.
            prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI assistant."),
                MessagesPlaceholder(variable_name="chat_history", optional=True),
                ("human", "{input}"),
            ])
            # LCEL chain: prompt -> LLM -> parse output.
            chain = prompt | llm | StrOutputParser()
        return chain

    async def chat(
        self,
        conversation_id: int,
        user_input: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
        use_rag: bool = True,
    ) -> Message:
        """
        Process a user message and generate an assistant reply (via LCEL).

        Args:
            conversation_id: Conversation ID.
            user_input: The user's message.
            kb_id: Optional knowledge-base ID for RAG.
            llm_name: Optional LLM model name.
            use_rag: Whether to use RAG (default True).

        Returns:
            Message: The assistant's reply message.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        logger.info(
            "processing_chat",
            conversation_id=conversation_id,
            use_rag=use_rag,
            kb_id=kb_id,
        )
        # Validate that the conversation exists (raises otherwise).
        conversation = self.get_conversation(conversation_id)
        # Persist the user message first so it is part of the history.
        self.add_message(conversation_id, "user", user_input)
        # Resolve the LLM to use.
        llm = self.model_manager.get_llm(llm_name)
        # Conversation history in LangChain message format.
        chat_history = self._build_chat_history_messages(conversation_id, limit=10)
        # Build the chain; disable RAG entirely when use_rag is False.
        kb_id_for_rag = kb_id if use_rag else None
        chain = self._create_rag_chain(llm, kb_id_for_rag)
        chain_input = {
            "input": user_input,
            "chat_history": chat_history,
        }
        try:
            if kb_id_for_rag is not None:
                # The RAG chain returns a dict with response and sources.
                result = await chain.ainvoke(chain_input)
                response = result.get("response", "")
                sources = result.get("_sources", [])
            else:
                # The plain chain returns a string directly.
                response = await chain.ainvoke(chain_input)
                sources = []
            logger.info("llm_response_generated", conversation_id=conversation_id)
        except Exception as e:
            # Never surface an LLM failure to the caller as an exception;
            # store an apology reply instead.
            logger.error("llm_generation_failed", error=str(e))
            response = "I apologize, but I encountered an error generating a response. Please try again."
            sources = []
        # Persist the assistant reply with optional provenance metadata.
        metadata = {}
        if sources:
            metadata["sources"] = sources
        if llm_name:
            metadata["model"] = llm_name
        assistant_message = self.add_message(
            conversation_id, "assistant", response, metadata
        )
        logger.info("chat_complete", conversation_id=conversation_id)
        return assistant_message

    def delete_conversation(self, conversation_id: int) -> bool:
        """
        Delete a conversation and all of its messages.

        Args:
            conversation_id: Conversation ID.

        Returns:
            bool: True if deleted.

        Raises:
            ResourceNotFoundError: If the conversation does not exist.
        """
        conversation = self.get_conversation(conversation_id)
        logger.info("deleting_conversation", conversation_id=conversation_id)
        self.db.delete(conversation)
        self.db.commit()
        logger.info("conversation_deleted", conversation_id=conversation_id)
        return True