refactor(conv):重构对话管理器以使用 LCEL 链

- 更新了 LangChain 导入路径以使用 langchain_core
- 实现了带 RAG 的 LCEL 链用于对话处理
- 添加了构建对话历史消息的方法
- 实现了从知识库检索上下文的逻辑
- 移除了旧的对话历史构建方法- 更新了聊天方法以使用新的 LCEL 链- 修复了代码以适应新的导入结构
This commit is contained in:
杨煜 2025-10-02 18:10:53 +08:00
parent 9bd715d080
commit 4bd343c237
9 changed files with 142 additions and 75 deletions

View File

@ -11,7 +11,9 @@
"Bash(git init:*)",
"Bash(git remote add:*)",
"Bash(git add:*)",
"Bash(rm:*)"
"Bash(rm:*)",
"Bash(git commit:*)",
"Bash(git push:*)"
],
"deny": [],
"ask": []

View File

@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from app.db.models import ToolCall
from app.services.model_manager import ModelManager

View File

@ -1,8 +1,12 @@
"""
支持 RAG 的多轮对话管理器
"""
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.orm import Session
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from app.db.models import Conversation, Message
from app.services.model_manager import ModelManager
@ -162,6 +166,107 @@ class ConversationManager:
return messages
def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List:
"""
构建对话历史消息列表LangChain 消息格式
Args:
conversation_id: 对话 ID
limit: 要包含的最大消息数
Returns:
List: LangChain 消息对象列表
"""
messages = self.get_messages(conversation_id, limit=limit)
history_messages = []
for msg in messages[:-1]: # 排除最后一条用户消息(将单独添加)
if msg.role == "user":
history_messages.append(HumanMessage(content=msg.content))
elif msg.role == "assistant":
history_messages.append(AIMessage(content=msg.content))
return history_messages
def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]:
"""
从知识库检索上下文
Args:
kb_id: 知识库 ID
query: 查询文本
k: 返回的文档数量
Returns:
Tuple: (格式化的上下文字符串, 来源元数据列表)
"""
try:
results = self.kb_manager.query_kb(kb_id, query, k=k)
context_parts = []
sources = []
for i, result in enumerate(results, 1):
content = result['content']
if len(content) > 500:
content = f"{content[:500]}..."
context_parts.append(f"[{i}] {content}")
sources.append(result['metadata'])
context = "\n\n".join(context_parts)
logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
return context, sources
except Exception as e:
logger.warning("rag_context_failed", error=str(e))
return "(No relevant context found)", []
def _create_rag_chain(self, llm, kb_id: Optional[int] = None):
"""
创建带 RAG LCEL
Args:
llm: LLM 实例
kb_id: 可选的知识库 ID
Returns:
Runnable: LCEL
"""
if kb_id:
# 带 RAG 的提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"),
MessagesPlaceholder(variable_name="chat_history", optional=True),
("human", "{input}"),
])
# 创建检索器函数
def retrieve_context(inputs):
context, sources = self._retrieve_context(kb_id, inputs["input"])
return {
**inputs,
"context": context,
"_sources": sources # 保存来源以供后续使用
}
# LCEL 链: 检索 -> 提示 -> LLM -> 解析输出
chain = (
RunnableLambda(retrieve_context)
| RunnablePassthrough.assign(
response=prompt | llm | StrOutputParser()
)
)
else:
# 无 RAG 的简单提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI assistant."),
MessagesPlaceholder(variable_name="chat_history", optional=True),
("human", "{input}"),
])
# LCEL 链: 提示 -> LLM -> 解析输出
chain = prompt | llm | StrOutputParser()
return chain
async def chat(
self,
conversation_id: int,
@ -171,7 +276,7 @@ class ConversationManager:
use_rag: bool = True,
) -> Message:
"""
处理用户消息并生成助手回复
处理用户消息并生成助手回复使用 LCEL
Args:
conversation_id: 对话 ID
@ -199,57 +304,39 @@ class ConversationManager:
# 保存用户消息
self.add_message(conversation_id, "user", user_input)
# 获取对话历史
history = self._build_history(conversation_id, limit=10)
# 获取 LLM
llm = self.model_manager.get_llm(llm_name)
# 如果启用了 RAG从知识库构建上下文
context = ""
sources = []
if use_rag and kb_id:
try:
results = self.kb_manager.query_kb(kb_id, user_input, k=3)
context_parts = []
for i, result in enumerate(results, 1):
context_parts.append(
f"[{i}] {result['content'][:500]}..."
if len(result['content']) > 500
else f"[{i}] {result['content']}"
)
sources.append(result['metadata'])
context = "\n\n".join(context_parts)
logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
except Exception as e:
logger.warning("rag_context_failed", error=str(e))
context = "(No relevant context found)"
# 获取对话历史LangChain 消息格式)
chat_history = self._build_chat_history_messages(conversation_id, limit=10)
# 构建提示
if context:
prompt = f"""You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.
# 创建 LCEL 链
kb_id_for_rag = kb_id if use_rag else None
chain = self._create_rag_chain(llm, kb_id_for_rag)
Context:
{context}
# 准备输入
chain_input = {
"input": user_input,
"chat_history": chat_history,
}
Conversation history:
{history}
User: {user_input}
Assistant:"""
else:
prompt = f"""{history}
User: {user_input}
Assistant:"""
# 生成回复
# 执行链
try:
response = llm.predict(prompt)
if kb_id_for_rag:
# 带 RAG 的链返回字典
result = await chain.ainvoke(chain_input)
response = result.get("response", "")
sources = result.get("_sources", [])
else:
# 简单链直接返回字符串
response = await chain.ainvoke(chain_input)
sources = []
logger.info("llm_response_generated", conversation_id=conversation_id)
except Exception as e:
logger.error("llm_generation_failed", error=str(e))
response = "I apologize, but I encountered an error generating a response. Please try again."
sources = []
# 保存助手消息
metadata = {}
@ -265,28 +352,6 @@ Assistant:"""
logger.info("chat_complete", conversation_id=conversation_id)
return assistant_message
def _build_history(self, conversation_id: int, limit: int = 10) -> str:
"""
构建对话历史字符串
Args:
conversation_id: 对话 ID
limit: 要包含的最大消息数
Returns:
str: 格式化的对话历史
"""
messages = self.get_messages(conversation_id, limit=limit)
history_parts = []
for msg in messages[:-1]: # 排除最后一条用户消息(已在提示中)
if msg.role == "user":
history_parts.append(f"User: {msg.content}")
elif msg.role == "assistant":
history_parts.append(f"Assistant: {msg.content}")
return "\n".join(history_parts) if history_parts else "(No previous messages)"
def delete_conversation(self, conversation_id: int) -> bool:
"""
删除对话及其所有消息

View File

@ -3,7 +3,7 @@
"""
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from langchain.schema import Document as LangChainDocument
from langchain_core.documents import Document as LangChainDocument
from app.db.models import KnowledgeBase, Document
from app.services.model_manager import ModelManager

View File

@ -6,8 +6,8 @@ import re
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from langchain.llms.base import BaseLLM
from langchain.embeddings.base import Embeddings
from langchain_core.language_models.llms import BaseLLM
from langchain_core.embeddings import Embeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from app.db.models import Model

View File

@ -3,7 +3,7 @@
"""
from typing import Type
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from langchain_core.tools import BaseTool
from app.utils.logger import get_logger

View File

@ -3,7 +3,7 @@
"""
from typing import Optional, Type
from pydantic import BaseModel, Field
from langchain.tools import BaseTool
from langchain_core.tools import BaseTool
from app.services.kb_manager import KnowledgeBaseManager
from app.utils.logger import get_logger

View File

@ -5,8 +5,8 @@ import os
from typing import List, Optional
from langchain_community.vectorstores import FAISS
from langchain.schema import Document as LangChainDocument
from langchain.embeddings.base import Embeddings
from langchain_core.documents import Document as LangChainDocument
from langchain_core.embeddings import Embeddings
from app.utils.exceptions import VectorStoreError
from app.utils.logger import get_logger

View File

@ -3,8 +3,8 @@
"""
from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document as LangChainDocument
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document as LangChainDocument
class TextSplitter: