refactor(conv):重构对话管理器以使用 LCEL 链
- 更新了 LangChain 导入路径以使用 langchain_core - 实现了带 RAG 的 LCEL 链用于对话处理 - 添加了构建对话历史消息的方法 - 实现了从知识库检索上下文的逻辑 - 移除了旧的对话历史构建方法- 更新了聊天方法以使用新的 LCEL 链- 修复了代码以适应新的导入结构
This commit is contained in:
parent
9bd715d080
commit
4bd343c237
|
|
@ -11,7 +11,9 @@
|
|||
"Bash(git init:*)",
|
||||
"Bash(git remote add:*)",
|
||||
"Bash(git add:*)",
|
||||
"Bash(rm:*)"
|
||||
"Bash(rm:*)",
|
||||
"Bash(git commit:*)",
|
||||
"Bash(git push:*)"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": []
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import List, Dict, Any, Optional
|
|||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
from app.db.models import ToolCall
|
||||
from app.services.model_manager import ModelManager
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
"""
|
||||
支持 RAG 的多轮对话管理器。
|
||||
"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
|
||||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
|
||||
|
||||
from app.db.models import Conversation, Message
|
||||
from app.services.model_manager import ModelManager
|
||||
|
|
@ -162,6 +166,107 @@ class ConversationManager:
|
|||
|
||||
return messages
|
||||
|
||||
def _build_chat_history_messages(self, conversation_id: int, limit: int = 10) -> List:
|
||||
"""
|
||||
构建对话历史消息列表(LangChain 消息格式)。
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
limit: 要包含的最大消息数
|
||||
|
||||
Returns:
|
||||
List: LangChain 消息对象列表
|
||||
"""
|
||||
messages = self.get_messages(conversation_id, limit=limit)
|
||||
|
||||
history_messages = []
|
||||
for msg in messages[:-1]: # 排除最后一条用户消息(将单独添加)
|
||||
if msg.role == "user":
|
||||
history_messages.append(HumanMessage(content=msg.content))
|
||||
elif msg.role == "assistant":
|
||||
history_messages.append(AIMessage(content=msg.content))
|
||||
|
||||
return history_messages
|
||||
|
||||
def _retrieve_context(self, kb_id: int, query: str, k: int = 3) -> Tuple[str, List[Dict]]:
|
||||
"""
|
||||
从知识库检索上下文。
|
||||
|
||||
Args:
|
||||
kb_id: 知识库 ID
|
||||
query: 查询文本
|
||||
k: 返回的文档数量
|
||||
|
||||
Returns:
|
||||
Tuple: (格式化的上下文字符串, 来源元数据列表)
|
||||
"""
|
||||
try:
|
||||
results = self.kb_manager.query_kb(kb_id, query, k=k)
|
||||
context_parts = []
|
||||
sources = []
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
content = result['content']
|
||||
if len(content) > 500:
|
||||
content = f"{content[:500]}..."
|
||||
context_parts.append(f"[{i}] {content}")
|
||||
sources.append(result['metadata'])
|
||||
|
||||
context = "\n\n".join(context_parts)
|
||||
logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
|
||||
return context, sources
|
||||
except Exception as e:
|
||||
logger.warning("rag_context_failed", error=str(e))
|
||||
return "(No relevant context found)", []
|
||||
|
||||
def _create_rag_chain(self, llm, kb_id: Optional[int] = None):
|
||||
"""
|
||||
创建带 RAG 的 LCEL 链。
|
||||
|
||||
Args:
|
||||
llm: LLM 实例
|
||||
kb_id: 可选的知识库 ID
|
||||
|
||||
Returns:
|
||||
Runnable: LCEL 链
|
||||
"""
|
||||
if kb_id:
|
||||
# 带 RAG 的提示模板
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", "You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.\n\nContext:\n{context}"),
|
||||
MessagesPlaceholder(variable_name="chat_history", optional=True),
|
||||
("human", "{input}"),
|
||||
])
|
||||
|
||||
# 创建检索器函数
|
||||
def retrieve_context(inputs):
|
||||
context, sources = self._retrieve_context(kb_id, inputs["input"])
|
||||
return {
|
||||
**inputs,
|
||||
"context": context,
|
||||
"_sources": sources # 保存来源以供后续使用
|
||||
}
|
||||
|
||||
# LCEL 链: 检索 -> 提示 -> LLM -> 解析输出
|
||||
chain = (
|
||||
RunnableLambda(retrieve_context)
|
||||
| RunnablePassthrough.assign(
|
||||
response=prompt | llm | StrOutputParser()
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 无 RAG 的简单提示模板
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
("system", "You are a helpful AI assistant."),
|
||||
MessagesPlaceholder(variable_name="chat_history", optional=True),
|
||||
("human", "{input}"),
|
||||
])
|
||||
|
||||
# LCEL 链: 提示 -> LLM -> 解析输出
|
||||
chain = prompt | llm | StrOutputParser()
|
||||
|
||||
return chain
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
conversation_id: int,
|
||||
|
|
@ -171,7 +276,7 @@ class ConversationManager:
|
|||
use_rag: bool = True,
|
||||
) -> Message:
|
||||
"""
|
||||
处理用户消息并生成助手回复。
|
||||
处理用户消息并生成助手回复(使用 LCEL)。
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
|
|
@ -199,57 +304,39 @@ class ConversationManager:
|
|||
# 保存用户消息
|
||||
self.add_message(conversation_id, "user", user_input)
|
||||
|
||||
# 获取对话历史
|
||||
history = self._build_history(conversation_id, limit=10)
|
||||
|
||||
# 获取 LLM
|
||||
llm = self.model_manager.get_llm(llm_name)
|
||||
|
||||
# 如果启用了 RAG,从知识库构建上下文
|
||||
context = ""
|
||||
sources = []
|
||||
if use_rag and kb_id:
|
||||
# 获取对话历史(LangChain 消息格式)
|
||||
chat_history = self._build_chat_history_messages(conversation_id, limit=10)
|
||||
|
||||
# 创建 LCEL 链
|
||||
kb_id_for_rag = kb_id if use_rag else None
|
||||
chain = self._create_rag_chain(llm, kb_id_for_rag)
|
||||
|
||||
# 准备输入
|
||||
chain_input = {
|
||||
"input": user_input,
|
||||
"chat_history": chat_history,
|
||||
}
|
||||
|
||||
# 执行链
|
||||
try:
|
||||
results = self.kb_manager.query_kb(kb_id, user_input, k=3)
|
||||
context_parts = []
|
||||
for i, result in enumerate(results, 1):
|
||||
context_parts.append(
|
||||
f"[{i}] {result['content'][:500]}..."
|
||||
if len(result['content']) > 500
|
||||
else f"[{i}] {result['content']}"
|
||||
)
|
||||
sources.append(result['metadata'])
|
||||
context = "\n\n".join(context_parts)
|
||||
logger.info("rag_context_retrieved", kb_id=kb_id, source_count=len(results))
|
||||
except Exception as e:
|
||||
logger.warning("rag_context_failed", error=str(e))
|
||||
context = "(No relevant context found)"
|
||||
|
||||
# 构建提示
|
||||
if context:
|
||||
prompt = f"""You are a helpful AI assistant. Use the following context from the knowledge base to answer the question.
|
||||
|
||||
Context:
|
||||
{context}
|
||||
|
||||
Conversation history:
|
||||
{history}
|
||||
|
||||
User: {user_input}
|
||||
Assistant:"""
|
||||
if kb_id_for_rag:
|
||||
# 带 RAG 的链返回字典
|
||||
result = await chain.ainvoke(chain_input)
|
||||
response = result.get("response", "")
|
||||
sources = result.get("_sources", [])
|
||||
else:
|
||||
prompt = f"""{history}
|
||||
# 简单链直接返回字符串
|
||||
response = await chain.ainvoke(chain_input)
|
||||
sources = []
|
||||
|
||||
User: {user_input}
|
||||
Assistant:"""
|
||||
|
||||
# 生成回复
|
||||
try:
|
||||
response = llm.predict(prompt)
|
||||
logger.info("llm_response_generated", conversation_id=conversation_id)
|
||||
except Exception as e:
|
||||
logger.error("llm_generation_failed", error=str(e))
|
||||
response = "I apologize, but I encountered an error generating a response. Please try again."
|
||||
sources = []
|
||||
|
||||
# 保存助手消息
|
||||
metadata = {}
|
||||
|
|
@ -265,28 +352,6 @@ Assistant:"""
|
|||
logger.info("chat_complete", conversation_id=conversation_id)
|
||||
return assistant_message
|
||||
|
||||
def _build_history(self, conversation_id: int, limit: int = 10) -> str:
|
||||
"""
|
||||
构建对话历史字符串。
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
limit: 要包含的最大消息数
|
||||
|
||||
Returns:
|
||||
str: 格式化的对话历史
|
||||
"""
|
||||
messages = self.get_messages(conversation_id, limit=limit)
|
||||
|
||||
history_parts = []
|
||||
for msg in messages[:-1]: # 排除最后一条用户消息(已在提示中)
|
||||
if msg.role == "user":
|
||||
history_parts.append(f"User: {msg.content}")
|
||||
elif msg.role == "assistant":
|
||||
history_parts.append(f"Assistant: {msg.content}")
|
||||
|
||||
return "\n".join(history_parts) if history_parts else "(No previous messages)"
|
||||
|
||||
def delete_conversation(self, conversation_id: int) -> bool:
|
||||
"""
|
||||
删除对话及其所有消息。
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from langchain.schema import Document as LangChainDocument
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
|
||||
from app.db.models import KnowledgeBase, Document
|
||||
from app.services.model_manager import ModelManager
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import re
|
|||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
|
||||
from app.db.models import Model
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
from typing import Type
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.utils.logger import get_logger
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.services.kb_manager import KnowledgeBaseManager
|
||||
from app.utils.logger import get_logger
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import os
|
|||
from typing import List, Optional
|
||||
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain.schema import Document as LangChainDocument
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from app.utils.exceptions import VectorStoreError
|
||||
from app.utils.logger import get_logger
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.schema import Document as LangChainDocument
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_core.documents import Document as LangChainDocument
|
||||
|
||||
|
||||
class TextSplitter:
|
||||
|
|
|
|||
Loading…
Reference in New Issue