refactor(conv):重构对话管理器以使用 LCEL 链

- 更新了 LangChain 导入路径以使用 langchain_core
- 实现了带 RAG 的 LCEL 链用于对话处理
- 添加了构建对话历史消息的方法
- 实现了从知识库检索上下文的逻辑
- 移除了旧的对话历史构建方法- 更新了聊天方法以使用新的 LCEL 链- 修复了代码以适应新的导入结构
This commit is contained in:
杨煜 2025-10-02 18:10:53 +08:00
parent 9bd715d080
commit 4bd343c237
9 changed files with 142 additions and 75 deletions

View File

@ -11,7 +11,9 @@
"Bash(git init:*)", "Bash(git init:*)",
"Bash(git remote add:*)", "Bash(git remote add:*)",
"Bash(git add:*)", "Bash(git add:*)",
"Bash(rm:*)" "Bash(rm:*)",
"Bash(git commit:*)",
"Bash(git push:*)"
], ],
"deny": [], "deny": [],
"ask": [] "ask": []

View File

@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional
from datetime import datetime from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from langchain.agents import AgentExecutor, create_openai_tools_agent from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from app.db.models import ToolCall from app.db.models import ToolCall
from app.services.model_manager import ModelManager from app.services.model_manager import ModelManager

View File

@ -1,8 +1,12 @@
""" """
支持 RAG 的多轮对话管理器 支持 RAG 的多轮对话管理器
""" """
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional, Tuple
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from app.db.models import Conversation, Message from app.db.models import Conversation, Message
from app.services.model_manager import ModelManager from app.services.model_manager import ModelManager
@ -162,6 +166,107 @@ class ConversationManager:
return messages return messages
def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List:
"""
构建对话历史消息列表LangChain 消息格式
Args:
conversation_id: 对话 ID
limit: 要包含的最大消息数
Returns:
List: LangChain 消息对象列表
"""
messages = self.get_messages(conversation_id, limit=limit)
history_messages = []
for msg in messages[:-1]: # 排除最后一条用户消息(将单独添加)
if msg.role == "user":
history_messages.append(HumanMessage(content=msg.content))
elif msg.role == "assistant":
history_messages.append(AIMessage(content=msg.content))
return history_messages
def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]:
"""
从知识库检索上下文
Args:
kb_id: 知识库 ID
query: 查询文本
k: 返回的文档数量
Returns:
Tuple: (格式化的上下文字符串, 来源元数据列表)
"""
try:
results = self.kb_manager.query_kb(kb_id, query, k=k)
context_parts = []
sources = []
for i, result in enumerate(results, 1):
content = result['content']
if len(content) > 500:
content = f"{content[:500]}..."
context_parts.append(f"[{i}] {content}")
sources.append(result['metadata'])
context = "\n\n".join(context_parts)
logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
return context, sources
except Exception as e:
logger.warning("rag_context_failed", error=str(e))
return "(No relevant context found)", []
def _create_rag_chain(self, llm, kb_id: Optional[int] = None):
"""
创建带 RAG LCEL
Args:
llm: LLM 实例
kb_id: 可选的知识库 ID
Returns:
Runnable: LCEL
"""
if kb_id:
# 带 RAG 的提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"),
MessagesPlaceholder(variable_name="chat_history", optional=True),
("human", "{input}"),
])
# 创建检索器函数
def retrieve_context(inputs):
context, sources = self._retrieve_context(kb_id, inputs["input"])
return {
**inputs,
"context": context,
"_sources": sources # 保存来源以供后续使用
}
# LCEL 链: 检索 -> 提示 -> LLM -> 解析输出
chain = (
RunnableLambda(retrieve_context)
| RunnablePassthrough.assign(
response=prompt | llm | StrOutputParser()
)
)
else:
# 无 RAG 的简单提示模板
prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful AI assistant."),
MessagesPlaceholder(variable_name="chat_history", optional=True),
("human", "{input}"),
])
# LCEL 链: 提示 -> LLM -> 解析输出
chain = prompt | llm | StrOutputParser()
return chain
async def chat( async def chat(
self, self,
conversation_id: int, conversation_id: int,
@ -171,7 +276,7 @@ class ConversationManager:
use_rag: bool = True, use_rag: bool = True,
) -> Message: ) -> Message:
""" """
处理用户消息并生成助手回复 处理用户消息并生成助手回复使用 LCEL
Args: Args:
conversation_id: 对话 ID conversation_id: 对话 ID
@ -199,57 +304,39 @@ class ConversationManager:
# 保存用户消息 # 保存用户消息
self.add_message(conversation_id, "user", user_input) self.add_message(conversation_id, "user", user_input)
# 获取对话历史
history = self._build_history(conversation_id, limit=10)
# 获取 LLM # 获取 LLM
llm = self.model_manager.get_llm(llm_name) llm = self.model_manager.get_llm(llm_name)
# 如果启用了 RAG从知识库构建上下文 # 获取对话历史LangChain 消息格式)
context = "" chat_history = self._build_chat_history_messages(conversation_id, limit=10)
sources = []
if use_rag and kb_id:
try:
results = self.kb_manager.query_kb(kb_id, user_input, k=3)
context_parts = []
for i, result in enumerate(results, 1):
context_parts.append(
f"[{i}] {result['content'][:500]}..."
if len(result['content']) > 500
else f"[{i}] {result['content']}"
)
sources.append(result['metadata'])
context = "\n\n".join(context_parts)
logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
except Exception as e:
logger.warning("rag_context_failed", error=str(e))
context = "(No relevant context found)"
# 构建提示 # 创建 LCEL 链
if context: kb_id_for_rag = kb_id if use_rag else None
prompt = f"""You are a helpful AI assistant. Use the following context from the knowledge base to answer the question. chain = self._create_rag_chain(llm, kb_id_for_rag)
Context: # 准备输入
{context} chain_input = {
"input": user_input,
"chat_history": chat_history,
}
Conversation history: # 执行链
{history}
User: {user_input}
Assistant:"""
else:
prompt = f"""{history}
User: {user_input}
Assistant:"""
# 生成回复
try: try:
response = llm.predict(prompt) if kb_id_for_rag:
# 带 RAG 的链返回字典
result = await chain.ainvoke(chain_input)
response = result.get("response", "")
sources = result.get("_sources", [])
else:
# 简单链直接返回字符串
response = await chain.ainvoke(chain_input)
sources = []
logger.info("llm_response_generated", conversation_id=conversation_id) logger.info("llm_response_generated", conversation_id=conversation_id)
except Exception as e: except Exception as e:
logger.error("llm_generation_failed", error=str(e)) logger.error("llm_generation_failed", error=str(e))
response = "I apologize, but I encountered an error generating a response. Please try again." response = "I apologize, but I encountered an error generating a response. Please try again."
sources = []
# 保存助手消息 # 保存助手消息
metadata = {} metadata = {}
@ -265,28 +352,6 @@ Assistant:"""
logger.info("chat_complete", conversation_id=conversation_id) logger.info("chat_complete", conversation_id=conversation_id)
return assistant_message return assistant_message
def _build_history(self, conversation_id: int, limit: int = 10) -> str:
"""
构建对话历史字符串
Args:
conversation_id: 对话 ID
limit: 要包含的最大消息数
Returns:
str: 格式化的对话历史
"""
messages = self.get_messages(conversation_id, limit=limit)
history_parts = []
for msg in messages[:-1]: # 排除最后一条用户消息(已在提示中)
if msg.role == "user":
history_parts.append(f"User: {msg.content}")
elif msg.role == "assistant":
history_parts.append(f"Assistant: {msg.content}")
return "\n".join(history_parts) if history_parts else "(No previous messages)"
def delete_conversation(self, conversation_id: int) -> bool: def delete_conversation(self, conversation_id: int) -> bool:
""" """
删除对话及其所有消息 删除对话及其所有消息

View File

@ -3,7 +3,7 @@
""" """
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from langchain.schema import Document as LangChainDocument from langchain_core.documents import Document as LangChainDocument
from app.db.models import KnowledgeBase, Document from app.db.models import KnowledgeBase, Document
from app.services.model_manager import ModelManager from app.services.model_manager import ModelManager

View File

@ -6,8 +6,8 @@ import re
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from langchain.llms.base import BaseLLM from langchain_core.language_models.llms import BaseLLM
from langchain.embeddings.base import Embeddings from langchain_core.embeddings import Embeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from app.db.models import Model from app.db.models import Model

View File

@ -3,7 +3,7 @@
""" """
from typing import Type from typing import Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.tools import BaseTool from langchain_core.tools import BaseTool
from app.utils.logger import get_logger from app.utils.logger import get_logger

View File

@ -3,7 +3,7 @@
""" """
from typing import Optional, Type from typing import Optional, Type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.tools import BaseTool from langchain_core.tools import BaseTool
from app.services.kb_manager import KnowledgeBaseManager from app.services.kb_manager import KnowledgeBaseManager
from app.utils.logger import get_logger from app.utils.logger import get_logger

View File

@ -5,8 +5,8 @@ import os
from typing import List, Optional from typing import List, Optional
from langchain_community.vectorstores import FAISS from langchain_community.vectorstores import FAISS
from langchain.schema import Document as LangChainDocument from langchain_core.documents import Document as LangChainDocument
from langchain.embeddings.base import Embeddings from langchain_core.embeddings import Embeddings
from app.utils.exceptions import VectorStoreError from app.utils.exceptions import VectorStoreError
from app.utils.logger import get_logger from app.utils.logger import get_logger

View File

@ -3,8 +3,8 @@
""" """
from typing import List from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema import Document as LangChainDocument from langchain_core.documents import Document as LangChainDocument
class TextSplitter: class TextSplitter: